Skip to content

Commit 4221820

Browse files
committed
commit missing autotune key
1 parent b43235b commit 4221820

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

backend-comparison/benches/reduce.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl<B: Backend> Benchmark for ReduceBenchmark<B> {
4343
self.tensor.clone().sum_dim(axis);
4444
}
4545
Instruction::Sum => {
46-
self.tensor.clone().sum();
46+
self.tensor.clone().sum_dim(2).sum_dim(1).sum_dim(0);
4747
}
4848
}
4949
}

crates/burn-jit/src/tune_key.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::kernel::{
22
conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey},
33
matmul::MatmulAutotuneKey,
4-
reduce::ReduceAutotuneKey,
4+
reduce::{ReduceAutotuneKey, SumAutotuneKey},
55
};
66
use cubecl::tune::AutotuneKey;
77
use serde::{Deserialize, Serialize};
@@ -14,6 +14,8 @@ pub enum JitAutotuneKey {
1414
Matmul(MatmulAutotuneKey),
1515
/// Key for reduce dim operations
1616
Reduce(ReduceAutotuneKey),
17+
/// Key for sum operations
18+
Sum(SumAutotuneKey),
1719
/// Key for convolution operations
1820
Conv2d(Conv2dAutotuneKey),
1921
/// Key for transpose convolution operations
@@ -25,6 +27,7 @@ impl Display for JitAutotuneKey {
2527
match self {
2628
JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
2729
JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
30+
JitAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
2831
JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
2932
JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
3033
}

0 commit comments

Comments
 (0)