Skip to content

Commit bfebf22

Browse files
authored
Fix weight params in conv1d and conv2d (#1245)
1 parent 57ee2ce commit bfebf22

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

burn-core/src/nn/conv/conv1d.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ pub struct Conv1dConfig {
5050
/// - bias: Tensor of shape `[channels_out]`
5151
#[derive(Module, Debug)]
5252
pub struct Conv1d<B: Backend> {
53-
weight: Param<Tensor<B, 3>>,
54-
bias: Option<Param<Tensor<B, 1>>>,
53+
/// Tensor of shape [channels_out, channels_in / groups, kernel_size]
54+
pub weight: Param<Tensor<B, 3>>,
55+
/// Tensor of shape `[channels_out]`
56+
pub bias: Option<Param<Tensor<B, 1>>>,
5557
stride: usize,
5658
kernel_size: usize,
5759
dilation: usize,

burn-core/src/nn/conv/conv2d.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ pub struct Conv2dConfig {
4949
/// - bias: Tensor of shape `[channels_out]`
5050
#[derive(Module, Debug)]
5151
pub struct Conv2d<B: Backend> {
52-
weight: Param<Tensor<B, 4>>,
53-
bias: Option<Param<Tensor<B, 1>>>,
52+
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
53+
pub weight: Param<Tensor<B, 4>>,
54+
/// Tensor of shape `[channels_out]`
55+
pub bias: Option<Param<Tensor<B, 1>>>,
5456
stride: [usize; 2],
5557
kernel_size: [usize; 2],
5658
dilation: [usize; 2],

burn-core/src/nn/conv/conv_transpose1d.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ pub struct ConvTranspose1dConfig {
5151
/// - bias: Tensor of shape `[channels_out]`
5252
#[derive(Module, Debug)]
5353
pub struct ConvTranspose1d<B: Backend> {
54-
weight: Param<Tensor<B, 3>>,
55-
bias: Option<Param<Tensor<B, 1>>>,
54+
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
55+
pub weight: Param<Tensor<B, 3>>,
56+
/// Tensor of shape `[channels_out]`
57+
pub bias: Option<Param<Tensor<B, 1>>>,
5658
stride: usize,
5759
kernel_size: usize,
5860
dilation: usize,

burn-core/src/nn/conv/conv_transpose2d.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ pub struct ConvTranspose2dConfig {
5151
/// - bias: Tensor of shape `[channels_out]`
5252
#[derive(Module, Debug)]
5353
pub struct ConvTranspose2d<B: Backend> {
54-
weight: Param<Tensor<B, 4>>,
55-
bias: Option<Param<Tensor<B, 1>>>,
54+
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
55+
pub weight: Param<Tensor<B, 4>>,
56+
/// Tensor of shape `[channels_out]`
57+
pub bias: Option<Param<Tensor<B, 1>>>,
5658
stride: [usize; 2],
5759
kernel_size: [usize; 2],
5860
dilation: [usize; 2],

0 commit comments

Comments
 (0)