Chore/update cubecl (#2465)
nathanielsimard authored Nov 7, 2024
1 parent 099b6dc commit 69de0ef
Showing 10 changed files with 44 additions and 25 deletions.
25 changes: 13 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4198b192c60bffff7ec51920bcd7191560d6c98b" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4198b192c60bffff7ec51920bcd7191560d6c98b" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
13 changes: 12 additions & 1 deletion crates/burn-jit/src/fusion/base.rs
@@ -8,6 +8,7 @@ use crate::{
use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
use burn_tensor::quantization::QuantizationScheme;
use burn_tensor::repr::{QuantizedKind, TensorHandle};
+use burn_tensor::DType;
use burn_tensor::{repr::ReprBackend, Shape};
use core::marker::PhantomData;
use cubecl::client::ComputeClient;
@@ -187,6 +188,7 @@ pub struct JitFusionHandle<R: JitRuntime> {
pub handle: cubecl::server::Handle,
/// The device of the current tensor.
pub device: R::Device,
+pub(crate) dtype: DType,
pub(crate) strides: Vec<usize>,
}

@@ -207,6 +209,7 @@ impl<R: JitRuntime> Clone for JitFusionHandle<R> {
handle: self.handle.clone(),
device: self.device.clone(),
strides: self.strides.clone(),
+dtype: self.dtype,
}
}
}
@@ -232,14 +235,21 @@ impl<R: JitRuntime> JitFusionHandle<R> {
strides: &self.strides,
shape,
runtime: PhantomData,
+elem_size: self.dtype.size(),
}
}
/// Return the reference to a tensor argument.
pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> {
let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape);

unsafe {
-TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation)
+TensorArg::from_raw_parts_and_size(
+handle.handle,
+handle.strides,
+handle.shape,
+vectorisation,
+self.dtype.size(),
+)
}
}
}
@@ -251,6 +261,7 @@ impl<R: JitRuntime, E: JitElement> From<JitTensor<R, E>> for JitFusionHandle<R>
handle: value.handle,
device: value.device,
strides: value.strides,
+dtype: E::dtype(),
}
}
}
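
The hunks above carry the core of the change: `JitFusionHandle` now stores a runtime `DType` instead of relying on a compile-time element type, and it builds its kernel argument with `from_raw_parts_and_size`, passing the element size in bytes explicitly. A minimal standalone sketch of that calling convention (not part of the commit), assuming `TensorArg` and the `Runtime` trait are re-exported from `cubecl::prelude` and that the argument order matches the call site above (handle, strides, shape, vectorization, element size):

use cubecl::prelude::*; // assumed to re-export `Runtime` and `TensorArg`
use cubecl::server::Handle;

/// Hypothetical helper mirroring `JitFusionHandle::as_tensor_arg` above: build a
/// `TensorArg` from raw parts plus an explicit element size in bytes.
///
/// # Safety
/// `handle`, `strides`, and `shape` must describe a valid tensor whose elements
/// are exactly `elem_size` bytes wide.
unsafe fn tensor_arg_from_parts<'a, R: Runtime>(
    handle: &'a Handle,
    strides: &'a [usize],
    shape: &'a [usize],
    vectorization: u8,
    elem_size: usize,
) -> TensorArg<'a, R> {
    unsafe {
        // The call itself is taken verbatim from the diff above.
        TensorArg::from_raw_parts_and_size(handle, strides, shape, vectorization, elem_size)
    }
}

Call sites that still know the element type (the kernels in the files below) keep the typed form instead, now written with an explicit turbofish: `TensorArg::from_raw_parts::<E>(handle, strides, shape, vectorization)`.
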
1 change: 1 addition & 0 deletions crates/burn-jit/src/fusion/on_write/trace.rs
@@ -285,6 +285,7 @@ impl FuseOnWriteTrace {
handle: client.empty(size),
device: device.clone(),
strides,
+dtype,
};

analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank);
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select.rs
@@ -51,7 +51,7 @@ pub(crate) fn select<R: JitRuntime, E: JitElement, I: JitElement>(
cube_dim,
tensor.as_tensor_arg(1),
// Ignore shape and stride
-TensorArg::from_raw_parts(&indices.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<I>(&indices.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
ScalarArg::new(dim as u32),
)
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select_assign.rs
@@ -82,7 +82,7 @@ pub(crate) fn select_assign<R: JitRuntime, E: JitElement, I: JitElement>(
cube_dim,
tensor.as_tensor_arg(1),
// Ignored shape + custom strides.
-TensorArg::from_raw_parts(&indices.handle, &strides, &strides, 1),
+TensorArg::from_raw_parts::<I>(&indices.handle, &strides, &strides, 1),
value.as_tensor_arg(1),
ScalarArg::new(dim as u32),
);
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/matmul/simple.rs
@@ -128,7 +128,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement>(
cube_count,
CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
lhs.as_tensor_arg(vectorization_factor),
-TensorArg::from_raw_parts(
+TensorArg::from_raw_parts::<E>(
&rhs.handle,
&rhs.strides,
&rhs_original_shape.dims, // We need the original shape.
6 changes: 3 additions & 3 deletions crates/burn-jit/src/kernel/quantization/dequantize.rs
@@ -170,8 +170,8 @@ where
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
-TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
-TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<I>(&offset.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
@@ -184,7 +184,7 @@ where
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
-TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
output.as_tensor_arg(1),
vectorization_factor > 1,
)
6 changes: 3 additions & 3 deletions crates/burn-jit/src/kernel/quantization/quantize.rs
@@ -164,8 +164,8 @@ where
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
-TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
-TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<I>(&offset.handle, &dummy_array, &dummy_array, 1),
ScalarArg::new(i8::MIN as f32),
ScalarArg::new(i8::MAX as f32),
output.as_tensor_arg(1),
@@ -180,7 +180,7 @@ where
cube_dim,
tensor.as_tensor_arg(vectorization_factor),
// Ignore shape and stride
-TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1),
+TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
ScalarArg::new(-i8::MAX as f32),
ScalarArg::new(i8::MAX as f32),
output.as_tensor_arg(1),
8 changes: 7 additions & 1 deletion crates/burn-jit/src/tensor/base.rs
@@ -134,6 +134,7 @@ where
strides: &self.strides,
shape: &self.shape.dims,
runtime: PhantomData,
+elem_size: E::dtype().size(),
}
}

@@ -142,7 +143,12 @@ where
let handle: TensorHandleRef<'a, R> = self.as_handle_ref();

unsafe {
-TensorArg::from_raw_parts(handle.handle, handle.strides, handle.shape, vectorisation)
+TensorArg::from_raw_parts::<E>(
+handle.handle,
+handle.strides,
+handle.shape,
+vectorisation,
+)
}
}

