1
- use burn_cube:: dialect:: { Item , Scope , Variable } ;
2
-
3
1
#[ cfg( feature = "autotune" ) ]
4
2
use crate :: kernel:: reduce:: reduce_dim_autotune;
5
3
use crate :: { element:: JitElement , tensor:: JitTensor , JitRuntime } ;
6
4
7
- use super :: { reduce_dim_naive, reduce_dim_shared, ArgMax , ArgMin , MeanDim , ProdDim , SumDim } ;
8
-
9
- /// Specifies the reduce dim algorithm in use
10
- pub trait ReduceDimAlgorithm < E : JitElement > : Send + Sync + ' static {
11
- /// The reduction accumulator
12
- type Accumulator : Copy ;
13
-
14
- /// Initialization for naive algorithm
15
- fn initialize_naive (
16
- scope : & mut Scope ,
17
- input_item : Item ,
18
- output_item : Item ,
19
- ) -> Self :: Accumulator ;
20
-
21
- /// Inner loop for naive algorithm
22
- fn inner_loop_naive (
23
- scope : & mut Scope ,
24
- accumulator : Self :: Accumulator ,
25
- current_value : Variable ,
26
- i : Variable ,
27
- ) ;
28
-
29
- /// Assignation for naive algorithm
30
- fn assign_naive (
31
- scope : & mut Scope ,
32
- output : Variable ,
33
- accumulator : Self :: Accumulator ,
34
- shape_reduce_dim : Variable ,
35
- ) ;
5
+ use super :: {
6
+ naive:: { base:: ReduceDimNaive , shader:: reduce_dim_naive} ,
7
+ shared:: { base:: ReduceDimShared , shader:: reduce_dim_shared} ,
8
+ } ;
36
9
37
- /// Initialization for shared algorithm
38
- fn initialize_shared (
39
- scope : & mut Scope ,
40
- shared_memory_size : u32 ,
41
- write_position : Variable ,
42
- input_item : Item ,
43
- ) -> Self :: Accumulator ;
44
-
45
- /// How to write to shared memory
46
- fn write_to_shared (
47
- scope : & mut Scope ,
48
- shared_memory : Self :: Accumulator ,
49
- write_position : Variable ,
50
- value : Self :: Accumulator ,
51
- ) ;
52
-
53
- /// How to read from input in shared algorithm
54
- fn read_from_input (
55
- scope : & mut Scope ,
56
- input : Variable ,
57
- read_position : Variable ,
58
- i : Variable ,
59
- ) -> Self :: Accumulator ;
60
-
61
- /// How to read from shared memory
62
- fn read_from_shared (
63
- scope : & mut Scope ,
64
- shared_memory : Self :: Accumulator ,
65
- read_position : Variable ,
66
- ) -> Self :: Accumulator ;
67
-
68
- /// How to assign from shared memory
69
- fn assign_shared (
70
- scope : & mut Scope ,
71
- shared_memory : Self :: Accumulator ,
72
- output : Variable ,
73
- write_position : Variable ,
74
- shape_reduce_dim : Variable ,
75
- ) ;
10
+ pub ( crate ) trait ReduceDimAlgorithm < E : JitElement > :
11
+ ReduceDimNaive < E > + ReduceDimShared < E >
12
+ {
76
13
}
77
14
78
15
/// Creates an empty output tensor with reduce output shape
@@ -116,7 +53,10 @@ impl Default for ReduceStrategy {
116
53
}
117
54
118
55
macro_rules! reduce_operation {
119
- ( $name: ident, $ops: ty) => {
56
+ ( $name: ident, $ops: ident) => {
57
+ pub ( crate ) struct $ops;
58
+ impl <E : JitElement > ReduceDimAlgorithm <E > for $ops { }
59
+
120
60
/// Executes the reduce operation with the given strategy.
121
61
pub fn $name<R : JitRuntime , EI : JitElement , EO : JitElement , const D : usize >(
122
62
tensor: JitTensor <R , EI , D >,
@@ -143,5 +83,5 @@ macro_rules! reduce_operation {
143
83
reduce_operation ! ( sum_dim, SumDim ) ;
144
84
reduce_operation ! ( mean_dim, MeanDim ) ;
145
85
reduce_operation ! ( prod_dim, ProdDim ) ;
146
- reduce_operation ! ( argmin, ArgMin ) ;
147
- reduce_operation ! ( argmax, ArgMax ) ;
86
+ reduce_operation ! ( argmin, Argmin ) ;
87
+ reduce_operation ! ( argmax, Argmax ) ;
0 commit comments