1
1
use crate :: kernel:: {
2
2
conv:: { Conv2dAutotuneKey , ConvTranspose2dAutotuneKey } ,
3
3
matmul:: MatmulAutotuneKey ,
4
- reduce:: ReduceAutotuneKey ,
4
+ reduce:: { ReduceAutotuneKey , SumAutotuneKey } ,
5
5
} ;
6
6
use cubecl:: tune:: AutotuneKey ;
7
7
use serde:: { Deserialize , Serialize } ;
@@ -14,6 +14,8 @@ pub enum JitAutotuneKey {
14
14
Matmul ( MatmulAutotuneKey ) ,
15
15
/// Key for reduce dim operations
16
16
Reduce ( ReduceAutotuneKey ) ,
17
+ /// Key for sum operations
18
+ Sum ( SumAutotuneKey ) ,
17
19
/// Key for convolution operations
18
20
Conv2d ( Conv2dAutotuneKey ) ,
19
21
/// Key for transpose convolution operations
@@ -25,6 +27,7 @@ impl Display for JitAutotuneKey {
25
27
match self {
26
28
JitAutotuneKey :: Matmul ( matmul_key) => std:: fmt:: Display :: fmt ( & matmul_key, f) ,
27
29
JitAutotuneKey :: Reduce ( reduce_key) => std:: fmt:: Display :: fmt ( & reduce_key, f) ,
30
+ JitAutotuneKey :: Sum ( reduce_key) => std:: fmt:: Display :: fmt ( & reduce_key, f) ,
28
31
JitAutotuneKey :: Conv2d ( conv2d_key) => std:: fmt:: Display :: fmt ( & conv2d_key, f) ,
29
32
JitAutotuneKey :: ConvTranspose2d ( conv2d_key) => std:: fmt:: Display :: fmt ( & conv2d_key, f) ,
30
33
}
0 commit comments