Feat/improve fusion #2773

Merged: 38 commits, Feb 6, 2025
Changes from 35 commits
Commits (38)
1c4fbb0  WIP (nathanielsimard, Jan 22, 2025)
06cb156  WIP (nathanielsimard, Jan 22, 2025)
e8493af  WIP testing (nathanielsimard, Jan 23, 2025)
4a2d03e  Very wip (nathanielsimard, Jan 24, 2025)
c83782b  WIP works better (nathanielsimard, Jan 28, 2025)
96d2ff0  Fix vectorization (nathanielsimard, Jan 28, 2025)
c431a43  Still debug (nathanielsimard, Jan 28, 2025)
4aeb900  Fix some problems (nathanielsimard, Jan 29, 2025)
e43bef0  Fix other broadcast issues (nathanielsimard, Jan 29, 2025)
19b01f4  Fix another bug, but still very wip (nathanielsimard, Jan 30, 2025)
f3b3459  WIP Works (nathanielsimard, Feb 1, 2025)
f80b722  Cleanup (nathanielsimard, Feb 1, 2025)
f86fbcb  Support broadcasted vectorization (nathanielsimard, Feb 1, 2025)
11550b4  Cleanup (nathanielsimard, Feb 1, 2025)
5e14374  Still some bugs (nathanielsimard, Feb 1, 2025)
b104e78  Fix multi vectorization broadcasting fused (nathanielsimard, Feb 1, 2025)
3eabb6c  Add fuse settings (nathanielsimard, Feb 2, 2025)
10fc217  Fix broadcast issue (nathanielsimard, Feb 2, 2025)
dcf563d  Fix performance (nathanielsimard, Feb 2, 2025)
4471ea3  Some cleanup (nathanielsimard, Feb 2, 2025)
b9bf504  Big refactoring (nathanielsimard, Feb 2, 2025)
6ce7c3e  Add reshape optimization (nathanielsimard, Feb 3, 2025)
f0c82e2  Merge branch 'main' into feat/fuse-reshape (nathanielsimard, Feb 3, 2025)
e127325  Cleanup (nathanielsimard, Feb 3, 2025)
1a84818  Add some docs (nathanielsimard, Feb 3, 2025)
5b25f18  Update cubecl ref (nathanielsimard, Feb 3, 2025)
651bb4c  Clippy + Fmt (nathanielsimard, Feb 3, 2025)
4ce7993  Add vulkan in example (nathanielsimard, Feb 3, 2025)
12de765  WIP (nathanielsimard, Feb 4, 2025)
8ed91eb  Fix test (nathanielsimard, Feb 4, 2025)
963e241  Cleanup (nathanielsimard, Feb 4, 2025)
756f7f0  Fix no std tests (nathanielsimard, Feb 4, 2025)
a778d34  Better autotune (nathanielsimard, Feb 4, 2025)
07b317c  Merge branch 'main' into feat/fuse-reshape (nathanielsimard, Feb 5, 2025)
53d2f57  Remove print (nathanielsimard, Feb 5, 2025)
76f4067  Update crates/burn-jit/src/fusion/on_write/trace/output.rs (laggui, Feb 6, 2025)
a2f01b2  Update crates/burn-jit/src/fusion/on_write/trace/plan.rs (laggui, Feb 6, 2025)
7f596f0  Merge branch 'main' into feat/fuse-reshape (laggui, Feb 6, 2025)
2 changes: 1 addition & 1 deletion Cargo.lock

3 changes: 1 addition & 2 deletions backend-comparison/benches/matmul_fused.rs
@@ -26,8 +26,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }

     fn execute(&self, (lhs, rhs, bias): Self::Args) {
-        let bias = bias.unsqueeze();
-        gelu(relu(lhs.matmul(rhs)) + bias);
+        let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
     }

     fn prepare(&self) -> Self::Args {
3 changes: 3 additions & 0 deletions crates/burn-core/src/lib.rs
@@ -59,12 +59,15 @@ extern crate alloc;
 pub type TestBackend = burn_ndarray::NdArray<f32>;

 #[cfg(all(test, feature = "test-tch"))]
+/// Backend for test cases
 pub type TestBackend = burn_tch::LibTorch<f32>;

 #[cfg(all(test, feature = "test-wgpu"))]
+/// Backend for test cases
 pub type TestBackend = burn_wgpu::Wgpu;

 #[cfg(all(test, feature = "test-cuda"))]
+/// Backend for test cases
 pub type TestBackend = burn_cuda::Cuda;

 /// Backend for autodiff test cases
1 change: 0 additions & 1 deletion crates/burn-core/src/nn/linear.rs
@@ -77,7 +77,6 @@ impl<B: Backend> Linear<B> {

         let weight = self.weight.val().unsqueeze();
         let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
-
         let output = input.matmul(weight);

         match bias {
25 changes: 11 additions & 14 deletions crates/burn-core/src/nn/transformer/decoder.rs
@@ -455,8 +455,9 @@ impl<B: Backend> TransformerDecoder<B> {

 #[cfg(test)]
 mod tests {
+    use burn_tensor::Device;
+
     use super::*;
-    use crate::tensor::Distribution;
     use crate::{nn::attention::generate_autoregressive_mask, TestBackend};

     #[test]
@@ -481,20 +482,16 @@ mod tests {
     }

     fn test_autoregressive(config: TransformerDecoderConfig) {
-        let device = Default::default();
+        let device: Device<TestBackend> = Default::default();
         let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
-        let transformer = config.init(&device);
-
-        let memory = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
-        let target = Tensor::<TestBackend, 3>::random(
-            [batch_size, seq_length, d_model],
-            Distribution::Default,
-            &device,
-        );
+        let transformer = config.init::<TestBackend>(&device);
+
+        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
+        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
+            .float()
+            .reshape([batch_size, seq_length, d_model]);
         let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
         let input = TransformerDecoderInput::new(target.clone(), memory.clone())
             .target_mask_attn(mask_attn);
8 changes: 4 additions & 4 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -17,8 +17,8 @@ use burn_tensor::{
         BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
         CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
         HandleContainer, OperationDescription, PermuteOperationDescription,
-        RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
-        SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
+        RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+        SwapDimsDescription, UnaryOperationDescription,
     },
     Device, Shape,
 };
@@ -171,7 +171,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -186,7 +186,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         let stream = tensor.stream;
         let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool);

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/float.rs
@@ -640,7 +640,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
     fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -656,7 +656,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         let dtype = tensor.dtype;
         let out = tensor.client.tensor_uninitialized(shape.dims, dtype);

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/int.rs
@@ -93,7 +93,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
         #[derive(new)]
         struct ReshapeDimsOps<B: FusionBackend> {
-            desc: ReshapeDescription,
+            desc: UnaryOperationDescription,
             _b: PhantomData<B>,
         }

@@ -110,7 +110,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             .client
             .tensor_uninitialized(shape.dims, B::IntElem::dtype());

-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
41 changes: 34 additions & 7 deletions crates/burn-fusion/src/stream/context.rs
@@ -39,12 +39,9 @@ pub struct Context<'a, H> {
     pub scalar_u8: &'a Vec<u8>,
 }

-#[derive(Default)]
 pub(crate) struct OperationConverter {
     tensors_relative2global: HashMap<TensorId, TensorDescription>,
     tensors_global2relative: HashMap<TensorId, TensorDescription>,
-    /// Only useful to create new shape ID.
-    /// You should use tensor descriptions to retrieve the proper shape.
     shapes_global2relative: HashMap<usize, usize>,
     scalar_f32: Vec<f32>,
     scalar_f16: Vec<f16>,
@@ -59,6 +56,32 @@ pub(crate) struct OperationConverter {
     scalar_u8: Vec<u8>,
 }

+impl Default for OperationConverter {
+    fn default() -> Self {
+        let mut val = Self {
+            tensors_relative2global: Default::default(),
+            tensors_global2relative: Default::default(),
+            shapes_global2relative: Default::default(),
+            scalar_f32: Default::default(),
+            scalar_f16: Default::default(),
+            scalar_bf16: Default::default(),
+            scalar_i64: Default::default(),
+            scalar_i32: Default::default(),
+            scalar_i16: Default::default(),
+            scalar_i8: Default::default(),
+            scalar_u64: Default::default(),
+            scalar_u32: Default::default(),
+            scalar_u16: Default::default(),
+            scalar_u8: Default::default(),
+        };
+
+        // global 1 is always shape id 0.
+        val.shapes_global2relative.insert(1, 0);
+
+        val
+    }
+}
+
 /// Fork of a [context](Context) which owns its data.
 pub struct ContextOwned<H> {
     tensors: HashMap<TensorId, TensorDescription>,
@@ -180,7 +203,11 @@ impl OperationConverter {
     pub(crate) fn clear(&mut self) {
         self.tensors_relative2global.clear();
         self.tensors_global2relative.clear();
+
         self.shapes_global2relative.clear();
+        // global 1 is always shape id 0.
+        self.shapes_global2relative.insert(1, 0);
+
         self.scalar_f32.clear();
         self.scalar_f16.clear();
         self.scalar_bf16.clear();
@@ -227,7 +254,6 @@ impl OperationConverter {

 impl RelativeOps for OperationDescription {
     fn to_relative(&self, converter: &mut OperationConverter) -> Self {
-        println!("To relative {self:?}");
         match self {
             OperationDescription::BaseFloat(ops) => {
                 OperationDescription::BaseFloat(ops.to_relative(converter))
@@ -1130,7 +1156,7 @@ impl RelativeOps for BaseOperationDescription {
                 BaseOperationDescription::ToDevice(desc.to_relative(converter))
             }
             BaseOperationDescription::Reshape(desc) => {
-                BaseOperationDescription::Reshape(ReshapeDescription {
+                BaseOperationDescription::Reshape(UnaryOperationDescription {
                     input: desc.input.to_relative(converter),
                     out: desc.out.to_relative(converter),
                 })
@@ -1247,6 +1273,7 @@ impl RelativeOps for TensorDescription {
                 // We never saw this dim value before, therefore we create a new ID.
                 let dim_id = converter.shapes_global2relative.len();
                 relative_shape.push(dim_id);
+
                 converter.shapes_global2relative.insert(*dim, dim_id);
             }
         }
@@ -1301,7 +1328,7 @@ mod tests {
             tensor1_local,
             TensorDescription {
                 id: TensorId::new(0),
-                shape: vec![0, 1, 2],
+                shape: vec![1, 2, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
@@ -1310,7 +1337,7 @@ mod tests {
             tensor2_local,
             TensorDescription {
                 id: TensorId::new(1),
-                shape: vec![0, 3, 2],
+                shape: vec![1, 4, 3],
                 status: TensorStatus::ReadOnly,
                 dtype: DType::F32
             }
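Note on the shape-id mapping above: the converter now reserves the global dimension value 1 as relative id 0 (seeded in both `Default` and `clear()`), and each unseen dimension value receives the next id based on the current map length; that is why the test expectations change from `[0, 1, 2]` / `[0, 3, 2]` to `[1, 2, 3]` / `[1, 4, 3]`. A minimal standalone sketch of that mapping, assuming made-up global shapes (this is not the actual burn-fusion type):

use std::collections::HashMap;

// Hypothetical stand-in for the shapes_global2relative logic shown in the diff.
struct ShapeConverter {
    global2relative: HashMap<usize, usize>,
}

impl ShapeConverter {
    fn new() -> Self {
        let mut global2relative = HashMap::new();
        // Global dim 1 is always relative id 0, mirroring the Default impl above.
        global2relative.insert(1, 0);
        Self { global2relative }
    }

    fn to_relative(&mut self, shape: &[usize]) -> Vec<usize> {
        shape
            .iter()
            .map(|dim| {
                // Unseen dimension values get the next id; known values reuse theirs.
                let next_id = self.global2relative.len();
                *self.global2relative.entry(*dim).or_insert(next_id)
            })
            .collect()
    }
}

fn main() {
    let mut converter = ShapeConverter::new();
    // Example global shapes (made up); the relative ids mirror the updated test.
    assert_eq!(converter.to_relative(&[512, 32, 2048]), vec![1, 2, 3]);
    assert_eq!(converter.to_relative(&[512, 64, 2048]), vec![1, 4, 3]);
    // A broadcast dimension of size 1 always maps to 0.
    assert_eq!(converter.to_relative(&[1, 32, 2048]), vec![0, 2, 3]);
}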
13 changes: 11 additions & 2 deletions crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -2,7 +2,7 @@ use burn_fusion::OptimizationBuilder;

 use crate::{
     fusion::{
-        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
+        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
         JitOptimization,
     },
     JitRuntime,
@@ -23,7 +23,16 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
         let max_bindings = props.hardware_properties().max_bindings;

         Self {
-            builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
+            builder: FuseOnWriteBuilder::new(
+                max_bindings,
+                bool_precision,
+                FuseSettings {
+                    broadcast: true,
+                    output_shape_updates: true,
+                    mix_vectorization: true,
+                    inplace: true,
+                },
+            ),
             device,
         }
     }
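The `FuseSettings` literal above and the one in the matmul builder further down are the only places the new settings appear in this diff; the element-wise builder enables everything, while the matmul builder sets `output_shape_updates: false`. A sketch of the struct implied by those call sites (the field names come from the diff, but the derives and doc comments here are assumptions, not the crate's actual definition):

// Hypothetical reconstruction of the type imported from on_write::settings.
#[derive(Clone, Copy, Debug)]
pub struct FuseSettings {
    /// Allow fusing operations whose operands broadcast against each other.
    pub broadcast: bool,
    /// Allow later operations to update the reference output shape of the fused kernel.
    pub output_shape_updates: bool,
    /// Allow mixing different vectorization factors within a single fused kernel.
    pub mix_vectorization: bool,
    /// Allow fused operations to reuse an input buffer for the output (in-place).
    pub inplace: bool,
}

`Copy` is assumed here because the matmul builder passes the same `settings` value to two `FuseOnWriteBuilder::new` calls by value.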
3 changes: 1 addition & 2 deletions crates/burn-jit/src/fusion/elemwise/optimization.rs
@@ -110,7 +110,6 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseRunner {
             },
             None => panic!("Invalid argument"),
         };
-
         let total_elem = shape.iter().product::<usize>() / *vectorization as usize;
         let cube_dim = CubeDim::default();
         let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
@@ -141,7 +140,7 @@ fn elemwise_fuse(
     let args = comptime![Sequence::<Arg>::new()];
     let pos = ABSOLUTE_POS;

-    let length = match comptime![config.ref_layout] {
+    let length = match comptime![config.ref_layout.clone()] {
         Arg::Input(index, precision, _) => match comptime![precision] {
             ElemwisePrecision::F32 => inputs.t_f32.index(index).len(),
             ElemwisePrecision::F16 => inputs.t_f16.index(index).len(),
12 changes: 7 additions & 5 deletions crates/burn-jit/src/fusion/matmul/args.rs
@@ -51,6 +51,7 @@ impl MatmulArgs for FusedMatmulArgs {
             LayoutInfo::IsRef,
             precision,
             &state.config,
+            None,
         )
     }

@@ -70,15 +71,16 @@
             LayoutInfo::IsRef,
             precision,
             &state.config,
+            None,
         )
     }

     fn write_out<EG: Numeric>(state: &mut Self::State<EG>, coordinate: u32, value: Line<EG>) {
         let mut values = Registry::<Arg, Line<EG>>::new();
         let mut args = comptime![Sequence::<Arg>::new()];

-        values.insert(state.out, value);
-        comptime![args.push(state.out)];
+        values.insert(comptime![state.out.clone()], value);
+        comptime![args.push(state.out.clone())];

         fuse_on_write(
             unsafe { &(*state.inputs) },
@@ -225,9 +227,9 @@ impl FusedMatmulState {
             inputs: &inputs.global,
             outputs,
             config: comptime![config.clone()],
-            lhs: comptime![inputs.lhs],
-            rhs: comptime![inputs.rhs],
-            out: comptime![inputs.out],
+            lhs: comptime![inputs.lhs.clone()],
+            rhs: comptime![inputs.rhs.clone()],
+            out: comptime![inputs.out.clone()],
         }
     }
 }
13 changes: 10 additions & 3 deletions crates/burn-jit/src/fusion/matmul/builder.rs
@@ -3,7 +3,7 @@ use burn_tensor::repr::{FloatOperationDescription, OperationDescription};

 use crate::{
     fusion::{
-        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
+        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
         JitOptimization,
     },
     JitRuntime,
@@ -24,10 +24,16 @@ impl<R: JitRuntime> MatmulBuilder<R> {
         let client = R::client(&device);
         let props = client.properties();
         let max_bindings = props.hardware_properties().max_bindings;
+        let settings = FuseSettings {
+            broadcast: true,
+            output_shape_updates: false,
+            mix_vectorization: true,
+            inplace: true,
+        };

         Self {
-            builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
-            builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision),
+            builder: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings),
+            builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings),
             device,
             matmul: None,
         }
@@ -56,6 +62,7 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
             ));
         } else {
             self.builder.close();
+            self.builder_fallback.close();
         }
     } else {
         self.builder.register(operation);
5 changes: 5 additions & 0 deletions crates/burn-jit/src/fusion/matmul/optimization.rs
@@ -123,6 +123,11 @@ impl<R: JitRuntime> MatmulOptimization<R> {
         }
     }

+    /// Returns the number of output buffers added by fusion.
+    pub fn num_output_buffers(&self) -> usize {
+        self.trace_fallback.outputs.len()
+    }
+
     pub fn execute_standard_fused<BT: BoolElement>(
         &self,
         context: &mut Context<'_, JitFusionHandle<R>>,