Skip to content

Commit

Permalink
Fix CI (#2268)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 10, 2024
1 parent 17050db commit d3fbdea
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ env:
# Note: It is not possible to define env vars in composite actions.
# To work around this issue we use inputs and define all the env vars here.

RUST_PREVIOUS_VERSION: 1.79.0
RUST_PREVIOUS_VERSION: 1.80.0

# Cargo
CARGO_TERM_COLOR: "always"
Expand Down
7 changes: 3 additions & 4 deletions crates/burn-jit/src/kernel/conv/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ fn conv2d_kernel<F: Float>(

let in_channels = weight.shape(1);

let kernel_size_0 = kernel_size_0_unroll.unwrap_or_else(|| weight.shape(2));
let kernel_size_0 = weight.shape(2);
let kernel_size_1 = kernel_size_1_unroll.unwrap_or_else(|| weight.shape(3));
let unroll_1 = kernel_size_1_unroll.is_some();


let b = ABSOLUTE_POS / output.stride(0) % output.shape(0);
let oc = ABSOLUTE_POS / output.stride(1) % output.shape(1);
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
Expand Down Expand Up @@ -130,7 +129,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let kernel_1_unroll = if kernel_1 > 8 {
None
} else {
Some(kernel_1.into())
Some(kernel_1 as u32)
};

let out_0 = calculate_conv_output_size(
Expand Down Expand Up @@ -188,7 +187,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.groups as u32),
),
Some(kernel_1 as u32),
kernel_1_unroll,
);

output
Expand Down
2 changes: 1 addition & 1 deletion crates/burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "burn"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn"
version.workspace = true
rust-version = "1.79"
rust-version = "1.80"

[features]
default = ["burn-core/default", "burn-train?/default", "std"]
Expand Down

0 comments on commit d3fbdea

Please sign in to comment.