Skip to content

Commit e931cfc

Browse files
committed
bump candle to 0.3.1 and conv_transpose_1d
1 parent 8c235d6 commit e931cfc

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

burn-candle/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ derive-new = { workspace = true }
1818
burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
1919
half = { workspace = true }
2020

21-
# TODO remove pinned version ("=") once candle-core is updated to 0.3.1
22-
candle-core = { version = "=0.3.0" }
21+
candle-core = { version = "0.3.1" }
22+
2323

2424
[dev-dependencies]
2525
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", default-features = false, features = [

burn-candle/src/backend.rs

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
5050
match device.location() {
5151
DeviceLocation::Cpu => CandleDevice::Cpu,
5252
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
53+
DeviceLocation::Metal => panic!("Metal unsupported"),
5354
}
5455
}
5556
}

burn-candle/src/ops/module.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,26 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
8383
bias: Option<FloatTensor<Self, 1>>,
8484
options: ConvTransposeOptions<1>,
8585
) -> FloatTensor<Self, 3> {
86-
panic!("Candle does not support conv_transpose1d")
86+
assert!(
87+
options.groups == 1,
88+
"Candle does not support groups in transposed convolutions"
89+
);
90+
let conv_transpose = x
91+
.tensor
92+
.conv_transpose1d(
93+
&weight.tensor,
94+
options.padding[0],
95+
options.padding_out[0],
96+
options.stride[0],
97+
options.dilation[0],
98+
)
99+
.unwrap();
100+
CandleTensor::new(match bias {
101+
Some(bias) => conv_transpose
102+
.broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
103+
.unwrap(),
104+
None => conv_transpose,
105+
})
87106
}
88107

89108
fn conv_transpose2d(

0 commit comments

Comments
 (0)