Skip to content

Commit b9bf504

Browse files
Big refactoring
1 parent 4471ea3 commit b9bf504

16 files changed

+1114
-886
lines changed

crates/burn-jit/src/fusion/elemwise/builder.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@ use burn_fusion::OptimizationBuilder;
22

33
use crate::{
44
fusion::{
5-
on_write::{
6-
builder::{FuseOnWriteBuilder, FuseSettings},
7-
ir::ElemwisePrecision,
8-
},
5+
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
96
JitOptimization,
107
},
118
JitRuntime,

crates/burn-jit/src/fusion/matmul/builder.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ use burn_tensor::repr::{FloatOperationDescription, OperationDescription};
33

44
use crate::{
55
fusion::{
6-
on_write::{
7-
builder::{FuseOnWriteBuilder, FuseSettings},
8-
ir::ElemwisePrecision,
9-
},
6+
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
107
JitOptimization,
118
},
129
JitRuntime,

crates/burn-jit/src/fusion/on_write/builder.rs

+2-22
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::{
22
ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs},
3-
trace::FuseOnWriteTrace,
4-
trace_builder::FuseOnWriteTraceBuilder,
3+
settings::FuseSettings,
4+
trace::{FuseOnWriteTrace, FuseOnWriteTraceBuilder},
55
};
66
use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
77
use burn_tensor::{
@@ -13,7 +13,6 @@ use burn_tensor::{
1313
Element,
1414
};
1515
use cubecl::ir::Elem;
16-
use serde::{Deserialize, Serialize};
1716

1817
/// Fused element wise operations that are normally memory bound.
1918
pub(crate) struct FuseOnWriteBuilder {
@@ -26,25 +25,6 @@ pub(crate) struct FuseOnWriteBuilder {
2625
max_bindings: u32,
2726
}
2827

29-
/// Controls which operations can be fused.
30-
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
31-
pub struct FuseSettings {
32-
/// Enables broadcasting of shapes.
33-
pub broadcast: bool,
34-
/// Enables output shape updates.
35-
///
36-
/// When broadcast is enabled, the output shape can become bigger after a fusion,
37-
/// therefore an update is needed.
38-
pub output_shape_updates: bool,
39-
/// Enables mix vectorization factor.
40-
///
41-
/// Useful when the last dimension is broadcasted for one of the tensors, which would limit the
42-
/// vectorization factor to be 1 without this setting enabled.
43-
pub mix_vectorization: bool,
44-
/// Enables the reuse of input buffers.
45-
pub inplace: bool,
46-
}
47-
4828
struct TryFuseBuilder {
4929
builder: FuseOnWriteTraceBuilder,
5030
max_bindings: u32,

crates/burn-jit/src/fusion/on_write/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ pub(crate) mod builder;
22
pub(crate) mod io;
33
pub(crate) mod ir;
44
pub(crate) mod kernel;
5-
pub(super) mod position;
5+
pub(crate) mod settings;
66

77
pub mod trace;
8-
pub(crate) mod trace_builder;

crates/burn-jit/src/fusion/on_write/position.rs

-31
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
/// Controls which operations can be fused.
4+
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
5+
pub struct FuseSettings {
6+
/// Enables broadcasting of shapes.
7+
pub broadcast: bool,
8+
/// Enables output shape updates.
9+
///
10+
/// When broadcast is enabled, the output shape can become bigger after a fusion,
11+
/// therefore an update is needed.
12+
pub output_shape_updates: bool,
13+
/// Enables mix vectorization factor.
14+
///
15+
/// Useful when the last dimension is broadcasted for one of the tensors, which would limit the
16+
/// vectorization factor to be 1 without this setting enabled.
17+
pub mix_vectorization: bool,
18+
/// Enables the reuse of input buffers.
19+
pub inplace: bool,
20+
}

0 commit comments

Comments
 (0)