
Commit 6015823

Feat/improve fusion (#2773)
* WIP
* WIP
* WIP testing
* Very wip
* WIP works better
* Fix vectorization
* Still debug
* Fix some problems
* Fix other broadcast issues
* Fix another bug, but still very wip
* WIP Works
* Cleanup
* Support broadcasted vectorization
* Cleanup
* Still some bugs
* Fix multi vectorization broadcasting fused
* Add fuse settings
* Fix broadcast issue
* Fix performance
* Some cleanup
* Big refactoring
* Add reshape optimization
* Cleanup
* Add some docs
* Update cubecl ref
* Clippy + Fmt
* Add vulkan in example
* WIP
* Fix test
* Cleanup
* Fix no std tests
* Better autotune
* Remove print
* Update crates/burn-jit/src/fusion/on_write/trace/output.rs
* Update crates/burn-jit/src/fusion/on_write/trace/plan.rs

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
1 parent 00422c1 commit 6015823

File tree

40 files changed: +1996 -835 lines changed


Cargo.lock

+15
Some generated files are not rendered by default.

backend-comparison/benches/matmul_fused.rs

+1 -2

@@ -26,8 +26,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }

     fn execute(&self, (lhs, rhs, bias): Self::Args) {
-        let bias = bias.unsqueeze();
-        gelu(relu(lhs.matmul(rhs)) + bias);
+        let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
     }

     fn prepare(&self) -> Self::Args {
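
For context, a minimal sketch of the fused chain this benchmark now exercises, assuming burn's public tensor API (the helper name and generic ranks are illustrative): keeping the whole chain in one expression lets the fusion backend capture unsqueeze, matmul, relu, add, and gelu in a single trace, and binding the result to `_output` keeps the computation alive without an unused-variable warning.

    use burn::tensor::activation::{gelu, relu};
    use burn::tensor::{backend::Backend, Tensor};

    /// Hypothetical helper mirroring the benchmark body above.
    fn fused_chain<B: Backend>(
        lhs: Tensor<B, 3>,
        rhs: Tensor<B, 3>,
        bias: Tensor<B, 2>,
    ) -> Tensor<B, 3> {
        // `unsqueeze` broadcasts the bias up to the rank of the matmul output,
        // so the add participates in the same element-wise chain as relu/gelu.
        gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze())
    }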

crates/burn-core/src/lib.rs

+3

@@ -59,12 +59,15 @@ extern crate alloc;
 pub type TestBackend = burn_ndarray::NdArray<f32>;

 #[cfg(all(test, feature = "test-tch"))]
+/// Backend for test cases
 pub type TestBackend = burn_tch::LibTorch<f32>;

 #[cfg(all(test, feature = "test-wgpu"))]
+/// Backend for test cases
 pub type TestBackend = burn_wgpu::Wgpu;

 #[cfg(all(test, feature = "test-cuda"))]
+/// Backend for test cases
 pub type TestBackend = burn_cuda::Cuda;

 /// Backend for autodiff test cases

crates/burn-core/src/nn/linear.rs

-1

@@ -77,7 +77,6 @@ impl<B: Backend> Linear<B> {

         let weight = self.weight.val().unsqueeze();
         let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
-
         let output = input.matmul(weight);

         match bias {

crates/burn-core/src/nn/transformer/decoder.rs

+11 -14

@@ -455,8 +455,9 @@ impl<B: Backend> TransformerDecoder<B> {

 #[cfg(test)]
 mod tests {
+    use burn_tensor::Device;
+
     use super::*;
-    use crate::tensor::Distribution;
     use crate::{nn::attention::generate_autoregressive_mask, TestBackend};

     #[test]
@@ -481,20 +482,16 @@ mod tests {
     }

     fn test_autoregressive(config: TransformerDecoderConfig) {
-        let device = Default::default();
+        let device: Device<TestBackend> = Default::default();
         let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
-        let transformer = config.init(&device);
-
-        let memory = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
-        let target = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
+        let transformer = config.init::<TestBackend>(&device);
+
+        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
+        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
         let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
         let input = TransformerDecoderInput::new(target.clone(), memory.clone())
             .target_mask_attn(mask_attn);
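
The test now builds deterministic inputs instead of random ones, which makes failures reproducible across runs and backends. A standalone sketch of the pattern, assuming burn's tensor API (the helper name is hypothetical):

    use burn::tensor::{backend::Backend, Int, Tensor};

    /// Hypothetical helper: a reproducible float tensor with values 0, 1, 2, ...
    fn deterministic_input<B: Backend>(device: &B::Device, [b, s, d]: [usize; 3]) -> Tensor<B, 3> {
        Tensor::<B, 1, Int>::arange(0..(b * s * d) as i64, device)
            .float() // integer ramp -> float values
            .reshape([b, s, d])
    }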

crates/burn-fusion/src/ops/boolean.rs

+4 -4

@@ -17,8 +17,8 @@ use burn_tensor::{
         BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
         CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
         HandleContainer, OperationDescription, PermuteOperationDescription,
-        RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
-        SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
+        RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+        SwapDimsDescription, UnaryOperationDescription,
     },
     Device, Shape,
 };
@@ -171,7 +171,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -186,7 +186,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         let stream = tensor.stream;
         let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool);

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
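
The same substitution appears in the float and int reshape ops below. The idea this diff suggests: reshape needs no dedicated `ReshapeDescription`, because the output tensor description already carries the new shape, so an input/output pair describes the operation fully. A stand-in sketch of the two description types (illustrative; the real definitions live in burn-tensor):

    /// Stand-in for burn-tensor's tensor description (id, shape, status, dtype, ...).
    struct TensorDescription {
        shape: Vec<usize>,
    }

    /// A unary op is fully described by its input and output tensors; for reshape,
    /// `out.shape` is the target shape, so no extra field is needed.
    struct UnaryOperationDescription {
        input: TensorDescription,
        out: TensorDescription,
    }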

crates/burn-fusion/src/ops/float.rs

+2 -2

@@ -640,7 +640,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
     fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -656,7 +656,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let out = tensor.client.tensor_uninitialized(shape.dims, dtype);

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };

crates/burn-fusion/src/ops/int.rs

+2 -2

@@ -93,7 +93,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -110,7 +110,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             .client
             .tensor_uninitialized(shape.dims, B::IntElem::dtype());

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };

crates/burn-fusion/src/stream/context.rs

+34 -6

@@ -39,12 +39,9 @@ pub struct Context<'a, H> {
     pub scalar_u8: &'a Vec<u8>,
 }

-#[derive(Default)]
 pub(crate) struct OperationConverter {
     tensors_relative2global: HashMap<TensorId, TensorDescription>,
     tensors_global2relative: HashMap<TensorId, TensorDescription>,
-    /// Only useful to create new shape ID.
-    /// You should use tensor descriptions to retrieve the proper shape.
     shapes_global2relative: HashMap<usize, usize>,
     scalar_f32: Vec<f32>,
     scalar_f16: Vec<f16>,
@@ -59,6 +56,32 @@ pub(crate) struct OperationConverter {
     scalar_u8: Vec<u8>,
 }

+impl Default for OperationConverter {
+    fn default() -> Self {
+        let mut val = Self {
+            tensors_relative2global: Default::default(),
+            tensors_global2relative: Default::default(),
+            shapes_global2relative: Default::default(),
+            scalar_f32: Default::default(),
+            scalar_f16: Default::default(),
+            scalar_bf16: Default::default(),
+            scalar_i64: Default::default(),
+            scalar_i32: Default::default(),
+            scalar_i16: Default::default(),
+            scalar_i8: Default::default(),
+            scalar_u64: Default::default(),
+            scalar_u32: Default::default(),
+            scalar_u16: Default::default(),
+            scalar_u8: Default::default(),
+        };
+
+        // global 1 is always shape id 0.
+        val.shapes_global2relative.insert(1, 0);
+
+        val
+    }
+}
+
 /// Fork of a [context](Context) which owns its data.
 pub struct ContextOwned<H> {
     tensors: HashMap<TensorId, TensorDescription>,
@@ -180,7 +203,11 @@ impl OperationConverter {
     pub(crate) fn clear(&mut self) {
         self.tensors_relative2global.clear();
         self.tensors_global2relative.clear();
+
         self.shapes_global2relative.clear();
+        // global 1 is always shape id 0.
+        self.shapes_global2relative.insert(1, 0);
+
         self.scalar_f32.clear();
         self.scalar_f16.clear();
         self.scalar_bf16.clear();
@@ -1129,7 +1156,7 @@ impl RelativeOps for BaseOperationDescription {
                 BaseOperationDescription::ToDevice(desc.to_relative(converter))
             }
             BaseOperationDescription::Reshape(desc) => {
-                BaseOperationDescription::Reshape(ReshapeDescription {
+                BaseOperationDescription::Reshape(UnaryOperationDescription {
                     input: desc.input.to_relative(converter),
                     out: desc.out.to_relative(converter),
                 })
@@ -1246,6 +1273,7 @@ impl RelativeOps for TensorDescription {
                     // We never saw this dim value before, therefore we create a new ID.
                     let dim_id = converter.shapes_global2relative.len();
                     relative_shape.push(dim_id);
+
                     converter.shapes_global2relative.insert(*dim, dim_id);
                 }
             }
@@ -1300,7 +1328,7 @@ mod tests {
             tensor1_local,
             TensorDescription {
                 id: TensorId::new(0),
-                shape: vec![0, 1, 2],
+                shape: vec![1, 2, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
@@ -1309,7 +1337,7 @@ mod tests {
             tensor2_local,
             TensorDescription {
                 id: TensorId::new(1),
-                shape: vec![0, 3, 2],
+                shape: vec![1, 4, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
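
A worked, self-contained sketch of the shape-ID mapping that the new `Default` impl pins down (the function is illustrative, not the crate's internals): dimension value 1 is pre-registered as id 0 so broadcast dims always receive the same relative id, and remaining ids are handed out in order of first appearance, which reproduces the updated test expectations above.

    use std::collections::HashMap;

    // Maps each global dim value to a stable relative id, first-come first-served.
    fn to_relative(shape: &[usize], global2relative: &mut HashMap<usize, usize>) -> Vec<usize> {
        shape
            .iter()
            .map(|dim| {
                let next_id = global2relative.len();
                *global2relative.entry(*dim).or_insert(next_id)
            })
            .collect()
    }

    fn main() {
        let mut map = HashMap::new();
        map.insert(1, 0); // global 1 is always shape id 0.

        // Concrete dim values are made up; equal values share relative ids,
        // matching the updated assertions above.
        assert_eq!(to_relative(&[32, 64, 128], &mut map), vec![1, 2, 3]);
        assert_eq!(to_relative(&[32, 16, 128], &mut map), vec![1, 4, 3]);
    }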

crates/burn-jit/src/fusion/elemwise/builder.rs

+11 -2

@@ -2,7 +2,7 @@ use burn_fusion::OptimizationBuilder;

 use crate::{
     fusion::{
-        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
+        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
         JitOptimization,
     },
     JitRuntime,
@@ -23,7 +23,16 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
         let max_bindings = props.hardware_properties().max_bindings;

         Self {
-            builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
+            builder: FuseOnWriteBuilder::new(
+                max_bindings,
+                bool_precision,
+                FuseSettings {
+                    broadcast: true,
+                    output_shape_updates: true,
+                    mix_vectorization: true,
+                    inplace: true,
+                },
+            ),
             device,
         }
     }
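
The new `FuseSettings` struct is what the PR's "Add fuse settings" bullet refers to. A self-contained stand-in with comments on what each flag plausibly controls, inferred from the field names alone (assumptions, not documented semantics):

    /// Stand-in for the crate's FuseSettings; field meanings are inferred.
    #[derive(Clone, Copy, Debug)]
    struct FuseSettings {
        broadcast: bool,            // allow fusing ops over broadcasted operands
        output_shape_updates: bool, // let the fused trace update output shapes
        mix_vectorization: bool,    // mix vectorization widths within one kernel
        inplace: bool,              // write outputs into reusable input buffers
    }

    fn main() {
        // Element-wise fusion enables everything, as in the diff above.
        let settings = FuseSettings {
            broadcast: true,
            output_shape_updates: true,
            mix_vectorization: true,
            inplace: true,
        };
        println!("{settings:?}");
    }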

crates/burn-jit/src/fusion/elemwise/optimization.rs

+1 -2

@@ -110,7 +110,6 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseRunner {
             },
             None => panic!("Invalid argument"),
         };
-
         let total_elem = shape.iter().product::<usize>() / *vectorization as usize;
         let cube_dim = CubeDim::default();
         let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
@@ -141,7 +140,7 @@ fn elemwise_fuse(
     let args = comptime![Sequence::<Arg>::new()];
     let pos = ABSOLUTE_POS;

-    let length = match comptime![config.ref_layout] {
+    let length = match comptime![config.ref_layout.clone()] {
         Arg::Input(index, precision, _) => match comptime![precision] {
             ElemwisePrecision::F32 => inputs.t_f32.index(index).len(),
             ElemwisePrecision::F16 => inputs.t_f16.index(index).len(),
