Skip to content

Commit

Permalink
Feat/fused matmul tune (#2726)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Jan 22, 2025
1 parent b33bd24 commit 245fbcd
Show file tree
Hide file tree
Showing 14 changed files with 525 additions and 60 deletions.
28 changes: 14 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
10 changes: 7 additions & 3 deletions backend-comparison/benches/matmul_fused.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use backend_comparison::persistence::save;
use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
use burn::tensor::{
activation::{gelu, relu},
backend::Backend,
Distribution, Shape, Tensor,
};
use burn_common::benchmark::{run_benchmark, Benchmark};
use derive_new::new;

Expand All @@ -14,7 +18,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
type Args = (Tensor<B, D>, Tensor<B, D>, Tensor<B, 1>);

fn name(&self) -> String {
"matmul_bias_relu".into()
"matmul_relu_bias_gelu".into()
}

fn shapes(&self) -> Vec<Vec<usize>> {
Expand All @@ -23,7 +27,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {

fn execute(&self, (lhs, rhs, bias): Self::Args) {
let bias = bias.unsqueeze();
relu(lhs.matmul(rhs) + bias);
gelu(relu(lhs.matmul(rhs)) + bias);
}

fn prepare(&self) -> Self::Args {
Expand Down
78 changes: 78 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,84 @@ pub(crate) struct OperationConverter {
scalar_u8: Vec<u8>,
}

/// Fork of a [context](Context) which owns its data.
///
/// Mirrors the field layout of [`Context`], but stores owned values
/// (`HashMap`, `HandleContainer`, `Vec`s) instead of borrows, so it can
/// outlive the context it was forked from.
pub struct ContextOwned<H> {
    // Tensor descriptions, keyed by tensor id.
    tensors: HashMap<TensorId, TensorDescription>,
    // Owned handle container; `H` is the backend handle type.
    handles: HandleContainer<H>,
    // Captured scalar values, one owned buffer per primitive dtype.
    scalar_f32: Vec<f32>,
    scalar_f16: Vec<f16>,
    scalar_bf16: Vec<bf16>,
    scalar_i64: Vec<i64>,
    scalar_i32: Vec<i32>,
    scalar_i16: Vec<i16>,
    scalar_i8: Vec<i8>,
    scalar_u64: Vec<u64>,
    scalar_u32: Vec<u32>,
    scalar_u16: Vec<u16>,
    scalar_u8: Vec<u8>,
}

impl<H: Clone> ContextOwned<H> {
    /// Convert into [context](Context).
    ///
    /// The returned context borrows from `self`: tensors and handles
    /// mutably, the scalar buffers immutably.
    pub fn as_context(&mut self) -> Context<'_, H> {
        // Destructure once so the mutable and shared field borrows are
        // split up-front instead of taken field by field.
        let Self {
            tensors,
            handles,
            scalar_f32,
            scalar_f16,
            scalar_bf16,
            scalar_i64,
            scalar_i32,
            scalar_i16,
            scalar_i8,
            scalar_u64,
            scalar_u32,
            scalar_u16,
            scalar_u8,
        } = self;

        Context {
            tensors,
            handles,
            // Reborrow the scalar buffers immutably for the view.
            scalar_f32: &*scalar_f32,
            scalar_f16: &*scalar_f16,
            scalar_bf16: &*scalar_bf16,
            scalar_i64: &*scalar_i64,
            scalar_i32: &*scalar_i32,
            scalar_i16: &*scalar_i16,
            scalar_i8: &*scalar_i8,
            scalar_u64: &*scalar_u64,
            scalar_u32: &*scalar_u32,
            scalar_u16: &*scalar_u16,
            scalar_u8: &*scalar_u8,
        }
    }

    /// Fork the context again.
    ///
    /// Produces an independent copy: the tensor map and scalar buffers
    /// are cloned, and the handle container is forked.
    pub fn fork(&self) -> ContextOwned<H> {
        let Self {
            tensors,
            handles,
            scalar_f32,
            scalar_f16,
            scalar_bf16,
            scalar_i64,
            scalar_i32,
            scalar_i16,
            scalar_i8,
            scalar_u64,
            scalar_u32,
            scalar_u16,
            scalar_u8,
        } = self;

        ContextOwned {
            tensors: tensors.clone(),
            handles: handles.fork(),
            scalar_f32: scalar_f32.clone(),
            scalar_f16: scalar_f16.clone(),
            scalar_bf16: scalar_bf16.clone(),
            scalar_i64: scalar_i64.clone(),
            scalar_i32: scalar_i32.clone(),
            scalar_i16: scalar_i16.clone(),
            scalar_i8: scalar_i8.clone(),
            scalar_u64: scalar_u64.clone(),
            scalar_u32: scalar_u32.clone(),
            scalar_u16: scalar_u16.clone(),
            scalar_u8: scalar_u8.clone(),
        }
    }
}

impl<H: Clone> Context<'_, H> {
    /// Fork the context into an [owned context](ContextOwned).
    ///
    /// Deep-copies everything reachable through the borrowed fields so
    /// the result no longer depends on this context's lifetime.
    pub fn fork(&self) -> ContextOwned<H> {
        ContextOwned {
            // Explicitly deref through the borrowed reference and clone
            // the tensor map itself.
            tensors: (*self.tensors).clone(),
            handles: self.handles.fork(),
            // Materialize owned vectors from the borrowed scalar buffers.
            scalar_f32: self.scalar_f32.to_vec(),
            scalar_f16: self.scalar_f16.to_vec(),
            scalar_bf16: self.scalar_bf16.to_vec(),
            scalar_i64: self.scalar_i64.to_vec(),
            scalar_i32: self.scalar_i32.to_vec(),
            scalar_i16: self.scalar_i16.to_vec(),
            scalar_i8: self.scalar_i8.to_vec(),
            scalar_u64: self.scalar_u64.to_vec(),
            scalar_u32: self.scalar_u32.to_vec(),
            scalar_u16: self.scalar_u16.to_vec(),
            scalar_u8: self.scalar_u8.to_vec(),
        }
    }
}

pub(crate) trait RelativeOps {
/// Convert (usually an [`OperationDescription`]) to a relative form.
///
Expand Down
16 changes: 6 additions & 10 deletions crates/burn-jit/src/fusion/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,16 @@ impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
fn optimizations(
device: R::Device,
) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
let mut optimizations: Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> =
vec![Box::new(ElementWiseBuilder::<R>::new(
vec![
Box::new(ElementWiseBuilder::<R>::new(
device.clone(),
BT::as_elem_native_unchecked().into(),
))];

if cfg!(feature = "fusion-experimental") {
optimizations.push(Box::new(MatmulBuilder::<R>::new(
)),
Box::new(MatmulBuilder::<R>::new(
device.clone(),
BT::as_elem_native_unchecked().into(),
)));
}

optimizations
)),
]
}
}

Expand Down
8 changes: 7 additions & 1 deletion crates/burn-jit/src/fusion/matmul/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for MatmulBuilder<R>
let rhs = self.builder.input_unhandled(&op.rhs);
let out = self.builder.output_unhandled(&op.out);

self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone()));
self.matmul = Some(FusedMatmul::new(
lhs,
rhs,
out,
op.clone(),
Default::default(),
));
} else {
self.builder.close();
}
Expand Down
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/matmul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub(crate) mod args;
pub(crate) mod builder;
pub(crate) mod optimization;
pub(crate) mod spec;
pub(crate) mod tune;
Loading

0 comments on commit 245fbcd

Please sign in to comment.