From 1c4fbb028e06ac5e14ded8c12d22535355808294 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 22 Jan 2025 13:30:54 -0500 Subject: [PATCH 01/28] WIP --- Cargo.lock | 14 -- Cargo.toml | 8 +- .../src/fusion/elemwise/optimization.rs | 2 +- crates/burn-jit/src/fusion/matmul/args.rs | 12 +- crates/burn-jit/src/fusion/on_write/io.rs | 210 +++++++++++++----- crates/burn-jit/src/fusion/on_write/ir.rs | 26 ++- crates/burn-jit/src/fusion/on_write/kernel.rs | 9 +- crates/burn-jit/src/fusion/on_write/trace.rs | 2 + .../src/fusion/on_write/trace_builder.rs | 11 +- 9 files changed, 201 insertions(+), 93 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de9444c5ff..6eb9713a50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1589,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1609,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1629,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1642,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1657,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,7 +1682,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "float-ord", @@ -1701,7 +1694,6 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1705,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "darling", @@ -1728,7 +1719,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1734,6 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" 
-source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1743,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1764,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1778,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index f731d063a9..c4723d0f7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } ### For local development. ### -# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. 
### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index 2e33eefc20..bc67e9e945 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -141,7 +141,7 @@ fn elemwise_fuse( let args = comptime![Sequence::::new()]; let pos = ABSOLUTE_POS; - let length = match comptime![config.ref_layout] { + let length = match comptime![config.ref_layout.clone()] { Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => inputs.t_f32.index(index).len(), ElemwisePrecision::F16 => inputs.t_f16.index(index).len(), diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index 1dbbf3baea..996443a801 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -51,6 +51,7 @@ impl MatmulArgs for FusedMatmulArgs { LayoutInfo::IsRef, precision, &state.config, + None, ) } @@ -70,6 +71,7 @@ impl MatmulArgs for FusedMatmulArgs { LayoutInfo::IsRef, precision, &state.config, + None, ) } @@ -77,8 +79,8 @@ impl MatmulArgs for FusedMatmulArgs { let mut values = Registry::>::new(); let mut args = comptime![Sequence::::new()]; - values.insert(state.out, value); - comptime![args.push(state.out)]; + values.insert(comptime![state.out.clone()], value); + comptime![args.push(state.out.clone())]; fuse_on_write( unsafe { &(*state.inputs) }, @@ -225,9 +227,9 @@ impl FusedMatmulState { inputs: &inputs.global, outputs, config: comptime![config.clone()], - lhs: comptime![inputs.lhs], - rhs: comptime![inputs.rhs], - out: comptime![inputs.out], + lhs: comptime![inputs.lhs.clone()], + rhs: comptime![inputs.rhs.clone()], + out: comptime![inputs.out.clone()], } } } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 497bc510df..25c4ce0c30 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -1,5 +1,5 @@ use super::ir::*; -use cubecl::{linalg::tensor::index_offset_with_layout, prelude::*}; +use cubecl::prelude::*; #[cube] /// Read the value from the [arg](Arg) and cast it to the generic cube primitive. @@ -12,9 +12,9 @@ pub fn read( #[comptime] config: &ElemwiseConfig, ) -> Line { match arg { - Arg::Input(pos, precision, layout) => { - read_input(inputs, outputs, pos, ref_pos, layout, precision, config) - } + Arg::Input(pos, precision, layout) => read_input( + inputs, outputs, pos, ref_pos, layout, precision, config, None, + ), Arg::Output(pos, precision, layout) => { read_output(inputs, outputs, pos, ref_pos, layout, precision, config) } @@ -47,6 +47,42 @@ pub fn read( _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Literal(val, _precision) => Line::cast_from(val.runtime()), + Arg::InputReshaped { + original, shape, .. 
+ } => match comptime![original.as_ref().clone()] { + Arg::Input(pos, precision, layout) => read_input( + inputs, + outputs, + pos, + ref_pos, + layout, + precision, + config, + comptime![Some(shape)], + ), + _ => comptime![panic![]], + }, + } +} + +#[cube] +pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) -> C { + match arg { + Arg::Scalar(pos, precision) => match comptime![precision] { + ElemwisePrecision::F32 => C::cast_from(*inputs.s_f32.index(pos)), + ElemwisePrecision::F16 => C::cast_from(*inputs.s_f16.index(pos)), + ElemwisePrecision::BF16 => C::cast_from(*inputs.s_bf16.index(pos)), + ElemwisePrecision::U64 => C::cast_from(*inputs.s_u64.index(pos)), + ElemwisePrecision::U32 => C::cast_from(*inputs.s_u32.index(pos)), + ElemwisePrecision::U16 => C::cast_from(*inputs.s_u16.index(pos)), + ElemwisePrecision::U8 => C::cast_from(*inputs.s_u8.index(pos)), + ElemwisePrecision::I64 => C::cast_from(*inputs.s_i64.index(pos)), + ElemwisePrecision::I32 => C::cast_from(*inputs.s_i32.index(pos)), + ElemwisePrecision::I16 => C::cast_from(*inputs.s_i16.index(pos)), + ElemwisePrecision::I8 => C::cast_from(*inputs.s_i8.index(pos)), + _ => comptime![panic!("Unsupported precision {precision:?}")], + }, + _ => comptime![panic!("Not a scalar")], } } @@ -59,6 +95,7 @@ pub fn read_input( #[comptime] layout: LayoutInfo, #[comptime] precision: ElemwisePrecision, #[comptime] config: &ElemwiseConfig, + #[comptime] shape: Option>, ) -> Line { match comptime![precision] { ElemwisePrecision::F32 => { @@ -66,7 +103,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -75,7 +112,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -84,7 +121,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -93,7 +130,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -102,7 +139,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -111,7 +148,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -120,7 +157,7 @@ pub fn read_input( let 
offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -129,7 +166,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -138,7 +175,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -147,7 +184,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -156,7 +193,7 @@ pub fn read_input( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, shape), }; Line::cast_from(tensor[offset]) } @@ -180,7 +217,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -189,7 +226,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -198,7 +235,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -207,7 +244,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -216,7 +253,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -225,7 +262,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, 
tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -234,7 +271,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -243,7 +280,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -252,7 +289,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -261,7 +298,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -270,7 +307,7 @@ pub fn read_output( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config, None), }; Line::cast_from(tensor[offset]) } @@ -296,7 +333,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_f32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -306,7 +345,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_f16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -316,7 +357,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_bf16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -326,7 +369,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u64.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -336,7 +381,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, 
config, None) + } }; let tensor = outputs.t_u32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -346,7 +393,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -356,7 +405,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_u8.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -366,7 +417,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i64.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -376,7 +429,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i32.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -386,7 +441,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i16.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -396,7 +453,9 @@ pub fn write( let offset = match layout { LayoutInfo::SameAsRef => ref_pos, LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + LayoutInfo::Unknown => { + get_offset(inputs, outputs, tensor, ref_pos, config, None) + } }; let tensor = outputs.t_i8.index_mut(pos); tensor[offset] = Line::cast_from(value); @@ -428,99 +487,100 @@ fn get_offset( tensor: &Tensor>, pos: u32, #[comptime] config: &ElemwiseConfig, + #[comptime] shape: Option>, ) -> u32 { - match comptime![config.ref_layout] { + match comptime![config.ref_layout.clone()] { Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = inputs.t_f32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::F16 => { let layout = inputs.t_f16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = inputs.t_bf16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U64 => { let layout = inputs.t_u64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, 
pos, 0, config.rank, shape) } ElemwisePrecision::U32 => { let layout = inputs.t_u32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U16 => { let layout = inputs.t_u16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U8 => { let layout = inputs.t_u8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I64 => { let layout = inputs.t_i64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I32 => { let layout = inputs.t_i32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I16 => { let layout = inputs.t_i16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I8 => { let layout = inputs.t_i8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = outputs.t_f32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::F16 => { let layout = outputs.t_f16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = outputs.t_bf16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U64 => { let layout = outputs.t_u64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U32 => { let layout = outputs.t_u32.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U16 => { let layout = outputs.t_u16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::U8 => { let layout = outputs.t_u8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I64 => { let layout = outputs.t_i64.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I32 => { let layout = outputs.t_i32.index(index); - 
index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I16 => { let layout = outputs.t_i16.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } ElemwisePrecision::I8 => { let layout = outputs.t_i8.index(index); - index_offset_with_layout(tensor, layout, pos, 0, config.rank, false) + index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, @@ -692,3 +752,33 @@ pub fn global_stride( _ => comptime![panic!("Unsupported precision {precision:?}")], } } + +/// Returns the offset of the tensor corresponding to the layout tensor. +#[cube] +fn index_offset_with_layout( + inputs: &GlobalArgs, + tensor: &Tensor>, + layout: &Tensor>, + offset_layout: u32, + dim_start: u32, + dim_end: u32, + #[comptime] shape: Option>, +) -> u32 { + let offset_ref = offset_layout * tensor.line_size(); + let mut offset = 0u32; + + #[unroll] + for i in dim_start..dim_end { + let shape_i = match comptime![shape.clone()] { + Some(s) => { + let arg = comptime![s.index(i.clone())]; + read_scalar::(inputs, comptime![arg.clone()]) + } + None => tensor.shape(i), + }; + let ogwl = offset_ref / layout.stride(i); + offset += ogwl % shape_i * tensor.stride(i); + } + + offset / tensor.line_size() +} diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0cec2d29c7..188a18ba39 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -4,7 +4,7 @@ use cubecl::prelude::*; use half::{bf16, f16}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to an [elemwise operation](ElemwiseOp). pub enum Arg { Input(u32, ElemwisePrecision, LayoutInfo), @@ -13,11 +13,14 @@ pub enum Arg { Scalar(u32, ElemwisePrecision), /// Only constant that can be encoded into an u32 can be used as literal. Literal(u32, ElemwisePrecision), + InputReshaped { + id: u32, + original: Box, + shape: Sequence, + }, } -#[derive( - CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, -)] +#[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Layout information. pub enum LayoutInfo { /// The layout if the same as the reference. @@ -36,6 +39,7 @@ impl Arg { Arg::Output(_, p, _) => p, Arg::Scalar(_, p) => p, Arg::Literal(_, p) => p, + Arg::InputReshaped { original, .. 
} => return original.precision(), } } } @@ -85,6 +89,19 @@ pub enum ElemwiseOp { rhs: Arg, out: Arg, }, + Reshape { + input: Arg, + out: Arg, + shape: Sequence, + }, +} + +#[derive(CubeLaunch)] +pub struct ReshapedTensor { + #[cube(comptime)] + original: Arg, + #[cube(comptime)] + shape: Sequence, } #[derive(CubeLaunch)] @@ -112,6 +129,7 @@ pub struct GlobalArgs { pub s_u32: Sequence, pub s_u16: Sequence, pub s_u8: Sequence, + pub t_reshaped: Sequence, } impl GlobalArgsLaunch<'_, R> { diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index 269ba1f3b8..3e9c6117f9 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -33,8 +33,8 @@ pub fn fuse_on_write( // Write the values given as arguments. #[unroll] for i in 0..write_args.len() { - let arg = comptime![*write_args.index(i)]; - let val = write_values.find(arg); + let arg = comptime![write_args.index(i).clone()]; + let val = write_values.find(comptime![arg.clone()]); write::(inputs, outputs, &mut locals, write_pos, val, arg, config); } @@ -550,6 +550,9 @@ pub fn fuse_on_write( } _ => comptime![panic!("Unsupported precision {op:?}")], }, + ElemwiseOp::Reshape { .. } => { + // Nothing to do. + } ElemwiseOp::ConditionalAssign { cond, lhs, @@ -677,7 +680,7 @@ pub fn fuse_on_write( out, config, ), - _ => comptime![panic!("Unsupported precision {op:?}")], + _ => comptime![panic!("Unsupported precision")], }, } } diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 2c29d05ce8..f89bcdac27 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -438,6 +438,7 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), + SequenceArg::new(), ); for hi in handle_inputs.iter() { @@ -526,6 +527,7 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), + SequenceArg::new(), ); for item in handle_outputs.iter() { match item { diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index c37237ae4e..eb6c3f0a02 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -54,7 +54,7 @@ impl FuseOnWriteTraceBuilder { pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg { let arg = self.output(tensor); - self.outputs_unhandled.push(arg); + self.outputs_unhandled.push(arg.clone()); arg } @@ -98,7 +98,10 @@ impl FuseOnWriteTraceBuilder { self.reads.insert( tensor.id, - ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }), + ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + }), ); out @@ -331,6 +334,10 @@ impl FuseOnWriteTraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + ElemwiseOp::Reshape { input, out, .. } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } }; // For all operators, mark their local tensor id in the proper set. 
} => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } }; // For all operators, mark their local tensor id in the proper set.
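The core of this first patch is the private `index_offset_with_layout` helper added to io.rs: it recovers the per-dimension coordinates of the reference position from the reference layout's strides and re-projects them onto the strides of the tensor actually being read, optionally substituting a runtime shape when the read goes through a fused reshape. A minimal host-side sketch of that remapping pattern (plain Rust, line size assumed to be 1; the names are illustrative and not part of the kernel API):

/// Re-map a linear index expressed in the reference layout onto another tensor's strides.
/// `target_shape` is where the reshape's runtime shape would be substituted.
fn remap_offset(
    layout_strides: &[usize],
    target_strides: &[usize],
    target_shape: &[usize],
    ref_offset: usize,
) -> usize {
    let mut offset = 0;
    for i in 0..layout_strides.len() {
        // Coordinate along dim i in the reference layout.
        let coord = ref_offset / layout_strides[i];
        // Wrap for broadcast dimensions of size 1, then project onto the target strides.
        offset += (coord % target_shape[i]) * target_strides[i];
    }
    offset
}

fn main() {
    // Reference layout: contiguous [2, 3] tensor -> strides [3, 1].
    // Target: a [1, 3] tensor broadcast along dim 0 -> shape [1, 3], strides [3, 1].
    let layout_strides = [3, 1];
    let target_strides = [3, 1];
    let target_shape = [1, 3];
    // Element (1, 2) of the reference (linear index 5) maps to element (0, 2) of the target.
    assert_eq!(remap_offset(&layout_strides, &target_strides, &target_shape, 5), 2);
}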
From 06cb156b56140ba162500500d811a34706fd0bf1 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 22 Jan 2025 17:29:42 -0500 Subject: [PATCH 02/28] WIP --- crates/burn-fusion/src/stream/context.rs | 29 +++++- .../burn-jit/src/fusion/elemwise/builder.rs | 4 +- .../burn-jit/src/fusion/on_write/builder.rs | 95 ++++++++++++++++--- crates/burn-jit/src/fusion/on_write/io.rs | 40 ++++---- crates/burn-jit/src/fusion/on_write/ir.rs | 8 +- crates/burn-jit/src/fusion/on_write/kernel.rs | 3 - crates/burn-jit/src/fusion/on_write/trace.rs | 9 ++ .../src/fusion/on_write/trace_builder.rs | 69 +++++++++++++- 8 files changed, 213 insertions(+), 44 deletions(-) diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 0f9c75fc94..eeeeacbb91 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -39,7 +39,6 @@ pub struct Context<'a, H> { pub scalar_u8: &'a Vec, } -#[derive(Default)] pub(crate) struct OperationConverter { tensors_relative2global: HashMap, tensors_global2relative: HashMap, @@ -59,6 +58,32 @@ pub(crate) struct OperationConverter { scalar_u8: Vec, } +impl Default for OperationConverter { + fn default() -> Self { + let mut val = Self { + tensors_relative2global: Default::default(), + tensors_global2relative: Default::default(), + shapes_global2relative: Default::default(), + scalar_f32: Default::default(), + scalar_f16: Default::default(), + scalar_bf16: Default::default(), + scalar_i64: Default::default(), + scalar_i32: Default::default(), + scalar_i16: Default::default(), + scalar_i8: Default::default(), + scalar_u64: Default::default(), + scalar_u32: Default::default(), + scalar_u16: Default::default(), + scalar_u8: Default::default(), + }; + + // global 1 is always shape id 0. + val.shapes_global2relative.insert(1, 0); + + val + } +} + /// Fork of a [context](Context) which owns its data. pub struct ContextOwned { tensors: HashMap, @@ -181,6 +206,8 @@ impl OperationConverter { self.tensors_relative2global.clear(); self.tensors_global2relative.clear(); self.shapes_global2relative.clear(); + // global 1 is always shape id 0. 
+ self.shapes_global2relative.insert(1, 0); self.scalar_f32.clear(); self.scalar_f16.clear(); self.scalar_bf16.clear(); diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 461767e9fc..13c12b076d 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -52,7 +52,9 @@ impl OptimizationBuilder> for ElementWiseBuild } fn properties(&self) -> burn_fusion::OptimizationProperties { - self.builder.properties() + let mut props = self.builder.properties(); + props.ready = props.ready && self.builder.num_ops > self.builder.num_reshapes; + props } fn len(&self) -> usize { diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index bf31ef78ea..3d5b917bec 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -19,7 +19,8 @@ pub(crate) struct FuseOnWriteBuilder { builder: TryFuseBuilder, current_output_shape: Vec, status: OptimizationStatus, - num_ops: usize, + pub(crate) num_ops: usize, + pub(crate) num_reshapes: usize, max_bindings: u32, } @@ -38,19 +39,26 @@ impl TryFuseBuilder { } } - fn register(&mut self, add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder)) -> bool { + fn register(&mut self, add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder) -> bool) -> bool { // Always allow the first operation to be added. if !self.added_ops { self.added_ops = true; - add_ops(&mut self.builder); + + if !add_ops(&mut self.builder) { + return false; + } return true; } let mut cloned = self.builder.clone(); - add_ops(&mut cloned); + if !add_ops(&mut cloned) { + return false; + } + if cloned.estimate_bindings() > self.max_bindings { return false; } + self.builder = cloned; true } @@ -141,6 +149,7 @@ impl FuseOnWriteBuilder { Self { builder: TryFuseBuilder::new(max_bindings, bool_precision), num_ops: 0, + num_reshapes: 0, max_bindings, current_output_shape: Vec::new(), status: OptimizationStatus::Open, @@ -172,6 +181,28 @@ impl FuseOnWriteBuilder { BaseOperationDescription::Cast(desc) => self.register_unary_ops(desc, |input, out| { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), + BaseOperationDescription::Reshape(desc) => { + if !self.output_is_compatible(&desc.out) { + return false; + } + + if self.builder.register(|build| { + let input = match build.input_reshaped(&desc.input, &desc.out) { + Some(val) => val, + None => return false, + }; + let out = build.output(&desc.out); + + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true + }) { + self.num_reshapes += 1; + true + } else { + false + } + } _ => false, } } @@ -302,7 +333,9 @@ impl FuseOnWriteBuilder { lhs, rhs, out, - }) + }); + + true }) } NumericOperationDescription::MaskFill(desc) => { @@ -321,7 +354,9 @@ impl FuseOnWriteBuilder { lhs, rhs, out, - }) + }); + + true }) } NumericOperationDescription::Ones(desc) => { @@ -336,7 +371,9 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } NumericOperationDescription::Zeros(desc) => { @@ -351,7 +388,9 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + 
build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } NumericOperationDescription::Full((desc, elem)) => { @@ -363,7 +402,9 @@ impl FuseOnWriteBuilder { let input = build.scalar(elem, desc.dtype); let out = build.output(desc); - build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })) + build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out })); + + true }) } _ => false, @@ -383,7 +424,9 @@ impl FuseOnWriteBuilder { let rhs = build.input(&desc.rhs); let out = build.output(&desc.out); - build.register_operation(func(lhs, rhs, out)) + build.register_operation(func(lhs, rhs, out)); + + true }) } @@ -398,7 +441,8 @@ impl FuseOnWriteBuilder { self.builder.register(|build| { let input = build.input(&desc.input); let out = build.output(&desc.out); - build.register_operation(func(input, out)) + build.register_operation(func(input, out)); + true }) } @@ -420,17 +464,42 @@ impl FuseOnWriteBuilder { let rhs = build.scalar(&desc.rhs, elem); let out = build.output(&desc.out); - build.register_operation(func(lhs, rhs, out)) + build.register_operation(func(lhs, rhs, out)); + + true }) } fn output_is_compatible(&mut self, out: &TensorDescription) -> bool { if self.current_output_shape.is_empty() { self.current_output_shape.clone_from(&out.shape); - } else if self.current_output_shape != out.shape { + return true; + } + + // Last axis should be equal. + if self.current_output_shape.last() != out.shape.last() { return false; } + let rank = self.current_output_shape.len(); + + // Rank should be equal. + if rank != out.shape.len() { + return false; + } + + for i in 0..(rank - 1) { + let curr = self.current_output_shape[i]; + let new = out.shape[i]; + + // Broadcast is supported. + // + // 0 is the shape id for a global shape of 1. + if curr != new && new != 0 { + return false; + } + } + true } } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 25c4ce0c30..ef8bed3b7b 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -32,20 +32,14 @@ pub fn read( ElemwisePrecision::I8 => Line::cast_from(locals.l_i8.find(pos)), ElemwisePrecision::Bool => Line::cast_from(locals.l_bool.find(pos)), }, - Arg::Scalar(pos, precision) => match comptime![precision] { - ElemwisePrecision::F32 => Line::cast_from(*inputs.s_f32.index(pos)), - ElemwisePrecision::F16 => Line::cast_from(*inputs.s_f16.index(pos)), - ElemwisePrecision::BF16 => Line::cast_from(*inputs.s_bf16.index(pos)), - ElemwisePrecision::U64 => Line::cast_from(*inputs.s_u64.index(pos)), - ElemwisePrecision::U32 => Line::cast_from(*inputs.s_u32.index(pos)), - ElemwisePrecision::U16 => Line::cast_from(*inputs.s_u16.index(pos)), - ElemwisePrecision::U8 => Line::cast_from(*inputs.s_u8.index(pos)), - ElemwisePrecision::I64 => Line::cast_from(*inputs.s_i64.index(pos)), - ElemwisePrecision::I32 => Line::cast_from(*inputs.s_i32.index(pos)), - ElemwisePrecision::I16 => Line::cast_from(*inputs.s_i16.index(pos)), - ElemwisePrecision::I8 => Line::cast_from(*inputs.s_i8.index(pos)), - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, + Arg::Scalar(..) => { + let scalar = read_scalar::(inputs, arg); + Line::new(scalar) + } + Arg::ScalarShape(_) => { + let scalar = read_scalar_shape(inputs, arg); + Line::cast_from(scalar) + } Arg::Literal(val, _precision) => Line::cast_from(val.runtime()), Arg::InputReshaped { original, shape, .. 
@@ -86,6 +80,17 @@ pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) } } +#[cube] +pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { + match arg { + Arg::ScalarShape(pos) => { + let offset = comptime![inputs.s_u32.len() - pos - 1]; + *inputs.s_u32.index(offset) + } + _ => comptime![panic!["Not a scalar shape"]], + } +} + #[cube] pub fn read_input( inputs: &GlobalArgs, @@ -767,12 +772,15 @@ fn index_offset_with_layout( let offset_ref = offset_layout * tensor.line_size(); let mut offset = 0u32; - #[unroll] + // Need to unroll when fusing a reshape. + let unroll = comptime![shape.is_some()]; + + #[unroll(unroll)] for i in dim_start..dim_end { let shape_i = match comptime![shape.clone()] { Some(s) => { let arg = comptime![s.index(i.clone())]; - read_scalar::(inputs, comptime![arg.clone()]) + read_scalar_shape(inputs, comptime![arg.clone()]) } None => tensor.shape(i), }; diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 188a18ba39..8530ad23cf 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -11,10 +11,10 @@ pub enum Arg { Local(u32, ElemwisePrecision), Output(u32, ElemwisePrecision, LayoutInfo), Scalar(u32, ElemwisePrecision), + ScalarShape(u32), /// Only constant that can be encoded into an u32 can be used as literal. Literal(u32, ElemwisePrecision), InputReshaped { - id: u32, original: Box, shape: Sequence, }, @@ -39,6 +39,7 @@ impl Arg { Arg::Output(_, p, _) => p, Arg::Scalar(_, p) => p, Arg::Literal(_, p) => p, + Arg::ScalarShape(_) => return ElemwisePrecision::U32, Arg::InputReshaped { original, .. } => return original.precision(), } } @@ -89,11 +90,6 @@ pub enum ElemwiseOp { rhs: Arg, out: Arg, }, - Reshape { - input: Arg, - out: Arg, - shape: Sequence, - }, } #[derive(CubeLaunch)] diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index 3e9c6117f9..fa5abbfba9 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -550,9 +550,6 @@ pub fn fuse_on_write( } _ => comptime![panic!("Unsupported precision {op:?}")], }, - ElemwiseOp::Reshape { .. } => { - // Nothing to do. 
- } ElemwiseOp::ConditionalAssign { cond, lhs, diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index f89bcdac27..8948a926a3 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -19,6 +19,7 @@ pub struct FuseOnWriteTrace { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, + shapes: Vec, ops: Vec, reads: BTreeMap, writes: BTreeMap, @@ -496,6 +497,14 @@ impl FuseOnWriteTrace { } } + for relative in self.shapes.iter().rev() { + let global = context.tensors.get(&relative.id).unwrap(); + + for shape in global.shape.iter().rev() { + inputs.s_u32.push(ScalarArg::new(*shape as u32)) + } + } + inputs } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index eb6c3f0a02..e22ae04c72 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -6,6 +6,7 @@ use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, DType, Element, }; +use cubecl::prelude::Sequence; use std::collections::BTreeMap; #[derive(Clone)] @@ -14,6 +15,7 @@ pub struct FuseOnWriteTraceBuilder { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, + shapes: Vec, ops: Vec, reads: BTreeMap, pub bool_precision: ElemwisePrecision, @@ -28,6 +30,7 @@ impl FuseOnWriteTraceBuilder { outputs: RegisteredTensors::default(), inputs: RegisteredTensors::default(), scalars: BTreeMap::default(), + shapes: Vec::default(), ops: Vec::new(), reads: BTreeMap::new(), bool_precision, @@ -130,6 +133,65 @@ impl FuseOnWriteTraceBuilder { } } + pub fn input_reshaped( + &mut self, + tensor: &TensorDescription, + output: &TensorDescription, + ) -> Option { + let precision = tensor.dtype.into(); + + // Bool tensors are encoded as bool_precision. + let precision_input = match precision { + ElemwisePrecision::Bool => self.bool_precision, + _ => precision, + }; + + match self.locals.get(precision, tensor.id) { + // Can't fused an already fused input. + // + // TODO: Can fuse one that is in global memory. + Some(_) => { + if self.outputs.get(precision_input, tensor.id).is_some() { + return None; + } + + // self.inputs.update(precision_input, tensor); + None + } + None => { + let new_input = self.inputs.insert(precision_input, tensor.clone()); + let out = self.locals.create(precision, tensor.id); + let original = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); + + let mut shape = Sequence::new(); + + let index = self.shapes.len(); + self.shapes.push(output.clone()); + let rank = output.shape.len(); + + for i in 0..output.shape.len() { + let id = index * rank + i; + shape.push(Arg::ScalarShape(id as u32)); + } + + let input = Arg::InputReshaped { + original: Box::new(original), + shape, + }; + + self.reads.insert( + tensor.id, + ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + }), + ); + + Some(out) + } + } + } + pub fn scalar(&mut self, _: &E, dtype: DType) -> Arg { let precision = dtype.into(); @@ -168,11 +230,14 @@ impl FuseOnWriteTraceBuilder { ); } + let shapes = self.shapes.clone(); + // Current problem is that I need btreemap instead of sequences. FuseOnWriteTrace::new( outputs, inputs, scalars, + shapes, ops, reads, writes, @@ -334,10 +399,6 @@ impl FuseOnWriteTraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), - ElemwiseOp::Reshape { input, out, .. 
} => { - mark(input, &mut local_tensor_ids_input); - mark(out, &mut local_tensor_ids_output); - } }; // For all operators, mark their local tensor id in the proper set. From e8493af3fd783d6cfef2452ccc534b01b448f5ae Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 22 Jan 2025 19:04:47 -0500 Subject: [PATCH 03/28] WIP testing --- crates/burn-core/Cargo.toml | 2 +- crates/burn-core/src/nn/linear.rs | 22 +++++++++++++------ .../burn-jit/src/fusion/on_write/builder.rs | 4 ++++ crates/burn-jit/src/fusion/on_write/io.rs | 2 +- crates/burn-jit/src/fusion/on_write/ir.rs | 1 - crates/burn-jit/src/fusion/on_write/trace.rs | 5 +++-- .../src/fusion/on_write/trace_builder.rs | 2 ++ 7 files changed, 26 insertions(+), 12 deletions(-) diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e895cc4572..dc14d45f62 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -113,7 +113,7 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. +test-cuda = ["cuda-jit", "fusion"] # To use cuda during testing, default uses ndarray. test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 738dd87c80..f65164b522 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -75,15 +75,20 @@ impl Linear { return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1); } - let weight = self.weight.val().unsqueeze(); - let bias = self.bias.as_ref().map(|b| b.val().unsqueeze()); + // let weight = self.weight.val().unsqueeze(); - let output = input.matmul(weight); + let output = input * 5; - match bias { - Some(bias) => output + bias, + match &self.bias { + Some(bias) => output + bias.val().unsqueeze(), None => output, } + + // let bias = self.bias.as_ref().map(|b| b.val().unsqueeze()); + // match bias { + // Some(bias) => output + bias, + // None => output, + // } } } @@ -189,13 +194,16 @@ mod tests { let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); let linear = config.init::(&device); - let input_1d = Tensor::::ones(Shape::new([2]), &device); - let input_2d = Tensor::::ones(Shape::new([1, 2]), &device); + let input_1d = Tensor::::ones(Shape::new([3]), &device); + let input_2d = Tensor::::ones(Shape::new([1, 3]), &device); let result_1d = linear.forward(input_1d).unsqueeze::<2>(); + println!("{result_1d}"); let result_2d = linear.forward(input_2d); + println!("{result_2d}"); assert_eq!(result_1d.into_data(), result_2d.into_data()); + panic!("Test"); } #[test] diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 3d5b917bec..dce613bc75 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -182,6 +182,10 @@ impl FuseOnWriteBuilder { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), BaseOperationDescription::Reshape(desc) => { + if self.current_output_shape.is_empty() { + return false; + } + if !self.output_is_compatible(&desc.out) { return false; } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index ef8bed3b7b..b8bac76235 
100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -84,7 +84,7 @@ pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { match arg { Arg::ScalarShape(pos) => { - let offset = comptime![inputs.s_u32.len() - pos - 1]; + let offset = comptime![{ inputs.s_u32.len() - pos - 1 }]; *inputs.s_u32.index(offset) } _ => comptime![panic!["Not a scalar shape"]], diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 8530ad23cf..c96c3e1e57 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -125,7 +125,6 @@ pub struct GlobalArgs { pub s_u32: Sequence, pub s_u16: Sequence, pub s_u8: Sequence, - pub t_reshaped: Sequence, } impl GlobalArgsLaunch<'_, R> { diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 8948a926a3..1656a3a1a2 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -439,7 +439,6 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), - SequenceArg::new(), ); for hi in handle_inputs.iter() { @@ -497,13 +496,16 @@ impl FuseOnWriteTrace { } } + let mut tmp = vec![]; for relative in self.shapes.iter().rev() { let global = context.tensors.get(&relative.id).unwrap(); for shape in global.shape.iter().rev() { + tmp.push(shape); inputs.s_u32.push(ScalarArg::new(*shape as u32)) } } + println!("Shape {:?}", tmp); inputs } @@ -536,7 +538,6 @@ impl FuseOnWriteTrace { SequenceArg::new(), SequenceArg::new(), SequenceArg::new(), - SequenceArg::new(), ); for item in handle_outputs.iter() { match item { diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index e22ae04c72..0f653b6033 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -169,8 +169,10 @@ impl FuseOnWriteTraceBuilder { self.shapes.push(output.clone()); let rank = output.shape.len(); + println!("output {output:?}"); for i in 0..output.shape.len() { let id = index * rank + i; + println!("id {id:?}"); shape.push(Arg::ScalarShape(id as u32)); } From 4a2d03ecc44a8d83cef5e80d2caea08db1116385 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Fri, 24 Jan 2025 18:12:50 -0500 Subject: [PATCH 04/28] Very wip --- crates/burn-core/src/lib.rs | 3 + crates/burn-core/src/nn/linear.rs | 27 ++-- .../burn-jit/src/fusion/elemwise/builder.rs | 5 +- .../src/fusion/elemwise/optimization.rs | 1 + .../burn-jit/src/fusion/on_write/builder.rs | 19 +-- crates/burn-jit/src/fusion/on_write/io.rs | 134 ++++++++++++------ crates/burn-jit/src/fusion/on_write/trace.rs | 41 +++--- .../src/fusion/on_write/trace_builder.rs | 100 +++++++------ crates/burn-tensor/src/tests/ops/reshape.rs | 61 ++++++++ 9 files changed, 263 insertions(+), 128 deletions(-) diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index ade8d64db7..34887ec6b9 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -59,12 +59,15 @@ extern crate alloc; pub type TestBackend = burn_ndarray::NdArray; #[cfg(all(test, feature = "test-tch"))] +/// Backend for test cases pub type TestBackend = burn_tch::LibTorch; #[cfg(all(test, feature = "test-wgpu"))] +/// Backend for test cases pub type TestBackend = burn_wgpu::Wgpu; #[cfg(all(test, feature = 
"test-cuda"))] +/// Backend for test cases pub type TestBackend = burn_cuda::Cuda; /// Backend for autodiff test cases diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index f65164b522..994cac5b04 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -62,6 +62,10 @@ impl LinearConfig { } } +// into_contuiguous +// strides contibous +// shape + impl Linear { /// Applies the forward pass on the input tensor. /// @@ -75,20 +79,14 @@ impl Linear { return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1); } - // let weight = self.weight.val().unsqueeze(); - - let output = input * 5; + let weight = self.weight.val().unsqueeze(); + let bias = self.bias.as_ref().map(|b| b.val().unsqueeze()); + let output = input.matmul(weight); - match &self.bias { - Some(bias) => output + bias.val().unsqueeze(), + match bias { + Some(bias) => output + bias, None => output, } - - // let bias = self.bias.as_ref().map(|b| b.val().unsqueeze()); - // match bias { - // Some(bias) => output + bias, - // None => output, - // } } } @@ -194,16 +192,13 @@ mod tests { let config = LinearConfig::new(2, 3).with_initializer(Initializer::Constant { value }); let linear = config.init::(&device); - let input_1d = Tensor::::ones(Shape::new([3]), &device); - let input_2d = Tensor::::ones(Shape::new([1, 3]), &device); + let input_1d = Tensor::::ones(Shape::new([2]), &device); + let input_2d = Tensor::::ones(Shape::new([1, 2]), &device); let result_1d = linear.forward(input_1d).unsqueeze::<2>(); - println!("{result_1d}"); let result_2d = linear.forward(input_2d); - println!("{result_2d}"); assert_eq!(result_1d.into_data(), result_2d.into_data()); - panic!("Test"); } #[test] diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 13c12b076d..9e817a9ac6 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -31,6 +31,7 @@ impl ElementWiseBuilder { impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_tensor::repr::OperationDescription) { + println!("{:?}", operation); self.builder.register(operation) } @@ -48,7 +49,9 @@ impl OptimizationBuilder> for ElementWiseBuild } fn status(&self) -> burn_fusion::OptimizationStatus { - self.builder.status() + let state = self.builder.status(); + println!("{state:?}"); + state } fn properties(&self) -> burn_fusion::OptimizationProperties { diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index bc67e9e945..bd76633c74 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -28,6 +28,7 @@ pub struct ElemwiseOptimizationState { impl ElemwiseOptimization { /// Execute the optimization. 
pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + println!("{:?}", self.trace); self.trace .run::(&self.client, &self.device, context, &ElemwiseRunner) .unwrap(); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index dce613bc75..ff7a7b8116 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -182,9 +182,9 @@ impl FuseOnWriteBuilder { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), BaseOperationDescription::Reshape(desc) => { - if self.current_output_shape.is_empty() { - return false; - } + // if self.current_output_shape.is_empty() { + // return false; + // } if !self.output_is_compatible(&desc.out) { return false; @@ -481,9 +481,9 @@ impl FuseOnWriteBuilder { } // Last axis should be equal. - if self.current_output_shape.last() != out.shape.last() { - return false; - } + // if self.current_output_shape.last() != out.shape.last() { + // return false; + // } let rank = self.current_output_shape.len(); @@ -492,16 +492,19 @@ impl FuseOnWriteBuilder { return false; } - for i in 0..(rank - 1) { + for i in 0..rank { let curr = self.current_output_shape[i]; let new = out.shape[i]; // Broadcast is supported. // // 0 is the shape id for a global shape of 1. - if curr != new && new != 0 { + if curr != new && new != 0 && curr != 0 { + println!("ALLO {curr} {new}"); return false; } + + self.current_output_shape[0] = usize::max(curr, new); } true diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index b8bac76235..e1414b513c 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -84,7 +84,7 @@ pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: Arg) pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { match arg { Arg::ScalarShape(pos) => { - let offset = comptime![{ inputs.s_u32.len() - pos - 1 }]; + let offset = comptime![inputs.s_u32.len() - pos - 1]; *inputs.s_u32.index(offset) } _ => comptime![panic!["Not a scalar shape"]], @@ -498,94 +498,94 @@ fn get_offset( Arg::Input(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = inputs.t_f32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::F16 => { let layout = inputs.t_f16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = inputs.t_bf16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U64 => { let layout = inputs.t_u64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U32 => { let layout = inputs.t_u32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U16 => { let layout = inputs.t_u16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, 
layout, pos, config.rank, shape) } ElemwisePrecision::U8 => { let layout = inputs.t_u8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I64 => { let layout = inputs.t_i64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I32 => { let layout = inputs.t_i32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I16 => { let layout = inputs.t_i16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I8 => { let layout = inputs.t_i8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, Arg::Output(index, precision, _) => match comptime![precision] { ElemwisePrecision::F32 => { let layout = outputs.t_f32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::F16 => { let layout = outputs.t_f16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::BF16 => { let layout = outputs.t_bf16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U64 => { let layout = outputs.t_u64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U32 => { let layout = outputs.t_u32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U16 => { let layout = outputs.t_u16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::U8 => { let layout = outputs.t_u8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I64 => { let layout = outputs.t_i64.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I32 => { let layout = outputs.t_i32.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } ElemwisePrecision::I16 => { let layout = outputs.t_i16.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } 
ElemwisePrecision::I8 => { let layout = outputs.t_i8.index(index); - index_offset_with_layout(inputs, tensor, layout, pos, 0, config.rank, shape) + index_offset_with_layout(inputs, tensor, layout, pos, config.rank, shape) } _ => comptime![panic!("Unsupported precision {precision:?}")], }, @@ -764,29 +764,81 @@ fn index_offset_with_layout( inputs: &GlobalArgs, tensor: &Tensor>, layout: &Tensor>, - offset_layout: u32, - dim_start: u32, - dim_end: u32, + index: u32, + #[comptime] rank: u32, #[comptime] shape: Option>, ) -> u32 { - let offset_ref = offset_layout * tensor.line_size(); + // Need to unroll when fusing a reshape. + match comptime![shape.clone()] { + Some(shape) => { + let index_standard = reshaped_index_standard(inputs, layout, index, rank, shape); + convert_index_standard_to_original_index(tensor, rank, index_standard) + } + None => { + let index = index * tensor.line_size(); + let mut offset = 0u32; + + for i in 0..rank { + let coordinate = layout.coordinate(index, i); + offset += coordinate * tensor.stride(i); + } + + let offset = offset / tensor.line_size(); + + offset + } + } +} + +#[cube] +fn reshaped_index_standard( + inputs: &GlobalArgs, + layout: &Tensor>, + index: u32, + #[comptime] rank: u32, + #[comptime] shape: Sequence, +) -> u32 { + let index = index * layout.line_size(); let mut offset = 0u32; + let mut stride_curr = 1u32; + let shapes = comptime![shape.rev()]; - // Need to unroll when fusing a reshape. - let unroll = comptime![shape.is_some()]; + // let mut j = comptime![0u32]; + + #[unroll] + for r in 0..rank { + let arg = comptime![shapes.index(r.clone())]; + let i = rank - r; + let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); + let coordinate = layout.coordinate(index, i); + + offset += coordinate * stride_curr; + stride_curr *= shape_i; + } + + let offset = offset / layout.line_size(); + + offset +} + +#[cube] +fn convert_index_standard_to_original_index( + original: &Tensor>, + rank: u32, + index_standard: u32, +) -> u32 { + let mut remaining = index_standard; + let mut index = 0; + + #[unroll] + for i in 0..rank { + let shape = original.shape(i); + let stride = original.stride(i); + let coordinate = remaining % shape; - #[unroll(unroll)] - for i in dim_start..dim_end { - let shape_i = match comptime![shape.clone()] { - Some(s) => { - let arg = comptime![s.index(i.clone())]; - read_scalar_shape(inputs, comptime![arg.clone()]) - } - None => tensor.shape(i), - }; - let ogwl = offset_ref / layout.stride(i); - offset += ogwl % shape_i * tensor.stride(i); + remaining /= shape; + index += coordinate * stride; } - offset / tensor.line_size() + index } diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 1656a3a1a2..64b869fc44 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -19,9 +19,9 @@ pub struct FuseOnWriteTrace { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes: Vec, + shapes: Vec, ops: Vec, - reads: BTreeMap, + reads: BTreeMap>, writes: BTreeMap, inputs_unhandled: Vec, } @@ -104,7 +104,7 @@ struct LaunchAnalysis<'a, R: JitRuntime> { handle_inputs: Vec>, handle_outputs: Vec>, reference: Option, - reads: BTreeMap, + reads: BTreeMap>, writes: BTreeMap, rank: usize, vectorization: u8, @@ -162,9 +162,12 @@ impl FuseOnWriteTrace { let outputs = self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); - let mut ops = Sequence::new(); - for op in analysis.reads.into_values() 
{ - ops.push(op); + let mut ops = Sequence::::new(); + + for read_ops in analysis.reads.into_values() { + for op in read_ops { + ops.push(op); + } } for op in self.ops.iter() { @@ -261,6 +264,7 @@ impl FuseOnWriteTrace { if status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) + && !self.shapes.contains(&tensor_relative.id) { analysis.potential_inplaces.push(PotentialInplace { input_pos: i, @@ -321,11 +325,13 @@ impl FuseOnWriteTrace { strides: handle_input.handle.strides.clone(), }); - if let Some(ElemwiseOp::Assign(op)) = - analysis.reads.get_mut(&handle_input.relative_id) - { - op.input.add_layout_info(LayoutInfo::IsRef); - }; + if let Some(ops) = analysis.reads.get_mut(&handle_input.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::IsRef); + }; + } + } if let Some(ElemwiseOp::Assign(op)) = analysis.writes.get_mut(&tensor_relative.id) @@ -402,8 +408,12 @@ impl FuseOnWriteTrace { for hi in analysis.handle_inputs.iter() { if let Some(reference) = analysis.reference.as_ref() { if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { - if let Some(ElemwiseOp::Assign(op)) = analysis.reads.get_mut(&hi.relative_id) { - op.input.add_layout_info(LayoutInfo::SameAsRef); + if let Some(ops) = analysis.reads.get_mut(&hi.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::SameAsRef); + } + } } } } @@ -496,16 +506,13 @@ impl FuseOnWriteTrace { } } - let mut tmp = vec![]; for relative in self.shapes.iter().rev() { - let global = context.tensors.get(&relative.id).unwrap(); + let global = context.tensors.get(relative).unwrap(); for shape in global.shape.iter().rev() { - tmp.push(shape); inputs.s_u32.push(ScalarArg::new(*shape as u32)) } } - println!("Shape {:?}", tmp); inputs } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 0f653b6033..1c25320fff 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -15,9 +15,9 @@ pub struct FuseOnWriteTraceBuilder { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes: Vec, + shapes: Vec, ops: Vec, - reads: BTreeMap, + reads: BTreeMap>, pub bool_precision: ElemwisePrecision, outputs_unhandled: Vec, inputs_unhandled: Vec, @@ -99,13 +99,17 @@ impl FuseOnWriteTraceBuilder { let out = self.locals.create(precision, tensor.id); let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); - self.reads.insert( - tensor.id, - ElemwiseOp::Assign(UnaryElemwiseArgs { - input, - out: out.clone(), - }), - ); + let reads = if !self.reads.contains_key(&tensor.id) { + self.reads.insert(tensor.id, Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; + + reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + })); out } @@ -146,52 +150,56 @@ impl FuseOnWriteTraceBuilder { _ => precision, }; - match self.locals.get(precision, tensor.id) { - // Can't fused an already fused input. - // - // TODO: Can fuse one that is in global memory. + let input_index = match self.locals.get(precision, tensor.id) { Some(_) => { + // Can't fused an already fused input. 
if self.outputs.get(precision_input, tensor.id).is_some() { return None; } - // self.inputs.update(precision_input, tensor); - None + match self.inputs.get_index(precision_input, tensor.id) { + Some(index) => { + self.inputs.update(precision_input, tensor); + index as u32 + } + None => return None, + } } - None => { - let new_input = self.inputs.insert(precision_input, tensor.clone()); - let out = self.locals.create(precision, tensor.id); - let original = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); + None => self.inputs.insert(precision_input, tensor.clone()), + }; - let mut shape = Sequence::new(); + let out = self.locals.create(precision, tensor.id); + let original = Arg::Input(input_index, precision_input, LayoutInfo::Unknown); - let index = self.shapes.len(); - self.shapes.push(output.clone()); - let rank = output.shape.len(); + let mut shape = Sequence::new(); - println!("output {output:?}"); - for i in 0..output.shape.len() { - let id = index * rank + i; - println!("id {id:?}"); - shape.push(Arg::ScalarShape(id as u32)); - } + let index = self.shapes.len(); + self.shapes.push(output.id.clone()); + let rank = output.shape.len(); - let input = Arg::InputReshaped { - original: Box::new(original), - shape, - }; + for i in 0..output.shape.len() { + let id = index * rank + i; + shape.push(Arg::ScalarShape(id as u32)); + } - self.reads.insert( - tensor.id, - ElemwiseOp::Assign(UnaryElemwiseArgs { - input, - out: out.clone(), - }), - ); + let input = Arg::InputReshaped { + original: Box::new(original), + shape, + }; - Some(out) - } - } + let reads = if !self.reads.contains_key(&tensor.id) { + self.reads.insert(tensor.id, Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; + + reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { + input, + out: out.clone(), + })); + + Some(out) } pub fn scalar(&mut self, _: &E, dtype: DType) -> Arg { @@ -404,8 +412,10 @@ impl FuseOnWriteTraceBuilder { }; // For all operators, mark their local tensor id in the proper set. 
- for (_, op) in self.reads.iter() { - mark_op(op); + for (_, ops) in self.reads.iter() { + for op in ops { + mark_op(op); + } } for op in self.ops.iter() { diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index bd9ffbf860..aa31b90c34 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -14,6 +14,67 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_reshape_maybe_fused() { + let tensor = TestTensorInt::arange(0..32, &Default::default()); + let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); + let tensor1 = tensor.clone().reshape([1, 4, 8]); + let tensor2 = tensor.reshape([8, 4, 1]); + + let output = tensor0 + tensor1 + tensor2; + let expected = TensorData::from([ + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [9, 10, 11, 12, 13, 14, 15, 16], + [18, 19, 20, 21, 22, 23, 24, 25], + [27, 28, 29, 30, 31, 32, 33, 34], + ], + [ + [4, 5, 6, 7, 8, 9, 10, 11], + [13, 14, 15, 16, 17, 18, 19, 20], + [22, 23, 24, 25, 26, 27, 28, 29], + [31, 32, 33, 34, 35, 36, 37, 38], + ], + [ + [8, 9, 10, 11, 12, 13, 14, 15], + [17, 18, 19, 20, 21, 22, 23, 24], + [26, 27, 28, 29, 30, 31, 32, 33], + [35, 36, 37, 38, 39, 40, 41, 42], + ], + [ + [12, 13, 14, 15, 16, 17, 18, 19], + [21, 22, 23, 24, 25, 26, 27, 28], + [30, 31, 32, 33, 34, 35, 36, 37], + [39, 40, 41, 42, 43, 44, 45, 46], + ], + [ + [16, 17, 18, 19, 20, 21, 22, 23], + [25, 26, 27, 28, 29, 30, 31, 32], + [34, 35, 36, 37, 38, 39, 40, 41], + [43, 44, 45, 46, 47, 48, 49, 50], + ], + [ + [20, 21, 22, 23, 24, 25, 26, 27], + [29, 30, 31, 32, 33, 34, 35, 36], + [38, 39, 40, 41, 42, 43, 44, 45], + [47, 48, 49, 50, 51, 52, 53, 54], + ], + [ + [24, 25, 26, 27, 28, 29, 30, 31], + [33, 34, 35, 36, 37, 38, 39, 40], + [42, 43, 44, 45, 46, 47, 48, 49], + [51, 52, 53, 54, 55, 56, 57, 58], + ], + [ + [28, 29, 30, 31, 32, 33, 34, 35], + [37, 38, 39, 40, 41, 42, 43, 44], + [46, 47, 48, 49, 50, 51, 52, 53], + [55, 56, 57, 58, 59, 60, 61, 62], + ], + ]); + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_reshape_int() { let data = TensorData::from([0, 1, 2]); From c83782b94761d9b9e381980a9f7f7dd4113fd83a Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 28 Jan 2025 15:51:23 -0500 Subject: [PATCH 05/28] WIP works better --- .../src/fusion/elemwise/optimization.rs | 1 + crates/burn-jit/src/fusion/on_write/io.rs | 48 ++++++++------- crates/burn-jit/src/fusion/on_write/trace.rs | 4 +- .../src/kernel/quantization/dequantize.rs | 2 +- .../src/kernel/quantization/quantize.rs | 3 + crates/burn-tensor/src/tests/ops/reshape.rs | 60 +++++++++++++++++++ 6 files changed, 96 insertions(+), 22 deletions(-) diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index bd76633c74..b732847dc6 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -115,6 +115,7 @@ impl TraceRunner for ElemwiseRunner { let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); + println!("{shape:?} - {total_elem:?} - {cube_count:?} - {vectorization}"); unsafe { elemwise_fuse::launch_unchecked( diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index e1414b513c..ad771dba0b 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ 
b/crates/burn-jit/src/fusion/on_write/io.rs @@ -1,5 +1,8 @@ use super::ir::*; -use cubecl::prelude::*; +use cubecl::{ + ir::{ExpandElement, Variable}, + prelude::*, +}; #[cube] /// Read the value from the [arg](Arg) and cast it to the generic cube primitive. @@ -355,7 +358,7 @@ pub fn write( } }; let tensor = outputs.t_f16.index_mut(pos); - tensor[offset] = Line::cast_from(value); + tensor[offset] = Line::cast_from(offset); } ElemwisePrecision::BF16 => { let tensor = outputs.t_bf16.index(pos); @@ -775,17 +778,15 @@ fn index_offset_with_layout( convert_index_standard_to_original_index(tensor, rank, index_standard) } None => { - let index = index * tensor.line_size(); + let offset_ref = index * tensor.line_size(); let mut offset = 0u32; - for i in 0..rank { - let coordinate = layout.coordinate(index, i); - offset += coordinate * tensor.stride(i); + for i in 0u32..rank { + let coordinate_broadcasted = (offset_ref / layout.stride(i)) % tensor.shape(i); + offset += coordinate_broadcasted * tensor.stride(i); } - let offset = offset / tensor.line_size(); - - offset + offset / tensor.line_size() } } } @@ -801,26 +802,33 @@ fn reshaped_index_standard( let index = index * layout.line_size(); let mut offset = 0u32; let mut stride_curr = 1u32; - let shapes = comptime![shape.rev()]; - - // let mut j = comptime![0u32]; #[unroll] for r in 0..rank { - let arg = comptime![shapes.index(r.clone())]; - let i = rank - r; + let i = comptime![index_i(rank, r)]; + let arg = comptime![shape.index(i.clone())]; let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); - let coordinate = layout.coordinate(index, i); - offset += coordinate * stride_curr; + let ogwl = index / layout.stride(i); + offset += ogwl % shape_i * stride_curr; + stride_curr *= shape_i; } - let offset = offset / layout.line_size(); - offset } +fn index_i>>(rank: u32, iter: Elem) -> ExpandElementTyped { + let elem = iter.into(); + let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); + let result = rank - elem - 1; + println!("Result rank {rank:?} elem {elem:?} {result:?}"); + let scalar: Variable = result.into(); + let expand: ExpandElement = ExpandElement::Plain(scalar); + + expand.into() +} + #[cube] fn convert_index_standard_to_original_index( original: &Tensor>, @@ -830,7 +838,7 @@ fn convert_index_standard_to_original_index( let mut remaining = index_standard; let mut index = 0; - #[unroll] + // #[unroll] for i in 0..rank { let shape = original.shape(i); let stride = original.stride(i); @@ -840,5 +848,5 @@ fn convert_index_standard_to_original_index( index += coordinate * stride; } - index + index / original.line_size() } diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 64b869fc44..f074a93b6e 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -50,8 +50,8 @@ pub trait TraceRunner { ) -> u8 { // The default version uses the last dimension as vectorization axis and assumes a // perpendicular contiguous line. - let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { + println!("Desc Input {desc:?}"); let rank = handle.strides.len(); // Last dimension strides should be 1, otherwise vecX won't be contiguous. 
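// ---------------------------------------------------------------------------
// Illustrative sketch (not part of any commit in this series): a plain-Rust,
// host-side picture of the index remapping that this patch's io.rs changes
// (`reshaped_index_standard` plus `convert_index_standard_to_original_index`)
// perform inside the kernel. All names below are assumptions for illustration.
// The idea: take an element index of the fused output, wrap its coordinates
// into the reshaped view (size-1 dims broadcast), flatten them row-major into
// a "standard" element index, then unflatten that index over the original
// tensor's shape and apply its strides to get the memory offset to read.
fn fused_reshape_offset(
    out_index: usize,           // flat element index in the fused output
    out_shape: &[usize],        // shape of the fused output (reference layout)
    reshaped_shape: &[usize],   // shape the input was reshaped to (broadcastable to out_shape)
    original_shape: &[usize],   // shape of the tensor before the reshape
    original_strides: &[usize], // its strides in memory
) -> usize {
    // 1) Coordinates of `out_index` in the output, wrapped into the reshaped
    //    view, re-flattened row-major into a "standard" element index.
    let mut standard = 0usize;
    let mut stride = 1usize;
    let mut remaining = out_index;
    for (out_dim, re_dim) in out_shape.iter().zip(reshaped_shape).rev() {
        let coord = (remaining % out_dim) % re_dim;
        remaining /= out_dim;
        standard += coord * stride;
        stride *= re_dim;
    }
    // 2) Unflatten the standard index over the original shape and apply the
    //    original strides to land on the right memory offset.
    let mut offset = 0usize;
    let mut remaining = standard;
    for (dim, st) in original_shape.iter().zip(original_strides).rev() {
        offset += (remaining % dim) * st;
        remaining /= dim;
    }
    offset
}
// For example, with out_shape = [8, 4, 8], reshaped_shape = [1, 4, 8] and a
// contiguous original of shape [32], output element (2, 1, 5) maps back to
// offset 13, matching the broadcast behaviour exercised by the reshape tests.
// ---------------------------------------------------------------------------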
@@ -70,6 +70,7 @@ pub trait TraceRunner { }; let vectorization_output = |desc: &TensorDescription| { + println!("Desc Output {desc:?}"); let rank = desc.shape.len(); for s in R::line_size_elem(&desc.dtype.into()) { @@ -510,6 +511,7 @@ impl FuseOnWriteTrace { let global = context.tensors.get(relative).unwrap(); for shape in global.shape.iter().rev() { + println!("{shape}"); inputs.s_u32.push(ScalarArg::new(*shape as u32)) } } diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..a29f38c04d 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -10,7 +10,7 @@ use super::{QParams, QTensor}; #[cube] pub(crate) fn dequantize_affine_int8( - value: Line, + value: Line,// 4 i32 scale: f32, offset: i32, ) -> Line { diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..a3492e1c23 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -24,6 +24,8 @@ pub(crate) fn quantize_affine_int8( ) } +/// 32 bits encoder en f32 => int8 +/// u32 => 4 valeurs quantizer #[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_affine_int8_kernel( input: &Tensor>, @@ -78,6 +80,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( } } + #[cube] pub(crate) fn quantize_symmetric_int8( value: Line, diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index aa31b90c34..a1c48742c2 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -14,6 +14,66 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_reshape_maybe_fused_1() { + let tensor = TestTensorInt::arange(0..32, &Default::default()); + let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); + let tensor1 = tensor.clone().reshape([1, 4, 8]); + + let output = tensor0 + tensor1; + let expected = TensorData::from([ + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + [ + [0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31], + ], + ]); + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_reshape_maybe_fused() { let tensor = TestTensorInt::arange(0..32, &Default::default()); From 96d2ff0650c2280f62c82e1af22ba11bc520bae7 Mon Sep 17 00:00:00 2001 From: nathaniel 
Date: Tue, 28 Jan 2025 17:52:49 -0500 Subject: [PATCH 06/28] Fix vectorization --- crates/burn-fusion/src/stream/context.rs | 5 +++-- .../burn-jit/src/fusion/elemwise/builder.rs | 1 - .../src/fusion/elemwise/optimization.rs | 2 -- .../burn-jit/src/fusion/on_write/builder.rs | 10 +++------ crates/burn-jit/src/fusion/on_write/io.rs | 2 -- crates/burn-jit/src/fusion/on_write/trace.rs | 22 ++++++++++++++----- .../src/fusion/on_write/trace_builder.rs | 15 +++++++------ .../src/kernel/quantization/dequantize.rs | 2 +- .../src/kernel/quantization/quantize.rs | 1 - crates/burn-tensor/src/tests/ops/reshape.rs | 3 +-- 10 files changed, 32 insertions(+), 31 deletions(-) diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index eeeeacbb91..21c2f69a9f 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -42,8 +42,6 @@ pub struct Context<'a, H> { pub(crate) struct OperationConverter { tensors_relative2global: HashMap, tensors_global2relative: HashMap, - /// Only useful to create new shape ID. - /// You should use tensor descriptions to retrieve the proper shape. shapes_global2relative: HashMap, scalar_f32: Vec, scalar_f16: Vec, @@ -205,9 +203,11 @@ impl OperationConverter { pub(crate) fn clear(&mut self) { self.tensors_relative2global.clear(); self.tensors_global2relative.clear(); + self.shapes_global2relative.clear(); // global 1 is always shape id 0. self.shapes_global2relative.insert(1, 0); + self.scalar_f32.clear(); self.scalar_f16.clear(); self.scalar_bf16.clear(); @@ -1186,6 +1186,7 @@ impl RelativeOps for TensorDescription { // We never saw this dim value before, therefore we create a new ID. let dim_id = converter.shapes_global2relative.len(); relative_shape.push(dim_id); + converter.shapes_global2relative.insert(*dim, dim_id); } } diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 9e817a9ac6..61adef4f29 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -31,7 +31,6 @@ impl ElementWiseBuilder { impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_tensor::repr::OperationDescription) { - println!("{:?}", operation); self.builder.register(operation) } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index b732847dc6..bc67e9e945 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -28,7 +28,6 @@ pub struct ElemwiseOptimizationState { impl ElemwiseOptimization { /// Execute the optimization. 
pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { - println!("{:?}", self.trace); self.trace .run::(&self.client, &self.device, context, &ElemwiseRunner) .unwrap(); @@ -115,7 +114,6 @@ impl TraceRunner for ElemwiseRunner { let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); - println!("{shape:?} - {total_elem:?} - {cube_count:?} - {vectorization}"); unsafe { elemwise_fuse::launch_unchecked( diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index ff7a7b8116..4c5dfe911e 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -63,8 +63,8 @@ impl TryFuseBuilder { true } - fn build(&self) -> FuseOnWriteTrace { - self.builder.build() + fn build(&self, shape: Vec) -> FuseOnWriteTrace { + self.builder.build(shape) } } @@ -116,7 +116,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { - self.builder.build() + self.builder.build(self.current_output_shape.clone()) } fn len(&self) -> usize { @@ -182,10 +182,6 @@ impl FuseOnWriteBuilder { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), BaseOperationDescription::Reshape(desc) => { - // if self.current_output_shape.is_empty() { - // return false; - // } - if !self.output_is_compatible(&desc.out) { return false; } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index ad771dba0b..7085643c35 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -822,7 +822,6 @@ fn index_i>>(rank: u32, iter: Elem) -> Expand let elem = iter.into(); let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); let result = rank - elem - 1; - println!("Result rank {rank:?} elem {elem:?} {result:?}"); let scalar: Variable = result.into(); let expand: ExpandElement = ExpandElement::Plain(scalar); @@ -838,7 +837,6 @@ fn convert_index_standard_to_original_index( let mut remaining = index_standard; let mut index = 0; - // #[unroll] for i in 0..rank { let shape = original.shape(i); let stride = original.stride(i); diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index f074a93b6e..1157480e29 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -19,7 +19,8 @@ pub struct FuseOnWriteTrace { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes: Vec, + shapes_reshape: Vec, + shape_ref: Vec, ops: Vec, reads: BTreeMap>, writes: BTreeMap, @@ -47,11 +48,11 @@ pub trait TraceRunner { handles_inputs: impl Iterator>, inputs: impl Iterator, outputs: impl Iterator, + reshaped: impl Iterator, ) -> u8 { // The default version uses the last dimension as vectorization axis and assumes a // perpendicular contiguous line. let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { - println!("Desc Input {desc:?}"); let rank = handle.strides.len(); // Last dimension strides should be 1, otherwise vecX won't be contiguous. 
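// ---------------------------------------------------------------------------
// Illustrative sketch (not part of any commit in this series): a simplified,
// host-side view of the line-size selection that `vectorization` performs in
// the hunk above, now that reshaped tensors participate as well. The candidate
// widths below are an assumption standing in for whatever the runtime reports;
// the real implementation queries `R::line_size_elem` per element type.
const CANDIDATE_LINE_SIZES: &[u8] = &[8, 4, 2, 1];

// A tensor can only be vectorized along its last dimension if that dimension
// is contiguous (stride 1) and its extent is divisible by the chosen width.
fn line_size_for(shape: &[usize], strides: &[usize]) -> u8 {
    let (Some(&last_dim), Some(&last_stride)) = (shape.last(), strides.last()) else {
        return 1;
    };
    if last_stride != 1 {
        return 1;
    }
    *CANDIDATE_LINE_SIZES
        .iter()
        .find(|&&s| last_dim % s as usize == 0)
        .unwrap_or(&1)
}

// The fused kernel runs at a single width, so the factor is the minimum over
// every input, output and reshaped view involved in the fusion.
fn fused_line_size(tensors: &[(&[usize], &[usize])]) -> u8 {
    tensors
        .iter()
        .map(|&(shape, strides)| line_size_for(shape, strides))
        .min()
        .unwrap_or(1)
}
// ---------------------------------------------------------------------------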
@@ -70,7 +71,6 @@ pub trait TraceRunner { }; let vectorization_output = |desc: &TensorDescription| { - println!("Desc Output {desc:?}"); let rank = desc.shape.len(); for s in R::line_size_elem(&desc.dtype.into()) { @@ -93,6 +93,10 @@ pub trait TraceRunner { output = Ord::min(vectorization_output(tensor), output); } + for tensor in reshaped { + output = Ord::min(vectorization_output(tensor), output); + } + output } } @@ -240,10 +244,16 @@ impl FuseOnWriteTrace { self.analyse_inputs(context, &mut analysis); self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); + let tensors_reshaped = self + .shapes_reshape + .iter() + .map(|id| context.tensors.get(id).unwrap()); + analysis.vectorization = Runner::vectorization( analysis.handle_inputs.iter().map(|item| &item.handle), analysis.global_inputs.iter(), analysis.global_outputs.iter(), + tensors_reshaped, ); analysis @@ -265,7 +275,8 @@ impl FuseOnWriteTrace { if status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) - && !self.shapes.contains(&tensor_relative.id) + && !self.shapes_reshape.contains(&tensor_relative.id) + && self.shape_ref == tensor_relative.shape { analysis.potential_inplaces.push(PotentialInplace { input_pos: i, @@ -507,11 +518,10 @@ impl FuseOnWriteTrace { } } - for relative in self.shapes.iter().rev() { + for relative in self.shapes_reshape.iter().rev() { let global = context.tensors.get(relative).unwrap(); for shape in global.shape.iter().rev() { - println!("{shape}"); inputs.s_u32.push(ScalarArg::new(*shape as u32)) } } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 1c25320fff..d60be69825 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -15,7 +15,7 @@ pub struct FuseOnWriteTraceBuilder { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes: Vec, + shapes_reshape: Vec, ops: Vec, reads: BTreeMap>, pub bool_precision: ElemwisePrecision, @@ -30,7 +30,7 @@ impl FuseOnWriteTraceBuilder { outputs: RegisteredTensors::default(), inputs: RegisteredTensors::default(), scalars: BTreeMap::default(), - shapes: Vec::default(), + shapes_reshape: Vec::new(), ops: Vec::new(), reads: BTreeMap::new(), bool_precision, @@ -173,8 +173,8 @@ impl FuseOnWriteTraceBuilder { let mut shape = Sequence::new(); - let index = self.shapes.len(); - self.shapes.push(output.id.clone()); + let index = self.shapes_reshape.len(); + self.shapes_reshape.push(output.id.clone()); let rank = output.shape.len(); for i in 0..output.shape.len() { @@ -218,7 +218,7 @@ impl FuseOnWriteTraceBuilder { Arg::Scalar(new_index, precision) } - pub fn build(&self) -> FuseOnWriteTrace { + pub fn build(&self, shape: Vec) -> FuseOnWriteTrace { let inputs = self.inputs.clone(); let outputs = self.output_tensors(); let ops = self.ops.clone(); @@ -240,14 +240,15 @@ impl FuseOnWriteTraceBuilder { ); } - let shapes = self.shapes.clone(); + let shapes_reshape = self.shapes_reshape.clone(); // Current problem is that I need btreemap instead of sequences. 
FuseOnWriteTrace::new( outputs, inputs, scalars, - shapes, + shapes_reshape, + shape, ops, reads, writes, diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index a29f38c04d..c8cac2860e 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -10,7 +10,7 @@ use super::{QParams, QTensor}; #[cube] pub(crate) fn dequantize_affine_int8( - value: Line,// 4 i32 + value: Line, // 4 i32 scale: f32, offset: i32, ) -> Line { diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index a3492e1c23..ec064ed5fa 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -80,7 +80,6 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( } } - #[cube] pub(crate) fn quantize_symmetric_int8( value: Line, diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index a1c48742c2..3a31bb5fc9 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -77,11 +77,10 @@ mod tests { #[test] fn should_support_reshape_maybe_fused() { let tensor = TestTensorInt::arange(0..32, &Default::default()); - let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); let tensor1 = tensor.clone().reshape([1, 4, 8]); let tensor2 = tensor.reshape([8, 4, 1]); - let output = tensor0 + tensor1 + tensor2; + let output = tensor1 + tensor2; let expected = TensorData::from([ [ [0, 1, 2, 3, 4, 5, 6, 7], From c431a4347d9de4c11ba95f35e7aa1df2fda28c0b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Tue, 28 Jan 2025 18:52:12 -0500 Subject: [PATCH 07/28] Still debug --- crates/burn-fusion/src/ops/boolean.rs | 8 +++---- crates/burn-fusion/src/ops/float.rs | 4 ++-- crates/burn-fusion/src/ops/int.rs | 4 ++-- crates/burn-fusion/src/stream/context.rs | 2 +- .../burn-jit/src/fusion/elemwise/builder.rs | 1 + .../burn-jit/src/fusion/on_write/builder.rs | 24 ++++++++++++++++--- crates/burn-jit/src/fusion/on_write/io.rs | 6 ++--- .../src/fusion/on_write/trace_builder.rs | 6 ++++- crates/burn-router/src/ops/op_bool.rs | 6 ++--- crates/burn-router/src/ops/op_float.rs | 9 ++++--- crates/burn-router/src/ops/op_int.rs | 9 ++++--- crates/burn-tensor/src/repr/operation.rs | 9 +------ 12 files changed, 51 insertions(+), 37 deletions(-) diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index baa5169db3..1857a85e5b 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -16,8 +16,8 @@ use burn_tensor::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, HandleContainer, OperationDescription, PermuteOperationDescription, - RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, - SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + RepeatDimOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }, Device, Shape, }; @@ -145,7 +145,7 @@ impl BoolTensorOps for Fusion { fn bool_reshape(tensor: BoolTensor, shape: Shape) -> BoolTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -160,7 +160,7 @@ impl 
BoolTensorOps for Fusion { let stream = tensor.stream; let out = tensor.client.tensor_uninitialized(shape.dims, DType::Bool); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index b3e2a80432..493224fa36 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -615,7 +615,7 @@ impl FloatTensorOps for Fusion { fn float_reshape(tensor: FloatTensor, shape: Shape) -> FloatTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -631,7 +631,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let out = tensor.client.tensor_uninitialized(shape.dims, dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index bdb47df02c..82343f7a1b 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -69,7 +69,7 @@ impl IntTensorOps for Fusion { fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { #[derive(new)] struct ReshapeDimsOps { - desc: ReshapeDescription, + desc: UnaryOperationDescription, _b: PhantomData, } @@ -86,7 +86,7 @@ impl IntTensorOps for Fusion { .client .tensor_uninitialized(shape.dims, B::IntElem::dtype()); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 21c2f69a9f..1bb0bc5deb 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1077,7 +1077,7 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::ToDevice(desc.to_relative(converter)) } BaseOperationDescription::Reshape(desc) => { - BaseOperationDescription::Reshape(ReshapeDescription { + BaseOperationDescription::Reshape(UnaryOperationDescription { input: desc.input.to_relative(converter), out: desc.out.to_relative(converter), }) diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 61adef4f29..c36599427d 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -31,6 +31,7 @@ impl ElementWiseBuilder { impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_tensor::repr::OperationDescription) { + println!("op {operation:?}"); self.builder.register(operation) } diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 4c5dfe911e..c6091eea34 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -105,6 +105,12 @@ impl OptimizationBuilder for FuseOnWriteBuilder { return; } } + OperationDescription::BaseBool(ops) => { + if !self.register_base(ops) { + self.status = OptimizationStatus::Closed; + return; + } + } _ => { self.status = OptimizationStatus::Closed; return; @@ -116,7 +122,9 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { - self.builder.build(self.current_output_shape.clone()) + let trace = 
self.builder.build(self.current_output_shape.clone()); + println!("Trace {trace:?}"); + trace } fn len(&self) -> usize { @@ -182,6 +190,12 @@ impl FuseOnWriteBuilder { ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) }), BaseOperationDescription::Reshape(desc) => { + if desc.input.shape == desc.out.shape { + return self.register_unary_ops(desc, |input, out| { + ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }) + }); + } + if !self.output_is_compatible(&desc.out) { return false; } @@ -485,9 +499,12 @@ impl FuseOnWriteBuilder { // Rank should be equal. if rank != out.shape.len() { + println!("Not same rank"); return false; } + let mut updated = self.current_output_shape.clone(); + for i in 0..rank { let curr = self.current_output_shape[i]; let new = out.shape[i]; @@ -496,12 +513,13 @@ impl FuseOnWriteBuilder { // // 0 is the shape id for a global shape of 1. if curr != new && new != 0 && curr != 0 { - println!("ALLO {curr} {new}"); + println!("Not compatible"); return false; } - self.current_output_shape[0] = usize::max(curr, new); + updated[0] = usize::max(curr, new); } + core::mem::swap(&mut updated, &mut self.current_output_shape); true } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 7085643c35..a89b389aff 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -358,7 +358,7 @@ pub fn write( } }; let tensor = outputs.t_f16.index_mut(pos); - tensor[offset] = Line::cast_from(offset); + tensor[offset] = Line::cast_from(value); } ElemwisePrecision::BF16 => { let tensor = outputs.t_bf16.index(pos); @@ -782,8 +782,8 @@ fn index_offset_with_layout( let mut offset = 0u32; for i in 0u32..rank { - let coordinate_broadcasted = (offset_ref / layout.stride(i)) % tensor.shape(i); - offset += coordinate_broadcasted * tensor.stride(i); + let ogwl = offset_ref / layout.stride(i); + offset += ogwl % tensor.shape(i) * tensor.stride(i); } offset / tensor.line_size() diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index d60be69825..a6c5d161fd 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -154,6 +154,7 @@ impl FuseOnWriteTraceBuilder { Some(_) => { // Can't fused an already fused input. 
if self.outputs.get(precision_input, tensor.id).is_some() { + println!("Can't fused an already fused input."); return None; } @@ -162,7 +163,10 @@ impl FuseOnWriteTraceBuilder { self.inputs.update(precision_input, tensor); index as u32 } - None => return None, + None => { + println!("HM"); + return None; + } } } None => self.inputs.insert(precision_input, tensor.clone()), diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 25c46ae854..b5ec3660ae 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -5,8 +5,8 @@ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, - ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + UnaryOperationDescription, }; use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; @@ -81,7 +81,7 @@ impl BoolTensorOps for BackendRouter { let client = tensor.client.clone(); let out = client.register_empty_tensor(shape.into(), tensor.dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index dda01990e0..10bddc3803 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -11,10 +11,9 @@ use burn_tensor::repr::{ FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + RepeatDimOperationDescription, ScalarOperationDescription, ScatterOperationDescription, + SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -408,7 +407,7 @@ impl FloatTensorOps for BackendRouter { let client = tensor.client.clone(); let out = client.register_empty_tensor(shape.into(), tensor.dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index db81602d4f..9aa3bf2dc8 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -11,10 +11,9 @@ use burn_tensor::repr::{ GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, 
ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + RepeatDimOperationDescription, ScalarOperationDescription, ScatterOperationDescription, + SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -63,7 +62,7 @@ impl IntTensorOps for BackendRouter { let client = tensor.client.clone(); let out = client.register_empty_tensor(shape.into(), tensor.dtype); - let desc = ReshapeDescription { + let desc = UnaryOperationDescription { input: tensor.into_description(), out: out.to_description_out(), }; diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 001b9d6e83..58a8d83fce 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -208,7 +208,7 @@ pub enum BaseOperationDescription { /// Float => [reshape](crate::ops::FloatTensorOps::float_reshape). /// Int => [reshape](crate::ops::IntTensorOps::int_reshape). /// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape). - Reshape(ReshapeDescription), + Reshape(UnaryOperationDescription), /// Operation corresponding to: /// @@ -586,13 +586,6 @@ pub struct RandomOperationDescription { pub distribution: Distribution, } -#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct ReshapeDescription { - pub input: TensorDescription, - pub out: TensorDescription, -} - #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ExpandDescription { From 4aeb900e51bb2fb8268b4a1e59e7e28d4613794b Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 29 Jan 2025 15:55:27 -0500 Subject: [PATCH 08/28] Fix some problems --- crates/burn-jit/src/fusion/elemwise/optimization.rs | 1 + crates/burn-jit/src/fusion/on_write/builder.rs | 8 ++------ crates/burn-jit/src/fusion/on_write/trace.rs | 10 +++++++++- crates/burn-tensor/src/tensor/api/bool.rs | 4 +++- crates/burn-tensor/src/tensor/api/numeric.rs | 2 +- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index bc67e9e945..6c48b3a2fd 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -111,6 +111,7 @@ impl TraceRunner for ElemwiseRunner { None => panic!("Invalid argument"), }; + println!("Shape {shape:?} - {vectorization}"); let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index c6091eea34..e32ac10057 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -122,9 +122,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { - let trace = self.builder.build(self.current_output_shape.clone()); - println!("Trace {trace:?}"); - trace + self.builder.build(self.current_output_shape.clone()) } fn len(&self) -> usize { @@ -499,7 
+497,6 @@ impl FuseOnWriteBuilder { // Rank should be equal. if rank != out.shape.len() { - println!("Not same rank"); return false; } @@ -513,11 +510,10 @@ impl FuseOnWriteBuilder { // // 0 is the shape id for a global shape of 1. if curr != new && new != 0 && curr != 0 { - println!("Not compatible"); return false; } - updated[0] = usize::max(curr, new); + updated[i] = usize::max(curr, new); } core::mem::swap(&mut updated, &mut self.current_output_shape); diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 1157480e29..382df84356 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -304,7 +304,15 @@ impl FuseOnWriteTrace { context: &mut Context<'_, JitFusionHandle>, analysis: &mut LaunchAnalysis<'a, R>, ) { - for (precision, tensor_relative) in self.outputs.iter() { + let mut output_sorted: Vec<_> = self.outputs.iter().collect(); + output_sorted.sort_by(|(_, a), (_, b)| { + let a_val: usize = a.shape.iter().sum(); + let b_val: usize = b.shape.iter().sum(); + + b_val.cmp(&a_val) + }); + + for (precision, tensor_relative) in output_sorted { let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let strides = strides_dyn_rank(&tensor_global.shape); diff --git a/crates/burn-tensor/src/tensor/api/bool.rs b/crates/burn-tensor/src/tensor/api/bool.rs index ea7c5b196d..e89ad38179 100644 --- a/crates/burn-tensor/src/tensor/api/bool.rs +++ b/crates/burn-tensor/src/tensor/api/bool.rs @@ -122,7 +122,9 @@ where }; // Generate and return the mask by applying the comparison to the matrix. - compare(matrix, 0).unsqueeze() + // println!("Matrix {matrix}"); + let out = compare(matrix, 0); + out.unsqueeze() } /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index b82175c3fe..e294b5b02f 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -1449,8 +1449,8 @@ where // last two dimensions let shape = &self.shape().dims[D - 2..].to_owned(); - let mask = Tensor::::tril_mask(shape, diagonal, &self.device()).unsqueeze(); + self.mask_fill(mask, 0) } From e43bef08a378325b2e582a3732ed846964d36cd8 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 29 Jan 2025 16:37:50 -0500 Subject: [PATCH 09/28] Fix other broadcast issues --- crates/burn-jit/src/fusion/on_write/trace.rs | 27 +++++++++++++++----- crates/burn-tensor/src/tensor/api/numeric.rs | 1 + crates/burn-tensor/src/tests/stats/eye.rs | 2 ++ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 382df84356..93ea9ae6b1 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -304,15 +304,22 @@ impl FuseOnWriteTrace { context: &mut Context<'_, JitFusionHandle>, analysis: &mut LaunchAnalysis<'a, R>, ) { - let mut output_sorted: Vec<_> = self.outputs.iter().collect(); - output_sorted.sort_by(|(_, a), (_, b)| { + let mut output_sorted: Vec<_> = self.outputs.iter().enumerate().collect(); + output_sorted.sort_by(|(_, (_, a)), (_, (_, b))| { let a_val: usize = a.shape.iter().sum(); let b_val: usize = b.shape.iter().sum(); b_val.cmp(&a_val) }); + let mut handles = Vec::with_capacity(self.outputs.len()); + let mut globals = Vec::with_capacity(self.outputs.len()); - for 
(precision, tensor_relative) in output_sorted { + for _ in 0..self.outputs.len() { + handles.push(None); + globals.push(None); + } + + for (position_original, (precision, tensor_relative)) in output_sorted.into_iter() { let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let strides = strides_dyn_rank(&tensor_global.shape); @@ -363,11 +370,12 @@ impl FuseOnWriteTrace { context .handles .register_handle(tensor_global.id, handle_input.handle.clone()); - analysis.handle_outputs.push(HandleOutput::Alias { + + handles[position_original] = Some(HandleOutput::Alias { input_pos: potential_inplace.input_pos, precision, }); - analysis.global_outputs.push(tensor_global); + globals[position_original] = Some(tensor_global); } else { if analysis.reference.is_none() { analysis.reference = Some(Reference { @@ -411,16 +419,21 @@ impl FuseOnWriteTrace { .handles .register_handle(tensor_global.id, handle.clone()); - analysis.handle_outputs.push(HandleOutput::Owned { + handles[position_original] = Some(HandleOutput::Owned { precision, handle, global_shape: tensor_global.shape.clone(), global_id: tensor_global.id, }); - analysis.global_outputs.push(tensor_global); + globals[position_original] = Some(tensor_global); } } + for (handle, global) in handles.into_iter().zip(globals.into_iter()) { + analysis.handle_outputs.push(handle.unwrap()); + analysis.global_outputs.push(global.unwrap()); + } + Self::add_layout_info_inputs(analysis); } diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index e294b5b02f..c9315823d7 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2208,6 +2208,7 @@ where let indices = Tensor::::arange(0..size as i64, device).unsqueeze::<2>(); let ones = K::ones([1, size].into(), device); let zeros = K::zeros([size, size].into(), device); + Self::new(K::scatter(0, zeros, indices.primitive, ones)) } } diff --git a/crates/burn-tensor/src/tests/stats/eye.rs b/crates/burn-tensor/src/tests/stats/eye.rs index b3bc9c9343..d490129884 100644 --- a/crates/burn-tensor/src/tests/stats/eye.rs +++ b/crates/burn-tensor/src/tests/stats/eye.rs @@ -9,6 +9,8 @@ mod tests { let device = Default::default(); let tensor = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); let rhs = TestTensor::<2>::eye(3, &device); + println!("Tensor {tensor}"); + println!("Rhs {rhs}"); assert_eq!(tensor.to_data(), rhs.to_data()); } From 19b01f495c30461ac59dbfd0f113c1fda4924d42 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Thu, 30 Jan 2025 17:05:17 -0500 Subject: [PATCH 10/28] Fix another bug, but still very wip --- crates/burn-jit/src/fusion/base.rs | 1 + .../burn-jit/src/fusion/elemwise/builder.rs | 1 + .../src/fusion/elemwise/optimization.rs | 2 +- .../burn-jit/src/fusion/on_write/builder.rs | 14 +++-- crates/burn-jit/src/fusion/on_write/io.rs | 25 ++++++-- crates/burn-jit/src/fusion/on_write/trace.rs | 6 +- crates/burn-jit/src/ops/int_ops.rs | 1 + crates/burn-tensor/src/tests/ops/reshape.rs | 62 ++----------------- 8 files changed, 44 insertions(+), 68 deletions(-) diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 48587a1bf9..b432bee09f 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -244,6 +244,7 @@ impl JitFusionHandle { /// Return the reference to a tensor argument. 
pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); + println!("Shape {:?} - Strides {:?}", handle.shape, handle.strides); unsafe { TensorArg::from_raw_parts_and_size( diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index c36599427d..423ac7a7cd 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -36,6 +36,7 @@ impl OptimizationBuilder> for ElementWiseBuild } fn build(&self) -> JitOptimization { + println!("Build"); let client = R::client(&self.device); let trace = self.builder.build(); let elementwise = diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index 6c48b3a2fd..2bee441df5 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -111,7 +111,7 @@ impl TraceRunner for ElemwiseRunner { None => panic!("Invalid argument"), }; - println!("Shape {shape:?} - {vectorization}"); + println!("RUN Shape {shape:?} - {vectorization}"); let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index e32ac10057..8b2789ce0a 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -105,12 +105,14 @@ impl OptimizationBuilder for FuseOnWriteBuilder { return; } } - OperationDescription::BaseBool(ops) => { - if !self.register_base(ops) { - self.status = OptimizationStatus::Closed; - return; - } - } + // TODO + // + // OperationDescription::BaseBool(ops) => { + // if !self.register_base(ops) { + // self.status = OptimizationStatus::Closed; + // return; + // } + // } _ => { self.status = OptimizationStatus::Closed; return; diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index a89b389aff..e6008e788c 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -829,8 +829,8 @@ fn index_i>>(rank: u32, iter: Elem) -> Expand } #[cube] -fn convert_index_standard_to_original_index( - original: &Tensor>, +fn convert_index_standard_to_original_index( + original: &Tensor>, rank: u32, index_standard: u32, ) -> u32 { @@ -839,12 +839,29 @@ fn convert_index_standard_to_original_index( for i in 0..rank { let shape = original.shape(i); - let stride = original.stride(i); let coordinate = remaining % shape; remaining /= shape; - index += coordinate * stride; + index += coordinate * original.stride(i); } index / original.line_size() } + +#[cube] +fn convert_index_standard_to_original_index_2( + original: &Tensor>, + layout: &Tensor>, + rank: u32, + index_standard: u32, +) -> u32 { + let offset_ref = index_standard; + let mut offset = 0u32; + + for i in 0u32..rank { + let ogwl = offset_ref / original.stride(i); + offset += ogwl % original.shape(i) * original.stride(i); + } + + offset / original.line_size() +} diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 93ea9ae6b1..d64fd20d0e 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -183,6 +183,7 @@ impl FuseOnWriteTrace 
{ ops.push(op); } + println!("Chosen refernec {:?}", analysis.reference); let config = ElemwiseConfig { rank: analysis.rank as u32, ref_layout: analysis @@ -322,6 +323,7 @@ impl FuseOnWriteTrace { for (position_original, (precision, tensor_relative)) in output_sorted.into_iter() { let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let strides = strides_dyn_rank(&tensor_global.shape); + println!("shape {:?} strides {strides:?}", tensor_global.shape); if let Some(index) = analysis .potential_inplaces @@ -379,7 +381,7 @@ impl FuseOnWriteTrace { } else { if analysis.reference.is_none() { analysis.reference = Some(Reference { - layout: Arg::Output(0, precision, LayoutInfo::IsRef), + layout: Arg::Output(position_original as u32, precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: strides.clone(), }); @@ -485,7 +487,9 @@ impl FuseOnWriteTrace { ); for hi in handle_inputs.iter() { + println!("Input shape {:?}", hi.global_shape); let arg = hi.handle.as_tensor_arg(&hi.global_shape, vectorization); + println!("Done"); match hi.precision { ElemwisePrecision::F32 => inputs.t_f32.push(arg), ElemwisePrecision::F16 => inputs.t_f16.push(arg), diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 5702a90849..73b64265df 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -39,6 +39,7 @@ where } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { + println!("Strides reshape {:?}", tensor.strides); super::reshape(tensor, shape) } diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index 3a31bb5fc9..0f48415ea1 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -75,63 +75,13 @@ mod tests { } #[test] - fn should_support_reshape_maybe_fused() { - let tensor = TestTensorInt::arange(0..32, &Default::default()); - let tensor1 = tensor.clone().reshape([1, 4, 8]); - let tensor2 = tensor.reshape([8, 4, 1]); + fn should_support_reshape_maybe_fused_3() { + let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); + let tensor1 = tensor.reshape([2, 2, 1]); + let tensor2 = TestTensorInt::<3>::full([2, 2, 3], 5, &Default::default()); - let output = tensor1 + tensor2; - let expected = TensorData::from([ - [ - [0, 1, 2, 3, 4, 5, 6, 7], - [9, 10, 11, 12, 13, 14, 15, 16], - [18, 19, 20, 21, 22, 23, 24, 25], - [27, 28, 29, 30, 31, 32, 33, 34], - ], - [ - [4, 5, 6, 7, 8, 9, 10, 11], - [13, 14, 15, 16, 17, 18, 19, 20], - [22, 23, 24, 25, 26, 27, 28, 29], - [31, 32, 33, 34, 35, 36, 37, 38], - ], - [ - [8, 9, 10, 11, 12, 13, 14, 15], - [17, 18, 19, 20, 21, 22, 23, 24], - [26, 27, 28, 29, 30, 31, 32, 33], - [35, 36, 37, 38, 39, 40, 41, 42], - ], - [ - [12, 13, 14, 15, 16, 17, 18, 19], - [21, 22, 23, 24, 25, 26, 27, 28], - [30, 31, 32, 33, 34, 35, 36, 37], - [39, 40, 41, 42, 43, 44, 45, 46], - ], - [ - [16, 17, 18, 19, 20, 21, 22, 23], - [25, 26, 27, 28, 29, 30, 31, 32], - [34, 35, 36, 37, 38, 39, 40, 41], - [43, 44, 45, 46, 47, 48, 49, 50], - ], - [ - [20, 21, 22, 23, 24, 25, 26, 27], - [29, 30, 31, 32, 33, 34, 35, 36], - [38, 39, 40, 41, 42, 43, 44, 45], - [47, 48, 49, 50, 51, 52, 53, 54], - ], - [ - [24, 25, 26, 27, 28, 29, 30, 31], - [33, 34, 35, 36, 37, 38, 39, 40], - [42, 43, 44, 45, 46, 47, 48, 49], - [51, 52, 53, 54, 55, 56, 57, 58], - ], - [ - [28, 29, 30, 31, 32, 33, 34, 35], - [37, 38, 39, 40, 41, 42, 43, 44], - [46, 47, 48, 49, 50, 51, 52, 53], - [55, 
56, 57, 58, 59, 60, 61, 62], - ], - ]); - output.into_data().assert_eq(&expected, false); + let expected_tensor1 = TensorData::from([[[0], [2]], [[1], [2]]]); + tensor1.into_data().assert_eq(&expected_tensor1, false); } #[test] From f3b34596d2b6f148b092a127c8cbb5aff7aafeaf Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 11:15:25 -0500 Subject: [PATCH 11/28] WIP Works --- crates/burn-jit/src/fusion/on_write/io.rs | 62 ++++++++----------- crates/burn-jit/src/fusion/on_write/mod.rs | 1 + .../burn-jit/src/fusion/on_write/position.rs | 31 ++++++++++ crates/burn-jit/src/fusion/on_write/trace.rs | 37 ++++++++--- crates/burn-tensor/src/tests/ops/reshape.rs | 5 +- 5 files changed, 89 insertions(+), 47 deletions(-) create mode 100644 crates/burn-jit/src/fusion/on_write/position.rs diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index e6008e788c..6a9b108270 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -774,8 +774,9 @@ fn index_offset_with_layout( // Need to unroll when fusing a reshape. match comptime![shape.clone()] { Some(shape) => { - let index_standard = reshaped_index_standard(inputs, layout, index, rank, shape); - convert_index_standard_to_original_index(tensor, rank, index_standard) + let index_reshaped = reshaped_index(inputs, layout, index, rank, shape); + reshaped_index_to_original_index(tensor, index_reshaped, rank) + // index_reshaped } None => { let offset_ref = index * tensor.line_size(); @@ -792,7 +793,7 @@ fn index_offset_with_layout( } #[cube] -fn reshaped_index_standard( +fn reshaped_index( inputs: &GlobalArgs, layout: &Tensor>, index: u32, @@ -800,6 +801,7 @@ fn reshaped_index_standard( #[comptime] shape: Sequence, ) -> u32 { let index = index * layout.line_size(); + let mut offset = 0u32; let mut stride_curr = 1u32; @@ -818,50 +820,36 @@ fn reshaped_index_standard( offset } -fn index_i>>(rank: u32, iter: Elem) -> ExpandElementTyped { - let elem = iter.into(); - let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); - let result = rank - elem - 1; - let scalar: Variable = result.into(); - let expand: ExpandElement = ExpandElement::Plain(scalar); - - expand.into() -} - #[cube] -fn convert_index_standard_to_original_index( +fn reshaped_index_to_original_index( original: &Tensor>, - rank: u32, - index_standard: u32, + index_reshaped: u32, + #[comptime] rank: u32, ) -> u32 { - let mut remaining = index_standard; - let mut index = 0; + let mut remaining = index_reshaped; + let mut offset = 0; + + #[unroll] + for r in 0..rank { + let i = comptime![index_i(rank, r)]; + let shape = original.shape(comptime![i.clone()]); + let stride = original.stride(i); - for i in 0..rank { - let shape = original.shape(i); let coordinate = remaining % shape; remaining /= shape; - index += coordinate * original.stride(i); + offset += coordinate * stride; } - index / original.line_size() + offset / original.line_size() } -#[cube] -fn convert_index_standard_to_original_index_2( - original: &Tensor>, - layout: &Tensor>, - rank: u32, - index_standard: u32, -) -> u32 { - let offset_ref = index_standard; - let mut offset = 0u32; - - for i in 0u32..rank { - let ogwl = offset_ref / original.stride(i); - offset += ogwl % original.shape(i) * original.stride(i); - } +fn index_i>>(rank: u32, iter: Elem) -> ExpandElementTyped { + let elem = iter.into(); + let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); + let result = rank - elem - 1; + let scalar: Variable = result.into(); 
+ let expand: ExpandElement = ExpandElement::Plain(scalar); - offset / original.line_size() + expand.into() } diff --git a/crates/burn-jit/src/fusion/on_write/mod.rs b/crates/burn-jit/src/fusion/on_write/mod.rs index d40d682dcd..4dda2d324f 100644 --- a/crates/burn-jit/src/fusion/on_write/mod.rs +++ b/crates/burn-jit/src/fusion/on_write/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod builder; pub(crate) mod io; pub(crate) mod ir; pub(crate) mod kernel; +pub(super) mod position; pub mod trace; pub(crate) mod trace_builder; diff --git a/crates/burn-jit/src/fusion/on_write/position.rs b/crates/burn-jit/src/fusion/on_write/position.rs new file mode 100644 index 0000000000..194e7dbb92 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/position.rs @@ -0,0 +1,31 @@ +use super::ir::ElemwisePrecision; +use std::collections::BTreeMap; + +/// Group output position by [element precision](ElemwisePrecision). +#[derive(Default, Debug)] +pub struct PositionMapper { + map: BTreeMap>, +} + +impl PositionMapper { + /// Register a new output with the given precision and position. + pub fn register(&mut self, precision: ElemwisePrecision, pos_handle: usize) { + if let Some(positions) = self.map.get_mut(&precision) { + positions.push(pos_handle); + } else { + self.map.insert(precision, vec![pos_handle]); + } + } + + /// Returns the right position from the precision and the global position in all outputs. + pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { + self.map + .get(&precision) + .unwrap() + .iter() + .enumerate() + .find(|(_pos_elem, pos_all)| **pos_all == pos_handle) + .map(|(pos_elem, _pos_all)| pos_elem) + .unwrap() as u32 + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index d64fd20d0e..501c3ee884 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -4,6 +4,7 @@ use crate::{ }; use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; +use super::position::PositionMapper; use burn_fusion::stream::Context; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, @@ -183,7 +184,7 @@ impl FuseOnWriteTrace { ops.push(op); } - println!("Chosen refernec {:?}", analysis.reference); + println!("Chosen reference {:?}", analysis.reference); let config = ElemwiseConfig { rank: analysis.rank as u32, ref_layout: analysis @@ -238,7 +239,7 @@ impl FuseOnWriteTrace { reference: None, reads: self.reads.clone(), writes: self.writes.clone(), - rank: 1, + rank: self.shape_ref.len(), vectorization: 1, }; @@ -266,12 +267,12 @@ impl FuseOnWriteTrace { analysis: &mut LaunchAnalysis<'a, R>, ) { for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); + let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); // Important to take the status of the relative graph and not // the global graph, since the status of the global graph // might be of a later operation on the same tensor id. 
let status = &tensor_relative.status; - let handle = context.handles.get_handle(&tensor_global.id, status); + let mut handle = context.handles.get_handle(&tensor_global.id, status); if status == &TensorStatus::ReadWrite && handle.handle.can_mut() @@ -286,7 +287,14 @@ impl FuseOnWriteTrace { }); } - analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); + if tensor_global.shape.len() < analysis.rank { + let num_elem: usize = tensor_global.shape.iter().product(); + for _ in 0..(analysis.rank - tensor_global.shape.len()) { + tensor_global.shape.insert(0, 1); + handle.strides.insert(0, num_elem); + } + } + analysis.handle_inputs.push(HandleInput { precision, handle, @@ -305,13 +313,25 @@ impl FuseOnWriteTrace { context: &mut Context<'_, JitFusionHandle>, analysis: &mut LaunchAnalysis<'a, R>, ) { - let mut output_sorted: Vec<_> = self.outputs.iter().enumerate().collect(); + let mut position_mapper = PositionMapper::default(); + let mut output_sorted: Vec<_> = self + .outputs + .iter() + .enumerate() + .map(|(pos, (precision, tensor))| { + position_mapper.register(precision, pos); + (pos, (precision, tensor)) + }) + .collect(); + output_sorted.sort_by(|(_, (_, a)), (_, (_, b))| { let a_val: usize = a.shape.iter().sum(); let b_val: usize = b.shape.iter().sum(); b_val.cmp(&a_val) }); + println!("Mapper {position_mapper:?}"); + println!("Sorted {output_sorted:?}"); let mut handles = Vec::with_capacity(self.outputs.len()); let mut globals = Vec::with_capacity(self.outputs.len()); @@ -380,8 +400,9 @@ impl FuseOnWriteTrace { globals[position_original] = Some(tensor_global); } else { if analysis.reference.is_none() { + let position = position_mapper.resolve_index(&precision, position_original); analysis.reference = Some(Reference { - layout: Arg::Output(position_original as u32, precision, LayoutInfo::IsRef), + layout: Arg::Output(position, precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: strides.clone(), }); @@ -487,9 +508,7 @@ impl FuseOnWriteTrace { ); for hi in handle_inputs.iter() { - println!("Input shape {:?}", hi.global_shape); let arg = hi.handle.as_tensor_arg(&hi.global_shape, vectorization); - println!("Done"); match hi.precision { ElemwisePrecision::F32 => inputs.t_f32.push(arg), ElemwisePrecision::F16 => inputs.t_f16.push(arg), diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index 0f48415ea1..ce06dd4aed 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -17,10 +17,13 @@ mod tests { #[test] fn should_support_reshape_maybe_fused_1() { let tensor = TestTensorInt::arange(0..32, &Default::default()); + // let tensor = tensor.reshape([1, 1, 32]); + // println!("{tensor}"); let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); let tensor1 = tensor.clone().reshape([1, 4, 8]); - let output = tensor0 + tensor1; + + println!("{output}"); let expected = TensorData::from([ [ [0, 1, 2, 3, 4, 5, 6, 7], From f80b7222e4c7e4473273b7480beddb9882cabbba Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 11:21:08 -0500 Subject: [PATCH 12/28] Cleanup --- crates/burn-jit/src/fusion/elemwise/builder.rs | 6 +----- crates/burn-jit/src/fusion/elemwise/optimization.rs | 1 - crates/burn-jit/src/fusion/on_write/builder.rs | 5 +++++ crates/burn-jit/src/fusion/on_write/io.rs | 1 - crates/burn-jit/src/fusion/on_write/trace.rs | 5 ----- crates/burn-jit/src/fusion/on_write/trace_builder.rs | 2 -- 6 files changed, 6 insertions(+), 14 deletions(-) 
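The reshaped-index helpers introduced above in io.rs split the remapping into two steps: reshaped_index rebuilds the row-major linear index of an element from the reshaped shape (read from the scalar inputs), and reshaped_index_to_original_index walks that linear index back through the original tensor's shape and strides to recover the memory offset. A minimal host-side sketch of that second step, with illustrative names and shapes that are not part of the patch:

fn reshaped_index_to_original_offset(
    index_reshaped: usize,
    original_shape: &[usize],
    original_strides: &[usize],
) -> usize {
    // Walk the axes from the fastest-moving (last) to the slowest (first),
    // peeling one coordinate off the row-major linear index at each step.
    let mut remaining = index_reshaped;
    let mut offset = 0;
    for axis in (0..original_shape.len()).rev() {
        let coordinate = remaining % original_shape[axis];
        remaining /= original_shape[axis];
        offset += coordinate * original_strides[axis];
    }
    offset
}

fn main() {
    // A [2, 3] tensor stored transposed (strides [1, 2]): element 4 of the
    // row-major ordering sits at coordinates (1, 1), i.e. offset 1*1 + 1*2 = 3.
    assert_eq!(reshaped_index_to_original_offset(4, &[2, 3], &[1, 2]), 3);
}

The device-side versions additionally multiply the incoming index by the layout's line size and divide the final offset by the original tensor's line size, to account for vectorized (Line) accesses.
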
diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 423ac7a7cd..13c12b076d 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -31,12 +31,10 @@ impl ElementWiseBuilder { impl OptimizationBuilder> for ElementWiseBuilder { fn register(&mut self, operation: &burn_tensor::repr::OperationDescription) { - println!("op {operation:?}"); self.builder.register(operation) } fn build(&self) -> JitOptimization { - println!("Build"); let client = R::client(&self.device); let trace = self.builder.build(); let elementwise = @@ -50,9 +48,7 @@ impl OptimizationBuilder> for ElementWiseBuild } fn status(&self) -> burn_fusion::OptimizationStatus { - let state = self.builder.status(); - println!("{state:?}"); - state + self.builder.status() } fn properties(&self) -> burn_fusion::OptimizationProperties { diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index 2bee441df5..bc67e9e945 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -111,7 +111,6 @@ impl TraceRunner for ElemwiseRunner { None => panic!("Invalid argument"), }; - println!("RUN Shape {shape:?} - {vectorization}"); let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 8b2789ce0a..911ad2253f 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -196,6 +196,11 @@ impl FuseOnWriteBuilder { }); } + if desc.input.shape.len() > desc.out.shape.len() { + // Not yet supported. 
+ return false; + } + if !self.output_is_compatible(&desc.out) { return false; } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 6a9b108270..2f76e86092 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -776,7 +776,6 @@ fn index_offset_with_layout( Some(shape) => { let index_reshaped = reshaped_index(inputs, layout, index, rank, shape); reshaped_index_to_original_index(tensor, index_reshaped, rank) - // index_reshaped } None => { let offset_ref = index * tensor.line_size(); diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 501c3ee884..83f19c6528 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -163,7 +163,6 @@ impl FuseOnWriteTrace { runner: &Runner, ) -> Result<(), Runner::Error> { let analysis = self.analyse::(client, device, context); - let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); let outputs = self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); @@ -184,7 +183,6 @@ impl FuseOnWriteTrace { ops.push(op); } - println!("Chosen reference {:?}", analysis.reference); let config = ElemwiseConfig { rank: analysis.rank as u32, ref_layout: analysis @@ -330,8 +328,6 @@ impl FuseOnWriteTrace { b_val.cmp(&a_val) }); - println!("Mapper {position_mapper:?}"); - println!("Sorted {output_sorted:?}"); let mut handles = Vec::with_capacity(self.outputs.len()); let mut globals = Vec::with_capacity(self.outputs.len()); @@ -343,7 +339,6 @@ impl FuseOnWriteTrace { for (position_original, (precision, tensor_relative)) in output_sorted.into_iter() { let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let strides = strides_dyn_rank(&tensor_global.shape); - println!("shape {:?} strides {strides:?}", tensor_global.shape); if let Some(index) = analysis .potential_inplaces diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index a6c5d161fd..a037d4a2f5 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -154,7 +154,6 @@ impl FuseOnWriteTraceBuilder { Some(_) => { // Can't fused an already fused input. if self.outputs.get(precision_input, tensor.id).is_some() { - println!("Can't fused an already fused input."); return None; } @@ -164,7 +163,6 @@ impl FuseOnWriteTraceBuilder { index as u32 } None => { - println!("HM"); return None; } } From f86fbcbaaada555a5f2f9e81540cfea9ffb115fd Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 12:45:50 -0500 Subject: [PATCH 13/28] Support broadcasted vectorization --- crates/burn-jit/src/fusion/base.rs | 1 - crates/burn-jit/src/fusion/on_write/trace.rs | 161 ++++++++++++++---- .../src/fusion/on_write/trace_builder.rs | 16 +- crates/burn-tensor/src/tests/ops/reshape.rs | 15 +- 4 files changed, 150 insertions(+), 43 deletions(-) diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index b432bee09f..48587a1bf9 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -244,7 +244,6 @@ impl JitFusionHandle { /// Return the reference to a tensor argument. 
pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); - println!("Shape {:?} - Strides {:?}", handle.shape, handle.strides); unsafe { TensorArg::from_raw_parts_and_size( diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 83f19c6528..712aa8ed39 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -20,7 +20,7 @@ pub struct FuseOnWriteTrace { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes_reshape: Vec, + reshapes: Vec, shape_ref: Vec, ops: Vec, reads: BTreeMap>, @@ -28,6 +28,12 @@ pub struct FuseOnWriteTrace { inputs_unhandled: Vec, } +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Reshape { + pub reshaped: TensorId, + pub original: TensorId, +} + /// A trace runner is responsible for determining the vectorization factor as well as launching /// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) /// with a provided [element wise config](ElemwiseConfig). @@ -46,11 +52,17 @@ pub trait TraceRunner { /// The vectorization factor for all inputs and outputs. fn vectorization<'a>( + vectorizations: &mut BTreeMap, handles_inputs: impl Iterator>, inputs: impl Iterator, outputs: impl Iterator, - reshaped: impl Iterator, - ) -> u8 { + reshaped: impl Iterator, + ) { + enum Vect { + Broadcated, + Max(u8), + } + // The default version uses the last dimension as vectorization axis and assumes a // perpendicular contiguous line. let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { @@ -58,17 +70,22 @@ pub trait TraceRunner { // Last dimension strides should be 1, otherwise vecX won't be contiguous. if handle.strides[rank - 1] != 1 { - return 1; + return Vect::Max(1); + } + let shape_axis = desc.shape[rank - 1]; + + if shape_axis == 1 { + return Vect::Broadcated; } for s in R::line_size_elem(&desc.dtype.into()) { - // The last dimension should be a multiple of the vector size. - if desc.shape[rank - 1] % s as usize == 0 { - return s; + // The last dimension should be a multiple of the vector size or broadcated. + if shape_axis % s as usize == 0 { + return Vect::Max(s); } } - 1 + Vect::Max(1) }; let vectorization_output = |desc: &TensorDescription| { @@ -77,28 +94,82 @@ pub trait TraceRunner { for s in R::line_size_elem(&desc.dtype.into()) { // The last dimension should be a multiple of the vector size. if desc.shape[rank - 1] % s as usize == 0 { - return s; + return Vect::Max(s); } } - 1 + Vect::Max(1) }; - let mut output = u8::MAX; + let vectorization_reshape = + |reshaped: &TensorDescription, original: &TensorDescription, multi_reads: bool| { + let reshape_axis = reshaped.shape[reshaped.shape.len() - 1]; + let shape_axis = original.shape[original.shape.len() - 1]; + + if !multi_reads && reshape_axis == 1 { + return Vect::Broadcated; + } + + for s in R::line_size_elem(&reshaped.dtype.into()) { + if !multi_reads { + // The last dimension should be a multiple of the vector size or broadcated. + if reshape_axis % s as usize == 0 { + return Vect::Max(s); + } + } else { + // Since the original tensor must share the same vectorization factor as the + // reshaped tensor, they must have compatible shapes when both are access + // independently. 
+ if reshape_axis % s as usize == 0 && shape_axis % s as usize == 0 { + return Vect::Max(s); + } + } + } + + Vect::Max(1) + }; + + let mut max_current = u8::MAX; for (handle, tensor) in handles_inputs.zip(inputs) { - output = Ord::min(vectorization_input(handle, tensor), output); + match vectorization_input(&handle, tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; } for tensor in outputs { - output = Ord::min(vectorization_output(tensor), output); + match vectorization_output(tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; } - for tensor in reshaped { - output = Ord::min(vectorization_output(tensor), output); + for (reshaped, original, multi_reads) in reshaped { + match vectorization_reshape(reshaped, original, multi_reads) { + Vect::Broadcated => { + vectorizations.insert(original.id, 1); + vectorizations.insert(reshaped.id, 1); + } + Vect::Max(val) => { + vectorizations.insert(original.id, 0); + vectorizations.insert(reshaped.id, 0); + max_current = Ord::min(val, max_current); + } + } } - output + for (_id, val) in vectorizations.iter_mut() { + if *val == 0 { + *val = max_current; + } + } } } @@ -112,8 +183,8 @@ struct LaunchAnalysis<'a, R: JitRuntime> { reference: Option, reads: BTreeMap>, writes: BTreeMap, + vectorization: BTreeMap, rank: usize, - vectorization: u8, } #[derive(Debug)] @@ -127,6 +198,7 @@ enum HandleOutput { precision: ElemwisePrecision, handle: JitFusionHandle, global_shape: Vec, + vectorization: u8, }, } @@ -137,6 +209,7 @@ struct HandleInput { precision: ElemwisePrecision, handle: JitFusionHandle, global_shape: Vec, + vectorization: u8, } #[derive(Debug)] @@ -163,9 +236,8 @@ impl FuseOnWriteTrace { runner: &Runner, ) -> Result<(), Runner::Error> { let analysis = self.analyse::(client, device, context); - let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); - let outputs = - self.register_outputs::<_, BT>(&analysis.handle_outputs, analysis.vectorization); + let inputs = self.register_inputs(context, &analysis.handle_inputs); + let outputs = self.register_outputs::<_, BT>(&analysis.handle_outputs); let mut ops = Sequence::::new(); @@ -235,27 +307,45 @@ impl FuseOnWriteTrace { handle_inputs: Vec::new(), handle_outputs: Vec::new(), reference: None, + vectorization: BTreeMap::default(), reads: self.reads.clone(), writes: self.writes.clone(), rank: self.shape_ref.len(), - vectorization: 1, }; self.analyse_inputs(context, &mut analysis); self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); - let tensors_reshaped = self - .shapes_reshape - .iter() - .map(|id| context.tensors.get(id).unwrap()); + let tensors_reshaped = self.reshapes.iter().map(|reshape| { + ( + context.tensors.get(&reshape.reshaped).unwrap(), + context.tensors.get(&reshape.original).unwrap(), + self.reads.get(&reshape.original).unwrap().len() > 1, + ) + }); - analysis.vectorization = Runner::vectorization( + Runner::vectorization( + &mut analysis.vectorization, analysis.handle_inputs.iter().map(|item| &item.handle), analysis.global_inputs.iter(), analysis.global_outputs.iter(), tensors_reshaped, ); + for handle in analysis.handle_inputs.iter_mut() { + handle.vectorization = *analysis.vectorization.get(&handle.global_id).unwrap(); + } + for handle in analysis.handle_outputs.iter_mut() { + 
match handle { + HandleOutput::Owned { + vectorization, + global_id, + .. + } => *vectorization = *analysis.vectorization.get(&global_id).unwrap(), + _ => {} + } + } + analysis } @@ -275,7 +365,11 @@ impl FuseOnWriteTrace { if status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) - && !self.shapes_reshape.contains(&tensor_relative.id) + && self + .reshapes + .iter() + .find(|r| r.reshaped == tensor_relative.id) + .is_none() && self.shape_ref == tensor_relative.shape { analysis.potential_inplaces.push(PotentialInplace { @@ -299,6 +393,7 @@ impl FuseOnWriteTrace { relative_id: tensor_relative.id, global_id: tensor_global.id, global_shape: tensor_global.shape.clone(), + vectorization: 1, }); analysis.global_inputs.push(tensor_global); } @@ -442,6 +537,7 @@ impl FuseOnWriteTrace { handle, global_shape: tensor_global.shape.clone(), global_id: tensor_global.id, + vectorization: 1, }); globals[position_original] = Some(tensor_global); } @@ -475,7 +571,6 @@ impl FuseOnWriteTrace { &self, context: &mut Context<'_, JitFusionHandle>, handle_inputs: &'h [HandleInput], - vectorization: u8, ) -> GlobalArgsLaunch<'h, R> { let mut inputs = GlobalArgsLaunch::new( SequenceArg::new(), @@ -503,7 +598,7 @@ impl FuseOnWriteTrace { ); for hi in handle_inputs.iter() { - let arg = hi.handle.as_tensor_arg(&hi.global_shape, vectorization); + let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); match hi.precision { ElemwisePrecision::F32 => inputs.t_f32.push(arg), ElemwisePrecision::F16 => inputs.t_f16.push(arg), @@ -557,8 +652,8 @@ impl FuseOnWriteTrace { } } - for relative in self.shapes_reshape.iter().rev() { - let global = context.tensors.get(relative).unwrap(); + for relative in self.reshapes.iter().rev() { + let global = context.tensors.get(&relative.reshaped).unwrap(); for shape in global.shape.iter().rev() { inputs.s_u32.push(ScalarArg::new(*shape as u32)) @@ -571,7 +666,6 @@ impl FuseOnWriteTrace { fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( &self, handle_outputs: &'s [HandleOutput], - vectorization: u8, ) -> GlobalArgsLaunch<'s, R> { let mut outputs = GlobalArgsLaunch::new( SequenceArg::new(), @@ -620,9 +714,10 @@ impl FuseOnWriteTrace { precision, handle, global_shape, + vectorization, .. 
} => { - let arg = handle.as_tensor_arg(global_shape, vectorization); + let arg = handle.as_tensor_arg(global_shape, *vectorization); match precision { ElemwisePrecision::F32 => outputs.t_f32.push(arg), diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index a037d4a2f5..babef2ae8c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -1,5 +1,6 @@ use super::{ ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, LayoutInfo, UnaryElemwiseArgs}, + trace::Reshape, trace::{FuseOnWriteTrace, RegisteredTensors}, }; use burn_tensor::{ @@ -15,7 +16,7 @@ pub struct FuseOnWriteTraceBuilder { outputs: RegisteredTensors, inputs: RegisteredTensors, scalars: BTreeMap, - shapes_reshape: Vec, + reshapes: Vec, ops: Vec, reads: BTreeMap>, pub bool_precision: ElemwisePrecision, @@ -30,7 +31,7 @@ impl FuseOnWriteTraceBuilder { outputs: RegisteredTensors::default(), inputs: RegisteredTensors::default(), scalars: BTreeMap::default(), - shapes_reshape: Vec::new(), + reshapes: Vec::new(), ops: Vec::new(), reads: BTreeMap::new(), bool_precision, @@ -175,8 +176,11 @@ impl FuseOnWriteTraceBuilder { let mut shape = Sequence::new(); - let index = self.shapes_reshape.len(); - self.shapes_reshape.push(output.id.clone()); + let index = self.reshapes.len(); + self.reshapes.push(Reshape { + reshaped: output.id.clone(), + original: tensor.id.clone(), + }); let rank = output.shape.len(); for i in 0..output.shape.len() { @@ -242,14 +246,14 @@ impl FuseOnWriteTraceBuilder { ); } - let shapes_reshape = self.shapes_reshape.clone(); + let reshapes = self.reshapes.clone(); // Current problem is that I need btreemap instead of sequences. FuseOnWriteTrace::new( outputs, inputs, scalars, - shapes_reshape, + reshapes, shape, ops, reads, diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index ce06dd4aed..b1d4dcd077 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -17,13 +17,10 @@ mod tests { #[test] fn should_support_reshape_maybe_fused_1() { let tensor = TestTensorInt::arange(0..32, &Default::default()); - // let tensor = tensor.reshape([1, 1, 32]); - // println!("{tensor}"); let tensor0 = TestTensorInt::zeros([8, 4, 8], &Default::default()); let tensor1 = tensor.clone().reshape([1, 4, 8]); let output = tensor0 + tensor1; - println!("{output}"); let expected = TensorData::from([ [ [0, 1, 2, 3, 4, 5, 6, 7], @@ -77,6 +74,18 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_reshape_maybe_fused_2() { + let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); + let tensor1 = tensor.reshape([2, 2, 1]); + let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default()); + let output = tensor2 + tensor1; + + let expected_tensor1 = + TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]); + output.into_data().assert_eq(&expected_tensor1, false); + } + #[test] fn should_support_reshape_maybe_fused_3() { let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); From 11550b4b31345e1f4d72d07db8788d52b0f67445 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 13:33:06 -0500 Subject: [PATCH 14/28] Cleanup --- crates/burn-core/src/nn/linear.rs | 4 ---- crates/burn-jit/src/fusion/on_write/builder.rs | 5 ----- 
crates/burn-jit/src/fusion/on_write/io.rs | 10 ++++++---- crates/burn-jit/src/kernel/quantization/dequantize.rs | 2 +- crates/burn-jit/src/kernel/quantization/quantize.rs | 2 -- crates/burn-jit/src/ops/int_ops.rs | 1 - crates/burn-tensor/src/tensor/api/bool.rs | 4 +--- crates/burn-tensor/src/tests/stats/eye.rs | 2 -- 8 files changed, 8 insertions(+), 22 deletions(-) diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 994cac5b04..d3cf5951b0 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -62,10 +62,6 @@ impl LinearConfig { } } -// into_contuiguous -// strides contibous -// shape - impl Linear { /// Applies the forward pass on the input tensor. /// diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 911ad2253f..5110f3df3f 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -495,11 +495,6 @@ impl FuseOnWriteBuilder { return true; } - // Last axis should be equal. - // if self.current_output_shape.last() != out.shape.last() { - // return false; - // } - let rank = self.current_output_shape.len(); // Rank should be equal. diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 2f76e86092..56a90ed525 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -771,7 +771,6 @@ fn index_offset_with_layout( #[comptime] rank: u32, #[comptime] shape: Option>, ) -> u32 { - // Need to unroll when fusing a reshape. match comptime![shape.clone()] { Some(shape) => { let index_reshaped = reshaped_index(inputs, layout, index, rank, shape); @@ -806,7 +805,7 @@ fn reshaped_index( #[unroll] for r in 0..rank { - let i = comptime![index_i(rank, r)]; + let i = comptime![reverse_index(rank, r)]; let arg = comptime![shape.index(i.clone())]; let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); @@ -830,7 +829,7 @@ fn reshaped_index_to_original_index( #[unroll] for r in 0..rank { - let i = comptime![index_i(rank, r)]; + let i = comptime![reverse_index(rank, r)]; let shape = original.shape(comptime![i.clone()]); let stride = original.stride(i); @@ -843,7 +842,10 @@ fn reshaped_index_to_original_index( offset / original.line_size() } -fn index_i>>(rank: u32, iter: Elem) -> ExpandElementTyped { +fn reverse_index>>( + rank: u32, + iter: Elem, +) -> ExpandElementTyped { let elem = iter.into(); let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); let result = rank - elem - 1; diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index c8cac2860e..72040d8839 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -10,7 +10,7 @@ use super::{QParams, QTensor}; #[cube] pub(crate) fn dequantize_affine_int8( - value: Line, // 4 i32 + value: Line, scale: f32, offset: i32, ) -> Line { diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index ec064ed5fa..e9494aa987 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -24,8 +24,6 @@ pub(crate) fn quantize_affine_int8( ) } -/// 32 bits encoder en f32 => int8 -/// u32 => 4 valeurs quantizer #[cube(launch_unchecked)] pub(crate) fn quantize_per_tensor_affine_int8_kernel( input: &Tensor>, diff 
--git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 73b64265df..5702a90849 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -39,7 +39,6 @@ where } fn int_reshape(tensor: IntTensor, shape: Shape) -> IntTensor { - println!("Strides reshape {:?}", tensor.strides); super::reshape(tensor, shape) } diff --git a/crates/burn-tensor/src/tensor/api/bool.rs b/crates/burn-tensor/src/tensor/api/bool.rs index e89ad38179..ea7c5b196d 100644 --- a/crates/burn-tensor/src/tensor/api/bool.rs +++ b/crates/burn-tensor/src/tensor/api/bool.rs @@ -122,9 +122,7 @@ where }; // Generate and return the mask by applying the comparison to the matrix. - // println!("Matrix {matrix}"); - let out = compare(matrix, 0); - out.unsqueeze() + compare(matrix, 0).unsqueeze() } /// Creates a mask for the upper triangle of a matrix, which can be used to fill the specified diff --git a/crates/burn-tensor/src/tests/stats/eye.rs b/crates/burn-tensor/src/tests/stats/eye.rs index d490129884..b3bc9c9343 100644 --- a/crates/burn-tensor/src/tests/stats/eye.rs +++ b/crates/burn-tensor/src/tests/stats/eye.rs @@ -9,8 +9,6 @@ mod tests { let device = Default::default(); let tensor = TestTensor::<2>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]); let rhs = TestTensor::<2>::eye(3, &device); - println!("Tensor {tensor}"); - println!("Rhs {rhs}"); assert_eq!(tensor.to_data(), rhs.to_data()); } From 5e14374147169a2628d83f66107a6db31804fbd4 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 14:11:22 -0500 Subject: [PATCH 15/28] Still some bugs --- crates/burn-jit/src/fusion/on_write/builder.rs | 1 + crates/burn-jit/src/fusion/on_write/trace.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 5110f3df3f..13b1809f3e 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -216,6 +216,7 @@ impl FuseOnWriteBuilder { true }) { + println!("Reshape"); self.num_reshapes += 1; true } else { diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 712aa8ed39..e6166e813f 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -368,7 +368,7 @@ impl FuseOnWriteTrace { && self .reshapes .iter() - .find(|r| r.reshaped == tensor_relative.id) + .find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) .is_none() && self.shape_ref == tensor_relative.shape { From b104e7898a6ec5d0aa1fa362e32fb1d394f0c4bc Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sat, 1 Feb 2025 16:47:32 -0500 Subject: [PATCH 16/28] Fix multi vectorization broadcasting fused --- crates/burn-core/Cargo.toml | 2 +- crates/burn-core/src/nn/rnn/gate_controller.rs | 9 ++++++++- crates/burn-core/src/nn/rnn/lstm.rs | 5 +++++ crates/burn-jit/src/fusion/base.rs | 8 ++++---- crates/burn-jit/src/fusion/on_write/builder.rs | 10 +++++++++- crates/burn-jit/src/fusion/on_write/io.rs | 2 +- crates/burn-tensor/src/tests/ops/reshape.rs | 14 ++++++++++++++ 7 files changed, 42 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index dc14d45f62..e895cc4572 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -113,7 +113,7 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = 
["burn-tensor/experimental-named-tensor"] -test-cuda = ["cuda-jit", "fusion"] # To use cuda during testing, default uses ndarray. +test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index f20f7fa2a1..4d5868a936 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -57,7 +57,14 @@ impl GateController { /// H = hidden state /// b = bias terms pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { - self.input_transform.forward(input) + self.hidden_transform.forward(hidden) + println!("{input}"); + let temp = self.input_transform.forward(input); + let temp2 = self.hidden_transform.forward(hidden); + println!("1: {temp}"); + println!("2: {temp2}"); + let temp = temp + temp2; + // panic!("3: {temp}"); + temp } /// Used to initialize a gate controller with known weight layers, diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 9a7c23399b..9968a48a0a 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -721,21 +721,26 @@ mod tests { output_with_init_state .to_data() .assert_approx_eq(&expected_output_with_init_state, 3); + println!("1"); output_without_init_state .to_data() .assert_approx_eq(&expected_output_without_init_state, 3); + println!("2"); state_with_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_with_init_state, 3); + println!("3"); state_with_init_state .cell .to_data() .assert_approx_eq(&expected_cn_with_init_state, 3); + println!("4"); state_without_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_without_init_state, 3); + println!("5"); state_without_init_state .cell .to_data() diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 48587a1bf9..9721735132 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -130,10 +130,10 @@ impl FusionRuntime for FusionJitRuntime { device.clone(), BT::as_elem_native_unchecked().into(), )), - Box::new(MatmulBuilder::::new( - device.clone(), - BT::as_elem_native_unchecked().into(), - )), + // Box::new(MatmulBuilder::::new( + // device.clone(), + // BT::as_elem_native_unchecked().into(), + // )), ] } } diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 13b1809f3e..70a3820d28 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -196,12 +196,19 @@ impl FuseOnWriteBuilder { }); } + println!( + "Input rank {} : {}", + desc.input.shape.len(), + desc.out.shape.len() + ); if desc.input.shape.len() > desc.out.shape.len() { + println!("Not supported - invalid rank"); // Not yet supported. 
return false; } if !self.output_is_compatible(&desc.out) { + println!("Not supported - invalid output"); return false; } @@ -216,10 +223,11 @@ impl FuseOnWriteBuilder { true }) { - println!("Reshape"); + println!("Fusing Reshape"); self.num_reshapes += 1; true } else { + println!("Can't reshape already fused tensor."); false } } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 56a90ed525..24f2b94aa3 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -777,7 +777,7 @@ fn index_offset_with_layout( reshaped_index_to_original_index(tensor, index_reshaped, rank) } None => { - let offset_ref = index * tensor.line_size(); + let offset_ref = index * layout.line_size(); let mut offset = 0u32; for i in 0u32..rank { diff --git a/crates/burn-tensor/src/tests/ops/reshape.rs b/crates/burn-tensor/src/tests/ops/reshape.rs index b1d4dcd077..de581b2d4d 100644 --- a/crates/burn-tensor/src/tests/ops/reshape.rs +++ b/crates/burn-tensor/src/tests/ops/reshape.rs @@ -96,6 +96,20 @@ mod tests { tensor1.into_data().assert_eq(&expected_tensor1, false); } + #[test] + fn should_support_reshape_maybe_fused_4() { + let tensor = TestTensorInt::<3>::from_data([[[0, 2], [1, 2]]], &Default::default()); + let tensor2 = TestTensorInt::<3>::full([2, 2, 4], 4, &Default::default()); + let tensor2 = tensor2.swap_dims(0, 1); + let tensor1 = tensor.reshape([2, 2, 1]); + let output = tensor2 + tensor1; + println!("{output}"); + + let expected_tensor1 = + TensorData::from([[[4, 4, 4, 4], [6, 6, 6, 6]], [[5, 5, 5, 5], [6, 6, 6, 6]]]); + output.into_data().assert_eq(&expected_tensor1, false); + } + #[test] fn should_support_reshape_int() { let data = TensorData::from([0, 1, 2]); From 3eabb6caa257d229f5c1896f3796a1af0dafba44 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 10:58:39 -0500 Subject: [PATCH 17/28] Add fuse settings --- .../burn-core/src/nn/transformer/decoder.rs | 24 +++--- .../burn-fusion/src/stream/execution/base.rs | 6 ++ crates/burn-jit/src/fusion/base.rs | 8 +- .../burn-jit/src/fusion/elemwise/builder.rs | 20 +++-- crates/burn-jit/src/fusion/matmul/builder.rs | 17 +++- .../burn-jit/src/fusion/on_write/builder.rs | 80 ++++++++++++++----- crates/burn-jit/src/fusion/on_write/trace.rs | 17 +++- .../src/fusion/on_write/trace_builder.rs | 10 ++- 8 files changed, 133 insertions(+), 49 deletions(-) diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index c79c21a78c..4c509bee86 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -455,6 +455,8 @@ impl TransformerDecoder { #[cfg(test)] mod tests { + use burn_tensor::Device; + use super::*; use crate::tensor::Distribution; use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; @@ -481,20 +483,16 @@ mod tests { } fn test_autoregressive(config: TransformerDecoderConfig) { - let device = Default::default(); + let device: Device = Default::default(); let [batch_size, seq_length, d_model] = [3, 4, config.d_model]; - let transformer = config.init(&device); - - let memory = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - &device, - ); - let target = Tensor::::random( - [batch_size, seq_length, d_model], - Distribution::Default, - &device, - ); + let transformer = config.init::(&device); + + let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) + .float() + 
.reshape([batch_size, seq_length, d_model]); + let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device) + .float() + .reshape([batch_size, seq_length, d_model]); let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device()); let input = TransformerDecoderInput::new(target.clone(), memory.clone()) .target_mask_attn(mask_attn); diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index 149b2d095d..ef31748a37 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -49,6 +49,12 @@ impl OperationQueue { let mut context = self.converter.context(handles); optimization.execute(&mut context); + log::info!("====== MATMUL ======"); + for op in &self.global[0..num_drained] { + log::info!("{op:?}") + } + log::info!("====== END ======"); + self.drain_queue(num_drained, handles); self.operations.drain(0..num_drained); } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 9721735132..48587a1bf9 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -130,10 +130,10 @@ impl FusionRuntime for FusionJitRuntime { device.clone(), BT::as_elem_native_unchecked().into(), )), - // Box::new(MatmulBuilder::::new( - // device.clone(), - // BT::as_elem_native_unchecked().into(), - // )), + Box::new(MatmulBuilder::::new( + device.clone(), + BT::as_elem_native_unchecked().into(), + )), ] } } diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 13c12b076d..3cf49b8069 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -2,7 +2,10 @@ use burn_fusion::OptimizationBuilder; use crate::{ fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + on_write::{ + builder::{FuseOnWriteBuilder, FuseSettings}, + ir::ElemwisePrecision, + }, JitOptimization, }, JitRuntime, @@ -23,7 +26,16 @@ impl ElementWiseBuilder { let max_bindings = props.hardware_properties().max_bindings; Self { - builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), + builder: FuseOnWriteBuilder::new( + max_bindings, + bool_precision, + FuseSettings { + broadcast: true, + output_shape_updates: true, + mix_vectorization: true, + inplace: true, + }, + ), device, } } @@ -52,9 +64,7 @@ impl OptimizationBuilder> for ElementWiseBuild } fn properties(&self) -> burn_fusion::OptimizationProperties { - let mut props = self.builder.properties(); - props.ready = props.ready && self.builder.num_ops > self.builder.num_reshapes; - props + self.builder.properties() } fn len(&self) -> usize { diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index f197237819..38fb1d0acc 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -3,7 +3,10 @@ use burn_tensor::repr::{FloatOperationDescription, OperationDescription}; use crate::{ fusion::{ - on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision}, + on_write::{ + builder::{FuseOnWriteBuilder, FuseSettings}, + ir::ElemwisePrecision, + }, JitOptimization, }, JitRuntime, @@ -24,10 +27,16 @@ impl MatmulBuilder { let client = R::client(&device); let props = client.properties(); let max_bindings = props.hardware_properties().max_bindings; + let settings = FuseSettings { + broadcast: true, + output_shape_updates: false, + 
mix_vectorization: true, + inplace: true, + }; Self { - builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), - builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision), + builder: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings), + builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision, settings), device, matmul: None, } @@ -41,6 +50,7 @@ impl OptimizationBuilder> for MatmulBuilder } if self.matmul.is_none() { + log::info!("New matmul fusion"); if let OperationDescription::Float(_, FloatOperationDescription::Matmul(op)) = operation { let lhs = self.builder.input_unhandled(&op.lhs); @@ -56,6 +66,7 @@ impl OptimizationBuilder> for MatmulBuilder )); } else { self.builder.close(); + self.builder_fallback.close(); } } else { self.builder.register(operation); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 70a3820d28..f03d9af9ab 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -13,10 +13,12 @@ use burn_tensor::{ Element, }; use cubecl::ir::Elem; +use serde::{Deserialize, Serialize}; /// Fused element wise operations that are normally memory bound. pub(crate) struct FuseOnWriteBuilder { builder: TryFuseBuilder, + settings: FuseSettings, current_output_shape: Vec, status: OptimizationStatus, pub(crate) num_ops: usize, @@ -24,6 +26,25 @@ pub(crate) struct FuseOnWriteBuilder { max_bindings: u32, } +/// Controls which operations can be fused. +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FuseSettings { + /// Enables broadcasting of shapes. + pub broadcast: bool, + /// Enables output shape updates. + /// + /// When broadcast is enabled, the output shape can become bigger after a fusion, + /// therefore an update is needed. + pub output_shape_updates: bool, + /// Enables mix vectorization factor. + /// + /// Useful when the last dimension is broadcasted for one of the tensors, which would limit the + /// vectorization factor to be 1 without this setting enabled. + pub mix_vectorization: bool, + /// Enables the reuse of input buffers. 
+ pub inplace: bool, +} + struct TryFuseBuilder { builder: FuseOnWriteTraceBuilder, max_bindings: u32, @@ -31,9 +52,9 @@ struct TryFuseBuilder { } impl TryFuseBuilder { - fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { + fn new(max_bindings: u32, bool_precision: ElemwisePrecision, settings: FuseSettings) -> Self { Self { - builder: FuseOnWriteTraceBuilder::new(bool_precision), + builder: FuseOnWriteTraceBuilder::new(bool_precision, settings), max_bindings, added_ops: false, } @@ -70,6 +91,7 @@ impl TryFuseBuilder { impl OptimizationBuilder for FuseOnWriteBuilder { fn register(&mut self, op: &OperationDescription) { + log::info!("Register {op:?}"); if let OptimizationStatus::Closed = self.status { return; } @@ -124,6 +146,10 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { + if !self.properties().ready { + panic!("Building a not ready sss"); + } + self.builder.build(self.current_output_shape.clone()) } @@ -134,7 +160,11 @@ impl OptimizationBuilder for FuseOnWriteBuilder { fn reset(&mut self) { self.num_ops = 0; self.status = OptimizationStatus::Open; - self.builder = TryFuseBuilder::new(self.max_bindings, self.builder.builder.bool_precision); + self.builder = TryFuseBuilder::new( + self.max_bindings, + self.builder.builder.bool_precision, + self.settings, + ); self.current_output_shape.clear(); } @@ -143,7 +173,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn properties(&self) -> OptimizationProperties { - let ready = self.num_ops > 0; + let ready = self.num_ops > 0 && self.num_ops > self.num_reshapes; OptimizationProperties { ready, @@ -153,9 +183,14 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } impl FuseOnWriteBuilder { - pub fn new(max_bindings: u32, bool_precision: ElemwisePrecision) -> Self { + pub fn new( + max_bindings: u32, + bool_precision: ElemwisePrecision, + settings: FuseSettings, + ) -> Self { Self { - builder: TryFuseBuilder::new(max_bindings, bool_precision), + builder: TryFuseBuilder::new(max_bindings, bool_precision, settings), + settings, num_ops: 0, num_reshapes: 0, max_bindings, @@ -196,19 +231,12 @@ impl FuseOnWriteBuilder { }); } - println!( - "Input rank {} : {}", - desc.input.shape.len(), - desc.out.shape.len() - ); if desc.input.shape.len() > desc.out.shape.len() { - println!("Not supported - invalid rank"); // Not yet supported. return false; } if !self.output_is_compatible(&desc.out) { - println!("Not supported - invalid output"); return false; } @@ -223,11 +251,9 @@ impl FuseOnWriteBuilder { true }) { - println!("Fusing Reshape"); self.num_reshapes += 1; true } else { - println!("Can't reshape already fused tensor."); false } } @@ -517,15 +543,29 @@ impl FuseOnWriteBuilder { let curr = self.current_output_shape[i]; let new = out.shape[i]; - // Broadcast is supported. - // - // 0 is the shape id for a global shape of 1. - if curr != new && new != 0 && curr != 0 { + if curr == new { + continue; + } + + // Broadcast not enabled. + if !self.settings.broadcast { return false; } - updated[i] = usize::max(curr, new); + // Broadcasted on new dim. + if new == 0 { + continue; + } + + // Broadcasted on curr dim - update reference output shape. 
+ if curr == 0 && self.settings.output_shape_updates { + updated[i] = new; + continue; + } + + return false; } + core::mem::swap(&mut updated, &mut self.current_output_shape); true diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index e6166e813f..7d141564e8 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -3,6 +3,7 @@ use crate::{ BoolElement, JitRuntime, }; +use super::builder::FuseSettings; use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; use super::position::PositionMapper; use burn_fusion::stream::Context; @@ -19,6 +20,7 @@ use std::collections::BTreeMap; pub struct FuseOnWriteTrace { outputs: RegisteredTensors, inputs: RegisteredTensors, + settings: FuseSettings, scalars: BTreeMap, reshapes: Vec, shape_ref: Vec, @@ -332,6 +334,18 @@ impl FuseOnWriteTrace { tensors_reshaped, ); + // If mix vectorization is disable, we set the vectorization factor of each tensor to the + // minimum value found. + if !self.settings.mix_vectorization { + let factor = analysis.vectorization.values().min().cloned(); + if let Some(factor) = factor { + analysis + .vectorization + .iter_mut() + .for_each(|(_, vf)| *vf = factor); + } + } + for handle in analysis.handle_inputs.iter_mut() { handle.vectorization = *analysis.vectorization.get(&handle.global_id).unwrap(); } @@ -362,7 +376,8 @@ impl FuseOnWriteTrace { let status = &tensor_relative.status; let mut handle = context.handles.get_handle(&tensor_global.id, status); - if status == &TensorStatus::ReadWrite + if self.settings.inplace + && status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) && self diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index babef2ae8c..3032206af2 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -1,7 +1,7 @@ use super::{ + builder::FuseSettings, ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, LayoutInfo, UnaryElemwiseArgs}, - trace::Reshape, - trace::{FuseOnWriteTrace, RegisteredTensors}, + trace::{FuseOnWriteTrace, RegisteredTensors, Reshape}, }; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, @@ -14,6 +14,7 @@ use std::collections::BTreeMap; pub struct FuseOnWriteTraceBuilder { locals: Locals, outputs: RegisteredTensors, + settings: FuseSettings, inputs: RegisteredTensors, scalars: BTreeMap, reshapes: Vec, @@ -25,10 +26,11 @@ pub struct FuseOnWriteTraceBuilder { } impl FuseOnWriteTraceBuilder { - pub fn new(bool_precision: ElemwisePrecision) -> Self { + pub fn new(bool_precision: ElemwisePrecision, settings: FuseSettings) -> Self { Self { locals: Locals::default(), outputs: RegisteredTensors::default(), + settings, inputs: RegisteredTensors::default(), scalars: BTreeMap::default(), reshapes: Vec::new(), @@ -247,11 +249,13 @@ impl FuseOnWriteTraceBuilder { } let reshapes = self.reshapes.clone(); + let settings = self.settings; // Current problem is that I need btreemap instead of sequences. 
FuseOnWriteTrace::new( outputs, inputs, + settings, scalars, reshapes, shape, From 10fc2170e2f16a071db00c42d534907b869f7ea6 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 11:57:45 -0500 Subject: [PATCH 18/28] Fix broadcast issue --- .../burn-jit/src/fusion/elemwise/optimization.rs | 1 - crates/burn-jit/src/fusion/on_write/builder.rs | 7 ++++++- .../burn-jit/src/fusion/on_write/trace_builder.rs | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index bc67e9e945..71b44b8e44 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -110,7 +110,6 @@ impl TraceRunner for ElemwiseRunner { }, None => panic!("Invalid argument"), }; - let total_elem = shape.iter().product::() / *vectorization as usize; let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index f03d9af9ab..2a38239f04 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -566,7 +566,12 @@ impl FuseOnWriteBuilder { return false; } - core::mem::swap(&mut updated, &mut self.current_output_shape); + if updated != self.current_output_shape { + if updated != out.shape { + return false; + } + self.current_output_shape.clone_from_slice(&out.shape); + } true } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 3032206af2..31eb11ab91 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -251,6 +251,20 @@ impl FuseOnWriteTraceBuilder { let reshapes = self.reshapes.clone(); let settings = self.settings; + // println!("=== Fusing {} Operations ===", self.ops.len()); + // for (i, r) in reads.iter() { + // println!(" READ {i:?} => {r:?}"); + // } + + // for (i, op) in self.ops.iter().enumerate() { + // println!(" EXECUTE {i} => {op:?}"); + // } + // for (i, w) in writes.iter() { + // println!(" WRITE {i:?} => {w:?}"); + // } + + // println!("=================="); + // Current problem is that I need btreemap instead of sequences. 
FuseOnWriteTrace::new( outputs, From dcf563d2297fec4cd33b6f1b2368d7d6f5a9da2f Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 12:19:16 -0500 Subject: [PATCH 19/28] Fix performance --- crates/burn-jit/src/fusion/on_write/builder.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 2a38239f04..b461ed36f6 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -146,10 +146,6 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn build(&self) -> FuseOnWriteTrace { - if !self.properties().ready { - panic!("Building a not ready sss"); - } - self.builder.build(self.current_output_shape.clone()) } @@ -173,7 +169,7 @@ impl OptimizationBuilder for FuseOnWriteBuilder { } fn properties(&self) -> OptimizationProperties { - let ready = self.num_ops > 0 && self.num_ops > self.num_reshapes; + let ready = self.num_ops > 0; OptimizationProperties { ready, From 4471ea33ab5aed64c123ae7de4234b316cbb80c3 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 12:50:19 -0500 Subject: [PATCH 20/28] Some cleanup --- backend-comparison/benches/matmul_fused.rs | 3 +- crates/burn-jit/src/fusion/on_write/ir.rs | 30 +++++++++++++++++++- crates/burn-jit/src/fusion/on_write/trace.rs | 26 ++--------------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs index fbec64c648..65c2a92074 100644 --- a/backend-comparison/benches/matmul_fused.rs +++ b/backend-comparison/benches/matmul_fused.rs @@ -26,8 +26,7 @@ impl Benchmark for MatmulBenchmark { } fn execute(&self, (lhs, rhs, bias): Self::Args) { - let bias = bias.unsqueeze(); - gelu(relu(lhs.matmul(rhs)) + bias); + let _output = gelu(relu(lhs.matmul(rhs)) + bias.unsqueeze()); } fn prepare(&self) -> Self::Args { diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index c96c3e1e57..ad774cd24d 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -100,7 +100,7 @@ pub struct ReshapedTensor { shape: Sequence, } -#[derive(CubeLaunch)] +#[derive(CubeLaunch, Default)] /// Global arguments that are used for fusing [element wise operations](ElemwiseOp). pub struct GlobalArgs { pub t_f32: Sequence>>, @@ -127,6 +127,34 @@ pub struct GlobalArgs { pub s_u8: Sequence, } +impl Default for GlobalArgsLaunch<'_, R> { + fn default() -> Self { + Self::new( + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + Default::default(), + ) + } +} impl GlobalArgsLaunch<'_, R> { /// Get the shape of the given [argument](Arg). 
/// diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index 7d141564e8..2b595d77d9 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -682,30 +682,8 @@ impl FuseOnWriteTrace { &self, handle_outputs: &'s [HandleOutput], ) -> GlobalArgsLaunch<'s, R> { - let mut outputs = GlobalArgsLaunch::new( - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - ); + let mut outputs = GlobalArgsLaunch::default(); + for item in handle_outputs.iter() { match item { HandleOutput::Alias { From b9bf504c16f173fad2a3894cef5ce6b2b82503bc Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 17:40:40 -0500 Subject: [PATCH 21/28] Big refactoring --- .../burn-jit/src/fusion/elemwise/builder.rs | 5 +- crates/burn-jit/src/fusion/matmul/builder.rs | 5 +- .../burn-jit/src/fusion/on_write/builder.rs | 24 +- crates/burn-jit/src/fusion/on_write/mod.rs | 3 +- .../burn-jit/src/fusion/on_write/position.rs | 31 - .../burn-jit/src/fusion/on_write/settings.rs | 20 + crates/burn-jit/src/fusion/on_write/trace.rs | 805 ------------------ .../src/fusion/on_write/trace/base.rs | 164 ++++ .../{trace_builder.rs => trace/builder.rs} | 21 +- .../src/fusion/on_write/trace/executor.rs | 209 +++++ .../src/fusion/on_write/trace/inputs.rs | 85 ++ .../burn-jit/src/fusion/on_write/trace/mod.rs | 14 + .../src/fusion/on_write/trace/outputs.rs | 294 +++++++ .../src/fusion/on_write/trace/plan.rs | 84 ++ .../src/fusion/on_write/trace/runner.rs | 154 ++++ .../fusion/on_write/trace/vectorization.rs | 82 ++ 16 files changed, 1114 insertions(+), 886 deletions(-) delete mode 100644 crates/burn-jit/src/fusion/on_write/position.rs create mode 100644 crates/burn-jit/src/fusion/on_write/settings.rs delete mode 100644 crates/burn-jit/src/fusion/on_write/trace.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/base.rs rename crates/burn-jit/src/fusion/on_write/{trace_builder.rs => trace/builder.rs} (96%) create mode 100644 crates/burn-jit/src/fusion/on_write/trace/executor.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/inputs.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/mod.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/outputs.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/plan.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/runner.rs create mode 100644 crates/burn-jit/src/fusion/on_write/trace/vectorization.rs diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 3cf49b8069..a137f88692 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -2,10 +2,7 @@ use burn_fusion::OptimizationBuilder; use crate::{ fusion::{ - on_write::{ - builder::{FuseOnWriteBuilder, FuseSettings}, - ir::ElemwisePrecision, - }, + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, JitOptimization, }, JitRuntime, diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs 
b/crates/burn-jit/src/fusion/matmul/builder.rs index 38fb1d0acc..fb0ebfb48d 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -3,10 +3,7 @@ use burn_tensor::repr::{FloatOperationDescription, OperationDescription}; use crate::{ fusion::{ - on_write::{ - builder::{FuseOnWriteBuilder, FuseSettings}, - ir::ElemwisePrecision, - }, + on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings}, JitOptimization, }, JitRuntime, diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index b461ed36f6..717e80ff2d 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -1,7 +1,7 @@ use super::{ ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, UnaryElemwiseArgs}, - trace::FuseOnWriteTrace, - trace_builder::FuseOnWriteTraceBuilder, + settings::FuseSettings, + trace::{FuseOnWriteTrace, FuseOnWriteTraceBuilder}, }; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::{ @@ -13,7 +13,6 @@ use burn_tensor::{ Element, }; use cubecl::ir::Elem; -use serde::{Deserialize, Serialize}; /// Fused element wise operations that are normally memory bound. pub(crate) struct FuseOnWriteBuilder { @@ -26,25 +25,6 @@ pub(crate) struct FuseOnWriteBuilder { max_bindings: u32, } -/// Controls which operations can be fused. -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -pub struct FuseSettings { - /// Enables broadcasting of shapes. - pub broadcast: bool, - /// Enables output shape updates. - /// - /// When broadcast is enabled, the output shape can become bigger after a fusion, - /// therefore an update is needed. - pub output_shape_updates: bool, - /// Enables mix vectorization factor. - /// - /// Useful when the last dimension is broadcasted for one of the tensors, which would limit the - /// vectorization factor to be 1 without this setting enabled. - pub mix_vectorization: bool, - /// Enables the reuse of input buffers. - pub inplace: bool, -} - struct TryFuseBuilder { builder: FuseOnWriteTraceBuilder, max_bindings: u32, diff --git a/crates/burn-jit/src/fusion/on_write/mod.rs b/crates/burn-jit/src/fusion/on_write/mod.rs index 4dda2d324f..69bbc724d1 100644 --- a/crates/burn-jit/src/fusion/on_write/mod.rs +++ b/crates/burn-jit/src/fusion/on_write/mod.rs @@ -2,7 +2,6 @@ pub(crate) mod builder; pub(crate) mod io; pub(crate) mod ir; pub(crate) mod kernel; -pub(super) mod position; +pub(crate) mod settings; pub mod trace; -pub(crate) mod trace_builder; diff --git a/crates/burn-jit/src/fusion/on_write/position.rs b/crates/burn-jit/src/fusion/on_write/position.rs deleted file mode 100644 index 194e7dbb92..0000000000 --- a/crates/burn-jit/src/fusion/on_write/position.rs +++ /dev/null @@ -1,31 +0,0 @@ -use super::ir::ElemwisePrecision; -use std::collections::BTreeMap; - -/// Group output position by [element precision](ElemwisePrecision). -#[derive(Default, Debug)] -pub struct PositionMapper { - map: BTreeMap>, -} - -impl PositionMapper { - /// Register a new output with the given precision and position. - pub fn register(&mut self, precision: ElemwisePrecision, pos_handle: usize) { - if let Some(positions) = self.map.get_mut(&precision) { - positions.push(pos_handle); - } else { - self.map.insert(precision, vec![pos_handle]); - } - } - - /// Returns the right position from the precision and the global position in all outputs. 
- pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { - self.map - .get(&precision) - .unwrap() - .iter() - .enumerate() - .find(|(_pos_elem, pos_all)| **pos_all == pos_handle) - .map(|(pos_elem, _pos_all)| pos_elem) - .unwrap() as u32 - } -} diff --git a/crates/burn-jit/src/fusion/on_write/settings.rs b/crates/burn-jit/src/fusion/on_write/settings.rs new file mode 100644 index 0000000000..761b98e8b2 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/settings.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +/// Controls which operations can be fused. +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FuseSettings { + /// Enables broadcasting of shapes. + pub broadcast: bool, + /// Enables output shape updates. + /// + /// When broadcast is enabled, the output shape can become bigger after a fusion, + /// therefore an update is needed. + pub output_shape_updates: bool, + /// Enables mix vectorization factor. + /// + /// Useful when the last dimension is broadcasted for one of the tensors, which would limit the + /// vectorization factor to be 1 without this setting enabled. + pub mix_vectorization: bool, + /// Enables the reuse of input buffers. + pub inplace: bool, +} diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs deleted file mode 100644 index 2b595d77d9..0000000000 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ /dev/null @@ -1,805 +0,0 @@ -use crate::{ - fusion::{on_write::ir::LayoutInfo, strides_dyn_rank, JitFusionHandle}, - BoolElement, JitRuntime, -}; - -use super::builder::FuseSettings; -use super::ir::{Arg, ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}; -use super::position::PositionMapper; -use burn_fusion::stream::Context; -use burn_tensor::{ - repr::{TensorDescription, TensorId, TensorStatus}, - DType, -}; -use cubecl::{ir::Elem, prelude::*}; -use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; - -#[derive(new, Clone, Serialize, Deserialize, Debug)] -/// Trace containing all element wise operations as well as reads and writes. -pub struct FuseOnWriteTrace { - outputs: RegisteredTensors, - inputs: RegisteredTensors, - settings: FuseSettings, - scalars: BTreeMap, - reshapes: Vec, - shape_ref: Vec, - ops: Vec, - reads: BTreeMap>, - writes: BTreeMap, - inputs_unhandled: Vec, -} - -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct Reshape { - pub reshaped: TensorId, - pub original: TensorId, -} - -/// A trace runner is responsible for determining the vectorization factor as well as launching -/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) -/// with a provided [element wise config](ElemwiseConfig). -pub trait TraceRunner { - /// The error that might happen while running the trace. - type Error; - - /// Run the trace. - fn run<'a>( - &'a self, - client: &'a ComputeClient, - inputs: GlobalArgsLaunch<'a, R>, - outputs: GlobalArgsLaunch<'a, R>, - config: &'a ElemwiseConfig, - ) -> Result<(), Self::Error>; - - /// The vectorization factor for all inputs and outputs. - fn vectorization<'a>( - vectorizations: &mut BTreeMap, - handles_inputs: impl Iterator>, - inputs: impl Iterator, - outputs: impl Iterator, - reshaped: impl Iterator, - ) { - enum Vect { - Broadcated, - Max(u8), - } - - // The default version uses the last dimension as vectorization axis and assumes a - // perpendicular contiguous line. 
- let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { - let rank = handle.strides.len(); - - // Last dimension strides should be 1, otherwise vecX won't be contiguous. - if handle.strides[rank - 1] != 1 { - return Vect::Max(1); - } - let shape_axis = desc.shape[rank - 1]; - - if shape_axis == 1 { - return Vect::Broadcated; - } - - for s in R::line_size_elem(&desc.dtype.into()) { - // The last dimension should be a multiple of the vector size or broadcated. - if shape_axis % s as usize == 0 { - return Vect::Max(s); - } - } - - Vect::Max(1) - }; - - let vectorization_output = |desc: &TensorDescription| { - let rank = desc.shape.len(); - - for s in R::line_size_elem(&desc.dtype.into()) { - // The last dimension should be a multiple of the vector size. - if desc.shape[rank - 1] % s as usize == 0 { - return Vect::Max(s); - } - } - - Vect::Max(1) - }; - - let vectorization_reshape = - |reshaped: &TensorDescription, original: &TensorDescription, multi_reads: bool| { - let reshape_axis = reshaped.shape[reshaped.shape.len() - 1]; - let shape_axis = original.shape[original.shape.len() - 1]; - - if !multi_reads && reshape_axis == 1 { - return Vect::Broadcated; - } - - for s in R::line_size_elem(&reshaped.dtype.into()) { - if !multi_reads { - // The last dimension should be a multiple of the vector size or broadcated. - if reshape_axis % s as usize == 0 { - return Vect::Max(s); - } - } else { - // Since the original tensor must share the same vectorization factor as the - // reshaped tensor, they must have compatible shapes when both are access - // independently. - if reshape_axis % s as usize == 0 && shape_axis % s as usize == 0 { - return Vect::Max(s); - } - } - } - - Vect::Max(1) - }; - - let mut max_current = u8::MAX; - - for (handle, tensor) in handles_inputs.zip(inputs) { - match vectorization_input(&handle, tensor) { - Vect::Broadcated => vectorizations.insert(tensor.id, 1), - Vect::Max(val) => { - max_current = Ord::min(val, max_current); - vectorizations.insert(tensor.id, 0) - } - }; - } - - for tensor in outputs { - match vectorization_output(tensor) { - Vect::Broadcated => vectorizations.insert(tensor.id, 1), - Vect::Max(val) => { - max_current = Ord::min(val, max_current); - vectorizations.insert(tensor.id, 0) - } - }; - } - - for (reshaped, original, multi_reads) in reshaped { - match vectorization_reshape(reshaped, original, multi_reads) { - Vect::Broadcated => { - vectorizations.insert(original.id, 1); - vectorizations.insert(reshaped.id, 1); - } - Vect::Max(val) => { - vectorizations.insert(original.id, 0); - vectorizations.insert(reshaped.id, 0); - max_current = Ord::min(val, max_current); - } - } - } - - for (_id, val) in vectorizations.iter_mut() { - if *val == 0 { - *val = max_current; - } - } - } -} - -#[derive(Debug)] -struct LaunchAnalysis<'a, R: JitRuntime> { - potential_inplaces: Vec>, - global_inputs: Vec, - global_outputs: Vec, - handle_inputs: Vec>, - handle_outputs: Vec>, - reference: Option, - reads: BTreeMap>, - writes: BTreeMap, - vectorization: BTreeMap, - rank: usize, -} - -#[derive(Debug)] -enum HandleOutput { - Alias { - input_pos: usize, - precision: ElemwisePrecision, - }, - Owned { - global_id: TensorId, - precision: ElemwisePrecision, - handle: JitFusionHandle, - global_shape: Vec, - vectorization: u8, - }, -} - -#[derive(Debug)] -struct HandleInput { - relative_id: TensorId, - global_id: TensorId, - precision: ElemwisePrecision, - handle: JitFusionHandle, - global_shape: Vec, - vectorization: u8, -} - -#[derive(Debug)] -struct 
Reference { - layout: Arg, - shape: Vec, - strides: Vec, -} - -#[derive(Debug)] -struct PotentialInplace<'a> { - input_pos: usize, - tensor_relative: &'a TensorDescription, - strides: Vec, -} - -impl FuseOnWriteTrace { - /// Run a trace with the given [runner](TraceRunner). - pub fn run>( - &self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - runner: &Runner, - ) -> Result<(), Runner::Error> { - let analysis = self.analyse::(client, device, context); - let inputs = self.register_inputs(context, &analysis.handle_inputs); - let outputs = self.register_outputs::<_, BT>(&analysis.handle_outputs); - - let mut ops = Sequence::::new(); - - for read_ops in analysis.reads.into_values() { - for op in read_ops { - ops.push(op); - } - } - - for op in self.ops.iter() { - ops.push(op.clone()); - } - - for op in analysis.writes.into_values() { - ops.push(op); - } - - let config = ElemwiseConfig { - rank: analysis.rank as u32, - ref_layout: analysis - .reference - .expect("An output should exist for the fused kernel") - .layout, - ops, - }; - - match Runner::run(runner, client, inputs, outputs, &config) { - Err(err) => { - self.rollback(context, analysis.handle_inputs, analysis.handle_outputs); - Err(err) - } - Ok(val) => Ok(val), - } - } - - fn rollback( - &self, - context: &mut Context<'_, JitFusionHandle>, - handle_inputs: Vec>, - handle_outputs: Vec>, - ) { - for input in handle_inputs { - context - .handles - .register_handle(input.global_id, input.handle); - } - for output in handle_outputs { - if let HandleOutput::Owned { - global_id, handle, .. - } = output - { - context.handles.register_handle(global_id, handle); - } - } - } - - fn analyse<'a, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( - &'a self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - ) -> LaunchAnalysis<'a, R> { - let mut analysis = LaunchAnalysis { - potential_inplaces: Vec::new(), - global_inputs: Vec::new(), - global_outputs: Vec::new(), - handle_inputs: Vec::new(), - handle_outputs: Vec::new(), - reference: None, - vectorization: BTreeMap::default(), - reads: self.reads.clone(), - writes: self.writes.clone(), - rank: self.shape_ref.len(), - }; - - self.analyse_inputs(context, &mut analysis); - self.analyse_outputs::<_, BT>(client, device, context, &mut analysis); - - let tensors_reshaped = self.reshapes.iter().map(|reshape| { - ( - context.tensors.get(&reshape.reshaped).unwrap(), - context.tensors.get(&reshape.original).unwrap(), - self.reads.get(&reshape.original).unwrap().len() > 1, - ) - }); - - Runner::vectorization( - &mut analysis.vectorization, - analysis.handle_inputs.iter().map(|item| &item.handle), - analysis.global_inputs.iter(), - analysis.global_outputs.iter(), - tensors_reshaped, - ); - - // If mix vectorization is disable, we set the vectorization factor of each tensor to the - // minimum value found. - if !self.settings.mix_vectorization { - let factor = analysis.vectorization.values().min().cloned(); - if let Some(factor) = factor { - analysis - .vectorization - .iter_mut() - .for_each(|(_, vf)| *vf = factor); - } - } - - for handle in analysis.handle_inputs.iter_mut() { - handle.vectorization = *analysis.vectorization.get(&handle.global_id).unwrap(); - } - for handle in analysis.handle_outputs.iter_mut() { - match handle { - HandleOutput::Owned { - vectorization, - global_id, - .. 
- } => *vectorization = *analysis.vectorization.get(&global_id).unwrap(), - _ => {} - } - } - - analysis - } - - fn analyse_inputs<'a, R: JitRuntime>( - &'a self, - context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, R>, - ) { - for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { - let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); - // Important to take the status of the relative graph and not - // the global graph, since the status of the global graph - // might be of a later operation on the same tensor id. - let status = &tensor_relative.status; - let mut handle = context.handles.get_handle(&tensor_global.id, status); - - if self.settings.inplace - && status == &TensorStatus::ReadWrite - && handle.handle.can_mut() - && !self.inputs_unhandled.contains(&tensor_relative.id) - && self - .reshapes - .iter() - .find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) - .is_none() - && self.shape_ref == tensor_relative.shape - { - analysis.potential_inplaces.push(PotentialInplace { - input_pos: i, - tensor_relative, - strides: handle.strides.clone(), - }); - } - - if tensor_global.shape.len() < analysis.rank { - let num_elem: usize = tensor_global.shape.iter().product(); - for _ in 0..(analysis.rank - tensor_global.shape.len()) { - tensor_global.shape.insert(0, 1); - handle.strides.insert(0, num_elem); - } - } - - analysis.handle_inputs.push(HandleInput { - precision, - handle, - relative_id: tensor_relative.id, - global_id: tensor_global.id, - global_shape: tensor_global.shape.clone(), - vectorization: 1, - }); - analysis.global_inputs.push(tensor_global); - } - } - - fn analyse_outputs<'a, R: JitRuntime, BT: BoolElement>( - &'a self, - client: &ComputeClient, - device: &R::Device, - context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, R>, - ) { - let mut position_mapper = PositionMapper::default(); - let mut output_sorted: Vec<_> = self - .outputs - .iter() - .enumerate() - .map(|(pos, (precision, tensor))| { - position_mapper.register(precision, pos); - (pos, (precision, tensor)) - }) - .collect(); - - output_sorted.sort_by(|(_, (_, a)), (_, (_, b))| { - let a_val: usize = a.shape.iter().sum(); - let b_val: usize = b.shape.iter().sum(); - - b_val.cmp(&a_val) - }); - let mut handles = Vec::with_capacity(self.outputs.len()); - let mut globals = Vec::with_capacity(self.outputs.len()); - - for _ in 0..self.outputs.len() { - handles.push(None); - globals.push(None); - } - - for (position_original, (precision, tensor_relative)) in output_sorted.into_iter() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); - let strides = strides_dyn_rank(&tensor_global.shape); - - if let Some(index) = analysis - .potential_inplaces - .iter() - .enumerate() - .find(|(_pos, pi)| { - pi.tensor_relative.dtype == tensor_global.dtype - && pi.tensor_relative.shape == tensor_relative.shape - && pi.strides == strides - }) - .map(|(pos, _)| pos) - { - let potential_inplace = analysis.potential_inplaces.remove(index); - let handle_input = analysis - .handle_inputs - .get(potential_inplace.input_pos) - .unwrap(); - - if analysis.reference.is_none() { - let index_input = self - .inputs - .get_index(precision, potential_inplace.tensor_relative.id) - .unwrap(); - - analysis.reference = Some(Reference { - layout: Arg::Input(index_input as u32, precision, LayoutInfo::IsRef), - shape: tensor_global.shape.clone(), - strides: handle_input.handle.strides.clone(), - }); - 
- if let Some(ops) = analysis.reads.get_mut(&handle_input.relative_id) { - for op in ops.iter_mut() { - if let ElemwiseOp::Assign(op) = op { - op.input.add_layout_info(LayoutInfo::IsRef); - }; - } - } - - if let Some(ElemwiseOp::Assign(op)) = - analysis.writes.get_mut(&tensor_relative.id) - { - op.out.add_layout_info(LayoutInfo::IsRef); - }; - } - - context - .handles - .register_handle(tensor_global.id, handle_input.handle.clone()); - - handles[position_original] = Some(HandleOutput::Alias { - input_pos: potential_inplace.input_pos, - precision, - }); - globals[position_original] = Some(tensor_global); - } else { - if analysis.reference.is_none() { - let position = position_mapper.resolve_index(&precision, position_original); - analysis.reference = Some(Reference { - layout: Arg::Output(position, precision, LayoutInfo::IsRef), - shape: tensor_global.shape.clone(), - strides: strides.clone(), - }); - - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&tensor_relative.id).unwrap() - { - op.out.add_layout_info(LayoutInfo::IsRef); - }; - } else if let Some(reference) = analysis.reference.as_ref() { - if reference.strides == strides && reference.shape == tensor_global.shape { - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&tensor_relative.id).unwrap() - { - op.out.add_layout_info(LayoutInfo::SameAsRef); - }; - } - } - - // We encode bool tensors as `B`. - let dtype = match tensor_global.dtype { - DType::Bool => BT::dtype(), - _ => tensor_global.dtype, - }; - let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); - - let handle = JitFusionHandle { - client: client.clone(), - handle: client.empty(size), - device: device.clone(), - strides, - dtype, - }; - - analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); - context - .handles - .register_handle(tensor_global.id, handle.clone()); - - handles[position_original] = Some(HandleOutput::Owned { - precision, - handle, - global_shape: tensor_global.shape.clone(), - global_id: tensor_global.id, - vectorization: 1, - }); - globals[position_original] = Some(tensor_global); - } - } - - for (handle, global) in handles.into_iter().zip(globals.into_iter()) { - analysis.handle_outputs.push(handle.unwrap()); - analysis.global_outputs.push(global.unwrap()); - } - - Self::add_layout_info_inputs(analysis); - } - - fn add_layout_info_inputs(analysis: &mut LaunchAnalysis<'_, R>) { - for hi in analysis.handle_inputs.iter() { - if let Some(reference) = analysis.reference.as_ref() { - if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { - if let Some(ops) = analysis.reads.get_mut(&hi.relative_id) { - for op in ops.iter_mut() { - if let ElemwiseOp::Assign(op) = op { - op.input.add_layout_info(LayoutInfo::SameAsRef); - } - } - } - } - } - } - } - - fn register_inputs<'h, R: JitRuntime>( - &self, - context: &mut Context<'_, JitFusionHandle>, - handle_inputs: &'h [HandleInput], - ) -> GlobalArgsLaunch<'h, R> { - let mut inputs = GlobalArgsLaunch::new( - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - SequenceArg::new(), - ); - - for hi in handle_inputs.iter() { - let arg = 
hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); - match hi.precision { - ElemwisePrecision::F32 => inputs.t_f32.push(arg), - ElemwisePrecision::F16 => inputs.t_f16.push(arg), - ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), - ElemwisePrecision::I64 => inputs.t_i64.push(arg), - ElemwisePrecision::I32 => inputs.t_i32.push(arg), - ElemwisePrecision::I16 => inputs.t_i16.push(arg), - ElemwisePrecision::I8 => inputs.t_i8.push(arg), - ElemwisePrecision::U64 => inputs.t_u64.push(arg), - ElemwisePrecision::U32 => inputs.t_u32.push(arg), - ElemwisePrecision::U16 => inputs.t_u16.push(arg), - ElemwisePrecision::U8 => inputs.t_u8.push(arg), - _ => panic!("Unsupported input precision {:?}", hi.precision), - }; - } - - for (precision, count) in self.scalars.iter() { - for i in 0..(*count as usize) { - match precision { - ElemwisePrecision::F32 => { - inputs.s_f32.push(ScalarArg::new(context.scalar_f32[i])) - } - ElemwisePrecision::F16 => { - inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) - } - ElemwisePrecision::BF16 => { - inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) - } - ElemwisePrecision::I64 => { - inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) - } - ElemwisePrecision::I32 => { - inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) - } - ElemwisePrecision::I16 => { - inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) - } - ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), - ElemwisePrecision::U64 => { - inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) - } - ElemwisePrecision::U32 => { - inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) - } - ElemwisePrecision::U16 => { - inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) - } - ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), - ElemwisePrecision::Bool => todo!(), - } - } - } - - for relative in self.reshapes.iter().rev() { - let global = context.tensors.get(&relative.reshaped).unwrap(); - - for shape in global.shape.iter().rev() { - inputs.s_u32.push(ScalarArg::new(*shape as u32)) - } - } - - inputs - } - - fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( - &self, - handle_outputs: &'s [HandleOutput], - ) -> GlobalArgsLaunch<'s, R> { - let mut outputs = GlobalArgsLaunch::default(); - - for item in handle_outputs.iter() { - match item { - HandleOutput::Alias { - input_pos, - precision, - } => match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), - ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), - _ => todo!(), - }, - HandleOutput::Owned { - precision, - handle, - global_shape, - vectorization, - .. 
- } => { - let arg = handle.as_tensor_arg(global_shape, *vectorization); - - match precision { - ElemwisePrecision::F32 => outputs.t_f32.push(arg), - ElemwisePrecision::F16 => outputs.t_f16.push(arg), - ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), - ElemwisePrecision::I64 => outputs.t_i64.push(arg), - ElemwisePrecision::I32 => outputs.t_i32.push(arg), - ElemwisePrecision::I16 => outputs.t_i16.push(arg), - ElemwisePrecision::I8 => outputs.t_i8.push(arg), - ElemwisePrecision::U64 => outputs.t_u64.push(arg), - ElemwisePrecision::U32 => outputs.t_u32.push(arg), - ElemwisePrecision::U16 => outputs.t_u16.push(arg), - ElemwisePrecision::U8 => outputs.t_u8.push(arg), - ElemwisePrecision::Bool => match BT::dtype() { - DType::U32 => outputs.t_u32.push(arg), - DType::U8 => outputs.t_u8.push(arg), - _ => todo!(), - }, - }; - } - } - } - - outputs - } -} - -#[derive(Default, Clone, Serialize, Deserialize, Debug)] -pub struct RegisteredTensors { - tensors: BTreeMap>, -} - -impl RegisteredTensors { - pub fn iter(&self) -> impl Iterator { - self.tensors.iter().flat_map(|(precision, descriptions)| { - descriptions.iter().map(|desc| (*precision, desc)) - }) - } - - pub fn len(&self) -> usize { - self.tensors.values().map(|v| v.len()).sum() - } - - pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { - self.tensors.get(&precision).and_then(|items| { - items - .iter() - .enumerate() - .find(|(_pos, tensor)| tensor.id == tensor_id) - .map(|(pos, _)| pos) - }) - } - - pub fn get_all(&self, precision: ElemwisePrecision) -> &[TensorDescription] { - self.tensors - .get(&precision) - .map(|v| v.as_slice()) - .unwrap_or(&[]) - } - - pub fn get( - &self, - precision: ElemwisePrecision, - tensor_id: TensorId, - ) -> Option<&TensorDescription> { - self.get_all(precision) - .iter() - .find(|desc| desc.id == tensor_id) - } - - pub fn insert(&mut self, precision: ElemwisePrecision, tensor: TensorDescription) -> u32 { - if let Some(tensors) = self.tensors.get_mut(&precision) { - let position = tensors.len() as u32; - tensors.push(tensor); - position - } else { - self.tensors.insert(precision, vec![tensor]); - 0 - } - } - - pub fn update(&mut self, precision: ElemwisePrecision, tensor: &TensorDescription) { - if let Some(tensors) = self.tensors.get_mut(&precision) { - if let Some(tensor_old) = tensors - .iter_mut() - .find(|tensor_old| tensor_old.id == tensor.id) - { - tensor_old.status = tensor.status.clone(); - } - } - } -} diff --git a/crates/burn-jit/src/fusion/on_write/trace/base.rs b/crates/burn-jit/src/fusion/on_write/trace/base.rs new file mode 100644 index 0000000000..cad56f9609 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/base.rs @@ -0,0 +1,164 @@ +use crate::{fusion::JitFusionHandle, BoolElement, JitRuntime}; + +use super::{ + super::{ + ir::{ElemwiseOp, ElemwisePrecision}, + settings::FuseSettings, + }, + executor::LaunchPlanExecutor, + inputs::InputsPlanner, + outputs::OutputsPlanner, + vectorization::VectorizationPlanner, + HandleInput, HandleOutput, LaunchPlan, TraceRunner, +}; +use burn_fusion::stream::Context; +use burn_tensor::repr::{TensorDescription, TensorId}; +use cubecl::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; + +#[derive(new, Clone, Serialize, Deserialize, Debug)] +/// Trace containing all element wise operations as well as reads and writes. 
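+///
+/// A trace is assembled by the `FuseOnWriteTraceBuilder` and executed through its `run` method,
+/// which plans inputs, outputs and vectorization before handing the resulting launch plan to a
+/// `TraceRunner`.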
+pub struct FuseOnWriteTrace { + outputs: RegisteredTensors, + inputs: RegisteredTensors, + settings: FuseSettings, + scalars: BTreeMap, + reshapes: Vec, + shape_ref: Vec, + ops: Vec, + reads: BTreeMap>, + writes: BTreeMap, + inputs_unhandled: Vec, +} + +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Reshape { + pub reshaped: TensorId, + pub original: TensorId, +} + +impl FuseOnWriteTrace { + /// Run a trace with the given [runner](TraceRunner). + pub fn run>( + &self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + runner: &Runner, + ) -> Result<(), Runner::Error> { + let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len()); + + InputsPlanner::::new( + &self.inputs, + &self.inputs_unhandled, + &self.reshapes, + &self.shape_ref, + &self.settings, + ) + .run(context, &mut plan); + + OutputsPlanner::::new(&self.inputs, &self.outputs) + .run::(client, device, context, &mut plan); + + VectorizationPlanner::::new(&self.reshapes, &self.reads, &self.settings) + .run::(context, &mut plan); + + match LaunchPlanExecutor::::new(&self.scalars, &self.reshapes, &self.ops) + .execute::<_, BT>(client, runner, context, plan) + { + Err((err, handle_inputs, handle_outputs)) => { + self.rollback(context, handle_inputs, handle_outputs); + Err(err) + } + Ok(val) => Ok(val), + } + } + + fn rollback( + &self, + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: Vec>, + handle_outputs: Vec>, + ) { + for input in handle_inputs { + context + .handles + .register_handle(input.global_id, input.handle); + } + for output in handle_outputs { + if let HandleOutput::Owned { + global_id, handle, .. + } = output + { + context.handles.register_handle(global_id, handle); + } + } + } +} + +#[derive(Default, Clone, Serialize, Deserialize, Debug)] +pub struct RegisteredTensors { + tensors: BTreeMap>, +} + +impl RegisteredTensors { + pub fn iter(&self) -> impl Iterator { + self.tensors.iter().flat_map(|(precision, descriptions)| { + descriptions.iter().map(|desc| (*precision, desc)) + }) + } + + pub fn len(&self) -> usize { + self.tensors.values().map(|v| v.len()).sum() + } + + pub fn get_index(&self, precision: ElemwisePrecision, tensor_id: TensorId) -> Option { + self.tensors.get(&precision).and_then(|items| { + items + .iter() + .enumerate() + .find(|(_pos, tensor)| tensor.id == tensor_id) + .map(|(pos, _)| pos) + }) + } + + pub fn get_all(&self, precision: ElemwisePrecision) -> &[TensorDescription] { + self.tensors + .get(&precision) + .map(|v| v.as_slice()) + .unwrap_or(&[]) + } + + pub fn get( + &self, + precision: ElemwisePrecision, + tensor_id: TensorId, + ) -> Option<&TensorDescription> { + self.get_all(precision) + .iter() + .find(|desc| desc.id == tensor_id) + } + + pub fn insert(&mut self, precision: ElemwisePrecision, tensor: TensorDescription) -> u32 { + if let Some(tensors) = self.tensors.get_mut(&precision) { + let position = tensors.len() as u32; + tensors.push(tensor); + position + } else { + self.tensors.insert(precision, vec![tensor]); + 0 + } + } + + pub fn update(&mut self, precision: ElemwisePrecision, tensor: &TensorDescription) { + if let Some(tensors) = self.tensors.get_mut(&precision) { + if let Some(tensor_old) = tensors + .iter_mut() + .find(|tensor_old| tensor_old.id == tensor.id) + { + tensor_old.status = tensor.status.clone(); + } + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace/builder.rs similarity index 96% rename from 
crates/burn-jit/src/fusion/on_write/trace_builder.rs rename to crates/burn-jit/src/fusion/on_write/trace/builder.rs index 31eb11ab91..8d30ec0283 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/builder.rs @@ -1,8 +1,8 @@ -use super::{ - builder::FuseSettings, +use super::super::{ ir::{Arg, BinaryElemwiseArgs, ElemwiseOp, ElemwisePrecision, LayoutInfo, UnaryElemwiseArgs}, - trace::{FuseOnWriteTrace, RegisteredTensors, Reshape}, + settings::FuseSettings, }; +use super::{FuseOnWriteTrace, RegisteredTensors, Reshape}; use burn_tensor::{ repr::{TensorDescription, TensorId, TensorStatus}, DType, Element, @@ -251,21 +251,6 @@ impl FuseOnWriteTraceBuilder { let reshapes = self.reshapes.clone(); let settings = self.settings; - // println!("=== Fusing {} Operations ===", self.ops.len()); - // for (i, r) in reads.iter() { - // println!(" READ {i:?} => {r:?}"); - // } - - // for (i, op) in self.ops.iter().enumerate() { - // println!(" EXECUTE {i} => {op:?}"); - // } - // for (i, w) in writes.iter() { - // println!(" WRITE {i:?} => {w:?}"); - // } - - // println!("=================="); - - // Current problem is that I need btreemap instead of sequences. FuseOnWriteTrace::new( outputs, inputs, diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs new file mode 100644 index 0000000000..a26bbf991c --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -0,0 +1,209 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use burn_fusion::stream::Context; +use burn_tensor::DType; +use cubecl::{ + client::ComputeClient, + prelude::{ScalarArg, Sequence, TensorArg}, +}; + +use super::{HandleInput, HandleOutput, LaunchPlan, Reshape, TraceRunner}; +use crate::{ + fusion::{ + on_write::ir::{ElemwiseConfig, ElemwiseOp, ElemwisePrecision, GlobalArgsLaunch}, + JitFusionHandle, + }, + BoolElement, JitRuntime, +}; + +pub struct LaunchPlanExecutor<'a, R: JitRuntime> { + scalars: &'a BTreeMap, + reshapes: &'a Vec, + ops: &'a Vec, + _r: PhantomData, +} + +impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { + pub fn new( + scalars: &'a BTreeMap, + reshapes: &'a Vec, + ops: &'a Vec, + ) -> Self { + Self { + scalars, + reshapes, + ops, + _r: PhantomData, + } + } + + pub fn execute, BT: BoolElement>( + self, + client: &ComputeClient, + runner: &Runner, + context: &mut Context<'_, JitFusionHandle>, + plan: LaunchPlan<'a, R>, + ) -> Result<(), (Runner::Error, Vec>, Vec>)> { + let inputs = self.register_inputs(context, &plan.handle_inputs); + let outputs = self.register_outputs::(&plan.handle_outputs); + + let mut ops = Sequence::::new(); + + for read_ops in plan.reads.into_values() { + for op in read_ops { + ops.push(op); + } + } + + for op in self.ops.iter() { + ops.push(op.clone()); + } + + for op in plan.writes.into_values() { + ops.push(op); + } + + let config = ElemwiseConfig { + rank: plan.rank as u32, + ref_layout: plan + .reference + .expect("An output should exist for the fused kernel") + .layout, + ops, + }; + + Runner::run(runner, client, inputs, outputs, &config) + .map_err(|err| (err, plan.handle_inputs, plan.handle_outputs)) + } + fn register_inputs<'h>( + &self, + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: &'h [HandleInput], + ) -> GlobalArgsLaunch<'h, R> { + let mut inputs = GlobalArgsLaunch::default(); + + for hi in handle_inputs.iter() { + let arg = hi.handle.as_tensor_arg(&hi.global_shape, hi.vectorization); + match hi.precision { + 
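+                // Tensor arguments are grouped by element precision: each input handle is pushed
+                // onto the sequence that matches its dtype.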
ElemwisePrecision::F32 => inputs.t_f32.push(arg), + ElemwisePrecision::F16 => inputs.t_f16.push(arg), + ElemwisePrecision::BF16 => inputs.t_bf16.push(arg), + ElemwisePrecision::I64 => inputs.t_i64.push(arg), + ElemwisePrecision::I32 => inputs.t_i32.push(arg), + ElemwisePrecision::I16 => inputs.t_i16.push(arg), + ElemwisePrecision::I8 => inputs.t_i8.push(arg), + ElemwisePrecision::U64 => inputs.t_u64.push(arg), + ElemwisePrecision::U32 => inputs.t_u32.push(arg), + ElemwisePrecision::U16 => inputs.t_u16.push(arg), + ElemwisePrecision::U8 => inputs.t_u8.push(arg), + _ => panic!("Unsupported input precision {:?}", hi.precision), + }; + } + + for (precision, count) in self.scalars.iter() { + for i in 0..(*count as usize) { + match precision { + ElemwisePrecision::F32 => { + inputs.s_f32.push(ScalarArg::new(context.scalar_f32[i])) + } + ElemwisePrecision::F16 => { + inputs.s_f16.push(ScalarArg::new(context.scalar_f16[i])) + } + ElemwisePrecision::BF16 => { + inputs.s_bf16.push(ScalarArg::new(context.scalar_bf16[i])) + } + ElemwisePrecision::I64 => { + inputs.s_i64.push(ScalarArg::new(context.scalar_i64[i])) + } + ElemwisePrecision::I32 => { + inputs.s_i32.push(ScalarArg::new(context.scalar_i32[i])) + } + ElemwisePrecision::I16 => { + inputs.s_i16.push(ScalarArg::new(context.scalar_i16[i])) + } + ElemwisePrecision::I8 => inputs.s_i8.push(ScalarArg::new(context.scalar_i8[i])), + ElemwisePrecision::U64 => { + inputs.s_u64.push(ScalarArg::new(context.scalar_u64[i])) + } + ElemwisePrecision::U32 => { + inputs.s_u32.push(ScalarArg::new(context.scalar_u32[i])) + } + ElemwisePrecision::U16 => { + inputs.s_u16.push(ScalarArg::new(context.scalar_u16[i])) + } + ElemwisePrecision::U8 => inputs.s_u8.push(ScalarArg::new(context.scalar_u8[i])), + ElemwisePrecision::Bool => todo!(), + } + } + } + + for relative in self.reshapes.iter().rev() { + let global = context.tensors.get(&relative.reshaped).unwrap(); + + for shape in global.shape.iter().rev() { + inputs.s_u32.push(ScalarArg::new(*shape as u32)) + } + } + + inputs + } + + fn register_outputs<'s, BT: BoolElement>( + &self, + handle_outputs: &'s [HandleOutput], + ) -> GlobalArgsLaunch<'s, R> { + let mut outputs = GlobalArgsLaunch::default(); + + for item in handle_outputs.iter() { + match item { + HandleOutput::Alias { + input_pos, + precision, + } => match precision { + ElemwisePrecision::F32 => outputs.t_f32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::F16 => outputs.t_f16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::BF16 => outputs.t_bf16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I64 => outputs.t_i64.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I32 => outputs.t_i32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I16 => outputs.t_i16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::I8 => outputs.t_i8.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U64 => outputs.t_u64.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U32 => outputs.t_u32.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U16 => outputs.t_u16.push(TensorArg::alias(*input_pos)), + ElemwisePrecision::U8 => outputs.t_u8.push(TensorArg::alias(*input_pos)), + _ => todo!(), + }, + HandleOutput::Owned { + precision, + handle, + global_shape, + vectorization, + .. 
+ } => { + let arg = handle.as_tensor_arg(global_shape, *vectorization); + + match precision { + ElemwisePrecision::F32 => outputs.t_f32.push(arg), + ElemwisePrecision::F16 => outputs.t_f16.push(arg), + ElemwisePrecision::BF16 => outputs.t_bf16.push(arg), + ElemwisePrecision::I64 => outputs.t_i64.push(arg), + ElemwisePrecision::I32 => outputs.t_i32.push(arg), + ElemwisePrecision::I16 => outputs.t_i16.push(arg), + ElemwisePrecision::I8 => outputs.t_i8.push(arg), + ElemwisePrecision::U64 => outputs.t_u64.push(arg), + ElemwisePrecision::U32 => outputs.t_u32.push(arg), + ElemwisePrecision::U16 => outputs.t_u16.push(arg), + ElemwisePrecision::U8 => outputs.t_u8.push(arg), + ElemwisePrecision::Bool => match BT::dtype() { + DType::U32 => outputs.t_u32.push(arg), + DType::U8 => outputs.t_u8.push(arg), + _ => todo!(), + }, + }; + } + } + } + + outputs + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs new file mode 100644 index 0000000000..0ae74c708d --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs @@ -0,0 +1,85 @@ +use super::Reshape; +use crate::{ + fusion::{on_write::settings::FuseSettings, JitFusionHandle}, + JitRuntime, +}; +use burn_fusion::stream::Context; +use burn_tensor::repr::{TensorId, TensorStatus}; +use std::marker::PhantomData; + +use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; + +pub struct InputsPlanner<'a, R: JitRuntime> { + inputs: &'a RegisteredTensors, + inputs_unhandled: &'a Vec, + reshapes: &'a Vec, + shape_ref: &'a Vec, + settings: &'a FuseSettings, + _r: PhantomData, +} + +impl<'a, R: JitRuntime> InputsPlanner<'a, R> { + pub fn new( + inputs: &'a RegisteredTensors, + inputs_unhandled: &'a Vec, + reshapes: &'a Vec, + shape_ref: &'a Vec, + settings: &'a FuseSettings, + ) -> Self { + Self { + inputs, + settings, + inputs_unhandled, + reshapes, + shape_ref, + _r: PhantomData, + } + } + + pub fn run(self, context: &mut Context<'_, JitFusionHandle>, plan: &mut LaunchPlan<'a, R>) { + for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { + let mut tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); + // Important to take the status of the relative graph and not + // the global graph, since the status of the global graph + // might be of a later operation on the same tensor id. 
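+            //
+            // An input only becomes an in-place candidate when `inplace` is enabled in the
+            // settings, the tensor is read-write, its buffer can be mutated exclusively, it is
+            // neither reshaped nor explicitly unhandled, and its shape matches the reference
+            // shape of the fusion.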
+ let status = &tensor_relative.status; + let mut handle = context.handles.get_handle(&tensor_global.id, status); + + if self.settings.inplace + && status == &TensorStatus::ReadWrite + && handle.handle.can_mut() + && !self.inputs_unhandled.contains(&tensor_relative.id) + && self + .reshapes + .iter() + .find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) + .is_none() + && self.shape_ref == &tensor_relative.shape + { + plan.potential_inplaces.push(PotentialInplace { + input_pos: i, + tensor_relative, + strides: handle.strides.clone(), + }); + } + + if tensor_global.shape.len() < plan.rank { + let num_elem: usize = tensor_global.shape.iter().product(); + for _ in 0..(plan.rank - tensor_global.shape.len()) { + tensor_global.shape.insert(0, 1); + handle.strides.insert(0, num_elem); + } + } + + plan.handle_inputs.push(HandleInput { + precision, + handle, + relative_id: tensor_relative.id, + global_id: tensor_global.id, + global_shape: tensor_global.shape.clone(), + vectorization: 1, + }); + plan.global_inputs.push(tensor_global); + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/mod.rs b/crates/burn-jit/src/fusion/on_write/trace/mod.rs new file mode 100644 index 0000000000..a3f0575299 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/mod.rs @@ -0,0 +1,14 @@ +pub(crate) mod executor; +pub(crate) mod inputs; +pub(crate) mod outputs; +pub(crate) mod vectorization; + +mod base; +mod builder; +mod plan; +mod runner; + +pub use base::*; +pub use builder::*; +pub use plan::*; +pub use runner::*; diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs new file mode 100644 index 0000000000..812880b197 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs @@ -0,0 +1,294 @@ +use burn_fusion::stream::Context; +use burn_tensor::{repr::TensorDescription, DType}; +use cubecl::{client::ComputeClient, ir::Elem}; + +use crate::{ + fusion::{ + on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, + strides_dyn_rank, JitFusionHandle, + }, + BoolElement, JitRuntime, +}; + +use super::{super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors}; +use std::collections::BTreeMap; + +pub struct OutputsPlanner<'a, R: JitRuntime> { + inputs: &'a RegisteredTensors, + outputs_sorted: Vec>, + handles: Vec>>, + globals: Vec>, + mapper: OutputPositionMapper, +} + +struct OutputSorted<'a> { + pos_original: usize, + precision: ElemwisePrecision, + tensor_relative: &'a TensorDescription, +} + +impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { + pub fn new(inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors) -> Self { + let mut mapper = OutputPositionMapper::default(); + let mut outputs_sorted: Vec<_> = outputs + .iter() + .enumerate() + .map(|(pos, (precision, tensor))| { + mapper.register(precision, pos); + OutputSorted { + pos_original: pos, + precision, + tensor_relative: tensor, + } + }) + .collect(); + + outputs_sorted.sort_by(|a, b| { + let a_val: usize = a.tensor_relative.shape.iter().sum(); + let b_val: usize = b.tensor_relative.shape.iter().sum(); + + b_val.cmp(&a_val) + }); + + let mut handles = Vec::with_capacity(outputs.len()); + let mut globals = Vec::with_capacity(outputs.len()); + + for _ in 0..outputs.len() { + handles.push(None); + globals.push(None); + } + + Self { + inputs, + outputs_sorted, + handles, + globals, + mapper, + } + } + + pub fn run( + mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, 
JitFusionHandle>, + analysis: &mut LaunchPlan<'a, R>, + ) { + // So that we can borrow self during the iteration. + let mut outputs = Vec::new(); + core::mem::swap(&mut outputs, &mut self.outputs_sorted); + + for output in outputs.into_iter() { + let tensor_global = context + .tensors + .get(&output.tensor_relative.id) + .unwrap() + .clone(); + let strides = strides_dyn_rank(&tensor_global.shape); + + match Self::select_input_inplace(analysis, &tensor_global, &output, &strides) { + Some(index) => { + self.analyse_inplace(context, analysis, output, tensor_global, index); + } + None => { + self.analyse_output::( + client, + device, + context, + analysis, + output, + tensor_global, + strides, + ); + } + } + } + + for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) { + analysis.handle_outputs.push(handle.unwrap()); + analysis.global_outputs.push(global.unwrap()); + } + + Self::add_layout_info_inputs(analysis); + } + + fn add_layout_info_inputs(analysis: &mut LaunchPlan<'_, R>) { + for hi in analysis.handle_inputs.iter() { + if let Some(reference) = analysis.reference.as_ref() { + if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { + if let Some(ops) = analysis.reads.get_mut(&hi.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::SameAsRef); + } + } + } + } + } + } + } + + fn select_input_inplace( + analysis: &mut LaunchPlan<'a, R>, + tensor_global: &TensorDescription, + output: &OutputSorted, + strides: &[usize], + ) -> Option { + analysis + .potential_inplaces + .iter() + .enumerate() + .find(|(_pos, pi)| { + pi.tensor_relative.dtype == tensor_global.dtype + && pi.tensor_relative.shape == output.tensor_relative.shape + && pi.strides == strides + }) + .map(|(pos, _)| pos) + } + + fn analyse_inplace( + &mut self, + context: &mut Context<'_, JitFusionHandle>, + analysis: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + input_index: usize, + ) { + let potential_inplace = analysis.potential_inplaces.remove(input_index); + let handle_input = analysis + .handle_inputs + .get(potential_inplace.input_pos) + .unwrap(); + + if analysis.reference.is_none() { + let index_input = self + .inputs + .get_index(output.precision, potential_inplace.tensor_relative.id) + .unwrap(); + + analysis.reference = Some(Reference { + layout: Arg::Input(index_input as u32, output.precision, LayoutInfo::IsRef), + shape: tensor_global.shape.clone(), + strides: handle_input.handle.strides.clone(), + }); + + if let Some(ops) = analysis.reads.get_mut(&handle_input.relative_id) { + for op in ops.iter_mut() { + if let ElemwiseOp::Assign(op) = op { + op.input.add_layout_info(LayoutInfo::IsRef); + }; + } + } + + if let Some(ElemwiseOp::Assign(op)) = + analysis.writes.get_mut(&output.tensor_relative.id) + { + op.out.add_layout_info(LayoutInfo::IsRef); + }; + } + + context + .handles + .register_handle(tensor_global.id, handle_input.handle.clone()); + + self.handles[output.pos_original] = Some(HandleOutput::Alias { + input_pos: potential_inplace.input_pos, + precision: output.precision, + }); + self.globals[output.pos_original] = Some(tensor_global); + } + + fn analyse_output( + &mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + analysis: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + strides: Vec, + ) { + if analysis.reference.is_none() { + let position = self + .mapper + 
.resolve_index(&output.precision, output.pos_original); + analysis.reference = Some(Reference { + layout: Arg::Output(position, output.precision, LayoutInfo::IsRef), + shape: tensor_global.shape.clone(), + strides: strides.clone(), + }); + + if let ElemwiseOp::Assign(op) = + analysis.writes.get_mut(&output.tensor_relative.id).unwrap() + { + op.out.add_layout_info(LayoutInfo::IsRef); + }; + } else if let Some(reference) = analysis.reference.as_ref() { + if reference.strides == strides && reference.shape == tensor_global.shape { + if let ElemwiseOp::Assign(op) = + analysis.writes.get_mut(&output.tensor_relative.id).unwrap() + { + op.out.add_layout_info(LayoutInfo::SameAsRef); + }; + } + } + + // We encode bool tensors as `B`. + let dtype = match tensor_global.dtype { + DType::Bool => BT::dtype(), + _ => tensor_global.dtype, + }; + let size = tensor_global.shape.iter().product::() * Elem::from(dtype).size(); + + let handle = JitFusionHandle { + client: client.clone(), + handle: client.empty(size), + device: device.clone(), + strides, + dtype, + }; + + analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); + context + .handles + .register_handle(tensor_global.id, handle.clone()); + + self.handles[output.pos_original] = Some(HandleOutput::Owned { + precision: output.precision, + handle, + global_shape: tensor_global.shape.clone(), + global_id: tensor_global.id, + vectorization: 1, + }); + self.globals[output.pos_original] = Some(tensor_global); + } +} + +/// Group output position by [element precision](ElemwisePrecision). +#[derive(Default, Debug)] +pub struct OutputPositionMapper { + map: BTreeMap>, +} + +impl OutputPositionMapper { + /// Register a new output with the given precision and position. + pub fn register(&mut self, precision: ElemwisePrecision, pos_handle: usize) { + if let Some(positions) = self.map.get_mut(&precision) { + positions.push(pos_handle); + } else { + self.map.insert(precision, vec![pos_handle]); + } + } + + /// Returns the right position from the precision and the global position in all outputs. 
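+    ///
+    /// A hypothetical usage sketch (illustrative values only, not taken from this patch):
+    ///
+    /// ```ignore
+    /// let mut mapper = OutputPositionMapper::default();
+    /// mapper.register(ElemwisePrecision::F32, 0);
+    /// mapper.register(ElemwisePrecision::U32, 1);
+    /// mapper.register(ElemwisePrecision::F32, 2);
+    /// // Global position 2 is the second `F32` output, so its per-precision index is 1.
+    /// assert_eq!(mapper.resolve_index(&ElemwisePrecision::F32, 2), 1);
+    /// ```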
+ pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { + self.map + .get(&precision) + .unwrap() + .iter() + .enumerate() + .find(|(_pos_elem, pos_all)| **pos_all == pos_handle) + .map(|(pos_elem, _pos_all)| pos_elem) + .unwrap() as u32 + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs new file mode 100644 index 0000000000..896b8d7f51 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -0,0 +1,84 @@ +use std::collections::BTreeMap; + +use crate::{ + fusion::{ + on_write::ir::{Arg, ElemwiseOp, ElemwisePrecision}, + JitFusionHandle, + }, + JitRuntime, +}; +use burn_tensor::repr::{TensorDescription, TensorId}; + +#[derive(Debug)] +pub(crate) struct LaunchPlan<'a, R: JitRuntime> { + pub potential_inplaces: Vec>, + pub global_inputs: Vec, + pub global_outputs: Vec, + pub handle_inputs: Vec>, + pub handle_outputs: Vec>, + pub reference: Option, + pub reads: BTreeMap>, + pub writes: BTreeMap, + pub vectorization: BTreeMap, + pub rank: usize, +} + +impl<'a, R: JitRuntime> LaunchPlan<'a, R> { + pub fn new( + reads: &BTreeMap>, + writes: &BTreeMap, + rank: usize, + ) -> Self { + LaunchPlan { + potential_inplaces: Vec::new(), + global_inputs: Vec::new(), + global_outputs: Vec::new(), + handle_inputs: Vec::new(), + handle_outputs: Vec::new(), + reference: None, + vectorization: BTreeMap::default(), + reads: reads.clone(), + writes: writes.clone(), + rank, + } + } +} + +#[derive(Debug)] +pub enum HandleOutput { + Alias { + input_pos: usize, + precision: ElemwisePrecision, + }, + Owned { + global_id: TensorId, + precision: ElemwisePrecision, + handle: JitFusionHandle, + global_shape: Vec, + vectorization: u8, + }, +} + +#[derive(Debug)] +pub struct HandleInput { + pub relative_id: TensorId, + pub global_id: TensorId, + pub precision: ElemwisePrecision, + pub handle: JitFusionHandle, + pub global_shape: Vec, + pub vectorization: u8, +} + +#[derive(Debug)] +pub struct Reference { + pub layout: Arg, + pub shape: Vec, + pub strides: Vec, +} + +#[derive(Debug)] +pub struct PotentialInplace<'a> { + pub input_pos: usize, + pub tensor_relative: &'a TensorDescription, + pub strides: Vec, +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/runner.rs b/crates/burn-jit/src/fusion/on_write/trace/runner.rs new file mode 100644 index 0000000000..dc9e2a8f83 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/runner.rs @@ -0,0 +1,154 @@ +use super::super::ir::{ElemwiseConfig, GlobalArgsLaunch}; +use crate::{fusion::JitFusionHandle, JitRuntime}; +use burn_tensor::repr::{TensorDescription, TensorId}; +use cubecl::prelude::*; +use std::collections::BTreeMap; + +/// A trace runner is responsible for determining the vectorization factor as well as launching +/// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) +/// with a provided [element wise config](ElemwiseConfig). +pub trait TraceRunner { + /// The error that might happen while running the trace. + type Error; + + /// Run the trace. + fn run<'a>( + &'a self, + client: &'a ComputeClient, + inputs: GlobalArgsLaunch<'a, R>, + outputs: GlobalArgsLaunch<'a, R>, + config: &'a ElemwiseConfig, + ) -> Result<(), Self::Error>; + + /// The vectorization factor for all inputs and outputs. 
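+    ///
+    /// The default implementation below vectorizes along the last dimension. As a rough,
+    /// hypothetical example (assuming the runtime supports line sizes 4, 2 and 1): an input
+    /// of shape `[8, 64]` with strides `[64, 1]` allows a factor of 4, a broadcast input of
+    /// shape `[8, 1]` is pinned to 1, and every non-broadcast tensor ends up with the
+    /// smallest factor found across all tensors.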
+ fn vectorization<'a>( + vectorizations: &mut BTreeMap, + handles_inputs: impl Iterator>, + inputs: impl Iterator, + outputs: impl Iterator, + reshaped: impl Iterator, + ) { + vectorization_default(vectorizations, handles_inputs, inputs, outputs, reshaped) + } +} + +fn vectorization_default<'a, R: JitRuntime>( + vectorizations: &mut BTreeMap, + handles_inputs: impl Iterator>, + inputs: impl Iterator, + outputs: impl Iterator, + reshaped: impl Iterator, +) { + enum Vect { + Broadcated, + Max(u8), + } + + // The default version uses the last dimension as vectorization axis and assumes a + // perpendicular contiguous line. + let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { + let rank = handle.strides.len(); + + // Last dimension strides should be 1, otherwise vecX won't be contiguous. + if handle.strides[rank - 1] != 1 { + return Vect::Max(1); + } + let shape_axis = desc.shape[rank - 1]; + + if shape_axis == 1 { + return Vect::Broadcated; + } + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size or broadcated. + if shape_axis % s as usize == 0 { + return Vect::Max(s); + } + } + + Vect::Max(1) + }; + + let vectorization_output = |desc: &TensorDescription| { + let rank = desc.shape.len(); + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size. + if desc.shape[rank - 1] % s as usize == 0 { + return Vect::Max(s); + } + } + + Vect::Max(1) + }; + + let vectorization_reshape = + |reshaped: &TensorDescription, original: &TensorDescription, multi_reads: bool| { + let reshape_axis = reshaped.shape[reshaped.shape.len() - 1]; + let shape_axis = original.shape[original.shape.len() - 1]; + + if !multi_reads && reshape_axis == 1 { + return Vect::Broadcated; + } + + for s in R::line_size_elem(&reshaped.dtype.into()) { + if !multi_reads { + // The last dimension should be a multiple of the vector size or broadcated. + if reshape_axis % s as usize == 0 { + return Vect::Max(s); + } + } else { + // Since the original tensor must share the same vectorization factor as the + // reshaped tensor, they must have compatible shapes when both are access + // independently. 
+ if reshape_axis % s as usize == 0 && shape_axis % s as usize == 0 { + return Vect::Max(s); + } + } + } + + Vect::Max(1) + }; + + let mut max_current = u8::MAX; + + for (handle, tensor) in handles_inputs.zip(inputs) { + match vectorization_input(&handle, tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; + } + + for tensor in outputs { + match vectorization_output(tensor) { + Vect::Broadcated => vectorizations.insert(tensor.id, 1), + Vect::Max(val) => { + max_current = Ord::min(val, max_current); + vectorizations.insert(tensor.id, 0) + } + }; + } + + for (reshaped, original, multi_reads) in reshaped { + match vectorization_reshape(reshaped, original, multi_reads) { + Vect::Broadcated => { + vectorizations.insert(original.id, 1); + vectorizations.insert(reshaped.id, 1); + } + Vect::Max(val) => { + vectorizations.insert(original.id, 0); + vectorizations.insert(reshaped.id, 0); + max_current = Ord::min(val, max_current); + } + } + } + + for (_id, val) in vectorizations.iter_mut() { + if *val == 0 { + *val = max_current; + } + } +} diff --git a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs new file mode 100644 index 0000000000..1dca91dcb3 --- /dev/null +++ b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs @@ -0,0 +1,82 @@ +use std::{collections::BTreeMap, marker::PhantomData}; + +use burn_fusion::stream::Context; +use burn_tensor::repr::TensorId; + +use crate::{ + fusion::{ + on_write::{ir::ElemwiseOp, settings::FuseSettings}, + JitFusionHandle, + }, + JitRuntime, +}; + +use super::{HandleOutput, LaunchPlan, Reshape, TraceRunner}; + +pub struct VectorizationPlanner<'a, R: JitRuntime> { + settings: &'a FuseSettings, + reshapes: &'a Vec, + reads: &'a BTreeMap>, + _r: PhantomData, +} + +impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { + pub fn new( + reshapes: &'a Vec, + reads: &'a BTreeMap>, + settings: &'a FuseSettings, + ) -> Self { + Self { + settings, + reshapes, + reads, + _r: PhantomData, + } + } + pub fn run>( + self, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + ) { + let tensors_reshaped = self.reshapes.iter().map(|reshape| { + ( + context.tensors.get(&reshape.reshaped).unwrap(), + context.tensors.get(&reshape.original).unwrap(), + self.reads.get(&reshape.original).unwrap().len() > 1, + ) + }); + + Runner::vectorization( + &mut plan.vectorization, + plan.handle_inputs.iter().map(|item| &item.handle), + plan.global_inputs.iter(), + plan.global_outputs.iter(), + tensors_reshaped, + ); + + // If mix vectorization is disable, we set the vectorization factor of each tensor to the + // minimum value found. + if !self.settings.mix_vectorization { + let factor = plan.vectorization.values().min().cloned(); + if let Some(factor) = factor { + plan.vectorization + .iter_mut() + .for_each(|(_, vf)| *vf = factor); + } + } + + for handle in plan.handle_inputs.iter_mut() { + handle.vectorization = *plan.vectorization.get(&handle.global_id).unwrap(); + } + for handle in plan.handle_outputs.iter_mut() { + match handle { + HandleOutput::Owned { + vectorization, + global_id, + .. 
+ } => *vectorization = *plan.vectorization.get(&global_id).unwrap(), + _ => {} + } + } + } +} From 6ce7c3ed6397c658880d208e5001fd75b6666486 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Sun, 2 Feb 2025 20:58:04 -0500 Subject: [PATCH 22/28] Add reshape optimization --- .../src/fusion/on_write/trace/base.rs | 2 +- .../src/fusion/on_write/trace/executor.rs | 17 +- .../src/fusion/on_write/trace/outputs.rs | 169 ++++++++++++++---- 3 files changed, 144 insertions(+), 44 deletions(-) diff --git a/crates/burn-jit/src/fusion/on_write/trace/base.rs b/crates/burn-jit/src/fusion/on_write/trace/base.rs index cad56f9609..c6298f4b4b 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/base.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/base.rs @@ -58,7 +58,7 @@ impl FuseOnWriteTrace { ) .run(context, &mut plan); - OutputsPlanner::::new(&self.inputs, &self.outputs) + OutputsPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) .run::(client, device, context, &mut plan); VectorizationPlanner::::new(&self.reshapes, &self.reads, &self.settings) diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs index a26bbf991c..55c1c39912 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -44,6 +44,18 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { context: &mut Context<'_, JitFusionHandle>, plan: LaunchPlan<'a, R>, ) -> Result<(), (Runner::Error, Vec>, Vec>)> { + let reference = match plan.reference { + Some(reference) => reference, + None => { + if plan.writes.is_empty() { + // Nothing to write, can skip execution. + return Ok(()); + } else { + panic!("An output should exist for the fused kernel") + } + } + }; + let inputs = self.register_inputs(context, &plan.handle_inputs); let outputs = self.register_outputs::(&plan.handle_outputs); @@ -65,10 +77,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { let config = ElemwiseConfig { rank: plan.rank as u32, - ref_layout: plan - .reference - .expect("An output should exist for the fused kernel") - .layout, + ref_layout: reference.layout, ops, }; diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs index 812880b197..bb84558091 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs @@ -7,14 +7,18 @@ use crate::{ on_write::ir::{Arg, ElemwiseOp, LayoutInfo}, strides_dyn_rank, JitFusionHandle, }, + tensor::is_contiguous, BoolElement, JitRuntime, }; -use super::{super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors}; +use super::{ + super::ir::ElemwisePrecision, HandleOutput, LaunchPlan, Reference, RegisteredTensors, Reshape, +}; use std::collections::BTreeMap; pub struct OutputsPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, + reshapes: &'a Vec, outputs_sorted: Vec>, handles: Vec>>, globals: Vec>, @@ -27,8 +31,18 @@ struct OutputSorted<'a> { tensor_relative: &'a TensorDescription, } +enum OutputKind { + Normal, + Inplace { input_pos: usize }, + Reshaped { reshape: Reshape }, +} + impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { - pub fn new(inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors) -> Self { + pub fn new( + inputs: &'a RegisteredTensors, + outputs: &'a RegisteredTensors, + reshapes: &'a Vec, + ) -> Self { let mut mapper = OutputPositionMapper::default(); let mut outputs_sorted: Vec<_> = outputs .iter() @@ -61,6 
+75,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { Self { inputs, outputs_sorted, + reshapes, handles, globals, mapper, @@ -72,7 +87,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { client: &ComputeClient, device: &R::Device, context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchPlan<'a, R>, + plan: &mut LaunchPlan<'a, R>, ) { // So that we can borrow self during the iteration. let mut outputs = Vec::new(); @@ -86,30 +101,42 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { .clone(); let strides = strides_dyn_rank(&tensor_global.shape); - match Self::select_input_inplace(analysis, &tensor_global, &output, &strides) { - Some(index) => { - self.analyse_inplace(context, analysis, output, tensor_global, index); + match self.output_kind(plan, &tensor_global, &output, &strides) { + OutputKind::Inplace { input_pos } => { + self.inplace_output(context, plan, output, tensor_global, input_pos); + } + OutputKind::Normal => { + self.normal_output::( + client, + device, + context, + plan, + output, + tensor_global, + strides, + ); } - None => { - self.analyse_output::( + OutputKind::Reshaped { reshape } => { + self.reshaped_output::( client, device, context, - analysis, + plan, output, tensor_global, strides, + reshape, ); } } } for (handle, global) in self.handles.into_iter().zip(self.globals.into_iter()) { - analysis.handle_outputs.push(handle.unwrap()); - analysis.global_outputs.push(global.unwrap()); + plan.handle_outputs.push(handle.unwrap()); + plan.global_outputs.push(global.unwrap()); } - Self::add_layout_info_inputs(analysis); + Self::add_layout_info_inputs(plan); } fn add_layout_info_inputs(analysis: &mut LaunchPlan<'_, R>) { @@ -128,14 +155,24 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { } } - fn select_input_inplace( - analysis: &mut LaunchPlan<'a, R>, + fn output_kind( + &self, + plan: &mut LaunchPlan<'a, R>, tensor_global: &TensorDescription, output: &OutputSorted, strides: &[usize], - ) -> Option { - analysis - .potential_inplaces + ) -> OutputKind { + if let Some(reshape) = self + .reshapes + .iter() + .find(|r| r.reshaped == output.tensor_relative.id) + { + return OutputKind::Reshaped { + reshape: reshape.clone(), + }; + } + + plan.potential_inplaces .iter() .enumerate() .find(|(_pos, pi)| { @@ -144,35 +181,34 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { && pi.strides == strides }) .map(|(pos, _)| pos) + .map(|input_pos| OutputKind::Inplace { input_pos }) + .unwrap_or(OutputKind::Normal) } - fn analyse_inplace( + fn inplace_output( &mut self, context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchPlan<'a, R>, + plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorDescription, input_index: usize, ) { - let potential_inplace = analysis.potential_inplaces.remove(input_index); - let handle_input = analysis - .handle_inputs - .get(potential_inplace.input_pos) - .unwrap(); + let potential_inplace = plan.potential_inplaces.remove(input_index); + let handle_input = plan.handle_inputs.get(potential_inplace.input_pos).unwrap(); - if analysis.reference.is_none() { + if plan.reference.is_none() { let index_input = self .inputs .get_index(output.precision, potential_inplace.tensor_relative.id) .unwrap(); - analysis.reference = Some(Reference { + plan.reference = Some(Reference { layout: Arg::Input(index_input as u32, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: handle_input.handle.strides.clone(), }); - if let Some(ops) = analysis.reads.get_mut(&handle_input.relative_id) { + if let Some(ops) = 
plan.reads.get_mut(&handle_input.relative_id) { for op in ops.iter_mut() { if let ElemwiseOp::Assign(op) = op { op.input.add_layout_info(LayoutInfo::IsRef); @@ -180,9 +216,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { } } - if let Some(ElemwiseOp::Assign(op)) = - analysis.writes.get_mut(&output.tensor_relative.id) - { + if let Some(ElemwiseOp::Assign(op)) = plan.writes.get_mut(&output.tensor_relative.id) { op.out.add_layout_info(LayoutInfo::IsRef); }; } @@ -198,35 +232,34 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } - fn analyse_output( + fn normal_output( &mut self, client: &ComputeClient, device: &R::Device, context: &mut Context<'_, JitFusionHandle>, - analysis: &mut LaunchPlan<'a, R>, + plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorDescription, strides: Vec, ) { - if analysis.reference.is_none() { + if plan.reference.is_none() { let position = self .mapper .resolve_index(&output.precision, output.pos_original); - analysis.reference = Some(Reference { + plan.reference = Some(Reference { layout: Arg::Output(position, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.clone(), strides: strides.clone(), }); - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&output.tensor_relative.id).unwrap() + if let ElemwiseOp::Assign(op) = plan.writes.get_mut(&output.tensor_relative.id).unwrap() { op.out.add_layout_info(LayoutInfo::IsRef); }; - } else if let Some(reference) = analysis.reference.as_ref() { + } else if let Some(reference) = plan.reference.as_ref() { if reference.strides == strides && reference.shape == tensor_global.shape { if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&output.tensor_relative.id).unwrap() + plan.writes.get_mut(&output.tensor_relative.id).unwrap() { op.out.add_layout_info(LayoutInfo::SameAsRef); }; @@ -248,7 +281,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { dtype, }; - analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); + plan.rank = usize::max(tensor_global.shape.len(), plan.rank); context .handles .register_handle(tensor_global.id, handle.clone()); @@ -262,6 +295,64 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { }); self.globals[output.pos_original] = Some(tensor_global); } + + fn reshaped_output( + &mut self, + client: &ComputeClient, + device: &R::Device, + context: &mut Context<'_, JitFusionHandle>, + plan: &mut LaunchPlan<'a, R>, + output: OutputSorted, + tensor_global: TensorDescription, + strides: Vec, + reshape: Reshape, + ) { + let original_handle = plan + .handle_inputs + .iter() + .find(|handle| handle.relative_id == reshape.original) + .unwrap(); + + // We encode bool tensors as `B`. + let dtype = match tensor_global.dtype { + DType::Bool => BT::dtype(), + _ => tensor_global.dtype, + }; + + if is_contiguous( + &original_handle.global_shape, + &original_handle.handle.strides, + ) { + plan.writes.remove(&output.tensor_relative.id); + + let handle = JitFusionHandle { + client: client.clone(), + handle: original_handle.handle.handle.clone(), + device: device.clone(), + strides, + dtype, + }; + context + .handles + .register_handle(tensor_global.id, handle.clone()); + // IT will never be access, just a way to keep the original position working. 
+ self.handles[output.pos_original] = Some(HandleOutput::Alias { + input_pos: 0, + precision: output.precision, + }); + self.globals[output.pos_original] = Some(tensor_global); + } else { + self.normal_output::( + client, + device, + context, + plan, + output, + tensor_global, + strides, + ); + } + } } /// Group output position by [element precision](ElemwisePrecision). From e1273254a20b3009a81cdafdd6e2d796effd737a Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 10:13:39 -0500 Subject: [PATCH 23/28] Cleanup --- .../burn-core/src/nn/rnn/gate_controller.rs | 9 +------- crates/burn-core/src/nn/rnn/lstm.rs | 5 ---- .../burn-core/src/nn/transformer/decoder.rs | 1 - .../burn-fusion/src/stream/execution/base.rs | 6 ----- crates/burn-jit/src/fusion/matmul/builder.rs | 1 - .../burn-jit/src/fusion/on_write/builder.rs | 23 ++++++++----------- crates/burn-jit/src/fusion/on_write/io.rs | 8 +++---- crates/burn-jit/src/fusion/on_write/kernel.rs | 4 +++- 8 files changed, 17 insertions(+), 40 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gate_controller.rs b/crates/burn-core/src/nn/rnn/gate_controller.rs index 4d5868a936..f20f7fa2a1 100644 --- a/crates/burn-core/src/nn/rnn/gate_controller.rs +++ b/crates/burn-core/src/nn/rnn/gate_controller.rs @@ -57,14 +57,7 @@ impl GateController { /// H = hidden state /// b = bias terms pub fn gate_product(&self, input: Tensor, hidden: Tensor) -> Tensor { - println!("{input}"); - let temp = self.input_transform.forward(input); - let temp2 = self.hidden_transform.forward(hidden); - println!("1: {temp}"); - println!("2: {temp2}"); - let temp = temp + temp2; - // panic!("3: {temp}"); - temp + self.input_transform.forward(input) + self.hidden_transform.forward(hidden) } /// Used to initialize a gate controller with known weight layers, diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 9968a48a0a..9a7c23399b 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -721,26 +721,21 @@ mod tests { output_with_init_state .to_data() .assert_approx_eq(&expected_output_with_init_state, 3); - println!("1"); output_without_init_state .to_data() .assert_approx_eq(&expected_output_without_init_state, 3); - println!("2"); state_with_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_with_init_state, 3); - println!("3"); state_with_init_state .cell .to_data() .assert_approx_eq(&expected_cn_with_init_state, 3); - println!("4"); state_without_init_state .hidden .to_data() .assert_approx_eq(&expected_hn_without_init_state, 3); - println!("5"); state_without_init_state .cell .to_data() diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index 4c509bee86..a3d7e49158 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -458,7 +458,6 @@ mod tests { use burn_tensor::Device; use super::*; - use crate::tensor::Distribution; use crate::{nn::attention::generate_autoregressive_mask, TestBackend}; #[test] diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index ef31748a37..149b2d095d 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -49,12 +49,6 @@ impl OperationQueue { let mut context = self.converter.context(handles); optimization.execute(&mut context); - log::info!("====== MATMUL ======"); - for op in &self.global[0..num_drained] { - 
log::info!("{op:?}") - } - log::info!("====== END ======"); - self.drain_queue(num_drained, handles); self.operations.drain(0..num_drained); } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index fb0ebfb48d..11e38626e3 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -47,7 +47,6 @@ impl OptimizationBuilder> for MatmulBuilder } if self.matmul.is_none() { - log::info!("New matmul fusion"); if let OperationDescription::Float(_, FloatOperationDescription::Matmul(op)) = operation { let lhs = self.builder.input_unhandled(&op.lhs); diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 717e80ff2d..bbfca0f3d2 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -71,7 +71,6 @@ impl TryFuseBuilder { impl OptimizationBuilder for FuseOnWriteBuilder { fn register(&mut self, op: &OperationDescription) { - log::info!("Register {op:?}"); if let OptimizationStatus::Closed = self.status { return; } @@ -107,14 +106,12 @@ impl OptimizationBuilder for FuseOnWriteBuilder { return; } } - // TODO - // - // OperationDescription::BaseBool(ops) => { - // if !self.register_base(ops) { - // self.status = OptimizationStatus::Closed; - // return; - // } - // } + OperationDescription::BaseBool(ops) => { + if !self.register_base(ops) { + self.status = OptimizationStatus::Closed; + return; + } + } _ => { self.status = OptimizationStatus::Closed; return; @@ -542,12 +539,10 @@ impl FuseOnWriteBuilder { return false; } - if updated != self.current_output_shape { - if updated != out.shape { - return false; - } - self.current_output_shape.clone_from_slice(&out.shape); + if updated != out.shape { + return false; } + self.current_output_shape.clone_from_slice(&out.shape); true } diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index 24f2b94aa3..cacb2decc5 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -57,7 +57,7 @@ pub fn read( config, comptime![Some(shape)], ), - _ => comptime![panic![]], + _ => comptime![panic!("Only input can be reshaped")], }, } } @@ -90,7 +90,7 @@ pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: Arg) -> u32 { let offset = comptime![inputs.s_u32.len() - pos - 1]; *inputs.s_u32.index(offset) } - _ => comptime![panic!["Not a scalar shape"]], + _ => comptime![panic!("Not a scalar shape")], } } @@ -773,8 +773,8 @@ fn index_offset_with_layout( ) -> u32 { match comptime![shape.clone()] { Some(shape) => { - let index_reshaped = reshaped_index(inputs, layout, index, rank, shape); - reshaped_index_to_original_index(tensor, index_reshaped, rank) + let index = reshaped_index(inputs, layout, index, rank, shape); + reshaped_index_to_original_index(tensor, index, rank) } None => { let offset_ref = index * layout.line_size(); diff --git a/crates/burn-jit/src/fusion/on_write/kernel.rs b/crates/burn-jit/src/fusion/on_write/kernel.rs index fa5abbfba9..cf79a0fb6b 100644 --- a/crates/burn-jit/src/fusion/on_write/kernel.rs +++ b/crates/burn-jit/src/fusion/on_write/kernel.rs @@ -404,7 +404,9 @@ pub fn fuse_on_write( ElemwisePrecision::U8 => { equal::(inputs, outputs, &mut locals, write_pos, op, config) } - _ => comptime![panic!("Unsupported precision {op:?}")], + ElemwisePrecision::Bool => { + equal::(inputs, outputs, &mut locals, write_pos, op, config) + } }, 
ElemwiseOp::Greater(op) => match op.lhs.precision() { ElemwisePrecision::F32 => { From 1a84818ed1fda5dcc0cdd752bb4670b48818f15e Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 10:23:26 -0500 Subject: [PATCH 24/28] Add some docs --- crates/burn-jit/src/fusion/on_write/trace/executor.rs | 3 +++ crates/burn-jit/src/fusion/on_write/trace/inputs.rs | 2 ++ crates/burn-jit/src/fusion/on_write/trace/outputs.rs | 3 +++ crates/burn-jit/src/fusion/on_write/trace/plan.rs | 2 ++ 4 files changed, 10 insertions(+) diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs index 55c1c39912..9b1d719068 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -16,6 +16,7 @@ use crate::{ BoolElement, JitRuntime, }; +/// Execute a [plan](LaunchPlan) using a [runner](TraceRunner) modifying the [context](Context). pub struct LaunchPlanExecutor<'a, R: JitRuntime> { scalars: &'a BTreeMap, reshapes: &'a Vec, @@ -84,6 +85,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { Runner::run(runner, client, inputs, outputs, &config) .map_err(|err| (err, plan.handle_inputs, plan.handle_outputs)) } + fn register_inputs<'h>( &self, context: &mut Context<'_, JitFusionHandle>, @@ -146,6 +148,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { } } + // Reshape values are pushed in reverse in the same scalar buffer for all `u32` for relative in self.reshapes.iter().rev() { let global = context.tensors.get(&relative.reshaped).unwrap(); diff --git a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs index 0ae74c708d..54791ec225 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/inputs.rs @@ -9,6 +9,8 @@ use std::marker::PhantomData; use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; +/// Fetch and register [input handles](HandleInput) and itendify potential inputs that +/// can be used inplace. pub struct InputsPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs index bb84558091..66624142fb 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/outputs.rs @@ -16,6 +16,9 @@ use super::{ }; use std::collections::BTreeMap; +/// Create or reuse handles for the outputs. +/// +/// It is also responsable to select the reference tensor. pub struct OutputsPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, reshapes: &'a Vec, diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs index 896b8d7f51..2fc68ba2be 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -9,6 +9,8 @@ use crate::{ }; use burn_tensor::repr::{TensorDescription, TensorId}; +/// The plan is responsable to keep runtime information related to the launch of a fused kernel +/// at one place. 
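+///
+/// Roughly, a plan is filled in stages: the input planner registers input handles and
+/// in-place candidates, the output planner creates or aliases output handles and selects
+/// the reference layout, the vectorization planner settles the line size of each tensor,
+/// and the executor finally turns the plan into a kernel launch.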
#[derive(Debug)] pub(crate) struct LaunchPlan<'a, R: JitRuntime> { pub potential_inplaces: Vec>, From 5b25f18e31cf0a82ff5c83ac569c63f4561ec513 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 11:41:09 -0500 Subject: [PATCH 25/28] Update cubecl ref --- Cargo.lock | 15 +++++++++++++++ Cargo.toml | 8 ++++---- .../src/fusion/on_write/trace/vectorization.rs | 1 + 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 62abf8ac0a..e2eca1da8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1475,6 +1475,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1489,6 +1490,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1509,6 +1511,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1529,6 +1532,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-common", @@ -1542,6 +1546,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-common", @@ -1557,6 +1562,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-common", @@ -1582,6 +1588,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1599,6 +1606,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-core", @@ -1610,6 +1618,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", "darling", @@ -1624,6 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "darling", "proc-macro2", @@ -1634,6 +1644,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", 
"cubecl-ir", @@ -1649,6 +1660,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1658,6 +1670,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "async-channel", "async-lock", @@ -1679,6 +1692,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1693,6 +1707,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index e225c304aa..2ffa714a6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" } ### For local development. ### -cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### # cubecl = { version = "0.4.0", default-features = false } # cubecl-common = { version = "0.4.0", default-features = false } diff --git a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs index 1dca91dcb3..e216b6e525 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs @@ -13,6 +13,7 @@ use crate::{ use super::{HandleOutput, LaunchPlan, Reshape, TraceRunner}; +/// Select the best vectorization factor for each tensor handle. 
pub struct VectorizationPlanner<'a, R: JitRuntime> { settings: &'a FuseSettings, reshapes: &'a Vec, From 651bb4ca1c55a46403e0e7ea655b27b279eb94be Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 11:52:47 -0500 Subject: [PATCH 26/28] Clippy + Fmt --- .../burn-jit/src/fusion/on_write/builder.rs | 1 + .../src/fusion/on_write/trace/base.rs | 36 +++++++++---------- .../src/fusion/on_write/trace/builder.rs | 34 ++++++++++-------- .../src/fusion/on_write/trace/executor.rs | 11 ++++-- .../on_write/trace/{inputs.rs => input.rs} | 9 +++-- .../burn-jit/src/fusion/on_write/trace/mod.rs | 4 +-- .../on_write/trace/{outputs.rs => output.rs} | 8 +++-- .../src/fusion/on_write/trace/plan.rs | 2 +- .../src/fusion/on_write/trace/runner.rs | 2 +- .../fusion/on_write/trace/vectorization.rs | 14 ++++---- 10 files changed, 67 insertions(+), 54 deletions(-) rename crates/burn-jit/src/fusion/on_write/trace/{inputs.rs => input.rs} (92%) rename crates/burn-jit/src/fusion/on_write/trace/{outputs.rs => output.rs} (98%) diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index bbfca0f3d2..ffb5dfcb79 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -512,6 +512,7 @@ impl FuseOnWriteBuilder { let mut updated = self.current_output_shape.clone(); + #[allow(clippy::needless_range_loop)] for i in 0..rank { let curr = self.current_output_shape[i]; let new = out.shape[i]; diff --git a/crates/burn-jit/src/fusion/on_write/trace/base.rs b/crates/burn-jit/src/fusion/on_write/trace/base.rs index c6298f4b4b..f017d6dc1b 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/base.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/base.rs @@ -6,8 +6,8 @@ use super::{ settings::FuseSettings, }, executor::LaunchPlanExecutor, - inputs::InputsPlanner, - outputs::OutputsPlanner, + input::InputPlanner, + output::OutputPlanner, vectorization::VectorizationPlanner, HandleInput, HandleOutput, LaunchPlan, TraceRunner, }; @@ -17,19 +17,19 @@ use cubecl::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -#[derive(new, Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug)] /// Trace containing all element wise operations as well as reads and writes. 
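+/// A trace is assembled by the trace builder and is meant to be reusable: each call to
+/// `run` builds a fresh [LaunchPlan] from the current [Context] before handing the launch
+/// to a [TraceRunner].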
pub struct FuseOnWriteTrace { - outputs: RegisteredTensors, - inputs: RegisteredTensors, - settings: FuseSettings, - scalars: BTreeMap, - reshapes: Vec, - shape_ref: Vec, - ops: Vec, - reads: BTreeMap>, - writes: BTreeMap, - inputs_unhandled: Vec, + pub outputs: RegisteredTensors, + pub inputs: RegisteredTensors, + pub settings: FuseSettings, + pub scalars: BTreeMap, + pub reshapes: Vec, + pub shape_ref: Vec, + pub ops: Vec, + pub reads: BTreeMap>, + pub writes: BTreeMap, + pub inputs_unhandled: Vec, } #[derive(Clone, Serialize, Deserialize, Debug)] @@ -49,7 +49,7 @@ impl FuseOnWriteTrace { ) -> Result<(), Runner::Error> { let mut plan = LaunchPlan::new(&self.reads, &self.writes, self.shape_ref.len()); - InputsPlanner::::new( + InputPlanner::::new( &self.inputs, &self.inputs_unhandled, &self.reshapes, @@ -58,7 +58,7 @@ impl FuseOnWriteTrace { ) .run(context, &mut plan); - OutputsPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) + OutputPlanner::::new(&self.inputs, &self.outputs, &self.reshapes) .run::(client, device, context, &mut plan); VectorizationPlanner::::new(&self.reshapes, &self.reads, &self.settings) @@ -67,9 +67,9 @@ impl FuseOnWriteTrace { match LaunchPlanExecutor::::new(&self.scalars, &self.reshapes, &self.ops) .execute::<_, BT>(client, runner, context, plan) { - Err((err, handle_inputs, handle_outputs)) => { - self.rollback(context, handle_inputs, handle_outputs); - Err(err) + Err(err) => { + self.rollback(context, err.handles_input, err.handles_output); + Err(err.runner_error) } Ok(val) => Ok(val), } diff --git a/crates/burn-jit/src/fusion/on_write/trace/builder.rs b/crates/burn-jit/src/fusion/on_write/trace/builder.rs index 8d30ec0283..896d7f3b1c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/builder.rs @@ -102,8 +102,10 @@ impl FuseOnWriteTraceBuilder { let out = self.locals.create(precision, tensor.id); let input = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); - let reads = if !self.reads.contains_key(&tensor.id) { - self.reads.insert(tensor.id, Vec::with_capacity(1)); + let reads = if let std::collections::btree_map::Entry::Vacant(e) = + self.reads.entry(tensor.id) + { + e.insert(Vec::with_capacity(1)); self.reads.get_mut(&tensor.id).unwrap() } else { self.reads.get_mut(&tensor.id).unwrap() @@ -180,8 +182,8 @@ impl FuseOnWriteTraceBuilder { let index = self.reshapes.len(); self.reshapes.push(Reshape { - reshaped: output.id.clone(), - original: tensor.id.clone(), + reshaped: output.id, + original: tensor.id, }); let rank = output.shape.len(); @@ -195,12 +197,13 @@ impl FuseOnWriteTraceBuilder { shape, }; - let reads = if !self.reads.contains_key(&tensor.id) { - self.reads.insert(tensor.id, Vec::with_capacity(1)); - self.reads.get_mut(&tensor.id).unwrap() - } else { - self.reads.get_mut(&tensor.id).unwrap() - }; + let reads = + if let std::collections::btree_map::Entry::Vacant(e) = self.reads.entry(tensor.id) { + e.insert(Vec::with_capacity(1)); + self.reads.get_mut(&tensor.id).unwrap() + } else { + self.reads.get_mut(&tensor.id).unwrap() + }; reads.push(ElemwiseOp::Assign(UnaryElemwiseArgs { input, @@ -226,7 +229,7 @@ impl FuseOnWriteTraceBuilder { Arg::Scalar(new_index, precision) } - pub fn build(&self, shape: Vec) -> FuseOnWriteTrace { + pub fn build(&self, shape_ref: Vec) -> FuseOnWriteTrace { let inputs = self.inputs.clone(); let outputs = self.output_tensors(); let ops = self.ops.clone(); @@ -250,19 +253,20 @@ impl FuseOnWriteTraceBuilder { let reshapes = 
self.reshapes.clone(); let settings = self.settings; + let inputs_unhandled = self.inputs_unhandled.clone(); - FuseOnWriteTrace::new( + FuseOnWriteTrace { outputs, inputs, settings, scalars, reshapes, - shape, + shape_ref, ops, reads, writes, - self.inputs_unhandled.clone(), - ) + inputs_unhandled, + } } fn output_tensors(&self) -> RegisteredTensors { diff --git a/crates/burn-jit/src/fusion/on_write/trace/executor.rs b/crates/burn-jit/src/fusion/on_write/trace/executor.rs index 9b1d719068..749e74340e 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/executor.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/executor.rs @@ -24,6 +24,13 @@ pub struct LaunchPlanExecutor<'a, R: JitRuntime> { _r: PhantomData, } +#[derive(new)] +pub struct ExecutionError> { + pub runner_error: Runner::Error, + pub handles_input: Vec>, + pub handles_output: Vec>, +} + impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { pub fn new( scalars: &'a BTreeMap, @@ -44,7 +51,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { runner: &Runner, context: &mut Context<'_, JitFusionHandle>, plan: LaunchPlan<'a, R>, - ) -> Result<(), (Runner::Error, Vec>, Vec>)> { + ) -> Result<(), ExecutionError> { let reference = match plan.reference { Some(reference) => reference, None => { @@ -83,7 +90,7 @@ impl<'a, R: JitRuntime> LaunchPlanExecutor<'a, R> { }; Runner::run(runner, client, inputs, outputs, &config) - .map_err(|err| (err, plan.handle_inputs, plan.handle_outputs)) + .map_err(|err| ExecutionError::new(err, plan.handle_inputs, plan.handle_outputs)) } fn register_inputs<'h>( diff --git a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs b/crates/burn-jit/src/fusion/on_write/trace/input.rs similarity index 92% rename from crates/burn-jit/src/fusion/on_write/trace/inputs.rs rename to crates/burn-jit/src/fusion/on_write/trace/input.rs index 54791ec225..a243fddac1 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/inputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/input.rs @@ -11,7 +11,7 @@ use super::{HandleInput, LaunchPlan, PotentialInplace, RegisteredTensors}; /// Fetch and register [input handles](HandleInput) and itendify potential inputs that /// can be used inplace. 
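+///
+/// An input only qualifies for in-place reuse when in-place fusion is enabled in the
+/// settings, its status is read-write, its handle can be mutated, it was not registered as
+/// an unhandled input, it is not involved in a reshape, and its shape matches the reference
+/// shape of the fused kernel.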
-pub struct InputsPlanner<'a, R: JitRuntime> { +pub struct InputPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, reshapes: &'a Vec, @@ -20,7 +20,7 @@ pub struct InputsPlanner<'a, R: JitRuntime> { _r: PhantomData, } -impl<'a, R: JitRuntime> InputsPlanner<'a, R> { +impl<'a, R: JitRuntime> InputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, inputs_unhandled: &'a Vec, @@ -51,11 +51,10 @@ impl<'a, R: JitRuntime> InputsPlanner<'a, R> { && status == &TensorStatus::ReadWrite && handle.handle.can_mut() && !self.inputs_unhandled.contains(&tensor_relative.id) - && self + && !self .reshapes .iter() - .find(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) - .is_none() + .any(|r| r.reshaped == tensor_relative.id || r.original == tensor_relative.id) && self.shape_ref == &tensor_relative.shape { plan.potential_inplaces.push(PotentialInplace { diff --git a/crates/burn-jit/src/fusion/on_write/trace/mod.rs b/crates/burn-jit/src/fusion/on_write/trace/mod.rs index a3f0575299..64de887986 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/mod.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod executor; -pub(crate) mod inputs; -pub(crate) mod outputs; +pub(crate) mod input; +pub(crate) mod output; pub(crate) mod vectorization; mod base; diff --git a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs b/crates/burn-jit/src/fusion/on_write/trace/output.rs similarity index 98% rename from crates/burn-jit/src/fusion/on_write/trace/outputs.rs rename to crates/burn-jit/src/fusion/on_write/trace/output.rs index 66624142fb..0964974c7a 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/outputs.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/output.rs @@ -19,7 +19,7 @@ use std::collections::BTreeMap; /// Create or reuse handles for the outputs. /// /// It is also responsable to select the reference tensor. -pub struct OutputsPlanner<'a, R: JitRuntime> { +pub struct OutputPlanner<'a, R: JitRuntime> { inputs: &'a RegisteredTensors, reshapes: &'a Vec, outputs_sorted: Vec>, @@ -40,7 +40,7 @@ enum OutputKind { Reshaped { reshape: Reshape }, } -impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { +impl<'a, R: JitRuntime> OutputPlanner<'a, R> { pub fn new( inputs: &'a RegisteredTensors, outputs: &'a RegisteredTensors, @@ -235,6 +235,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } + #[allow(clippy::too_many_arguments)] fn normal_output( &mut self, client: &ComputeClient, @@ -299,6 +300,7 @@ impl<'a, R: JitRuntime> OutputsPlanner<'a, R> { self.globals[output.pos_original] = Some(tensor_global); } + #[allow(clippy::too_many_arguments)] fn reshaped_output( &mut self, client: &ComputeClient, @@ -377,7 +379,7 @@ impl OutputPositionMapper { /// Returns the right position from the precision and the global position in all outputs. 
pub fn resolve_index(&mut self, precision: &ElemwisePrecision, pos_handle: usize) -> u32 { self.map - .get(&precision) + .get(precision) .unwrap() .iter() .enumerate() diff --git a/crates/burn-jit/src/fusion/on_write/trace/plan.rs b/crates/burn-jit/src/fusion/on_write/trace/plan.rs index 2fc68ba2be..89a11a188c 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/plan.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/plan.rs @@ -25,7 +25,7 @@ pub(crate) struct LaunchPlan<'a, R: JitRuntime> { pub rank: usize, } -impl<'a, R: JitRuntime> LaunchPlan<'a, R> { +impl LaunchPlan<'_, R> { pub fn new( reads: &BTreeMap>, writes: &BTreeMap, diff --git a/crates/burn-jit/src/fusion/on_write/trace/runner.rs b/crates/burn-jit/src/fusion/on_write/trace/runner.rs index dc9e2a8f83..fc3109327d 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/runner.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/runner.rs @@ -113,7 +113,7 @@ fn vectorization_default<'a, R: JitRuntime>( let mut max_current = u8::MAX; for (handle, tensor) in handles_inputs.zip(inputs) { - match vectorization_input(&handle, tensor) { + match vectorization_input(handle, tensor) { Vect::Broadcated => vectorizations.insert(tensor.id, 1), Vect::Max(val) => { max_current = Ord::min(val, max_current); diff --git a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs index e216b6e525..ff775e3327 100644 --- a/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs +++ b/crates/burn-jit/src/fusion/on_write/trace/vectorization.rs @@ -70,13 +70,13 @@ impl<'a, R: JitRuntime> VectorizationPlanner<'a, R> { handle.vectorization = *plan.vectorization.get(&handle.global_id).unwrap(); } for handle in plan.handle_outputs.iter_mut() { - match handle { - HandleOutput::Owned { - vectorization, - global_id, - .. - } => *vectorization = *plan.vectorization.get(&global_id).unwrap(), - _ => {} + if let HandleOutput::Owned { + vectorization, + global_id, + .. 
+ } = handle + { + *vectorization = *plan.vectorization.get(global_id).unwrap() } } } From 4ce799319f4406139f605c32fdbb2ed80392ac69 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 12:48:28 -0500 Subject: [PATCH 27/28] Add vulkan in example --- examples/text-classification/Cargo.toml | 2 +- .../text-classification/examples/ag-news-train.rs | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 043c61672d..d380dcc9cb 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -16,7 +16,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -vulkan = ["wgpu", "burn/vulkan"] +vulkan = ["burn/vulkan", "burn/default"] remote = ["burn/remote"] cuda = ["burn/cuda"] hip = ["burn/hip"] diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 927c190b2c..9e85cbc48b 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -93,6 +93,16 @@ mod wgpu { } } +#[cfg(feature = "vulkan")] +mod vulkan { + use crate::{launch, ElemType}; + use burn::backend::{Vulkan, Autodiff}; + + pub fn run() { + launch::>>(vec![Default::default()]); + } +} + #[cfg(feature = "remote")] mod remote { use crate::{launch, ElemType}; @@ -143,4 +153,6 @@ fn main() { hip::run(); #[cfg(feature = "remote")] remote::run(); + #[cfg(feature = "vulkan")] + vulkan::run(); } From 12de7659d3f1fc633744d9f61a239457b14a0d07 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Mon, 3 Feb 2025 21:43:10 -0500 Subject: [PATCH 28/28] WIP --- Cargo.lock | 15 --------------- Cargo.toml | 8 ++++---- crates/burn-jit/src/fusion/on_write/builder.rs | 3 +++ crates/burn-jit/src/kernel/reduce/base.rs | 17 ++++++++++++----- crates/burn-jit/src/kernel/reduce/tune.rs | 16 +++++++++++++--- .../examples/ag-news-train.rs | 2 +- 6 files changed, 33 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e2eca1da8a..62abf8ac0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1475,7 +1475,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1490,7 +1489,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1511,7 +1509,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1532,7 +1529,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-common", @@ -1546,7 +1542,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" 
dependencies = [ "bytemuck", "cubecl-common", @@ -1562,7 +1557,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-common", @@ -1588,7 +1582,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1606,7 +1599,6 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bytemuck", "cubecl-core", @@ -1618,7 +1610,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", "darling", @@ -1633,7 +1624,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "darling", "proc-macro2", @@ -1644,7 +1634,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1660,7 +1649,6 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1670,7 +1658,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "async-channel", "async-lock", @@ -1692,7 +1679,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1707,7 +1693,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=276b8db0b5402492cc72013ce8da9b63be3a165f#276b8db0b5402492cc72013ce8da9b63be3a165f" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 2ffa714a6c..331ac11c18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. 
###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
+# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
+# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "276b8db0b5402492cc72013ce8da9b63be3a165f" }
### For local development. ###
-# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
-# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
+cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
+cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs
index ffb5dfcb79..d1584af3b7 100644
--- a/crates/burn-jit/src/fusion/on_write/builder.rs
+++ b/crates/burn-jit/src/fusion/on_write/builder.rs
@@ -183,6 +183,9 @@ impl FuseOnWriteBuilder {
pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg {
if self.current_output_shape.is_empty() {
self.current_output_shape = tensor.shape.clone();
+ } else if self.current_output_shape.iter().sum::<usize>() < tensor.shape.iter().sum() {
+ // The largest shape wins.
+ self.current_output_shape = tensor.shape.clone();
}

self.builder.builder.output_unhandled(tensor)
diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs
index ccfcc3ef9e..75c7207885 100644
--- a/crates/burn-jit/src/kernel/reduce/base.rs
+++ b/crates/burn-jit/src/kernel/reduce/base.rs
@@ -25,11 +25,18 @@ pub fn sum(
match cube_count {
SumStrategy::OneShot(cube_count) => {
- let output = shared_sum::(&client, tensor.as_handle_ref(), cube_count)?;
- Ok(from_data::(
- TensorData::new(vec![output], vec![1]),
- &device,
- ))
+ let handle = client.empty(E::size().unwrap());
+ let output =
+ JitTensor::new_contiguous(client.clone(), device, [1].into(), handle, E::dtype());
+
+ shared_sum::(
+ &client,
+ tensor.as_handle_ref(),
+ output.as_handle_ref(),
+ cube_count,
+ )?;
+
+ Ok(output)
}
SumStrategy::Chained(strategy) => reduce::(tensor, strategy),
#[cfg(feature = "autotune")]
diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs
index cd5cd61157..b26d852656 100644
--- a/crates/burn-jit/src/kernel/reduce/tune.rs
+++ b/crates/burn-jit/src/kernel/reduce/tune.rs
@@ -275,10 +275,20 @@ mod sum_ops {
pub(crate) fn sum_one_shot(
input: JitTensor,
) -> Result, String> {
+ let client = input.client.clone();
let device = input.device.clone();
- cubecl::reduce::shared_sum::(&input.client, input.as_handle_ref(), C)
- .map(|output| from_data::(TensorData::new(vec![output], vec![1]), &device))
- .map_err(|e| e.to_string())
+ let handle = client.empty(E::size().unwrap());
+ let output = JitTensor::new_contiguous(client, device, [1].into(), handle, E::dtype());
+
+ cubecl::reduce::shared_sum::(
+ &input.client,
+ input.as_handle_ref(),
+ output.as_handle_ref(),
+ C,
+ )
+ .map_err(|e| e.to_string())?;
+
+ Ok(output)
}

#[cfg(feature = "autotune")]
diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs
index 9e85cbc48b..610cd821c4 100644
--- a/examples/text-classification/examples/ag-news-train.rs
+++ b/examples/text-classification/examples/ag-news-train.rs
@@ -96,7 +96,7 @@
#[cfg(feature = "vulkan")]
mod vulkan {
use crate::{launch, ElemType};
- use burn::backend::{Vulkan, Autodiff};
+ use burn::backend::{Autodiff, Vulkan};

pub fn run() {
launch::>>(vec![Default::default()]);
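Note on the vulkan module added in patch 27: it follows the same feature-gated pattern as the example's other backends, a #[cfg(feature = "...")] module that calls the shared launch function with a concrete backend type. The generic arguments to launch are not legible in this copy of the hunk (only "launch::>>" survived), so the sketch below fills them in by analogy with the other backend modules; ElemType and the i32 integer type are assumptions, not taken verbatim from the patch.

// Sketch only; the type parameters on `launch` are assumed (`ElemType`, `i32`),
// since the hunk above only shows `launch::>>(...)`.
#[cfg(feature = "vulkan")]
mod vulkan {
    use crate::{launch, ElemType};
    use burn::backend::{Autodiff, Vulkan};

    pub fn run() {
        // Train on the default Vulkan device with autodiff enabled.
        launch::<Autodiff<Vulkan<ElemType, i32>>>(vec![Default::default()]);
    }
}

The matching Cargo.toml feature (vulkan = ["burn/vulkan", "burn/default"]) and the #[cfg(feature = "vulkan")] vulkan::run() call added to main are what make the module reachable.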
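Note on the output_unhandled change in builder.rs: it keeps the fused kernel's reference output shape in sync with the broadest output registered so far, replacing the current shape whenever a candidate's dimension sizes sum to a larger value. A minimal, self-contained sketch of that heuristic, with illustrative names rather than burn-jit's actual types:

/// Sketch of the "largest shape wins" rule used when registering unhandled
/// outputs. `current` is the reference shape the fused kernel is planned
/// against; a candidate replaces it when its dimension sum is larger.
fn update_reference_shape(current: &mut Vec<usize>, candidate: &[usize]) {
    if current.is_empty() {
        *current = candidate.to_vec();
    } else if current.iter().sum::<usize>() < candidate.iter().sum::<usize>() {
        // The largest shape wins.
        *current = candidate.to_vec();
    }
}

fn main() {
    let mut reference: Vec<usize> = Vec::new();
    update_reference_shape(&mut reference, &[8, 16]);
    update_reference_shape(&mut reference, &[32, 64]); // larger sum: replaces [8, 16]
    assert_eq!(reference, vec![32, 64]);
}

The comparison uses the sum of the dimensions rather than their product, so it orders shapes by a rough size proxy rather than by exact element count.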
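Note on the two reduce changes in patch 28: both follow the same pattern. Instead of letting shared_sum return the reduced scalar to the host and rebuilding a one-element tensor from host data via from_data, the caller now allocates a single-element buffer on the device up front and passes its handle into the kernel, removing a device-to-host round trip. The exact burn-jit/cubecl signatures are not fully legible in this copy of the diff (the generic parameters were stripped), so the sketch below uses simplified stand-in types to show only the calling convention:

// Stand-ins for the runtime client and device allocation; the real code uses
// burn-jit's JitTensor and cubecl's shared_sum.
struct Client;
#[allow(dead_code)]
struct Handle(usize); // pretend device buffer, sized in bytes

impl Client {
    fn empty(&self, size_bytes: usize) -> Handle {
        Handle(size_bytes)
    }
}

// Old convention (sketch): the reduction returns the scalar on the host and the
// caller re-uploads it as a 1-element tensor.
fn shared_sum_returning(_client: &Client, _input: &Handle) -> Result<f32, String> {
    Ok(42.0)
}

// New convention (sketch): the caller pre-allocates a 1-element output on the
// device and the reduction writes into it; nothing crosses back to the host.
fn shared_sum_into(_client: &Client, _input: &Handle, _output: &Handle) -> Result<(), String> {
    Ok(())
}

fn main() -> Result<(), String> {
    let client = Client;
    let input = client.empty(4 * 1024);

    // Before: the scalar comes back to the host.
    let _host_value = shared_sum_returning(&client, &input)?;

    // After: allocate room for exactly one element, then hand the kernel the
    // output handle, mirroring `client.empty(E::size().unwrap())` followed by
    // `JitTensor::new_contiguous(..)` in the patch.
    let output = client.empty(std::mem::size_of::<f32>());
    shared_sum_into(&client, &input, &output)?;
    Ok(())
}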