diff --git a/Cargo.lock b/Cargo.lock index ab9eddb0bd..62abf8ac0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1475,7 +1475,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1490,7 +1489,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1511,7 +1509,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1532,7 +1529,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1546,7 +1542,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1562,7 +1557,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1588,7 +1582,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1606,7 +1599,6 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-core", @@ -1618,7 +1610,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "darling", @@ -1633,7 +1624,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "darling", "proc-macro2", @@ -1644,7 +1634,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1660,7 +1649,6 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1670,7 +1658,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "async-channel", "async-lock", @@ -1692,7 +1679,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1707,7 +1693,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "ash", "async-channel", @@ -3544,7 +3529,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 169d668aa8..331ac11c18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" } ### For local development. ### -# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. 
### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs index fbec64c648..65c2a92074 100644 --- a/backend-comparison/benches/matmul_fused.rs +++ b/backend-comparison/benches/matmul_fused.rs @@ -26,8 +26,7 @@ impl Benchmark for MatmulBenchmark { } fn execute(&self, (lhs, rhs, bias): Self::Args) { - let bias = bias.unsqueeze(); - gelu(relu(lhs.matmul(rhs)) + bias); + let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze()); } fn prepare(&self) -> Self::Args { diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index ade8d64db7..34887ec6b9 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -59,12 +59,15 @@ extern crate alloc; pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] +/// Backend for test cases pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] +/// Backend for test cases pub type TestBackend = burn_wgpu::Wgpu; #[cfg(all(test, feature = "test-cuda"))] +/// Backend for test cases pub type TestBackend = burn_cuda::Cuda; /// Backend for autodiff test cases diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 738dd87c80..d3cf5951b0 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -77,7 +77,6 @@ impl Linear { let weight = self.weight.val().unsqueeze(); let bias = self.bias.as_ref().map(|b| b.val().unsqueeze()); - let output = input.matmul(weight); match bias { diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index c79c21a78c..a3d7e49158 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -455,8 +455,9 @@ impl TransformerDecoder { #[cfg(test)] mod tests { + use burn_tensor::Device; + use super::*; - use crate::tensor::Distribution; use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; #[test] @@ -481,20 +482,16 @@ mod tests { } fn test_autoregressive(config: TransformerDecoderConfig) { - let device = Default::default(); + let device: Device = Default::default(); let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(&device); - - let memory = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - &device, - ); - let target = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - &device, - ); + let transformer = config.init::(&device); + + let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) + .float() + .reshape([batch_size, seq_length, d_model]); + let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) + .float() + .reshape([batch_size, seq_length, d_model]); let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); let input = TransformerDecoderInput::new(target.clone(), memory.clone()) .target_mask_attn(mask_attn); diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 658907bf3e..a37fb98495 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -17,8 +17,8 @@ use burn_tensor::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, 
FlipOperationDescription, HandleContainer, OperationDescription, PermuteOperationDescription, - RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, - SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }, Device, Shape, }; @@ -182,7 +182,7 @@ impl BoolTensorOps for Fusion { fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -197,7 +197,7 @@ impl BoolTensorOps for Fusion { let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 1ba2717bfb..b798ac618f 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -650,7 +650,7 @@ impl FloatTensorOps for Fusion { fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -666,7 +666,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let out = tensor.client.tensor_uninitialized(shape.dims, dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index bf88bbd25b..4ee4d6e804 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -103,7 +103,7 @@ impl IntTensorOps for Fusion { fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -120,7 +120,7 @@ impl IntTensorOps for Fusion { .client .tensor_uninitialized(shape.dims, B::IntElem::dtype()); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index ed1a1902f8..671ecfb473 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -39,12 +39,9 @@ pub struct Context<'a, H> { pub scalar_u8: &'a Vec, } -#[derive(Default)] pub(crate) struct OperationConverter { tensors_relative2global: HashMap, tensors_global2relative: HashMap, - /// Only useful to create new shape ID. - /// You should use tensor descriptions to retrieve the proper shape. 
shapes_global2relative: HashMap, scalar_f32: Vec, scalar_f16: Vec, @@ -59,6 +56,32 @@ pub(crate) struct OperationConverter { scalar_u8: Vec, } +impl Default for OperationConverter { + fn default() -> Self { + let mut val = Self { + tensors_relative2global: Default::default(), + tensors_global2relative: Default::default(), + shapes_global2relative: Default::default(), + scalar_f32: Default::default(), + scalar_f16: Default::default(), + scalar_bf16: Default::default(), + scalar_i64: Default::default(), + scalar_i32: Default::default(), + scalar_i16: Default::default(), + scalar_i8: Default::default(), + scalar_u64: Default::default(), + scalar_u32: Default::default(), + scalar_u16: Default::default(), + scalar_u8: Default::default(), + }; + + // global 1 is always shape id 0. + val.shapes_global2relative.insert(1, 0); + + val + } +} + /// Fork of a [context](Context) which owns its data. pub struct ContextOwned { tensors: HashMap, @@ -180,7 +203,11 @@ impl OperationConverter { pub(crate) fn clear(&mut self) { self.tensors_relative2global.clear(); self.tensors_global2relative.clear(); + self.shapes_global2relative.clear(); + // global 1 is always shape id 0. + self.shapes_global2relative.insert(1, 0); + self.scalar_f32.clear(); self.scalar_f16.clear(); self.scalar_bf16.clear(); @@ -1126,7 +1153,7 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::ToDevice(desc.to_relative(converter)) } BaseOperationDescription::Reshape(desc) => { - BaseOperationDescription::Reshape(ReshapeDescription { + BaseOperationDescription::Reshape(UnaryOperationDescription { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }) @@ -1241,6 +1268,7 @@ impl RelativeOps for TensorDescription { // We never saw this dim value before, therefore we create a new ID. 
let dim_id = converter.shapes_global2relative.len(); relative_shape.push(dim_id); + converter.shapes_global2relative.insert(*dim, dim_id); } } diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 461767e9fc..a137f88692 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -2,7 +2,7 @@ use burn_fusion::OptimizationBuilder; use crate::{ fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, JitOptimization, }, JitRuntime, @@ -23,7 +23,16 @@ impl ElementWiseBuilder { let max_bindings = props.hardware_properties().max_bindings; Self { - builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), + builder: FuseOnWriteBuilder::new( + max_bindings, + bool_precision, + FuseSettings { + broadcast: true, + output_shape_updates: true, + mix_vectorization: true, + inplace: true, + }, + ), device, } } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index 2e33eefc20..71b44b8e44 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -110,7 +110,6 @@ impl TraceRunner for ElemwiseRunner { }, None => panic!("Invalid argument"), }; - let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); @@ -141,7 +140,7 @@ fn elemwise_fuse( let args = comptime![Sequence::::new()]; let pos = ABSOLUTE_POS; - let length = match comptime![config.ref_layout] { + let length = match comptime![config.ref_layout.clone()] { Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => inputs.t_f32.index(index).len(), ElemwisePrecision::F16 => inputs.t_f16.index(index).len(), diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index bba18e88f9..00532853d5 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -51,6 +51,7 @@ impl MatmulArgs for FusedMatmulArgs { LayoutInfo::IsRef, precision, &state.config, + None, ) } @@ -70,6 +71,7 @@ impl MatmulArgs for FusedMatmulArgs { LayoutInfo::IsRef, precision, &state.config, + None, ) } @@ -77,8 +79,8 @@ impl MatmulArgs for FusedMatmulArgs { let mut values = Registry::>::new(); let mut args = comptime![Sequence::::new()]; - values.insert(state.out, value); - comptime![args.push(state.out)]; + values.insert(comptime![state.out.clone()], value); + comptime![args.push(state.out.clone())]; fuse_on_write( unsafe { &(*state.inputs) }, @@ -225,9 +227,9 @@ impl FusedMatmulState { inputs: &inputs.global, outputs, config: comptime![config.clone()], - lhs: comptime![inputs.lhs], - rhs: comptime![inputs.rhs], - out: comptime![inputs.out], + lhs: comptime![inputs.lhs.clone()], + rhs: comptime![inputs.rhs.clone()], + out: comptime![inputs.out.clone()], } } } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index f197237819..11e38626e3 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -3,7 +3,7 @@ use burn_tensor::repr::{FloatOperationDescription, OperationDescription}; use crate::{ fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + 
on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, JitOptimization, }, JitRuntime, @@ -24,10 +24,16 @@ impl MatmulBuilder { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware_properties().max_bindings; + let settings = FuseSettings { + broadcast: true, + output_shape_updates: false, + mix_vectorization: true, + inplace: true, + }; Self { - builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), - builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision), + builder: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings), + builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings), device, matmul: None, } @@ -56,6 +62,7 @@ impl OptimizationBuilder> for MatmulBuilder )); } else { self.builder.close(); + self.builder_fallback.close(); } } else { self.builder.register(operation); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index bf31ef78ea..d1584af3b7 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -1,7 +1,7 @@ use super::{ ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs}, - trace::FuseOnWriteTrace, - trace_builder::FuseOnWriteTraceBuilder, + settings::FuseSettings, + trace::{FuseOnWriteTrace, FuseOnWriteTraceBuilder}, }; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::{ @@ -17,9 +17,11 @@ use cubecl::ir::Elem; /// Fused element wise operations that are normally memory bound. pub(crate) struct FuseOnWriteBuilder { builder: TryFuseBuilder, + settings: FuseSettings, current_output_shape: Vec, status: OptimizationStatus, - num_ops: usize, + pub(crate) num_ops: usize, + pub(crate) num_reshapes: usize, max_bindings: u32, } @@ -30,33 +32,40 @@ struct TryFuseBuilder { } impl TryFuseBuilder { - fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { + fn new(max_bindings: u32, bool_precision: ElemwisePrecision, settings: FuseSettings) -> Self { Self { - builder: FuseOnWriteTraceBuilder::new(bool_precision), + builder: FuseOnWriteTraceBuilder::new(bool_precision, settings), max_bindings, added_ops: false, } } - fn register(&mut self, add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder)) -> bool { + fn register(&mut self, add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder) -> bool) -> bool { // Always allow the first operation to be added. 
if !self.added_ops { self.added_ops = true; - add_ops(&mut self.builder); + + if !add_ops(&mut self.builder) { + return false; + } return true; } let mut cloned = self.builder.clone(); - add_ops(&mut cloned); + if !add_ops(&mut cloned) { + return false; + } + if cloned.estimate_bindings() > self.max_bindings { return false; } + self.builder = cloned; true } - fn build(&self) -> FuseOnWriteTrace { - self.builder.build() + fn build(&self, shape: Vec) -> FuseOnWriteTrace { + self.builder.build(shape) } } @@ -97,6 +106,12 @@ impl OptimizationBuilder for FuseOnWriteBuilder { return; } } + OperationDescription::BaseBool(ops) => { + if !self.register_base(ops) { + self.status = OptimizationStatus::Closed; + return; + } + } _ => { self.status = OptimizationStatus::Closed; return; } @@ -108,7 +123,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { - self.builder.build() + self.builder.build(self.current_output_shape.clone()) } fn len(&self) -> usize { @@ -118,7 +133,11 @@ impl OptimizationBuilder for FuseOnWriteBuilder { fn reset(&mut self) { self.num_ops = 0; self.status = OptimizationStatus::Open; - self.builder = TryFuseBuilder::new(self.max_bindings, self.builder.builder.bool_precision); + self.builder = TryFuseBuilder::new( + self.max_bindings, + self.builder.builder.bool_precision, + self.settings, + ); self.current_output_shape.clear(); } @@ -137,10 +156,16 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } impl FuseOnWriteBuilder { - pub fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { + pub fn new( + max_bindings: u32, + bool_precision: ElemwisePrecision, + settings: FuseSettings, + ) -> Self { Self { - builder: TryFuseBuilder::new(max_bindings, bool_precision), + builder: TryFuseBuilder::new(max_bindings, bool_precision, settings), + settings, num_ops: 0, + num_reshapes: 0, max_bindings, current_output_shape: Vec::new(), status: OptimizationStatus::Open, } } @@ -158,6 +183,9 @@ impl FuseOnWriteBuilder { pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg { if self.current_output_shape.is_empty() { self.current_output_shape = tensor.shape.clone(); + } else if self.current_output_shape.iter().sum::() < tensor.shape.iter().sum() { + // The largest shape wins. + self.current_output_shape = tensor.shape.clone(); } self.builder.builder.output_unhandled(tensor) } @@ -172,6 +200,39 @@ impl FuseOnWriteBuilder { BaseOperationDescription::Cast(desc) => self.register_unary_ops(desc, |input, out| { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), + BaseOperationDescription::Reshape(desc) => { + if desc.input.shape == desc.out.shape { + return self.register_unary_ops(desc, |input, out| { + ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) + }); + } + + if desc.input.shape.len() > desc.out.shape.len() { + // Not yet supported.
+ return false; + } + + if !self.output_is_compatible(&desc.out) { + return false; + } + + if self.builder.register(|build| { + let input = match build.input_reshaped(&desc.input, &desc.out) { + Some(val) => val, + None => return false, + }; + let out = build.output(&desc.out); + + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true + }) { + self.num_reshapes += 1; + true + } else { + false + } + } _ => false, } } @@ -302,7 +363,9 @@ impl FuseOnWriteBuilder { lhs, rhs, out, - }) + }); + + true }) } NumericOperationDescription::MaskFill(desc) => { @@ -321,7 +384,9 @@ impl FuseOnWriteBuilder { lhs, rhs, out, - }) + }); + + true }) } NumericOperationDescription::Ones(desc) => { @@ -336,7 +401,9 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } NumericOperationDescription::Zeros(desc) => { @@ -351,7 +418,9 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } NumericOperationDescription::Full((desc, elem)) => { @@ -363,7 +432,9 @@ impl FuseOnWriteBuilder { let input = build.scalar(elem, desc.dtype); let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } _ => false, @@ -383,7 +454,9 @@ impl FuseOnWriteBuilder { let rhs = build.input(&desc.rhs); let out = build.output(&desc.out); - build.register_operation(func(lhs, rhs, out)) + build.register_operation(func(lhs, rhs, out)); + + true }) } @@ -398,7 +471,8 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let input = build.input(&desc.input); let out = build.output(&desc.out); - build.register_operation(func(input, out)) + build.register_operation(func(input, out)); + true }) } @@ -420,16 +494,59 @@ impl FuseOnWriteBuilder { let rhs = build.scalar(&desc.rhs, elem); let out = build.output(&desc.out); - build.register_operation(func(lhs, rhs, out)) + build.register_operation(func(lhs, rhs, out)); + + true }) } fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { if self.current_output_shape.is_empty() { self.current_output_shape.clone_from(&out.shape); - } else if self.current_output_shape != out.shape { + return true; + } + + let rank = self.current_output_shape.len(); + + // Rank should be equal. + if rank != out.shape.len() { + return false; + } + + let mut updated = self.current_output_shape.clone(); + + #[allow(clippy::needless_range_loop)] + for i in 0..rank { + let curr = self.current_output_shape[i]; + let new = out.shape[i]; + + if curr == new { + continue; + } + + // Broadcast not enabled. + if !self.settings.broadcast { + return false; + } + + // Broadcasted on new dim. + if new == 0 { + continue; + } + + // Broadcasted on curr dim - update reference output shape. 
+ if curr == 0 && self.settings.output_shape_updates { + updated[i] = new; + continue; + } + + return false; + } + + if updated != out.shape { return false; } + self.current_output_shape.clone_from_slice(&out.shape); true } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 497bc510df..cacb2decc5 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -1,5 +1,8 @@ use super::ir::*; -use cubecl::{linalg::tensor::index_offset_with_layout, prelude::*}; +use cubecl::{ + ir::{ExpandElement, Variable}, + prelude::*, +}; #[cube] /// Read the value from the [arg](Arg) and cast it to the generic cube primitive. @@ -12,9 +15,9 @@ pub fn read( #[comptime] config: &ElemwiseConfig, ) -> Line { match arg { - Arg::Input(pos, precision, layout) => { - read_input(inputs, outputs, pos, ref_pos, layout, precision, config) - } + Arg::Input(pos, precision, layout) => read_input( + inputs, outputs, pos, ref_pos, layout, precision, config, None, + ), Arg::Output(pos, precision, layout) => { read_output(inputs, outputs, pos, ref_pos, layout, precision, config) } @@ -32,21 +35,62 @@ pub fn read( ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)), ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)), }, + Arg::Scalar(..) => { + let scalar = read_scalar::(inputs, arg); + Line::new(scalar) + } + Arg::ScalarShape(_) => { + let scalar = read_scalar_shape(inputs, arg); + Line::cast_from(scalar) + } + Arg::Literal(val, _precision) => Line::cast_from(val.runtime()), + Arg::InputReshaped { + original, shape, .. + } => match comptime![original.as_ref().clone()] { + Arg::Input(pos, precision, layout) => read_input( + inputs, + outputs, + pos, + ref_pos, + layout, + precision, + config, + comptime![Some(shape)], + ), + _ => comptime![panic!("Only input can be reshaped")], + }, + } +} + +#[cube] +pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) -> C { + match arg { Arg::Scalar(pos, precision) => match comptime![precision] { - ElemwisePrecision::F32 => Line::cast_from(*inputs.s_f32.index(pos)), - ElemwisePrecision::F16 => Line::cast_from(*inputs.s_f16.index(pos)), - ElemwisePrecision::BF16 => Line::cast_from(*inputs.s_bf16.index(pos)), - ElemwisePrecision::U64 => Line::cast_from(*inputs.s_u64.index(pos)), - ElemwisePrecision::U32 => Line::cast_from(*inputs.s_u32.index(pos)), - ElemwisePrecision::U16 => Line::cast_from(*inputs.s_u16.index(pos)), - ElemwisePrecision::U8 => Line::cast_from(*inputs.s_u8.index(pos)), - ElemwisePrecision::I64 => Line::cast_from(*inputs.s_i64.index(pos)), - ElemwisePrecision::I32 => Line::cast_from(*inputs.s_i32.index(pos)), - ElemwisePrecision::I16 => Line::cast_from(*inputs.s_i16.index(pos)), - ElemwisePrecision::I8 => Line::cast_from(*inputs.s_i8.index(pos)), + ElemwisePrecision::F32 => C::cast_from(*inputs.s_f32.index(pos)), + ElemwisePrecision::F16 => C::cast_from(*inputs.s_f16.index(pos)), + ElemwisePrecision::BF16 => C::cast_from(*inputs.s_bf16.index(pos)), + ElemwisePrecision::U64 => C::cast_from(*inputs.s_u64.index(pos)), + ElemwisePrecision::U32 => C::cast_from(*inputs.s_u32.index(pos)), + ElemwisePrecision::U16 => C::cast_from(*inputs.s_u16.index(pos)), + ElemwisePrecision::U8 => C::cast_from(*inputs.s_u8.index(pos)), + ElemwisePrecision::I64 => C::cast_from(*inputs.s_i64.index(pos)), + ElemwisePrecision::I32 => C::cast_from(*inputs.s_i32.index(pos)), + ElemwisePrecision::I16 => C::cast_from(*inputs.s_i16.index(pos)), + ElemwisePrecision::I8 => 
C::cast_from(*inputs.s_i8.index(pos)), _ => comptime![panic!("Unsupported precision {precision:?}")], }, - Arg::Literal(val, _precision) => Line::cast_from(val.runtime()), + _ => comptime![panic!("Not a scalar")], + } +} + +#[cube] +pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { + match arg { + Arg::ScalarShape(pos) => { + let offset = comptime![inputs.s_u32.len() - pos - 1]; + *inputs.s_u32.index(offset) + } + _ => comptime![panic!("Not a scalar shape")], } } @@ -59,6 +103,7 @@ pub fn read_input( #[comptime] layout: LayoutInfo, #[comptime] precision: ElemwisePrecision, #[comptime] config: &ElemwiseConfig, + #[comptime] shape: Option>, ) -> Line { match comptime![precision] { ElemwisePrecision::F32 => { @@ -66,7 +111,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -75,7 +120,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -84,7 +129,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -93,7 +138,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -102,7 +147,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -111,7 +156,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -120,7 +165,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -129,7 +174,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -138,7 +183,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - 
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -147,7 +192,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -156,7 +201,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -180,7 +225,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -189,7 +234,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -198,7 +243,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -207,7 +252,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -216,7 +261,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -225,7 +270,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -234,7 +279,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -243,7 +288,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -252,7 +297,7 @@ pub 
fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -261,7 +306,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -270,7 +315,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -296,7 +341,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_f32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -306,7 +353,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_f16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -316,7 +365,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_bf16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -326,7 +377,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u64.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -336,7 +389,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -346,7 +401,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -356,7 +413,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, 
ref_pos, config, None) + } }; let tensor = outputs.t_u8.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -366,7 +425,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i64.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -376,7 +437,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -386,7 +449,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -396,7 +461,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i8.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -428,99 +495,100 @@ fn get_offset( tensor: &Tensor>, pos: u32, #[comptime] config: &ElemwiseConfig, + #[comptime] shape: Option>, ) -> u32 { - match comptime![config.ref_layout] { + match comptime![config.ref_layout.clone()] { Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = inputs.t_f32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::F16 => { let layout = inputs.t_f16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = inputs.t_bf16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U64 => { let layout = inputs.t_u64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U32 => { let layout = inputs.t_u32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U16 => { let layout = inputs.t_u16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U8 => { let layout = inputs.t_u8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I64 => { let layout = inputs.t_i64.index(index); - 
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I32 => { let layout = inputs.t_i32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I16 => { let layout = inputs.t_i16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I8 => { let layout = inputs.t_i8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = outputs.t_f32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::F16 => { let layout = outputs.t_f16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = outputs.t_bf16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U64 => { let layout = outputs.t_u64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U32 => { let layout = outputs.t_u32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U16 => { let layout = outputs.t_u16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U8 => { let layout = outputs.t_u8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I64 => { let layout = outputs.t_i64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I32 => { let layout = outputs.t_i32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I16 => { let layout = outputs.t_i16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I8 => { let layout = outputs.t_i8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, @@ -692,3 +760,97 @@ pub fn global_stride( _ => comptime![panic!("Unsupported precision {precision:?}")], } } + +/// Returns the offset of the tensor corresponding 
to the layout tensor. +#[cube] +fn index_offset_with_layout( + inputs: &GlobalArgs, + tensor: &Tensor>, + layout: &Tensor>, + index: u32, + #[comptime] rank: u32, + #[comptime] shape: Option>, +) -> u32 { + match comptime![shape.clone()] { + Some(shape) => { + let index = reshaped_index(inputs, layout, index, rank, shape); + reshaped_index_to_original_index(tensor, index, rank) + } + None => { + let offset_ref = index * layout.line_size(); + let mut offset = 0u32; + + for i in 0u32..rank { + let ogwl = offset_ref / layout.stride(i); + offset += ogwl % tensor.shape(i) * tensor.stride(i); + } + + offset / tensor.line_size() + } + } +} + +#[cube] +fn reshaped_index( + inputs: &GlobalArgs, + layout: &Tensor>, + index: u32, + #[comptime] rank: u32, + #[comptime] shape: Sequence, +) -> u32 { + let index = index * layout.line_size(); + + let mut offset = 0u32; + let mut stride_curr = 1u32; + + #[unroll] + for r in 0..rank { + let i = comptime![reverse_index(rank, r)]; + let arg = comptime![shape.index(i.clone())]; + let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); + + let ogwl = index / layout.stride(i); + offset += ogwl % shape_i * stride_curr; + + stride_curr *= shape_i; + } + + offset +} + +#[cube] +fn reshaped_index_to_original_index( + original: &Tensor>, + index_reshaped: u32, + #[comptime] rank: u32, +) -> u32 { + let mut remaining = index_reshaped; + let mut offset = 0; + + #[unroll] + for r in 0..rank { + let i = comptime![reverse_index(rank, r)]; + let shape = original.shape(comptime![i.clone()]); + let stride = original.stride(i); + + let coordinate = remaining % shape; + + remaining /= shape; + offset += coordinate * stride; + } + + offset / original.line_size() +} + +fn reverse_index>>( + rank: u32, + iter: Elem, +) -> ExpandElementTyped { + let elem = iter.into(); + let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); + let result = rank - elem - 1; + let scalar: Variable = result.into(); + let expand: ExpandElement = ExpandElement::Plain(scalar); + + expand.into() +} diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 36c8e402a0..d189badcdf 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -4,20 +4,23 @@ use cubecl::prelude::*; use half::{bf16, f16}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to an [elemwise operation](ElemwiseOp). pub enum Arg { Input(u32, ElemwisePrecision, LayoutInfo), Local(u32, ElemwisePrecision), Output(u32, ElemwisePrecision, LayoutInfo), Scalar(u32, ElemwisePrecision), + ScalarShape(u32), /// Only constant that can be encoded into an u32 can be used as literal. Literal(u32, ElemwisePrecision), + InputReshaped { + original: Box, + shape: Sequence, + }, } -#[derive( - CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, -)] +#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Layout information. pub enum LayoutInfo { /// The layout if the same as the reference. @@ -36,6 +39,8 @@ impl Arg { Arg::Output(_, p, _) => p, Arg::Scalar(_, p) => p, Arg::Literal(_, p) => p, + Arg::ScalarShape(_) => return ElemwisePrecision::U32, + Arg::InputReshaped { original, .. 
} => return original.precision(), } } } @@ -88,6 +93,14 @@ pub enum ElemwiseOp { } #[derive(CubeLaunch)] +pub struct ReshapedTensor { + #[cube(comptime)] + original: Arg, + #[cube(comptime)] + shape: Sequence, +} + +#[derive(CubeLaunch, Default)] /// Global arguments that are used for fusing [element wise operations](ElemwiseOp). pub struct GlobalArgs { pub t_f32: Sequence>>, @@ -114,6 +127,34 @@ pub struct GlobalArgs { pub s_u8: Sequence, } +impl Default for GlobalArgsLaunch<'_, R> { + fn default() -> Self { + Self::new( + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ) + } +} impl GlobalArgsLaunch<'_, R> { /// Get the shape of the given [argument](Arg). /// diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index 269ba1f3b8..cf79a0fb6b 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -33,8 +33,8 @@ pub fn fuse_on_write( // Write the values given as arguments. #[unroll] for i in 0..write_args.len() { - let arg = comptime![*write_args.index(i)]; - let val = write_values.find(arg); + let arg = comptime![write_args.index(i).clone()]; + let val = write_values.find(comptime![arg.clone()]); write::(inputs, outputs, &mut locals, write_pos, val, arg, config); } @@ -404,7 +404,9 @@ pub fn fuse_on_write( ElemwisePrecision::U8 => { equal::(inputs, outputs, &mut locals, write_pos, op, config) } - _ => comptime![panic!("Unsupported precision {op:?}")], + ElemwisePrecision::Bool => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } }, ElemwiseOp::Greater(op) => match op.lhs.precision() { ElemwisePrecision::F32 => { @@ -677,7 +679,7 @@ pub fn fuse_on_write( out, config, ), - _ => comptime![panic!("Unsupported precision {op:?}")], + _ => comptime![panic!("Unsupported precision")], }, } } diff --git a/crates/burn-jit/src/fusion/on_write/mod.rs b/crates/burn-jit/src/fusion/on_write/mod.rs index d40d682dcd..69bbc724d1 100644 --- a/crates/burn-jit/src/fusion/on_write/mod.rs +++ b/crates/burn-jit/src/fusion/on_write/mod.rs @@ -2,6 +2,6 @@ pub(crate) mod builder; pub(crate) mod io; pub(crate) mod ir; pub(crate) mod kernel; +pub(crate) mod settings; pub mod trace; -pub(crate) mod trace_builder; diff --git a/crates/burn-jit/src/fusion/on_write/settings.rs b/crates/burn-jit/src/fusion/on_write/settings.rs new file mode 100644 index 0000000000..761b98e8b2 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/settings.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +/// Controls which operations can be fused. +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FuseSettings { + /// Enables broadcasting of shapes. + pub broadcast: bool, + /// Enables output shape updates. + /// + /// When broadcast is enabled, the output shape can become bigger after a fusion, + /// therefore an update is needed. + pub output_shape_updates: bool, + /// Enables mixed vectorization factors.
+ /// + /// Useful when the last dimension is broadcasted for one of the tensors, which would limit the + /// vectorization factor to be 1 without this setting enabled. + pub mix_vectorization: bool, + /// Enables the reuse of input buffers. + pub inplace: bool, +} diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs deleted file mode 100644 index 2c29d05ce8..0000000000 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ /dev/null @@ -1,647 +0,0 @@ -use crate::{ - fusion::{on_write::ir::LayoutInfo, strides_dyn_rank, JitFusionHandle}, - BoolElement, JitRuntime, -}; - -use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; -use burn_fusion::stream::Context; -use burn_tensor::{ - repr::{TensorDescription, TensorId, TensorStatus}, - DType, -}; -use cubecl::{ir::Elem, prelude::*}; -use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; - -#[derive(new, Clone, Serialize, Deserialize, Debug)] -/// Trace containing all element wise operations as well as reads and writes. -pub struct FuseOnWriteTrace { - outputs: RegisteredTensors, - inputs: RegisteredTensors, - scalars: BTreeMap, - ops: Vec, - reads: BTreeMap, - writes: BTreeMap, - inputs_unhandled: Vec, -} - -/// A trace runner is responsible for determining the vectorization factor as well as launching -/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) -/// with a provided [element wise config](ElemwiseConfig). -pub trait TraceRunner { - /// The error that might happen while running the trace. - type Error; - - /// Run the trace. - fn run<'a>( - &'a self, - client: &'a ComputeClient, - inputs: GlobalArgsLaunch<'a, R>, - outputs: GlobalArgsLaunch<'a, R>, - config: &'a ElemwiseConfig, - ) -> Result<(), Self::Error>; - - /// The vectorization factor for all inputs and outputs. - fn vectorization<'a>( - handles_inputs: impl Iterator>, - inputs: impl Iterator, - outputs: impl Iterator, - ) -> u8 { - // The default version uses the last dimension as vectorization axis and assumes a - // perpendicular contiguous line. - - let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { - let rank = handle.strides.len(); - - // Last dimension strides should be 1, otherwise vecX won't be contiguous. - if handle.strides[rank - 1] != 1 { - return 1; - } - - for s in R::line_size_elem(&desc.dtype.into()) { - // The last dimension should be a multiple of the vector size. - if desc.shape[rank - 1] % s as usize == 0 { - return s; - } - } - - 1 - }; - - let vectorization_output = |desc: &TensorDescription| { - let rank = desc.shape.len(); - - for s in R::line_size_elem(&desc.dtype.into()) { - // The last dimension should be a multiple of the vector size. 
- if desc.shape[rank - 1] % s as usize == 0 { - return s; - } - } - - 1 - }; - - let mut output = u8::MAX; - - for (handle, tensor) in handles_inputs.zip(inputs) { - output = Ord::min(vectorization_input(handle, tensor), output); - } - - for tensor in outputs { - output = Ord::min(vectorization_output(tensor), output); - } - - output - } -} - -#[derive(Debug)] -struct LaunchAnalysis<'a, R: JitRuntime> { - potential_inplaces: Vec>, - global_inputs: Vec, - global_outputs: Vec, - handle_inputs: Vec>, - handle_outputs: Vec>, - reference: Option, - reads: BTreeMap, - writes: BTreeMap, - rank: usize, - vectorization: u8, -} - -#[derive(Debug)] -enum HandleOutput { - Alias { - input_pos: usize, - precision: ElemwisePrecision, - }, - Owned { - global_id: TensorId, - precision: ElemwisePrecision, - handle: JitFusionHandle, - global_shape: Vec, - }, -} - -#[derive(Debug)] -struct HandleInput { - relative_id: TensorId, - global_id: TensorId, - precision: ElemwisePrecision, - handle: JitFusionHandle, - global_shape: Vec, -} - -#[derive(Debug)] -struct Reference { - layout: Arg, - shape: Vec, - strides: Vec, -} - -#[derive(Debug)] -struct PotentialInplace<'a> { - input_pos: usize, - tensor_relative: &'a TensorDescription, - strides: Vec, -} - -impl FuseOnWriteTrace { - /// Run a trace with the given [runner](TraceRunner). - pub fn run>( - &self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - runner: &Runner, - ) -> Result<(), Runner::Error> { - let analysis = self.analyse::(client, device, context); - - let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); - let outputs = - self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); - - let mut ops = Sequence::new(); - for op in analysis.reads.into_values() { - ops.push(op); - } - - for op in self.ops.iter() { - ops.push(op.clone()); - } - - for op in analysis.writes.into_values() { - ops.push(op); - } - - let config = ElemwiseConfig { - rank: analysis.rank as u32, - ref_layout: analysis - .reference - .expect("An output should exist for the fused kernel") - .layout, - ops, - }; - - match Runner::run(runner, client, inputs, outputs, &config) { - Err(err) => { - self.rollback(context, analysis.handle_inputs, analysis.handle_outputs); - Err(err) - } - Ok(val) => Ok(val), - } - } - - fn rollback( - &self, - context: &mut Context<'_, JitFusionHandle>, - handle_inputs: Vec>, - handle_outputs: Vec>, - ) { - for input in handle_inputs { - context - .handles - .register_handle(input.global_id, input.handle); - } - for output in handle_outputs { - if let HandleOutput::Owned { - global_id, handle, .. 
- } = output - { - context.handles.register_handle(global_id, handle); - } - } - } - - fn analyse<'a, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( - &'a self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - ) -> LaunchAnalysis<'a, R> { - let mut analysis = LaunchAnalysis { - potential_inplaces: Vec::new(), - global_inputs: Vec::new(), - global_outputs: Vec::new(), - handle_inputs: Vec::new(), - handle_outputs: Vec::new(), - reference: None, - reads: self.reads.clone(), - writes: self.writes.clone(), - rank: 1, - vectorization: 1, - }; - - self.analyse_inputs(context, &mut analysis); - self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); - - analysis.vectorization = Runner::vectorization( - analysis.handle_inputs.iter().map(|item| &item.handle), - analysis.global_inputs.iter(), - analysis.global_outputs.iter(), - ); - - analysis - } - - fn analyse_inputs<'a, R: JitRuntime>( - &'a self, - context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, R>, - ) { - for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); - // Important to take the status of the relative graph and not - // the global graph, since the status of the global graph - // might be of a later operation on the same tensor id. - let status = &tensor_relative.status; - let handle = context.handles.get_handle(&tensor_global.id, status); - - if status == &TensorStatus::ReadWrite - && handle.handle.can_mut() - && !self.inputs_unhandled.contains(&tensor_relative.id) - { - analysis.potential_inplaces.push(PotentialInplace { - input_pos: i, - tensor_relative, - strides: handle.strides.clone(), - }); - } - - analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); - analysis.handle_inputs.push(HandleInput { - precision, - handle, - relative_id: tensor_relative.id, - global_id: tensor_global.id, - global_shape: tensor_global.shape.clone(), - }); - analysis.global_inputs.push(tensor_global); - } - } - - fn analyse_outputs<'a, R: JitRuntime, BT: BoolElement>( - &'a self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, R>, - ) { - for (precision, tensor_relative) in self.outputs.iter() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); - let strides = strides_dyn_rank(&tensor_global.shape); - - if let Some(index) = analysis - .potential_inplaces - .iter() - .enumerate() - .find(|(_pos, pi)| { - pi.tensor_relative.dtype == tensor_global.dtype - && pi.tensor_relative.shape == tensor_relative.shape - && pi.strides == strides - }) - .map(|(pos, _)| pos) - { - let potential_inplace = analysis.potential_inplaces.remove(index); - let handle_input = analysis - .handle_inputs - .get(potential_inplace.input_pos) - .unwrap(); - - if analysis.reference.is_none() { - let index_input = self - .inputs - .get_index(precision, potential_inplace.tensor_relative.id) - .unwrap(); - - analysis.reference = Some(Reference { - layout: Arg::Input(index_input as u32, precision, LayoutInfo::IsRef), - shape: tensor_global.shape.clone(), - strides: handle_input.handle.strides.clone(), - }); - - if let Some(ElemwiseOp::Assign(op)) = - analysis.reads.get_mut(&handle_input.relative_id) - { - op.input.add_layout_info(LayoutInfo::IsRef); - }; - - if let Some(ElemwiseOp::Assign(op)) = - analysis.writes.get_mut(&tensor_relative.id) - { - 
op.out.add_layout_info(LayoutInfo::IsRef); - }; - } - - context - .handles - .register_handle(tensor_global.id, handle_input.handle.clone()); - analysis.handle_outputs.push(HandleOutput::Alias { - input_pos: potential_inplace.input_pos, - precision, - }); - analysis.global_outputs.push(tensor_global); - } else { - if analysis.reference.is_none() { - analysis.reference = Some(Reference { - layout: Arg::Output(0, precision, LayoutInfo::IsRef), - shape: tensor_global.shape.clone(), - strides: strides.clone(), - }); - - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&tensor_relative.id).unwrap() - { - op.out.add_layout_info(LayoutInfo::IsRef); - }; - } else if let Some(reference) = analysis.reference.as_ref() { - if reference.strides == strides && reference.shape == tensor_global.shape { - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&tensor_relative.id).unwrap() - { - op.out.add_layout_info(LayoutInfo::SameAsRef); - }; - } - } - - // We encode bool tensors as `B`. - let dtype = match tensor_global.dtype { - DType::Bool => BT::dtype(), - _ => tensor_global.dtype, - }; - let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); - - let handle = JitFusionHandle { - client: client.clone(), - handle: client.empty(size), - device: device.clone(), - strides, - dtype, - }; - - analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); - context - .handles - .register_handle(tensor_global.id, handle.clone()); - - analysis.handle_outputs.push(HandleOutput::Owned { - precision, - handle, - global_shape: tensor_global.shape.clone(), - global_id: tensor_global.id, - }); - analysis.global_outputs.push(tensor_global); - } - } - - Self::add_layout_info_inputs(analysis); - } - - fn add_layout_info_inputs(analysis: &mut LaunchAnalysis<'_, R>) { - for hi in analysis.handle_inputs.iter() { - if let Some(reference) = analysis.reference.as_ref() { - if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { - if let Some(ElemwiseOp::Assign(op)) = analysis.reads.get_mut(&hi.relative_id) { - op.input.add_layout_info(LayoutInfo::SameAsRef); - } - } - } - } - } - - fn register_inputs<'h, R: JitRuntime>( - &self, - context: &mut Context<'_, JitFusionHandle>, - handle_inputs: &'h [HandleInput], - vectorization: u8, - ) -> GlobalArgsLaunch<'h, R> { - let mut inputs = GlobalArgsLaunch::new( - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - ); - - for hi in handle_inputs.iter() { - let arg = hi.handle.as_tensor_arg(&hi.global_shape, vectorization); - match hi.precision { - ElemwisePrecision::F32 => inputs.t_f32.push(arg), - ElemwisePrecision::F16 => inputs.t_f16.push(arg), - ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), - ElemwisePrecision::I64 => inputs.t_i64.push(arg), - ElemwisePrecision::I32 => inputs.t_i32.push(arg), - ElemwisePrecision::I16 => inputs.t_i16.push(arg), - ElemwisePrecision::I8 => inputs.t_i8.push(arg), - ElemwisePrecision::U64 => inputs.t_u64.push(arg), - ElemwisePrecision::U32 => inputs.t_u32.push(arg), - ElemwisePrecision::U16 => inputs.t_u16.push(arg), - ElemwisePrecision::U8 => 
inputs.t_u8.push(arg), - _ => panic!("Unsupported input precision {:?}", hi.precision), - }; - } - - for (precision, count) in self.scalars.iter() { - for i in 0..(*count as usize) { - match precision { - ElemwisePrecision::F32 => { - inputs.s_f32.push(ScalarArg::new(context.scalar_f32[i])) - } - ElemwisePrecision::F16 => { - inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) - } - ElemwisePrecision::BF16 => { - inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) - } - ElemwisePrecision::I64 => { - inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) - } - ElemwisePrecision::I32 => { - inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) - } - ElemwisePrecision::I16 => { - inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) - } - ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), - ElemwisePrecision::U64 => { - inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) - } - ElemwisePrecision::U32 => { - inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) - } - ElemwisePrecision::U16 => { - inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) - } - ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), - ElemwisePrecision::Bool => todo!(), - } - } - } - - inputs - } - - fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( - &self, - handle_outputs: &'s [HandleOutput], - vectorization: u8, - ) -> GlobalArgsLaunch<'s, R> { - let mut outputs = GlobalArgsLaunch::new( - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - ); - for item in handle_outputs.iter() { - match item { - HandleOutput::Alias { - input_pos, - precision, - } => match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), - _ => todo!(), - }, - HandleOutput::Owned { - precision, - handle, - global_shape, - .. 
- } => { - let arg = handle.as_tensor_arg(global_shape, vectorization); - - match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(arg), - ElemwisePrecision::F16 => outputs.t_f16.push(arg), - ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), - ElemwisePrecision::I64 => outputs.t_i64.push(arg), - ElemwisePrecision::I32 => outputs.t_i32.push(arg), - ElemwisePrecision::I16 => outputs.t_i16.push(arg), - ElemwisePrecision::I8 => outputs.t_i8.push(arg), - ElemwisePrecision::U64 => outputs.t_u64.push(arg), - ElemwisePrecision::U32 => outputs.t_u32.push(arg), - ElemwisePrecision::U16 => outputs.t_u16.push(arg), - ElemwisePrecision::U8 => outputs.t_u8.push(arg), - ElemwisePrecision::Bool => match BT::dtype() { - DType::U32 => outputs.t_u32.push(arg), - DType::U8 => outputs.t_u8.push(arg), - _ => todo!(), - }, - }; - } - } - } - - outputs - } -} - -#[derive(Default, Clone, Serialize, Deserialize, Debug)] -pub struct RegisteredTensors { - tensors: BTreeMap>, -} - -impl RegisteredTensors { - pub fn iter(&self) -> impl Iterator { - self.tensors.iter().flat_map(|(precision, descriptions)| { - descriptions.iter().map(|desc| (*precision, desc)) - }) - } - - pub fn len(&self) -> usize { - self.tensors.values().map(|v| v.len()).sum() - } - - pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { - self.tensors.get(&precision).and_then(|items| { - items - .iter() - .enumerate() - .find(|(_pos, tensor)| tensor.id == tensor_id) - .map(|(pos, _)| pos) - }) - } - - pub fn get_all(&self, precision: ElemwisePrecision) -> &[TensorDescription] { - self.tensors - .get(&precision) - .map(|v| v.as_slice()) - .unwrap_or(&[]) - } - - pub fn get( - &self, - precision: ElemwisePrecision, - tensor_id: TensorId, - ) -> Option<&TensorDescription> { - self.get_all(precision) - .iter() - .find(|desc| desc.id == tensor_id) - } - - pub fn insert(&mut self, precision: ElemwisePrecision, tensor: TensorDescription) -> u32 { - if let Some(tensors) = self.tensors.get_mut(&precision) { - let position = tensors.len() as u32; - tensors.push(tensor); - position - } else { - self.tensors.insert(precision, vec![tensor]); - 0 - } - } - - pub fn update(&mut self, precision: ElemwisePrecision, tensor: &TensorDescription) { - if let Some(tensors) = self.tensors.get_mut(&precision) { - if let Some(tensor_old) = tensors - .iter_mut() - .find(|tensor_old| tensor_old.id == tensor.id) - { - tensor_old.status = tensor.status.clone(); - } - } - } -} diff --git a/crates/burn-jit/src/fusion/on_write/trace/base.rs b/crates/burn-jit/src/fusion/on_write/trace/base.rs new file mode 100644 index 0000000000..f017d6dc1b --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/base.rs @@ -0,0 +1,164 @@ +use crate::{fusion::JitFusionHandle, BoolElement, JitRuntime}; + +use super::{ + super::{ + ir::{ElemwiseOp, ElemwisePrecision}, + settings::FuseSettings, + }, + executor::LaunchPlanExecutor, + input::InputPlanner, + output::OutputPlanner, + vectorization::VectorizationPlanner, + HandleInput, HandleOutput, LaunchPlan, TraceRunner, +}; +use burn_fusion::stream::Context; +use burn_tensor::repr::{TensorDescription, TensorId}; +use cubecl::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; + +#[derive(Clone, Serialize, Deserialize, Debug)] +/// Trace containing all element wise operations as well as reads and writes. 
+pub struct FuseOnWriteTrace { + pub outputs: RegisteredTensors, + pub inputs: RegisteredTensors, + pub settings: FuseSettings, + pub scalars: BTreeMap, + pub reshapes: Vec, + pub shape_ref: Vec, + pub ops: Vec, + pub reads: BTreeMap>, + pub writes: BTreeMap, + pub inputs_unhandled: Vec, +} + +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Reshape { + pub reshaped: TensorId, + pub original: TensorId, +} + +impl FuseOnWriteTrace { + /// Run a trace with the given [runner](TraceRunner). + pub fn run>( + &self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + runner: &Runner, + ) -> Result<(), Runner::Error> { + let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len()); + + InputPlanner::::new( + &self.inputs, + &self.inputs_unhandled, + &self.reshapes, + &self.shape_ref, + &self.settings, + ) + .run(context, &mut plan); + + OutputPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) + .run::(client, device, context, &mut plan); + + VectorizationPlanner::::new(&self.reshapes, &self.reads, &self.settings) + .run::(context, &mut plan); + + match LaunchPlanExecutor::::new(&self.scalars, &self.reshapes, &self.ops) + .execute::<_, BT>(client, runner, context, plan) + { + Err(err) => { + self.rollback(context, err.handles_input, err.handles_output); + Err(err.runner_error) + } + Ok(val) => Ok(val), + } + } + + fn rollback( + &self, + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: Vec>, + handle_outputs: Vec>, + ) { + for input in handle_inputs { + context + .handles + .register_handle(input.global_id, input.handle); + } + for output in handle_outputs { + if let HandleOutput::Owned { + global_id, handle, .. + } = output + { + context.handles.register_handle(global_id, handle); + } + } + } +} + +#[derive(Default, Clone, Serialize, Deserialize, Debug)] +pub struct RegisteredTensors { + tensors: BTreeMap>, +} + +impl RegisteredTensors { + pub fn iter(&self) -> impl Iterator { + self.tensors.iter().flat_map(|(precision, descriptions)| { + descriptions.iter().map(|desc| (*precision, desc)) + }) + } + + pub fn len(&self) -> usize { + self.tensors.values().map(|v| v.len()).sum() + } + + pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { + self.tensors.get(&precision).and_then(|items| { + items + .iter() + .enumerate() + .find(|(_pos, tensor)| tensor.id == tensor_id) + .map(|(pos, _)| pos) + }) + } + + pub fn get_all(&self, precision: ElemwisePrecision) -> &[TensorDescription] { + self.tensors + .get(&precision) + .map(|v| v.as_slice()) + .unwrap_or(&[]) + } + + pub fn get( + &self, + precision: ElemwisePrecision, + tensor_id: TensorId, + ) -> Option<&TensorDescription> { + self.get_all(precision) + .iter() + .find(|desc| desc.id == tensor_id) + } + + pub fn insert(&mut self, precision: ElemwisePrecision, tensor: TensorDescription) -> u32 { + if let Some(tensors) = self.tensors.get_mut(&precision) { + let position = tensors.len() as u32; + tensors.push(tensor); + position + } else { + self.tensors.insert(precision, vec![tensor]); + 0 + } + } + + pub fn update(&mut self, precision: ElemwisePrecision, tensor: &TensorDescription) { + if let Some(tensors) = self.tensors.get_mut(&precision) { + if let Some(tensor_old) = tensors + .iter_mut() + .find(|tensor_old| tensor_old.id == tensor.id) + { + tensor_old.status = tensor.status.clone(); + } + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace/builder.rs 
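A small usage sketch for `RegisteredTensors` as defined above (the surrounding function and assertions are illustrative, not part of the patch): positions returned by `insert` and `get_index` are relative to the precision bucket, which is the index the per-precision `t_*` sequences in `GlobalArgs` expect.

```rust
use burn_tensor::repr::TensorDescription;
// Crate-internal paths, assuming the re-exports from `trace/mod.rs` shown below.
use crate::fusion::on_write::{ir::ElemwisePrecision, trace::RegisteredTensors};

fn register_two(f32_tensor: TensorDescription, i32_tensor: TensorDescription) {
    let mut tensors = RegisteredTensors::default();

    // Each precision gets its own position counter.
    let pos_f32 = tensors.insert(ElemwisePrecision::F32, f32_tensor.clone());
    let pos_i32 = tensors.insert(ElemwisePrecision::I32, i32_tensor);
    assert_eq!(pos_f32, 0);
    assert_eq!(pos_i32, 0);

    // The global count still reflects every registered tensor.
    assert_eq!(tensors.len(), 2);
    assert_eq!(
        tensors.get_index(ElemwisePrecision::F32, f32_tensor.id),
        Some(0)
    );
}
```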
similarity index 79% rename from crates/burn-jit/src/fusion/on_write/trace_builder.rs rename to crates/burn-jit/src/fusion/on_write/trace/builder.rs index c37237ae4e..896d7f3b1c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/builder.rs @@ -1,33 +1,39 @@ -use super::{ +use super::super::{ ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, LayoutInfo, UnaryElemwiseArgs}, - trace::{FuseOnWriteTrace, RegisteredTensors}, + settings::FuseSettings, }; +use super::{FuseOnWriteTrace, RegisteredTensors, Reshape}; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, DType, Element, }; +use cubecl::prelude::Sequence; use std::collections::BTreeMap; #[derive(Clone)] pub struct FuseOnWriteTraceBuilder { locals: Locals, outputs: RegisteredTensors, + settings: FuseSettings, inputs: RegisteredTensors, scalars: BTreeMap, + reshapes: Vec, ops: Vec, - reads: BTreeMap, + reads: BTreeMap>, pub bool_precision: ElemwisePrecision, outputs_unhandled: Vec, inputs_unhandled: Vec, } impl FuseOnWriteTraceBuilder { - pub fn new(bool_precision: ElemwisePrecision) -> Self { + pub fn new(bool_precision: ElemwisePrecision, settings: FuseSettings) -> Self { Self { locals: Locals::default(), outputs: RegisteredTensors::default(), + settings, inputs: RegisteredTensors::default(), scalars: BTreeMap::default(), + reshapes: Vec::new(), ops: Vec::new(), reads: BTreeMap::new(), bool_precision, @@ -54,7 +60,7 @@ impl FuseOnWriteTraceBuilder { pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg { let arg = self.output(tensor); - self.outputs_unhandled.push(arg); + self.outputs_unhandled.push(arg.clone()); arg } @@ -96,10 +102,19 @@ impl FuseOnWriteTraceBuilder { let out = self.locals.create(precision, tensor.id); let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); - self.reads.insert( - tensor.id, - ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }), - ); + let reads = if let std::collections::btree_map::Entry::Vacant(e) = + self.reads.entry(tensor.id) + { + e.insert(Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; + + reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + })); out } @@ -127,6 +142,77 @@ impl FuseOnWriteTraceBuilder { } } + pub fn input_reshaped( + &mut self, + tensor: &TensorDescription, + output: &TensorDescription, + ) -> Option { + let precision = tensor.dtype.into(); + + // Bool tensors are encoded as bool_precision. + let precision_input = match precision { + ElemwisePrecision::Bool => self.bool_precision, + _ => precision, + }; + + let input_index = match self.locals.get(precision, tensor.id) { + Some(_) => { + // Can't fused an already fused input. 
+ if self.outputs.get(precision_input, tensor.id).is_some() { + return None; + } + + match self.inputs.get_index(precision_input, tensor.id) { + Some(index) => { + self.inputs.update(precision_input, tensor); + index as u32 + } + None => { + return None; + } + } + } + None => self.inputs.insert(precision_input, tensor.clone()), + }; + + let out = self.locals.create(precision, tensor.id); + let original = Arg::Input(input_index, precision_input, LayoutInfo::Unknown); + + let mut shape = Sequence::new(); + + let index = self.reshapes.len(); + self.reshapes.push(Reshape { + reshaped: output.id, + original: tensor.id, + }); + let rank = output.shape.len(); + + for i in 0..output.shape.len() { + let id = index * rank + i; + shape.push(Arg::ScalarShape(id as u32)); + } + + let input = Arg::InputReshaped { + original: Box::new(original), + shape, + }; + + let reads = + if let std::collections::btree_map::Entry::Vacant(e) = self.reads.entry(tensor.id) { + e.insert(Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; + + reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + })); + + Some(out) + } + pub fn scalar(&mut self, _: &E, dtype: DType) -> Arg { let precision = dtype.into(); @@ -143,7 +229,7 @@ impl FuseOnWriteTraceBuilder { Arg::Scalar(new_index, precision) } - pub fn build(&self) -> FuseOnWriteTrace { + pub fn build(&self, shape_ref: Vec) -> FuseOnWriteTrace { let inputs = self.inputs.clone(); let outputs = self.output_tensors(); let ops = self.ops.clone(); @@ -165,16 +251,22 @@ impl FuseOnWriteTraceBuilder { ); } - // Current problem is that I need btreemap instead of sequences. - FuseOnWriteTrace::new( + let reshapes = self.reshapes.clone(); + let settings = self.settings; + let inputs_unhandled = self.inputs_unhandled.clone(); + + FuseOnWriteTrace { outputs, inputs, + settings, scalars, + reshapes, + shape_ref, ops, reads, writes, - self.inputs_unhandled.clone(), - ) + inputs_unhandled, + } } fn output_tensors(&self) -> RegisteredTensors { @@ -334,8 +426,10 @@ impl FuseOnWriteTraceBuilder { }; // For all operators, mark their local tensor id in the proper set. - for (_, op) in self.reads.iter() { - mark_op(op); + for (_, ops) in self.reads.iter() { + for op in ops { + mark_op(op); + } } for op in self.ops.iter() { diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs new file mode 100644 index 0000000000..749e74340e --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -0,0 +1,228 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use burn_fusion::stream::Context; +use burn_tensor::DType; +use cubecl::{ + client::ComputeClient, + prelude::{ScalarArg, Sequence, TensorArg}, +}; + +use super::{HandleInput, HandleOutput, LaunchPlan, Reshape, TraceRunner}; +use crate::{ + fusion::{ + on_write::ir::{ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}, + JitFusionHandle, + }, + BoolElement, JitRuntime, +}; + +/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). 
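To make the shape indexing in `input_reshaped` above concrete: each registered reshape gets one `Arg::ScalarShape` id per output dimension, computed as `index * rank + i`. A small illustrative sketch; the helper and test are not part of the patch.

```rust
/// Mirrors the id computation in `input_reshaped`: the reshape registered at
/// `index`, whose output has `rank` dimensions, uses the ids `index * rank + i`.
fn scalar_shape_ids(index: usize, rank: usize) -> Vec<u32> {
    (0..rank).map(|i| (index * rank + i) as u32).collect()
}

#[test]
fn second_reshape_of_a_rank_3_output() {
    // Hypothetical second reshape (index = 1) with a rank-3 output shape.
    assert_eq!(scalar_shape_ids(1, 3), vec![3, 4, 5]);
}
```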
+pub struct LaunchPlanExecutor<'a, R: JitRuntime> { + scalars: &'a BTreeMap, + reshapes: &'a Vec, + ops: &'a Vec, + _r: PhantomData, +} + +#[derive(new)] +pub struct ExecutionError> { + pub runner_error: Runner::Error, + pub handles_input: Vec>, + pub handles_output: Vec>, +} + +impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { + pub fn new( + scalars: &'a BTreeMap, + reshapes: &'a Vec, + ops: &'a Vec, + ) -> Self { + Self { + scalars, + reshapes, + ops, + _r: PhantomData, + } + } + + pub fn execute, BT: BoolElement>( + self, + client: &ComputeClient, + runner: &Runner, + context: &mut Context<'_, JitFusionHandle>, + plan: LaunchPlan<'a, R>, + ) -> Result<(), ExecutionError> { + let reference = match plan.reference { + Some(reference) => reference, + None => { + if plan.writes.is_empty() { + // Nothing to write, can skip execution. + return Ok(()); + } else { + panic!("An output should exist for the fused kernel") + } + } + }; + + let inputs = self.register_inputs(context, &plan.handle_inputs); + let outputs = self.register_outputs::(&plan.handle_outputs); + + let mut ops = Sequence::::new(); + + for read_ops in plan.reads.into_values() { + for op in read_ops { + ops.push(op); + } + } + + for op in self.ops.iter() { + ops.push(op.clone()); + } + + for op in plan.writes.into_values() { + ops.push(op); + } + + let config = ElemwiseConfig { + rank: plan.rank as u32, + ref_layout: reference.layout, + ops, + }; + + Runner::run(runner, client, inputs, outputs, &config) + .map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs)) + } + + fn register_inputs<'h>( + &self, + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: &'h [HandleInput], + ) -> GlobalArgsLaunch<'h, R> { + let mut inputs = GlobalArgsLaunch::default(); + + for hi in handle_inputs.iter() { + let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); + match hi.precision { + ElemwisePrecision::F32 => inputs.t_f32.push(arg), + ElemwisePrecision::F16 => inputs.t_f16.push(arg), + ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), + ElemwisePrecision::I64 => inputs.t_i64.push(arg), + ElemwisePrecision::I32 => inputs.t_i32.push(arg), + ElemwisePrecision::I16 => inputs.t_i16.push(arg), + ElemwisePrecision::I8 => inputs.t_i8.push(arg), + ElemwisePrecision::U64 => inputs.t_u64.push(arg), + ElemwisePrecision::U32 => inputs.t_u32.push(arg), + ElemwisePrecision::U16 => inputs.t_u16.push(arg), + ElemwisePrecision::U8 => inputs.t_u8.push(arg), + _ => panic!("Unsupported input precision {:?}", hi.precision), + }; + } + + for (precision, count) in self.scalars.iter() { + for i in 0..(*count as usize) { + match precision { + ElemwisePrecision::F32 => { + inputs.s_f32.push(ScalarArg::new(context.scalar_f32[i])) + } + ElemwisePrecision::F16 => { + inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) + } + ElemwisePrecision::BF16 => { + inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) + } + ElemwisePrecision::I64 => { + inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) + } + ElemwisePrecision::I32 => { + inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) + } + ElemwisePrecision::I16 => { + inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) + } + ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), + ElemwisePrecision::U64 => { + inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) + } + ElemwisePrecision::U32 => { + inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) + } + ElemwisePrecision::U16 => { + 
inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) + } + ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), + ElemwisePrecision::Bool => todo!(), + } + } + } + + // Reshape values are pushed in reverse in the same scalar buffer for all `u32` + for relative in self.reshapes.iter().rev() { + let global = context.tensors.get(&relative.reshaped).unwrap(); + + for shape in global.shape.iter().rev() { + inputs.s_u32.push(ScalarArg::new(*shape as u32)) + } + } + + inputs + } + + fn register_outputs<'s, BT: BoolElement>( + &self, + handle_outputs: &'s [HandleOutput], + ) -> GlobalArgsLaunch<'s, R> { + let mut outputs = GlobalArgsLaunch::default(); + + for item in handle_outputs.iter() { + match item { + HandleOutput::Alias { + input_pos, + precision, + } => match precision { + ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), + _ => todo!(), + }, + HandleOutput::Owned { + precision, + handle, + global_shape, + vectorization, + .. + } => { + let arg = handle.as_tensor_arg(global_shape, *vectorization); + + match precision { + ElemwisePrecision::F32 => outputs.t_f32.push(arg), + ElemwisePrecision::F16 => outputs.t_f16.push(arg), + ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), + ElemwisePrecision::I64 => outputs.t_i64.push(arg), + ElemwisePrecision::I32 => outputs.t_i32.push(arg), + ElemwisePrecision::I16 => outputs.t_i16.push(arg), + ElemwisePrecision::I8 => outputs.t_i8.push(arg), + ElemwisePrecision::U64 => outputs.t_u64.push(arg), + ElemwisePrecision::U32 => outputs.t_u32.push(arg), + ElemwisePrecision::U16 => outputs.t_u16.push(arg), + ElemwisePrecision::U8 => outputs.t_u8.push(arg), + ElemwisePrecision::Bool => match BT::dtype() { + DType::U32 => outputs.t_u32.push(arg), + DType::U8 => outputs.t_u8.push(arg), + _ => todo!(), + }, + }; + } + } + } + + outputs + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/input.rs b/crates/burn-jit/src/fusion/on_write/trace/input.rs new file mode 100644 index 0000000000..a243fddac1 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/input.rs @@ -0,0 +1,86 @@ +use super::Reshape; +use crate::{ + fusion::{on_write::settings::FuseSettings, JitFusionHandle}, + JitRuntime, +}; +use burn_fusion::stream::Context; +use burn_tensor::repr::{TensorId, TensorStatus}; +use std::marker::PhantomData; + +use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; + +/// Fetch and register [input handles](HandleInput) and itendify potential inputs that +/// can be used inplace. 
+pub struct InputPlanner<'a, R: JitRuntime> { + inputs: &'a RegisteredTensors, + inputs_unhandled: &'a Vec, + reshapes: &'a Vec, + shape_ref: &'a Vec, + settings: &'a FuseSettings, + _r: PhantomData, +} + +impl<'a, R: JitRuntime> InputPlanner<'a, R> { + pub fn new( + inputs: &'a RegisteredTensors, + inputs_unhandled: &'a Vec, + reshapes: &'a Vec, + shape_ref: &'a Vec, + settings: &'a FuseSettings, + ) -> Self { + Self { + inputs, + settings, + inputs_unhandled, + reshapes, + shape_ref, + _r: PhantomData, + } + } + + pub fn run(self, context: &mut Context<'_, JitFusionHandle>, plan: &mut LaunchPlan<'a, R>) { + for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { + let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); + // Important to take the status of the relative graph and not + // the global graph, since the status of the global graph + // might be of a later operation on the same tensor id. + let status = &tensor_relative.status; + let mut handle = context.handles.get_handle(&tensor_global.id, status); + + if self.settings.inplace + && status == &TensorStatus::ReadWrite + && handle.handle.can_mut() + && !self.inputs_unhandled.contains(&tensor_relative.id) + && !self + .reshapes + .iter() + .any(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) + && self.shape_ref == &tensor_relative.shape + { + plan.potential_inplaces.push(PotentialInplace { + input_pos: i, + tensor_relative, + strides: handle.strides.clone(), + }); + } + + if tensor_global.shape.len() < plan.rank { + let num_elem: usize = tensor_global.shape.iter().product(); + for _ in 0..(plan.rank - tensor_global.shape.len()) { + tensor_global.shape.insert(0, 1); + handle.strides.insert(0, num_elem); + } + } + + plan.handle_inputs.push(HandleInput { + precision, + handle, + relative_id: tensor_relative.id, + global_id: tensor_global.id, + global_shape: tensor_global.shape.clone(), + vectorization: 1, + }); + plan.global_inputs.push(tensor_global); + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/mod.rs b/crates/burn-jit/src/fusion/on_write/trace/mod.rs new file mode 100644 index 0000000000..64de887986 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/mod.rs @@ -0,0 +1,14 @@ +pub(crate) mod executor; +pub(crate) mod input; +pub(crate) mod output; +pub(crate) mod vectorization; + +mod base; +mod builder; +mod plan; +mod runner; + +pub use base::*; +pub use builder::*; +pub use plan::*; +pub use runner::*; diff --git a/crates/burn-jit/src/fusion/on_write/trace/output.rs b/crates/burn-jit/src/fusion/on_write/trace/output.rs new file mode 100644 index 0000000000..0964974c7a --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/output.rs @@ -0,0 +1,390 @@ +use burn_fusion::stream::Context; +use burn_tensor::{repr::TensorDescription, DType}; +use cubecl::{client::ComputeClient, ir::Elem}; + +use crate::{ + fusion::{ + on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, + strides_dyn_rank, JitFusionHandle, + }, + tensor::is_contiguous, + BoolElement, JitRuntime, +}; + +use super::{ + super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors, Reshape, +}; +use std::collections::BTreeMap; + +/// Create or reuse handles for the outputs. +/// +/// It is also responsable to select the reference tensor. 
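The shape and stride padding performed by `InputPlanner::run` above can be summed up with a short standalone sketch, assuming the input is contiguous enough that prepending the total element count as a stride keeps broadcasting valid, as the planner does.

```rust
/// Sketch of the rank alignment above: missing leading dimensions become 1 and
/// their strides become the total number of elements.
fn align_rank(shape: &mut Vec<usize>, strides: &mut Vec<usize>, rank: usize) {
    let num_elem: usize = shape.iter().product();
    for _ in 0..rank.saturating_sub(shape.len()) {
        shape.insert(0, 1);
        strides.insert(0, num_elem);
    }
}

#[test]
fn aligns_a_rank_2_input_to_rank_3() {
    let (mut shape, mut strides) = (vec![4, 8], vec![8, 1]);
    align_rank(&mut shape, &mut strides, 3);
    assert_eq!(shape, vec![1, 4, 8]);
    assert_eq!(strides, vec![32, 8, 1]);
}
```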
+pub struct OutputPlanner<'a, R: JitRuntime> { + inputs: &'a RegisteredTensors, + reshapes: &'a Vec, + outputs_sorted: Vec>, + handles: Vec>>, + globals: Vec>, + mapper: OutputPositionMapper, +} + +struct OutputSorted<'a> { + pos_original: usize, + precision: ElemwisePrecision, + tensor_relative: &'a TensorDescription, +} + +enum OutputKind { + Normal, + Inplace { input_pos: usize }, + Reshaped { reshape: Reshape }, +} + +impl<'a, R: JitRuntime> OutputPlanner<'a, R> { + pub fn new( + inputs: &'a RegisteredTensors, + outputs: &'a RegisteredTensors, + reshapes: &'a Vec, + ) -> Self { + let mut mapper = OutputPositionMapper::default(); + let mut outputs_sorted: Vec<_> = outputs + .iter() + .enumerate() + .map(|(pos, (precision, tensor))| { + mapper.register(precision, pos); + OutputSorted { + pos_original: pos, + precision, + tensor_relative: tensor, + } + }) + .collect(); + + outputs_sorted.sort_by(|a, b| { + let a_val: usize = a.tensor_relative.shape.iter().sum(); + let b_val: usize = b.tensor_relative.shape.iter().sum(); + + b_val.cmp(&a_val) + }); + + let mut handles = Vec::with_capacity(outputs.len()); + let mut globals = Vec::with_capacity(outputs.len()); + + for _ in 0..outputs.len() { + handles.push(None); + globals.push(None); + } + + Self { + inputs, + outputs_sorted, + reshapes, + handles, + globals, + mapper, + } + } + + pub fn run( + mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + ) { + // So that we can borrow self during the iteration. + let mut outputs = Vec::new(); + core::mem::swap(&mut outputs, &mut self.outputs_sorted); + + for output in outputs.into_iter() { + let tensor_global = context + .tensors + .get(&output.tensor_relative.id) + .unwrap() + .clone(); + let strides = strides_dyn_rank(&tensor_global.shape); + + match self.output_kind(plan, &tensor_global, &output, &strides) { + OutputKind::Inplace { input_pos } => { + self.inplace_output(context, plan, output, tensor_global, input_pos); + } + OutputKind::Normal => { + self.normal_output::( + client, + device, + context, + plan, + output, + tensor_global, + strides, + ); + } + OutputKind::Reshaped { reshape } => { + self.reshaped_output::( + client, + device, + context, + plan, + output, + tensor_global, + strides, + reshape, + ); + } + } + } + + for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) { + plan.handle_outputs.push(handle.unwrap()); + plan.global_outputs.push(global.unwrap()); + } + + Self::add_layout_info_inputs(plan); + } + + fn add_layout_info_inputs(analysis: &mut LaunchPlan<'_, R>) { + for hi in analysis.handle_inputs.iter() { + if let Some(reference) = analysis.reference.as_ref() { + if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { + if let Some(ops) = analysis.reads.get_mut(&hi.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::SameAsRef); + } + } + } + } + } + } + } + + fn output_kind( + &self, + plan: &mut LaunchPlan<'a, R>, + tensor_global: &TensorDescription, + output: &OutputSorted, + strides: &[usize], + ) -> OutputKind { + if let Some(reshape) = self + .reshapes + .iter() + .find(|r| r.reshaped == output.tensor_relative.id) + { + return OutputKind::Reshaped { + reshape: reshape.clone(), + }; + } + + plan.potential_inplaces + .iter() + .enumerate() + .find(|(_pos, pi)| { + pi.tensor_relative.dtype == tensor_global.dtype + && pi.tensor_relative.shape == 
output.tensor_relative.shape + && pi.strides == strides + }) + .map(|(pos, _)| pos) + .map(|input_pos| OutputKind::Inplace { input_pos }) + .unwrap_or(OutputKind::Normal) + } + + fn inplace_output( + &mut self, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + input_index: usize, + ) { + let potential_inplace = plan.potential_inplaces.remove(input_index); + let handle_input = plan.handle_inputs.get(potential_inplace.input_pos).unwrap(); + + if plan.reference.is_none() { + let index_input = self + .inputs + .get_index(output.precision, potential_inplace.tensor_relative.id) + .unwrap(); + + plan.reference = Some(Reference { + layout: Arg::Input(index_input as u32, output.precision, LayoutInfo::IsRef), + shape: tensor_global.shape.clone(), + strides: handle_input.handle.strides.clone(), + }); + + if let Some(ops) = plan.reads.get_mut(&handle_input.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::IsRef); + }; + } + } + + if let Some(ElemwiseOp::Assign(op)) = plan.writes.get_mut(&output.tensor_relative.id) { + op.out.add_layout_info(LayoutInfo::IsRef); + }; + } + + context + .handles + .register_handle(tensor_global.id, handle_input.handle.clone()); + + self.handles[output.pos_original] = Some(HandleOutput::Alias { + input_pos: potential_inplace.input_pos, + precision: output.precision, + }); + self.globals[output.pos_original] = Some(tensor_global); + } + + #[allow(clippy::too_many_arguments)] + fn normal_output( + &mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + strides: Vec, + ) { + if plan.reference.is_none() { + let position = self + .mapper + .resolve_index(&output.precision, output.pos_original); + plan.reference = Some(Reference { + layout: Arg::Output(position, output.precision, LayoutInfo::IsRef), + shape: tensor_global.shape.clone(), + strides: strides.clone(), + }); + + if let ElemwiseOp::Assign(op) = plan.writes.get_mut(&output.tensor_relative.id).unwrap() + { + op.out.add_layout_info(LayoutInfo::IsRef); + }; + } else if let Some(reference) = plan.reference.as_ref() { + if reference.strides == strides && reference.shape == tensor_global.shape { + if let ElemwiseOp::Assign(op) = + plan.writes.get_mut(&output.tensor_relative.id).unwrap() + { + op.out.add_layout_info(LayoutInfo::SameAsRef); + }; + } + } + + // We encode bool tensors as `B`. 
+ let dtype = match tensor_global.dtype { + DType::Bool => BT::dtype(), + _ => tensor_global.dtype, + }; + let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); + + let handle = JitFusionHandle { + client: client.clone(), + handle: client.empty(size), + device: device.clone(), + strides, + dtype, + }; + + plan.rank = usize::max(tensor_global.shape.len(), plan.rank); + context + .handles + .register_handle(tensor_global.id, handle.clone()); + + self.handles[output.pos_original] = Some(HandleOutput::Owned { + precision: output.precision, + handle, + global_shape: tensor_global.shape.clone(), + global_id: tensor_global.id, + vectorization: 1, + }); + self.globals[output.pos_original] = Some(tensor_global); + } + + #[allow(clippy::too_many_arguments)] + fn reshaped_output( + &mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + strides: Vec, + reshape: Reshape, + ) { + let original_handle = plan + .handle_inputs + .iter() + .find(|handle| handle.relative_id == reshape.original) + .unwrap(); + + // We encode bool tensors as `B`. + let dtype = match tensor_global.dtype { + DType::Bool => BT::dtype(), + _ => tensor_global.dtype, + }; + + if is_contiguous( + &original_handle.global_shape, + &original_handle.handle.strides, + ) { + plan.writes.remove(&output.tensor_relative.id); + + let handle = JitFusionHandle { + client: client.clone(), + handle: original_handle.handle.handle.clone(), + device: device.clone(), + strides, + dtype, + }; + context + .handles + .register_handle(tensor_global.id, handle.clone()); + // IT will never be access, just a way to keep the original position working. + self.handles[output.pos_original] = Some(HandleOutput::Alias { + input_pos: 0, + precision: output.precision, + }); + self.globals[output.pos_original] = Some(tensor_global); + } else { + self.normal_output::( + client, + device, + context, + plan, + output, + tensor_global, + strides, + ); + } + } +} + +/// Group output position by [element precision](ElemwisePrecision). +#[derive(Default, Debug)] +pub struct OutputPositionMapper { + map: BTreeMap>, +} + +impl OutputPositionMapper { + /// Register a new output with the given precision and position. + pub fn register(&mut self, precision: ElemwisePrecision, pos_handle: usize) { + if let Some(positions) = self.map.get_mut(&precision) { + positions.push(pos_handle); + } else { + self.map.insert(precision, vec![pos_handle]); + } + } + + /// Returns the right position from the precision and the global position in all outputs. + pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { + self.map + .get(precision) + .unwrap() + .iter() + .enumerate() + .find(|(_pos_elem, pos_all)| **pos_all == pos_handle) + .map(|(pos_elem, _pos_all)| pos_elem) + .unwrap() as u32 + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs new file mode 100644 index 0000000000..89a11a188c --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -0,0 +1,86 @@ +use std::collections::BTreeMap; + +use crate::{ + fusion::{ + on_write::ir::{Arg, ElemwiseOp, ElemwisePrecision}, + JitFusionHandle, + }, + JitRuntime, +}; +use burn_tensor::repr::{TensorDescription, TensorId}; + +/// The plan is responsable to keep runtime information related to the launch of a fused kernel +/// at one place. 
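Outputs are planned largest-first but written back at their original positions, and the reference layout needs the index inside the matching precision bucket; that is what `OutputPositionMapper` resolves. A minimal sketch with illustrative values, assuming the crate-internal module path.

```rust
use crate::fusion::on_write::{ir::ElemwisePrecision, trace::output::OutputPositionMapper};

fn map_three_outputs() {
    let mut mapper = OutputPositionMapper::default();

    // Three outputs registered in their original (global) order.
    mapper.register(ElemwisePrecision::F32, 0);
    mapper.register(ElemwisePrecision::U32, 1);
    mapper.register(ElemwisePrecision::F32, 2);

    // The output at global position 2 is the second f32 output, so the
    // reference layout `Arg::Output(..)` receives index 1 for it.
    assert_eq!(mapper.resolve_index(&ElemwisePrecision::F32, 2), 1);
    assert_eq!(mapper.resolve_index(&ElemwisePrecision::U32, 1), 0);
}
```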
+#[derive(Debug)] +pub(crate) struct LaunchPlan<'a, R: JitRuntime> { + pub potential_inplaces: Vec>, + pub global_inputs: Vec, + pub global_outputs: Vec, + pub handle_inputs: Vec>, + pub handle_outputs: Vec>, + pub reference: Option, + pub reads: BTreeMap>, + pub writes: BTreeMap, + pub vectorization: BTreeMap, + pub rank: usize, +} + +impl LaunchPlan<'_, R> { + pub fn new( + reads: &BTreeMap>, + writes: &BTreeMap, + rank: usize, + ) -> Self { + LaunchPlan { + potential_inplaces: Vec::new(), + global_inputs: Vec::new(), + global_outputs: Vec::new(), + handle_inputs: Vec::new(), + handle_outputs: Vec::new(), + reference: None, + vectorization: BTreeMap::default(), + reads: reads.clone(), + writes: writes.clone(), + rank, + } + } +} + +#[derive(Debug)] +pub enum HandleOutput { + Alias { + input_pos: usize, + precision: ElemwisePrecision, + }, + Owned { + global_id: TensorId, + precision: ElemwisePrecision, + handle: JitFusionHandle, + global_shape: Vec, + vectorization: u8, + }, +} + +#[derive(Debug)] +pub struct HandleInput { + pub relative_id: TensorId, + pub global_id: TensorId, + pub precision: ElemwisePrecision, + pub handle: JitFusionHandle, + pub global_shape: Vec, + pub vectorization: u8, +} + +#[derive(Debug)] +pub struct Reference { + pub layout: Arg, + pub shape: Vec, + pub strides: Vec, +} + +#[derive(Debug)] +pub struct PotentialInplace<'a> { + pub input_pos: usize, + pub tensor_relative: &'a TensorDescription, + pub strides: Vec, +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/runner.rs b/crates/burn-jit/src/fusion/on_write/trace/runner.rs new file mode 100644 index 0000000000..fc3109327d --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/runner.rs @@ -0,0 +1,154 @@ +use super::super::ir::{ElemwiseConfig, GlobalArgsLaunch}; +use crate::{fusion::JitFusionHandle, JitRuntime}; +use burn_tensor::repr::{TensorDescription, TensorId}; +use cubecl::prelude::*; +use std::collections::BTreeMap; + +/// A trace runner is responsible for determining the vectorization factor as well as launching +/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) +/// with a provided [element wise config](ElemwiseConfig). +pub trait TraceRunner { + /// The error that might happen while running the trace. + type Error; + + /// Run the trace. + fn run<'a>( + &'a self, + client: &'a ComputeClient, + inputs: GlobalArgsLaunch<'a, R>, + outputs: GlobalArgsLaunch<'a, R>, + config: &'a ElemwiseConfig, + ) -> Result<(), Self::Error>; + + /// The vectorization factor for all inputs and outputs. + fn vectorization<'a>( + vectorizations: &mut BTreeMap, + handles_inputs: impl Iterator>, + inputs: impl Iterator, + outputs: impl Iterator, + reshaped: impl Iterator, + ) { + vectorization_default(vectorizations, handles_inputs, inputs, outputs, reshaped) + } +} + +fn vectorization_default<'a, R: JitRuntime>( + vectorizations: &mut BTreeMap, + handles_inputs: impl Iterator>, + inputs: impl Iterator, + outputs: impl Iterator, + reshaped: impl Iterator, +) { + enum Vect { + Broadcated, + Max(u8), + } + + // The default version uses the last dimension as vectorization axis and assumes a + // perpendicular contiguous line. + let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { + let rank = handle.strides.len(); + + // Last dimension strides should be 1, otherwise vecX won't be contiguous. 
+ if handle.strides[rank - 1] != 1 { + return Vect::Max(1); + } + let shape_axis = desc.shape[rank - 1]; + + if shape_axis == 1 { + return Vect::Broadcated; + } + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size or broadcated. + if shape_axis % s as usize == 0 { + return Vect::Max(s); + } + } + + Vect::Max(1) + }; + + let vectorization_output = |desc: &TensorDescription| { + let rank = desc.shape.len(); + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size. + if desc.shape[rank - 1] % s as usize == 0 { + return Vect::Max(s); + } + } + + Vect::Max(1) + }; + + let vectorization_reshape = + |reshaped: &TensorDescription, original: &TensorDescription, multi_reads: bool| { + let reshape_axis = reshaped.shape[reshaped.shape.len() - 1]; + let shape_axis = original.shape[original.shape.len() - 1]; + + if !multi_reads && reshape_axis == 1 { + return Vect::Broadcated; + } + + for s in R::line_size_elem(&reshaped.dtype.into()) { + if !multi_reads { + // The last dimension should be a multiple of the vector size or broadcated. + if reshape_axis % s as usize == 0 { + return Vect::Max(s); + } + } else { + // Since the original tensor must share the same vectorization factor as the + // reshaped tensor, they must have compatible shapes when both are access + // independently. + if reshape_axis % s as usize == 0 && shape_axis % s as usize == 0 { + return Vect::Max(s); + } + } + } + + Vect::Max(1) + }; + + let mut max_current = u8::MAX; + + for (handle, tensor) in handles_inputs.zip(inputs) { + match vectorization_input(handle, tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; + } + + for tensor in outputs { + match vectorization_output(tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; + } + + for (reshaped, original, multi_reads) in reshaped { + match vectorization_reshape(reshaped, original, multi_reads) { + Vect::Broadcated => { + vectorizations.insert(original.id, 1); + vectorizations.insert(reshaped.id, 1); + } + Vect::Max(val) => { + vectorizations.insert(original.id, 0); + vectorizations.insert(reshaped.id, 0); + max_current = Ord::min(val, max_current); + } + } + } + + for (_id, val) in vectorizations.iter_mut() { + if *val == 0 { + *val = max_current; + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs new file mode 100644 index 0000000000..ff775e3327 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs @@ -0,0 +1,83 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use burn_fusion::stream::Context; +use burn_tensor::repr::TensorId; + +use crate::{ + fusion::{ + on_write::{ir::ElemwiseOp, settings::FuseSettings}, + JitFusionHandle, + }, + JitRuntime, +}; + +use super::{HandleOutput, LaunchPlan, Reshape, TraceRunner}; + +/// Select the best vectorization factor for each tensor handle. 
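The default vectorization policy above boils down to this: a broadcast last dimension opts the tensor out without constraining the others, otherwise the widest supported line size dividing the last dimension caps the shared factor. A simplified sketch of that divisor search; the candidate list stands in for `R::line_size_elem(..)` and is assumed to be ordered widest first.

```rust
/// Simplified version of the per-tensor search above, for a non-broadcast input
/// whose last stride is 1: return the widest candidate dividing the last dimension.
fn max_line_size(last_dim: usize, candidates: &[u8]) -> u8 {
    candidates
        .iter()
        .copied()
        .find(|s| last_dim % *s as usize == 0)
        .unwrap_or(1)
}

#[test]
fn picks_the_widest_divisor() {
    // Hypothetical candidate line sizes, widest first.
    assert_eq!(max_line_size(8, &[8, 4, 2]), 8);
    assert_eq!(max_line_size(12, &[8, 4, 2]), 4);
    assert_eq!(max_line_size(7, &[8, 4, 2]), 1);
}
```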
+pub struct VectorizationPlanner<'a, R: JitRuntime> { + settings: &'a FuseSettings, + reshapes: &'a Vec, + reads: &'a BTreeMap>, + _r: PhantomData, +} + +impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { + pub fn new( + reshapes: &'a Vec, + reads: &'a BTreeMap>, + settings: &'a FuseSettings, + ) -> Self { + Self { + settings, + reshapes, + reads, + _r: PhantomData, + } + } + pub fn run>( + self, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + ) { + let tensors_reshaped = self.reshapes.iter().map(|reshape| { + ( + context.tensors.get(&reshape.reshaped).unwrap(), + context.tensors.get(&reshape.original).unwrap(), + self.reads.get(&reshape.original).unwrap().len() > 1, + ) + }); + + Runner::vectorization( + &mut plan.vectorization, + plan.handle_inputs.iter().map(|item| &item.handle), + plan.global_inputs.iter(), + plan.global_outputs.iter(), + tensors_reshaped, + ); + + // If mix vectorization is disable, we set the vectorization factor of each tensor to the + // minimum value found. + if !self.settings.mix_vectorization { + let factor = plan.vectorization.values().min().cloned(); + if let Some(factor) = factor { + plan.vectorization + .iter_mut() + .for_each(|(_, vf)| *vf = factor); + } + } + + for handle in plan.handle_inputs.iter_mut() { + handle.vectorization = *plan.vectorization.get(&handle.global_id).unwrap(); + } + for handle in plan.handle_outputs.iter_mut() { + if let HandleOutput::Owned { + vectorization, + global_id, + .. + } = handle + { + *vectorization = *plan.vectorization.get(global_id).unwrap() + } + } + } +} diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index ccfcc3ef9e..75c7207885 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -25,11 +25,18 @@ pub fn sum( match cube_count { SumStrategy::OneShot(cube_count) => { - let output = shared_sum::(&client, tensor.as_handle_ref(), cube_count)?; - Ok(from_data::( - TensorData::new(vec![output], vec![1]), - &device, - )) + let handle = client.empty(E::size().unwrap()); + let output = + JitTensor::new_contiguous(client.clone(), device, [1].into(), handle, E::dtype()); + + shared_sum::( + &client, + tensor.as_handle_ref(), + output.as_handle_ref(), + cube_count, + )?; + + Ok(output) } SumStrategy::Chained(strategy) => reduce::(tensor, strategy), #[cfg(feature = "autotune")] diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index cd5cd61157..b26d852656 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -275,10 +275,20 @@ mod sum_ops { pub(crate) fn sum_one_shot( input: JitTensor, ) -> Result, String> { + let client = input.client.clone(); let device = input.device.clone(); - cubecl::reduce::shared_sum::(&input.client, input.as_handle_ref(), C) - .map(|output| from_data::(TensorData::new(vec![output], vec![1]), &device)) - .map_err(|e| e.to_string()) + let handle = client.empty(E::size().unwrap()); + let output = JitTensor::new_contiguous(client, device, [1].into(), handle, E::dtype()); + + cubecl::reduce::shared_sum::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + C, + ) + .map_err(|e| e.to_string())?; + + Ok(output) } #[cfg(feature = "autotune")] diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 5d01ddd7d3..fb3361ec17 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -5,8 
+5,8 @@ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, FromDataOperationDescription, OperationDescription, PermuteOperationDescription, - RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, - SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; @@ -92,7 +92,7 @@ impl BoolTensorOps for BackendRouter { let client = tensor.client.clone(); let out = client.register_empty_tensor(shape.into(), tensor.dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 1cf211701c..1b17d5a2ad 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -11,10 +11,10 @@ use burn_tensor::repr::{ FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, - ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, - ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, - SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -419,7 +419,7 @@ impl FloatTensorOps for BackendRouter { let client = tensor.client.clone(); let out = client.register_empty_tensor(shape.into(), tensor.dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index 997bf5b9e6..eefecd7ef8 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -11,10 +11,10 @@ use burn_tensor::repr::{ FromDataOperationDescription, GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, - ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, - ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, - SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + 
+    UnaryOperationDescription,
 };
 use burn_tensor::{
     DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata,
@@ -74,7 +74,7 @@ impl IntTensorOps for BackendRouter {
         let client = tensor.client.clone();
         let out = client.register_empty_tensor(shape.into(), tensor.dtype);
 
-        let desc = ReshapeDescription {
+        let desc = UnaryOperationDescription {
             input: tensor.into_description(),
             out: out.to_description_out(),
         };
diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs
index e4b0f3ccaf..d3203ea14d 100644
--- a/crates/burn-tensor/src/repr/operation.rs
+++ b/crates/burn-tensor/src/repr/operation.rs
@@ -215,7 +215,7 @@ pub enum BaseOperationDescription {
     /// Float => [reshape](crate::ops::FloatTensorOps::float_reshape).
     /// Int => [reshape](crate::ops::IntTensorOps::int_reshape).
     /// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape).
-    Reshape(ReshapeDescription),
+    Reshape(UnaryOperationDescription),
 
     /// Operation corresponding to:
     ///
@@ -644,13 +644,6 @@ pub struct FromDataOperationDescription {
     pub data: TensorData,
 }
 
-#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
-#[allow(missing_docs)]
-pub struct ReshapeDescription {
-    pub input: TensorDescription,
-    pub out: TensorDescription,
-}
-
 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 #[allow(missing_docs)]
 pub struct ExpandDescription {
diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs
index b82175c3fe..c9315823d7 100644
--- a/crates/burn-tensor/src/tensor/api/numeric.rs
+++ b/crates/burn-tensor/src/tensor/api/numeric.rs
@@ -1449,8 +1449,8 @@ where
         // last two dimensions
         let shape = &self.shape().dims[D - 2..].to_owned();
-
         let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();
+
         self.mask_fill(mask, 0)
     }
 
@@ -2208,6 +2208,7 @@ where
         let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze::<2>();
         let ones = K::ones([1, size].into(), device);
         let zeros = K::zeros([size, size].into(), device);
+
         Self::new(K::scatter(0, zeros, indices.primitive, ones))
     }
 }
diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs
index bd9ffbf860..de581b2d4d 100644
--- a/crates/burn-tensor/src/tests/ops/reshape.rs
+++ b/crates/burn-tensor/src/tests/ops/reshape.rs
@@ -14,6 +14,102 @@ mod tests {
         output.into_data().assert_eq(&expected, false);
     }
 
+    #[test]
+    fn should_support_reshape_maybe_fused_1() {
+        let tensor = TestTensorInt::arange(0..32, &Default::default());
+        let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default());
+        let tensor1 = tensor.clone().reshape([1, 4, 8]);
+        let output = tensor0 + tensor1;
+
+        let expected = TensorData::from([
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+            [
+                [0, 1, 2, 3, 4, 5, 6, 7],
+                [8, 9, 10, 11, 12, 13, 14, 15],
+                [16, 17, 18, 19, 20, 21, 22, 23],
+                [24, 25, 26, 27, 28, 29, 30, 31],
+            ],
+        ]);
+        output.into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_reshape_maybe_fused_2() {
+        let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());
+        let tensor1 = tensor.reshape([2, 2, 1]);
+        let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default());
+        let output = tensor2 + tensor1;
+
+        let expected_tensor1 =
+            TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]);
+        output.into_data().assert_eq(&expected_tensor1, false);
+    }
+
+    #[test]
+    fn should_support_reshape_maybe_fused_3() {
+        let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());
+        let tensor1 = tensor.reshape([2, 2, 1]);
+        let tensor2 = TestTensorInt::<3>::full([2, 2, 3], 5, &Default::default());
+
+        let expected_tensor1 = TensorData::from([[[0], [2]], [[1], [2]]]);
+        tensor1.into_data().assert_eq(&expected_tensor1, false);
+    }
+
+    #[test]
+    fn should_support_reshape_maybe_fused_4() {
+        let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default());
+        let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default());
+        let tensor2 = tensor2.swap_dims(0, 1);
+        let tensor1 = tensor.reshape([2, 2, 1]);
+        let output = tensor2 + tensor1;
+        println!("{output}");
+
+        let expected_tensor1 =
+            TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]);
+        output.into_data().assert_eq(&expected_tensor1, false);
+    }
+
     #[test]
     fn should_support_reshape_int() {
         let data = TensorData::from([0, 1, 2]);
diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml
index 043c61672d..d380dcc9cb 100644
--- a/examples/text-classification/Cargo.toml
+++ b/examples/text-classification/Cargo.toml
@@ -16,7 +16,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
-vulkan = ["wgpu", "burn/vulkan"]
+vulkan = ["burn/vulkan", "burn/default"]
 remote = ["burn/remote"]
 cuda = ["burn/cuda"]
 hip = ["burn/hip"]
diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs
index 927c190b2c..610cd821c4 100644
--- a/examples/text-classification/examples/ag-news-train.rs
+++ b/examples/text-classification/examples/ag-news-train.rs
@@ -93,6 +93,16 @@ mod wgpu {
     }
 }
 
+#[cfg(feature = "vulkan")]
+mod vulkan {
+    use crate::{launch, ElemType};
+    use burn::backend::{Autodiff, Vulkan};
+
+    pub fn run() {
+        launch::<Autodiff<Vulkan<ElemType, i32>>>(vec![Default::default()]);
+    }
+}
+
 #[cfg(feature = "remote")]
 mod remote {
    use crate::{launch, ElemType};
@@ -143,4 +153,6 @@ fn main() {
     hip::run();
     #[cfg(feature = "remote")]
     remote::run();
+    #[cfg(feature = "vulkan")]
+    vulkan::run();
 }