diff --git a/Cargo.lock b/Cargo.lock index 5d4be36767..7ecc9112ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1553,7 +1553,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1584,7 +1584,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1601,7 +1601,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1619,7 +1619,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1649,7 +1649,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1674,7 +1674,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "bytemuck", "cubecl-core", @@ -1685,7 +1685,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1700,7 +1700,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1737,7 +1737,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "async-channel", "async-lock", @@ -1758,7 +1758,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1772,7 +1772,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4198b192c60bffff7ec51920bcd7191560d6c98b#4198b192c60bffff7ec51920bcd7191560d6c98b" +source = "git+https://github.com/tracel-ai/cubecl?rev=3882ed25b47506d49562c501a179b7468e61702e#3882ed25b47506d49562c501a179b7468e61702e" dependencies = [ "ash", "async-channel", @@ -1786,6 +1786,7 @@ dependencies = [ "derive-new 0.6.0", "hashbrown 0.14.5", "log", + "sanitize-filename", "web-time", "wgpu", ] diff --git a/Cargo.toml b/Cargo.toml index 8e48e1186b..c6a4a2154e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4198b192c60bffff7ec51920bcd7191560d6c98b" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4198b192c60bffff7ec51920bcd7191560d6c98b" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 527c4bd98e..bc17501916 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -8,6 +8,7 @@ use crate::{ use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; use burn_tensor::repr::{QuantizedKind, TensorHandle}; +use burn_tensor::DType; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use cubecl::client::ComputeClient; @@ -187,6 +188,7 @@ pub struct JitFusionHandle { pub handle: cubecl::server::Handle, /// The device of the current tensor. pub device: R::Device, + pub(crate) dtype: DType, pub(crate) strides: Vec, } @@ -207,6 +209,7 @@ impl Clone for JitFusionHandle { handle: self.handle.clone(), device: self.device.clone(), strides: self.strides.clone(), + dtype: self.dtype, } } } @@ -232,6 +235,7 @@ impl JitFusionHandle { strides: &self.strides, shape, runtime: PhantomData, + elem_size: self.dtype.size(), } } /// Return the reference to a tensor argument. @@ -239,7 +243,13 @@ impl JitFusionHandle { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); unsafe { - TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation) + TensorArg::from_raw_parts_and_size( + handle.handle, + handle.strides, + handle.shape, + vectorisation, + self.dtype.size(), + ) } } } @@ -251,6 +261,7 @@ impl From> for JitFusionHandle handle: value.handle, device: value.device, strides: value.strides, + dtype: E::dtype(), } } } diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index c7ee6bde15..5f55140657 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -285,6 +285,7 @@ impl FuseOnWriteTrace { handle: client.empty(size), device: device.clone(), strides, + dtype, }; analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 716187416f..545bbe6f01 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -51,7 +51,7 @@ pub(crate) fn select( cube_dim, tensor.as_tensor_arg(1), // Ignore shape and stride - TensorArg::from_raw_parts(&indices.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&indices.handle, &dummy_array, &dummy_array, 1), output.as_tensor_arg(1), ScalarArg::new(dim as u32), ) diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 9da13e498c..37a39a6331 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -82,7 +82,7 @@ pub(crate) fn select_assign( cube_dim, tensor.as_tensor_arg(1), // Ignored shape + custom strides. - TensorArg::from_raw_parts(&indices.handle, &strides, &strides, 1), + TensorArg::from_raw_parts::(&indices.handle, &strides, &strides, 1), value.as_tensor_arg(1), ScalarArg::new(dim as u32), ); diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index 2adae02666..8e09abf5aa 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -128,7 +128,7 @@ pub fn matmul_simple( cube_count, CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), lhs.as_tensor_arg(vectorization_factor), - TensorArg::from_raw_parts( + TensorArg::from_raw_parts::( &rhs.handle, &rhs.strides, &rhs_original_shape.dims, // We need the original shape. diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 65d50f48fc..61dc0733d8 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -170,8 +170,8 @@ where cube_dim, tensor.as_tensor_arg(vectorization_factor), // Ignore shape and stride - TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), - TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), output.as_tensor_arg(1), vectorization_factor > 1, ) @@ -184,7 +184,7 @@ where cube_dim, tensor.as_tensor_arg(vectorization_factor), // Ignore shape and stride - TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), output.as_tensor_arg(1), vectorization_factor > 1, ) diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e11a8da0d1..c951bf40b6 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -164,8 +164,8 @@ where cube_dim, tensor.as_tensor_arg(vectorization_factor), // Ignore shape and stride - TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), - TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&offset.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(i8::MIN as f32), ScalarArg::new(i8::MAX as f32), output.as_tensor_arg(1), @@ -180,7 +180,7 @@ where cube_dim, tensor.as_tensor_arg(vectorization_factor), // Ignore shape and stride - TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), ScalarArg::new(-i8::MAX as f32), ScalarArg::new(i8::MAX as f32), output.as_tensor_arg(1), diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 50132ddae0..a1aae766c3 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -134,6 +134,7 @@ where strides: &self.strides, shape: &self.shape.dims, runtime: PhantomData, + elem_size: E::dtype().size(), } } @@ -142,7 +143,12 @@ where let handle: TensorHandleRef<'a, R> = self.as_handle_ref(); unsafe { - TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation) + TensorArg::from_raw_parts::( + handle.handle, + handle.strides, + handle.shape, + vectorisation, + ) } }