Feat/improve fusion #2773

Open
wants to merge 29 commits into base: main

Changes from all commits (29 commits)
1c4fbb0  WIP (nathanielsimard, Jan 22, 2025)
06cb156  WIP (nathanielsimard, Jan 22, 2025)
e8493af  WIP testing (nathanielsimard, Jan 23, 2025)
4a2d03e  Very wip (nathanielsimard, Jan 24, 2025)
c83782b  WIP works better (nathanielsimard, Jan 28, 2025)
96d2ff0  Fix vectorization (nathanielsimard, Jan 28, 2025)
c431a43  Still debug (nathanielsimard, Jan 28, 2025)
4aeb900  Fix some problems (nathanielsimard, Jan 29, 2025)
e43bef0  Fix other broadcast issues (nathanielsimard, Jan 29, 2025)
19b01f4  Fix another bug, but still very wip (nathanielsimard, Jan 30, 2025)
f3b3459  WIP Works (nathanielsimard, Feb 1, 2025)
f80b722  Cleanup (nathanielsimard, Feb 1, 2025)
f86fbcb  Support broadcasted vectorization (nathanielsimard, Feb 1, 2025)
11550b4  Cleanup (nathanielsimard, Feb 1, 2025)
5e14374  Still some bugs (nathanielsimard, Feb 1, 2025)
b104e78  Fix multi vectorization broadcasting fused (nathanielsimard, Feb 1, 2025)
3eabb6c  Add fuse settings (nathanielsimard, Feb 2, 2025)
10fc217  Fix broadcast issue (nathanielsimard, Feb 2, 2025)
dcf563d  Fix performance (nathanielsimard, Feb 2, 2025)
4471ea3  Some cleanup (nathanielsimard, Feb 2, 2025)
b9bf504  Big refactoring (nathanielsimard, Feb 2, 2025)
6ce7c3e  Add reshape optimization (nathanielsimard, Feb 3, 2025)
f0c82e2  Merge branch 'main' into feat/fuse-reshape (nathanielsimard, Feb 3, 2025)
e127325  Cleanup (nathanielsimard, Feb 3, 2025)
1a84818  Add some docs (nathanielsimard, Feb 3, 2025)
5b25f18  Update cubecl ref (nathanielsimard, Feb 3, 2025)
651bb4c  Clippy + Fmt (nathanielsimard, Feb 3, 2025)
4ce7993  Add vulkan in example (nathanielsimard, Feb 3, 2025)
12de765  WIP (nathanielsimard, Feb 4, 2025)
17 changes: 1 addition & 16 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
3 changes: 1 addition & 2 deletions backend-comparison/benches/matmul_fused.rs
@@ -26,8 +26,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}

fn execute(&self, (lhs, rhs, bias): Self::Args) {
let bias = bias.unsqueeze();
gelu(relu(lhs.matmul(rhs)) + bias);
let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
}

fn prepare(&self) -> Self::Args {
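For context, the chain this benchmark measures is exactly the kind of elementwise tail (relu, bias add, gelu) that the fused backend can fold into the matmul output. A minimal standalone sketch, assuming the burn umbrella crate with the wgpu feature enabled; the shapes are arbitrary and not taken from the benchmark:

```rust
use burn::backend::Wgpu;
use burn::tensor::activation::{gelu, relu};
use burn::tensor::{Distribution, Tensor};

fn main() {
    type B = Wgpu;
    let device = Default::default();

    let lhs = Tensor::<B, 3>::random([2, 64, 64], Distribution::Default, &device);
    let rhs = Tensor::<B, 3>::random([2, 64, 64], Distribution::Default, &device);
    let bias = Tensor::<B, 2>::random([64, 64], Distribution::Default, &device);

    // The elementwise tail (relu, bias add, gelu) is what the fusion backend
    // can merge into the matmul's output kernel.
    let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze());
}
```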
3 changes: 3 additions & 0 deletions crates/burn-core/src/lib.rs
@@ -59,12 +59,15 @@ extern crate alloc;
pub type TestBackend = burn_ndarray::NdArray<f32>;

#[cfg(all(test, feature = "test-tch"))]
/// Backend for test cases
pub type TestBackend = burn_tch::LibTorch<f32>;

#[cfg(all(test, feature = "test-wgpu"))]
/// Backend for test cases
pub type TestBackend = burn_wgpu::Wgpu;

#[cfg(all(test, feature = "test-cuda"))]
/// Backend for test cases
pub type TestBackend = burn_cuda::Cuda;

/// Backend for autodiff test cases
1 change: 0 additions & 1 deletion crates/burn-core/src/nn/linear.rs
@@ -77,7 +77,6 @@ impl<B: Backend> Linear<B> {

let weight = self.weight.val().unsqueeze();
let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());

let output = input.matmul(weight);

match bias {
25 changes: 11 additions & 14 deletions crates/burn-core/src/nn/transformer/decoder.rs
@@ -455,8 +455,9 @@ impl<B: Backend> TransformerDecoder<B> {

#[cfg(test)]
mod tests {
use burn_tensor::Device;

use super::*;
use crate::tensor::Distribution;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};

#[test]
@@ -481,20 +481,16 @@ mod tests {
}

fn test_autoregressive(config: TransformerDecoderConfig) {
let device = Default::default();
let device: Device<TestBackend> = Default::default();
let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
let transformer = config.init(&device);

let memory = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let target = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let transformer = config.init::<TestBackend>(&device);

let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
.float()
.reshape([batch_size, seq_length, d_model]);
let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
.float()
.reshape([batch_size, seq_length, d_model]);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
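The reworked test above replaces random inputs with deterministic ones built from arange followed by reshape, presumably so that failures reproduce identically across runs. A standalone sketch of that pattern, assuming the burn umbrella crate with the ndarray feature; the shapes mirror the test and the model itself is omitted:

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();
    let [batch_size, seq_length, d_model] = [3_usize, 4, 8];

    // Deterministic input: 0, 1, 2, ... reshaped to [batch, seq, d_model],
    // unlike Distribution::Default, which changes on every run.
    let memory = Tensor::<B, 1, Int>::arange(0..(batch_size * seq_length * d_model) as i64, &device)
        .float()
        .reshape([batch_size, seq_length, d_model]);

    assert_eq!(memory.dims(), [3, 4, 8]);
}
```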
8 changes: 4 additions & 4 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -17,8 +17,8 @@ use burn_tensor::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
HandleContainer, OperationDescription, PermuteOperationDescription,
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
},
Device, Shape,
};
@@ -182,7 +182,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_reshape(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
#[derive(new)]
struct ReshapeDimsOps<B: FusionBackend> {
desc: ReshapeDescription,
desc: UnaryOperationDescription,
_b: PhantomData<B>,
}

Expand All @@ -197,7 +197,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool);

let desc = ReshapeDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
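Here, and in the float and int variants below, the dedicated ReshapeDescription is replaced by the generic UnaryOperationDescription: no reshape-specific payload is needed because the output TensorDescription already carries the target shape. A simplified sketch of that idea, using stand-in types rather than the real burn-tensor ones:

```rust
// Stand-ins for illustration only; the real TensorDescription and
// UnaryOperationDescription live in burn-tensor.
#[derive(Clone, Debug)]
struct TensorDescription {
    id: u64,
    shape: Vec<usize>,
}

#[derive(Clone, Debug)]
struct UnaryOperationDescription {
    input: TensorDescription,
    out: TensorDescription,
}

fn describe_reshape(input: TensorDescription, new_shape: Vec<usize>) -> UnaryOperationDescription {
    // The output description records the target shape, so a plain
    // input/out pair is enough to describe a reshape.
    let out = TensorDescription {
        id: input.id + 1,
        shape: new_shape,
    };
    UnaryOperationDescription { input, out }
}

fn main() {
    let input = TensorDescription { id: 0, shape: vec![2, 3, 4] };
    let desc = describe_reshape(input, vec![6, 4]);
    assert_eq!(desc.out.shape, vec![6, 4]);
}
```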
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/float.rs
@@ -650,7 +650,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
#[derive(new)]
struct ReshapeDimsOps<B: FusionBackend> {
desc: ReshapeDescription,
desc: UnaryOperationDescription,
_b: PhantomData<B>,
}

@@ -666,7 +666,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
let dtype = tensor.dtype;
let out = tensor.client.tensor_uninitialized(shape.dims, dtype);

let desc = ReshapeDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/ops/int.rs
@@ -103,7 +103,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
#[derive(new)]
struct ReshapeDimsOps<B: FusionBackend> {
desc: ReshapeDescription,
desc: UnaryOperationDescription,
_b: PhantomData<B>,
}

@@ -120,7 +120,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(shape.dims, B::IntElem::dtype());

let desc = ReshapeDescription {
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
36 changes: 32 additions & 4 deletions crates/burn-fusion/src/stream/context.rs
@@ -39,12 +39,9 @@ pub struct Context<'a, H> {
pub scalar_u8: &'a Vec<u8>,
}

#[derive(Default)]
pub(crate) struct OperationConverter {
tensors_relative2global: HashMap<TensorId, TensorDescription>,
tensors_global2relative: HashMap<TensorId, TensorDescription>,
/// Only useful to create new shape ID.
/// You should use tensor descriptions to retrieve the proper shape.
shapes_global2relative: HashMap<usize, usize>,
scalar_f32: Vec<f32>,
scalar_f16: Vec<f16>,
@@ -59,6 +56,32 @@ pub(crate) struct OperationConverter {
scalar_u8: Vec<u8>,
}

impl Default for OperationConverter {
fn default() -> Self {
let mut val = Self {
tensors_relative2global: Default::default(),
tensors_global2relative: Default::default(),
shapes_global2relative: Default::default(),
scalar_f32: Default::default(),
scalar_f16: Default::default(),
scalar_bf16: Default::default(),
scalar_i64: Default::default(),
scalar_i32: Default::default(),
scalar_i16: Default::default(),
scalar_i8: Default::default(),
scalar_u64: Default::default(),
scalar_u32: Default::default(),
scalar_u16: Default::default(),
scalar_u8: Default::default(),
};

// global 1 is always shape id 0.
val.shapes_global2relative.insert(1, 0);

val
}
}

/// Fork of a [context](Context) which owns its data.
pub struct ContextOwned<H> {
tensors: HashMap<TensorId, TensorDescription>,
@@ -180,7 +203,11 @@ impl OperationConverter {
pub(crate) fn clear(&mut self) {
self.tensors_relative2global.clear();
self.tensors_global2relative.clear();

self.shapes_global2relative.clear();
// global 1 is always shape id 0.
self.shapes_global2relative.insert(1, 0);

self.scalar_f32.clear();
self.scalar_f16.clear();
self.scalar_bf16.clear();
@@ -1126,7 +1153,7 @@ impl RelativeOps for BaseOperationDescription {
BaseOperationDescription::ToDevice(desc.to_relative(converter))
}
BaseOperationDescription::Reshape(desc) => {
BaseOperationDescription::Reshape(ReshapeDescription {
BaseOperationDescription::Reshape(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
@@ -1241,6 +1268,7 @@ impl RelativeOps for TensorDescription {
// We never saw this dim value before, therefore we create a new ID.
let dim_id = converter.shapes_global2relative.len();
relative_shape.push(dim_id);

converter.shapes_global2relative.insert(*dim, dim_id);
}
}
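The notable change above is that the global-to-relative shape map now pins dimension value 1 to relative id 0, both in Default and after clear, presumably so that broadcast axes (size 1) always map to the same relative id no matter when they are first seen. A self-contained sketch of the mapping idea; this is an illustration, not the actual OperationConverter:

```rust
use std::collections::HashMap;

// Stripped-down version of the global-to-relative shape mapping,
// with dimension value 1 pinned to relative id 0.
struct ShapeConverter {
    shapes_global2relative: HashMap<usize, usize>,
}

impl Default for ShapeConverter {
    fn default() -> Self {
        let mut shapes_global2relative = HashMap::new();
        // Global 1 is always shape id 0.
        shapes_global2relative.insert(1, 0);
        Self { shapes_global2relative }
    }
}

impl ShapeConverter {
    fn to_relative(&mut self, shape: &[usize]) -> Vec<usize> {
        shape
            .iter()
            .map(|dim| {
                // Unseen dim values get the next free id; known values reuse theirs.
                let next_id = self.shapes_global2relative.len();
                *self.shapes_global2relative.entry(*dim).or_insert(next_id)
            })
            .collect()
    }
}

fn main() {
    // Two different broadcasted shapes map their size-1 axis to the same id (0),
    // so the relative descriptions they produce stay comparable.
    assert_eq!(ShapeConverter::default().to_relative(&[32, 1, 64]), vec![1, 0, 2]);
    assert_eq!(ShapeConverter::default().to_relative(&[8, 1, 16]), vec![1, 0, 2]);
}
```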
13 changes: 11 additions & 2 deletions crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -2,7 +2,7 @@ use burn_fusion::OptimizationBuilder;

use crate::{
fusion::{
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
JitOptimization,
},
JitRuntime,
@@ -23,7 +23,16 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
let max_bindings = props.hardware_properties().max_bindings;

Self {
builder: FuseOnWriteBuilder::new(max_bindings, bool_precision),
builder: FuseOnWriteBuilder::new(
max_bindings,
bool_precision,
FuseSettings {
broadcast: true,
output_shape_updates: true,
mix_vectorization: true,
inplace: true,
},
),
device,
}
}
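The element-wise builder now constructs its FuseOnWriteBuilder with an explicit FuseSettings value. A sketch of what such a settings struct might look like; the field names come from this diff, while the doc text is an interpretation rather than the upstream documentation:

```rust
/// Hypothetical shape of the fusion settings added in this PR (field names from the diff).
#[derive(Clone, Copy, Debug)]
pub struct FuseSettings {
    /// Whether operations with broadcasted operands may join the fused kernel.
    pub broadcast: bool,
    /// Whether the fused output shape may be updated as more operations are registered.
    pub output_shape_updates: bool,
    /// Whether different vectorization factors may be mixed within one fused kernel.
    pub mix_vectorization: bool,
    /// Whether the fused kernel may reuse an input buffer for its output.
    pub inplace: bool,
}

fn main() {
    // Mirrors the values passed by ElementWiseBuilder in the diff above.
    let settings = FuseSettings {
        broadcast: true,
        output_shape_updates: true,
        mix_vectorization: true,
        inplace: true,
    };
    println!("{settings:?}");
}
```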