Skip to content

Commit

Permalink
Use the dilation parameter in the cpu backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Aug 29, 2023
1 parent d289286 commit ea9a6f1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions candle-core/src/cpu_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * l_out;
for dst_l in 0..l_out {
let dst_idx = dst_idx + dst_l;
let src_l = p.stride * dst_l + offset;
let src_l = (p.stride * dst_l + offset) * p.dilation;
if src_l < p.padding || src_l >= p.padding + p.l_in {
continue;
}
Expand Down Expand Up @@ -1141,14 +1141,14 @@ impl<'a> Map2 for Conv2D<'a> {
let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
for dst_h in 0..out_h {
let dst_idx = dst_idx + dst_h * out_w;
let src_h = p.stride * dst_h + offset_h;
let src_h = (p.stride * dst_h + offset_h) * p.dilation;
if src_h < p.padding || src_h >= p.i_h + p.padding {
continue;
}
let src_h = src_h - p.padding;
for dst_w in 0..out_w {
let dst_idx = dst_idx + dst_w;
let src_w = p.stride * dst_w + offset_w;
let src_w = (p.stride * dst_w + offset_w) * p.dilation;
if src_w < p.padding || src_w >= p.i_w + p.padding {
continue;
}
Expand Down

0 comments on commit ea9a6f1

Please sign in to comment.