Update cubecl (#2764)
* Update cubecl

* Update to scope merge

* Fix bitwise shift

* Update

* Update lock for OpenSSL fix
wingertge authored Feb 3, 2025
1 parent cb0854c commit c8f385c
Showing 10 changed files with 253 additions and 236 deletions.
396 changes: 200 additions & 196 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/matmul/args.rs
@@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState {
 }
 
 impl Init for FusedMatmulStateExpand {
-    fn init(self, _context: &mut CubeContext) -> Self {
+    fn init(self, _context: &mut Scope) -> Self {
         self
     }
 }
4 changes: 2 additions & 2 deletions crates/burn-jit/src/fusion/on_write/ir.rs
@@ -45,13 +45,13 @@ impl CubeType for Arg {
 }
 
 impl Init for Arg {
-    fn init(self, _context: &mut CubeContext) -> Self {
+    fn init(self, _context: &mut Scope) -> Self {
         self
     }
 }
 
 impl IntoRuntime for Arg {
-    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
+    fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType {
         self
     }
 }
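The rename above is the "scope merge" referenced in the commit message: cubecl folds the old CubeContext into the IR Scope, so Init (and IntoRuntime) implementations now receive &mut Scope. A minimal sketch of the post-merge pattern for a downstream comptime-only type, assuming only what these hunks show (trait at cubecl::frontend::Init, scope type at cubecl::ir::Scope, no extra bounds on Init); TileSetting is a hypothetical example type, not part of burn or cubecl:

use cubecl::frontend::Init;
use cubecl::ir::Scope;

// Hypothetical comptime-only value that is passed through expansion unchanged.
#[derive(Clone, Copy)]
struct TileSetting {
    line_size: u8,
}

impl Init for TileSetting {
    // Post-merge signature: the expansion context is the IR `Scope` itself,
    // not the removed `CubeContext`.
    fn init(self, _scope: &mut Scope) -> Self {
        self
    }
}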
12 changes: 4 additions & 8 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -7,7 +7,7 @@ use burn_tensor::{
 use cubecl::{
     flex32,
     ir::{Elem, FloatKind},
-    linalg::matmul::{self},
+    linalg::matmul::{self, kernels::MatmulLaunchError},
     tensor_line_size, tf32, Feature,
 };
 use half::{bf16, f16};
@@ -195,18 +195,14 @@ where
     let cube_count = Alg::cube_count(&selection, &problem);
 
     let advanced_config = Default::default();
-    let config = match Alg::make_config(
+    let config = Alg::make_config(
         config_input,
         &problem,
         &cube_dim,
         &cube_count,
         &advanced_config,
-    ) {
-        Ok(val) => val,
-        Err(err) => {
-            panic!("Can't launch conv kernel because of an invalid config: {err}")
-        }
-    };
+    )
+    .map_err(MatmulLaunchError::InvalidConfig)?;
 
     let bias = bias.unwrap_or_else(|| {
         empty_device::<R, SP::EG>(input.client.clone(), input.device.clone(), Shape::new([1]))
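The hunk above replaces the panic with error propagation: a failed make_config is wrapped in MatmulLaunchError::InvalidConfig and bubbled up with ?, which presumably relies on the From<MatmulLaunchError> for ConvLaunchError conversion shown in error.rs further down. A self-contained sketch of that pattern using stand-in types (not the real burn/cubecl definitions):

// Stand-in error types mirroring the shape of the conversion chain above.
#[derive(Debug)]
struct InvalidConfigError(String);

#[derive(Debug)]
enum MatmulLaunchError {
    InvalidConfig(InvalidConfigError),
}

#[derive(Debug)]
enum ConvLaunchError {
    Matmul(MatmulLaunchError),
}

impl From<MatmulLaunchError> for ConvLaunchError {
    fn from(value: MatmulLaunchError) -> Self {
        Self::Matmul(value)
    }
}

// Fallible config construction, standing in for Alg::make_config.
fn make_config(valid: bool) -> Result<u32, InvalidConfigError> {
    if valid {
        Ok(42)
    } else {
        Err(InvalidConfigError("invalid tiling".into()))
    }
}

// Same shape as the launch code: map the config error, then let `?` convert it.
fn launch(valid: bool) -> Result<u32, ConvLaunchError> {
    let config = make_config(valid).map_err(MatmulLaunchError::InvalidConfig)?;
    Ok(config)
}

fn main() {
    assert!(launch(true).is_ok());
    assert!(matches!(
        launch(false),
        Err(ConvLaunchError::Matmul(MatmulLaunchError::InvalidConfig(_)))
    ));
}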
38 changes: 25 additions & 13 deletions crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
@@ -98,25 +98,38 @@ fn im2col_kernel<F: Float>(
 }
 
 #[cfg(not(test))]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
-    let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX);
+pub(crate) fn batches_per_run(
+    batch_size: usize,
+    out_h: usize,
+    out_w: usize,
+) -> Result<usize, ConvLaunchError> {
+    use cubecl::linalg::matmul::kernels::MatmulAvailabilityError;
+
+    let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX);
     let max_cube_count = u16::MAX as usize;
     let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
     if max_simultaneous == 0 {
-        return None;
+        return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
+            cube_count_per_batch as u32,
+            1,
+            1,
+        ))
+        .into());
     }
-    Some(
-        (0..=max_simultaneous)
-            .rev()
-            .find(|per_run| batch_size % per_run == 0)
-            .expect("Logically not possible"),
-    )
+    Ok((0..=max_simultaneous)
+        .rev()
+        .find(|per_run| batch_size % per_run == 0)
+        .expect("Logically not possible"))
 }
 
 #[cfg(test)]
 #[allow(unused)]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
-    Some(1)
+pub(crate) fn batches_per_run(
+    batch_size: usize,
+    out_h: usize,
+    out_w: usize,
+) -> Result<usize, ConvLaunchError> {
+    Ok(1)
 }
 
 fn im2col<R: JitRuntime, E: FloatElement>(
@@ -214,8 +227,7 @@ pub fn conv2d_im2col<R: JitRuntime, E: FloatElement>(
         return execute_1x1_kernel::<R, E>(input, weight, bias, options);
     }
 
-    let batches_per_run = batches_per_run(batch_size, out_h, out_w)
-        .expect("Image too large to run even one batch at once");
+    let batches_per_run = batches_per_run(batch_size, out_h, out_w)?;
     let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]);
 
     let mut out = if batches_per_run != batch_size {
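For reference, the divisor search that both the Option and Result versions of batches_per_run share picks the largest per_run that does not exceed max_simultaneous and evenly divides batch_size; since 1 always divides, the expect is only a formality once max_simultaneous >= 1 (the 0 candidate in the 0..=max_simultaneous range is never reached). A standalone sketch with a hypothetical helper name, separate from the burn code:

// Largest `per_run <= max_simultaneous` that evenly divides `batch_size`.
// Hypothetical standalone helper; the real logic lives inside `batches_per_run`.
fn largest_dividing_run(batch_size: usize, max_simultaneous: usize) -> usize {
    (1..=max_simultaneous)
        .rev()
        .find(|per_run| batch_size % per_run == 0)
        .unwrap_or(1)
}

fn main() {
    assert_eq!(largest_dividing_run(6, 4), 3); // 4 doesn't divide 6, 3 does
    assert_eq!(largest_dividing_run(7, 4), 1); // prime batch size falls back to 1
    assert_eq!(largest_dividing_run(8, 4), 4); // exact divisor is kept
}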
11 changes: 10 additions & 1 deletion crates/burn-jit/src/kernel/conv/error.rs
@@ -1,5 +1,8 @@
 use core::fmt::Debug;
-use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError};
+use cubecl::{
+    linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError},
+    tune::AutotuneError,
+};
 
 pub enum ConvLaunchError {
     Matmul(MatmulLaunchError),
@@ -30,6 +33,12 @@ impl From<MatmulLaunchError> for ConvLaunchError {
     }
 }
 
+impl From<MatmulAvailabilityError> for ConvLaunchError {
+    fn from(value: MatmulAvailabilityError) -> Self {
+        Self::Matmul(MatmulLaunchError::Unavailable(value))
+    }
+}
+
 #[allow(clippy::from_over_into)]
 impl Into<AutotuneError> for ConvLaunchError {
     fn into(self) -> AutotuneError {
16 changes: 4 additions & 12 deletions crates/burn-jit/src/ops/int_ops.rs
@@ -328,26 +328,18 @@ where
     }
 
     fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
-        let lhs_cast = kernel::cast::<R, I, u32>(lhs);
-        let rhs_cast = kernel::cast::<R, I, u32>(rhs);
-        launch_binop_int::<R, u32, kernel::BitwiseShlOp>(lhs_cast, rhs_cast)
+        launch_binop_int::<R, I, kernel::BitwiseShlOp>(lhs, rhs)
     }
 
     fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
-        let lhs_cast = kernel::cast::<R, I, u32>(lhs);
-        let rhs_cast = rhs.elem::<u32>();
-        launch_scalar_binop_int::<R, u32, BitwiseShlOp>(lhs_cast, rhs_cast)
+        launch_scalar_binop_int::<R, I, BitwiseShlOp>(lhs, rhs)
    }
 
     fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
-        let lhs_cast = kernel::cast::<R, I, u32>(lhs);
-        let rhs_cast = kernel::cast::<R, I, u32>(rhs);
-        launch_binop_int::<R, u32, BitwiseShrOp>(lhs_cast, rhs_cast)
+        launch_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
     }
 
     fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
-        let lhs_cast = kernel::cast::<R, I, u32>(lhs);
-        let rhs_cast = rhs.elem::<u32>();
-        launch_scalar_binop_int::<R, u32, BitwiseShrOp>(lhs_cast, rhs_cast)
+        launch_scalar_binop_int::<R, I, BitwiseShrOp>(lhs, rhs)
     }
 }
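The shift ops above now launch directly on the backend's native int element I instead of first casting both operands to u32. A standalone illustration (plain CPU Rust, not the JIT kernels) of the kind of semantic differences such round-trip casts can introduce for wide or signed values; whether these exact cases motivated the fix is not stated in the commit:

fn main() {
    // A 64-bit value loses its high bits if routed through u32 before shifting.
    let wide: i64 = 1 << 40;
    assert_eq!(wide >> 8, 1 << 32);
    assert_eq!((wide as u32) >> 8, 0);

    // A negative value turns an arithmetic right shift into a logical one.
    let negative: i32 = -16;
    assert_eq!(negative >> 2, -4);
    assert_eq!(((negative as u32) >> 2) as i32, 0x3FFF_FFFC);
}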
2 changes: 1 addition & 1 deletion crates/burn-tensor/src/tensor/quantization/scheme.rs
@@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme {
 }
 #[cfg(feature = "cubecl")]
 impl cubecl::frontend::Init for QuantizationScheme {
-    fn init(self, _context: &mut CubeContext) -> Self {
+    fn init(self, _scope: &mut cubecl::ir::Scope) -> Self {
         self
     }
 }
4 changes: 4 additions & 0 deletions crates/burn-tensor/src/tests/ops/bitwise.rs
@@ -124,6 +124,10 @@ mod tests {
 
     #[test]
     fn should_apply_bitwise_left_shift_2d() {
+        if (IntType::MAX as u32) < 512 {
+            return;
+        }
+
         let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
         let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]);
 
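The guard added above presumably skips the 2-D left-shift test on backends whose int element is too narrow: the largest value these inputs produce is 8 << 6 = 512, which does not fit when IntType::MAX is below 512 (for example with 8-bit int backends).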
