Skip to content

Commit 7661deb

Browse files
authored
Fix image-classsification-web + autotune flag usage (#2011)
1 parent 3afff43 commit 7661deb

File tree

5 files changed

+18
-5
lines changed

5 files changed

+18
-5
lines changed

crates/burn-jit/src/kernel/matmul/base.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ use burn_cube::prelude::*;
33
use burn_tensor::Shape;
44

55
use super::{
6-
config::Tiling2dConfig, init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d,
6+
config::Tiling2dConfig, init_matmul_output, matmul_simple, matmul_tiling_2d,
77
matmul_tiling_2d_cube, matmul_tiling_2d_padded,
88
};
99

10+
#[cfg(feature = "autotune")]
11+
use super::matmul_autotune;
12+
1013
/// The strategy to be used when launching a matmul kernel.
1114
pub enum MatmulStrategy {
1215
/// A simple kernel will be used with memory coalescing optimization.
@@ -27,11 +30,14 @@ pub enum MatmulStrategy {
2730
Tiling2dCube(Tiling2dConfig),
2831
}
2932

30-
#[allow(clippy::derivable_impls)] // Necessary otherwise the feature flags dont' work.
31-
#[cfg(feature = "autotune")]
3233
impl Default for MatmulStrategy {
3334
fn default() -> Self {
34-
MatmulStrategy::Autotune
35+
// if autotune is enabled, default to autotune
36+
#[cfg(feature = "autotune")]
37+
return MatmulStrategy::Autotune;
38+
39+
#[cfg(not(feature = "autotune"))]
40+
MatmulStrategy::Tiling2d(Tiling2dConfig::default())
3541
}
3642
}
3743

Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#[cfg(feature = "autotune")]
12
mod base;
23
mod key;
34

5+
#[cfg(feature = "autotune")]
46
pub use base::*;
57
pub use key::*;

crates/burn-jit/src/kernel/reduce/base.rs

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::{
77
shared::{base::ReduceDimShared, shader::reduce_dim_shared},
88
};
99

10+
#[allow(dead_code)]
1011
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
1112
ReduceDimNaive<E> + ReduceDimShared<E>
1213
{
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
#[cfg(feature = "autotune")]
12
mod base;
23
mod key;
34

5+
#[cfg(feature = "autotune")]
46
pub(crate) use base::*;
57
pub use key::*;

examples/image-classification-web/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ half_precision = []
1717
burn = { path = "../../crates/burn", version = "0.14.0", default-features = false, features = [
1818
"ndarray",
1919
] }
20-
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false }
20+
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false, features = [
21+
"autotune",
22+
] }
2123
burn-candle = { path = "../../crates/burn-candle", version = "0.14.0", default-features = false }
2224

2325
log = { workspace = true }

0 commit comments

Comments
 (0)