Skip to content

Commit e39b4d2

Browse files
authored
refactor reduce into separate traits (#1798)
1 parent 0331719 commit e39b4d2

19 files changed

+332
-259
lines changed
+13-73
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,15 @@
1-
use burn_cube::dialect::{Item, Scope, Variable};
2-
31
#[cfg(feature = "autotune")]
42
use crate::kernel::reduce::reduce_dim_autotune;
53
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
64

7-
use super::{reduce_dim_naive, reduce_dim_shared, ArgMax, ArgMin, MeanDim, ProdDim, SumDim};
8-
9-
/// Specifies the reduce dim algorithm in use
10-
pub trait ReduceDimAlgorithm<E: JitElement>: Send + Sync + 'static {
11-
/// The reduction accumulator
12-
type Accumulator: Copy;
13-
14-
/// Initialization for naive algorithm
15-
fn initialize_naive(
16-
scope: &mut Scope,
17-
input_item: Item,
18-
output_item: Item,
19-
) -> Self::Accumulator;
20-
21-
/// Inner loop for naive algorithm
22-
fn inner_loop_naive(
23-
scope: &mut Scope,
24-
accumulator: Self::Accumulator,
25-
current_value: Variable,
26-
i: Variable,
27-
);
28-
29-
/// Assignation for naive algorithm
30-
fn assign_naive(
31-
scope: &mut Scope,
32-
output: Variable,
33-
accumulator: Self::Accumulator,
34-
shape_reduce_dim: Variable,
35-
);
5+
use super::{
6+
naive::{base::ReduceDimNaive, shader::reduce_dim_naive},
7+
shared::{base::ReduceDimShared, shader::reduce_dim_shared},
8+
};
369

37-
/// Initialization for shared algorithm
38-
fn initialize_shared(
39-
scope: &mut Scope,
40-
shared_memory_size: u32,
41-
write_position: Variable,
42-
input_item: Item,
43-
) -> Self::Accumulator;
44-
45-
/// How to write to shared memory
46-
fn write_to_shared(
47-
scope: &mut Scope,
48-
shared_memory: Self::Accumulator,
49-
write_position: Variable,
50-
value: Self::Accumulator,
51-
);
52-
53-
/// How to read from input in shared algorithm
54-
fn read_from_input(
55-
scope: &mut Scope,
56-
input: Variable,
57-
read_position: Variable,
58-
i: Variable,
59-
) -> Self::Accumulator;
60-
61-
/// How to read from shared memory
62-
fn read_from_shared(
63-
scope: &mut Scope,
64-
shared_memory: Self::Accumulator,
65-
read_position: Variable,
66-
) -> Self::Accumulator;
67-
68-
/// How to assign from shared memory
69-
fn assign_shared(
70-
scope: &mut Scope,
71-
shared_memory: Self::Accumulator,
72-
output: Variable,
73-
write_position: Variable,
74-
shape_reduce_dim: Variable,
75-
);
10+
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
11+
ReduceDimNaive<E> + ReduceDimShared<E>
12+
{
7613
}
7714

7815
/// Creates an empty output tensor with reduce output shape
@@ -116,7 +53,10 @@ impl Default for ReduceStrategy {
11653
}
11754

11855
macro_rules! reduce_operation {
119-
($name:ident, $ops:ty) => {
56+
($name:ident, $ops:ident) => {
57+
pub(crate) struct $ops;
58+
impl<E: JitElement> ReduceDimAlgorithm<E> for $ops {}
59+
12060
/// Executes the reduce operation with the given strategy.
12161
pub fn $name<R: JitRuntime, EI: JitElement, EO: JitElement, const D: usize>(
12262
tensor: JitTensor<R, EI, D>,
@@ -143,5 +83,5 @@ macro_rules! reduce_operation {
14383
reduce_operation!(sum_dim, SumDim);
14484
reduce_operation!(mean_dim, MeanDim);
14585
reduce_operation!(prod_dim, ProdDim);
146-
reduce_operation!(argmin, ArgMin);
147-
reduce_operation!(argmax, ArgMax);
86+
reduce_operation!(argmin, Argmin);
87+
reduce_operation!(argmax, Argmax);
+2-14
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,11 @@
1-
mod argmax_dim;
2-
mod argmin_dim;
31
mod base;
4-
mod mean_dim;
5-
mod naive_reduce_shader;
2+
mod naive;
63
mod prod;
7-
mod prod_dim;
8-
mod shared_reduce_shader;
4+
mod shared;
95
mod sum;
10-
mod sum_dim;
116
mod tune;
127

13-
pub(crate) use argmax_dim::*;
14-
pub(crate) use argmin_dim::*;
158
pub use base::*;
16-
pub(crate) use mean_dim::*;
17-
pub use naive_reduce_shader::*;
189
pub use prod::*;
19-
pub(crate) use prod_dim::*;
20-
pub use shared_reduce_shader::*;
2110
pub use sum::*;
22-
pub(crate) use sum_dim::*;
2311
pub use tune::*;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use crate::{kernel::reduce::Argmax, JitElement};
2+
use burn_cube::{
3+
cpa,
4+
dialect::{Elem, Item, Scope, Variable},
5+
};
6+
7+
use super::base::ReduceDimNaive;
8+
9+
impl<E: JitElement> ReduceDimNaive<E> for Argmax {
10+
type Accumulator = (Variable, Variable);
11+
12+
fn initialize_naive(
13+
scope: &mut Scope,
14+
input_item: Item,
15+
_output_item: Item,
16+
) -> Self::Accumulator {
17+
let index = scope.create_local(Elem::UInt);
18+
let max = scope.create_local(input_item);
19+
let max_initial =
20+
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
21+
cpa!(scope, max = max_initial);
22+
23+
(max, index)
24+
}
25+
26+
fn inner_loop_naive(
27+
scope: &mut Scope,
28+
(max, index): Self::Accumulator,
29+
value: Variable,
30+
i: Variable,
31+
) {
32+
let condition = scope.create_local(Elem::Bool);
33+
cpa!(scope, condition = value > max);
34+
cpa!(scope, if(condition).then(|scope| {
35+
cpa!(scope, max = value);
36+
cpa!(scope, index = i);
37+
}));
38+
}
39+
40+
fn assign_naive(
41+
scope: &mut Scope,
42+
output: Variable,
43+
(_max, index): Self::Accumulator,
44+
_shape_reduce_dim: Variable,
45+
) {
46+
let id = Variable::Id;
47+
cpa!(scope, output[id] = index);
48+
}
49+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
use burn_cube::{
2+
cpa,
3+
dialect::{Elem, Item, Scope, Variable},
4+
};
5+
6+
use crate::{kernel::reduce::Argmin, JitElement};
7+
8+
use super::base::ReduceDimNaive;
9+
10+
impl<E: JitElement> ReduceDimNaive<E> for Argmin {
11+
type Accumulator = (Variable, Variable);
12+
13+
fn initialize_naive(
14+
scope: &mut Scope,
15+
input_item: Item,
16+
_output_item: Item,
17+
) -> Self::Accumulator {
18+
let index = scope.create_local(Elem::UInt);
19+
let min = scope.create_local(input_item);
20+
let min_initial =
21+
Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
22+
cpa!(scope, min = min_initial);
23+
24+
(min, index)
25+
}
26+
27+
fn inner_loop_naive(
28+
scope: &mut Scope,
29+
(min, index): Self::Accumulator,
30+
value: Variable,
31+
i: Variable,
32+
) {
33+
let condition = scope.create_local(Elem::Bool);
34+
cpa!(scope, condition = value < min);
35+
cpa!(scope, if(condition).then(|scope| {
36+
cpa!(scope, min = value);
37+
cpa!(scope, index = i);
38+
}));
39+
}
40+
41+
fn assign_naive(
42+
scope: &mut Scope,
43+
output: Variable,
44+
(_min, index): Self::Accumulator,
45+
_shape_reduce_dim: Variable,
46+
) {
47+
let id = Variable::Id;
48+
cpa!(scope, output[id] = index);
49+
}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use burn_cube::dialect::{Item, Scope, Variable};
2+
3+
use crate::JitElement;
4+
5+
/// Specifies the reduce dim algorithm in use
6+
pub trait ReduceDimNaive<E: JitElement>: Send + Sync + 'static {
7+
/// The reduction accumulator
8+
type Accumulator: Copy;
9+
10+
/// Initialization for naive algorithm
11+
fn initialize_naive(
12+
scope: &mut Scope,
13+
input_item: Item,
14+
output_item: Item,
15+
) -> Self::Accumulator;
16+
17+
/// Inner loop for naive algorithm
18+
fn inner_loop_naive(
19+
scope: &mut Scope,
20+
accumulator: Self::Accumulator,
21+
current_value: Variable,
22+
i: Variable,
23+
);
24+
25+
/// Assignation for naive algorithm
26+
fn assign_naive(
27+
scope: &mut Scope,
28+
output: Variable,
29+
accumulator: Self::Accumulator,
30+
shape_reduce_dim: Variable,
31+
);
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use crate::{kernel::reduce::MeanDim, JitElement};
2+
use burn_cube::{
3+
cpa,
4+
dialect::{Item, Scope, Variable},
5+
};
6+
7+
use super::base::ReduceDimNaive;
8+
9+
impl<E: JitElement> ReduceDimNaive<E> for MeanDim {
10+
type Accumulator = Variable;
11+
12+
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
13+
scope.zero(output_item)
14+
}
15+
16+
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
17+
cpa!(scope, accumulator += value);
18+
}
19+
20+
fn assign_naive(
21+
scope: &mut Scope,
22+
output: Variable,
23+
accumulator: Variable,
24+
shape_reduce_dim: Variable,
25+
) {
26+
let id = Variable::Id;
27+
let denominator = scope.create_local(accumulator.item());
28+
cpa!(scope, denominator = cast(shape_reduce_dim));
29+
cpa!(scope, accumulator = accumulator / denominator);
30+
cpa!(scope, output[id] = accumulator);
31+
}
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pub(crate) mod argmax;
2+
pub(crate) mod argmin;
3+
pub(crate) mod base;
4+
pub(crate) mod mean_dim;
5+
pub(crate) mod prod_dim;
6+
pub(crate) mod shader;
7+
pub(crate) mod sum_dim;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use crate::{kernel::reduce::ProdDim, JitElement};
2+
use burn_cube::{
3+
cpa,
4+
dialect::{Item, Scope, Variable},
5+
};
6+
7+
use super::base::ReduceDimNaive;
8+
9+
impl<E: JitElement> ReduceDimNaive<E> for ProdDim {
10+
type Accumulator = Variable;
11+
12+
fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
13+
scope.create_with_value(1, output_item)
14+
}
15+
16+
fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
17+
cpa!(scope, accumulator *= value);
18+
}
19+
20+
fn assign_naive(
21+
scope: &mut Scope,
22+
output: Variable,
23+
accumulator: Variable,
24+
_shape_reduce_dim: Variable,
25+
) {
26+
let id = Variable::Id;
27+
cpa!(scope, output[id] = accumulator);
28+
}
29+
}

crates/burn-jit/src/kernel/reduce/naive_reduce_shader.rs crates/burn-jit/src/kernel/reduce/naive/shader.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ use std::marker::PhantomData;
88

99
use crate::{element::JitElement, kernel::GpuComputeShaderPhase, tensor::JitTensor, JitRuntime};
1010

11-
use super::ReduceDimAlgorithm;
11+
use super::base::ReduceDimNaive;
1212

13-
pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgorithm<E>> {
13+
pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimNaive<E>> {
1414
tensor: Variable,
1515
dim: usize,
1616
output: Variable,
@@ -20,7 +20,7 @@ pub(crate) struct NaiveReduceDimComputeShader<E: JitElement, RD: ReduceDimAlgori
2020

2121
#[derive(new)]
2222
pub(crate) struct NaiveReduceDimEagerKernel<
23-
RD: ReduceDimAlgorithm<EI>,
23+
RD: ReduceDimNaive<EI>,
2424
R: JitRuntime,
2525
EI: JitElement,
2626
EO: JitElement,
@@ -32,8 +32,8 @@ pub(crate) struct NaiveReduceDimEagerKernel<
3232
_elem_out: PhantomData<EO>,
3333
}
3434

35-
impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
36-
GpuComputeShaderPhase for NaiveReduceDimEagerKernel<RD, R, EI, EO>
35+
impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> GpuComputeShaderPhase
36+
for NaiveReduceDimEagerKernel<RD, R, EI, EO>
3737
{
3838
fn compile(&self) -> ComputeShader {
3939
let mut scope = Scope::root();
@@ -76,7 +76,7 @@ impl<RD: ReduceDimAlgorithm<EI>, R: JitRuntime, EI: JitElement, EO: JitElement>
7676
}
7777
}
7878

79-
impl<E: JitElement, RD: ReduceDimAlgorithm<E>> NaiveReduceDimComputeShader<E, RD> {
79+
impl<E: JitElement, RD: ReduceDimNaive<E>> NaiveReduceDimComputeShader<E, RD> {
8080
pub(crate) fn expand(self, scope: &mut Scope) {
8181
let tensor = self.tensor;
8282
let dim: Variable = self.dim.into();
@@ -136,7 +136,7 @@ impl<E: JitElement, RD: ReduceDimAlgorithm<E>> NaiveReduceDimComputeShader<E, RD
136136

137137
/// Executes the naive kernel for reduce dim
138138
pub fn reduce_dim_naive<
139-
RD: ReduceDimAlgorithm<EI>,
139+
RD: ReduceDimNaive<EI>,
140140
R: JitRuntime,
141141
EI: JitElement,
142142
EO: JitElement,

0 commit comments

Comments
 (0)