Skip to content

Commit

Permalink
Fix the apply vs apply_t bug, closes #8
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Feb 1, 2024
1 parent 8b516d4 commit d3189ed
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions candle-lora-transformers/src/resnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ fn downsample(
if stride != 1 || c_in != c_out {
let conv = conv2d(c_in, c_out, 1, 0, stride, vb.pp(0), merge, lora_config)?;
let bn = batch_norm(c_out, 1e-5, vb.pp(1))?;
Ok(UnsyncFunc::new(move |xs| xs.apply(&conv)?.apply(&bn)))
Ok(UnsyncFunc::new(move |xs| {
xs.apply(&conv)?.apply_t(&bn, false)
}))
} else {
Ok(UnsyncFunc::new(|xs| Ok(xs.clone())))
}
Expand Down Expand Up @@ -124,10 +126,10 @@ fn basic_block(
Ok(UnsyncFunc::new(move |xs| {
let ys = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
.apply(&bn2)?;
.apply_t(&bn2, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
Expand Down Expand Up @@ -190,7 +192,7 @@ fn resnet(
Ok(UnsyncFunc::new(move |xs| {
let xs = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?
Expand Down Expand Up @@ -292,13 +294,13 @@ fn bottleneck_block(
Ok(UnsyncFunc::new(move |xs| {
let ys = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.apply(&conv2)?
.apply(&bn2)?
.apply_t(&bn2, false)?
.relu()?
.apply(&conv3)?
.apply(&bn3)?;
.apply_t(&bn3, false)?;
(xs.apply(&downsample)? + ys)?.relu()
}))
}
Expand Down Expand Up @@ -378,7 +380,7 @@ fn bottleneck_resnet(
Ok(UnsyncFunc::new(move |xs| {
let xs = xs
.apply(&conv1)?
.apply(&bn1)?
.apply_t(&bn1, false)?
.relu()?
.pad_with_same(D::Minus1, 1, 1)?
.pad_with_same(D::Minus2, 1, 1)?
Expand Down

0 comments on commit d3189ed

Please sign in to comment.