From 29c383b87d58190499c2671eb08f47cd48cc4e43 Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 27 Jan 2025 09:57:28 -0500 Subject: [PATCH] Replace return with terminate (#2742) * replace return with terminate * bump cubecl * cargo fmt --- Cargo.lock | 41 ++++++++++++------- Cargo.toml | 4 +- crates/burn-jit/src/kernel/binary.rs | 4 +- crates/burn-jit/src/kernel/binary_int.rs | 4 +- crates/burn-jit/src/kernel/cast/base.rs | 2 +- crates/burn-jit/src/kernel/comparison.rs | 4 +- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 2 +- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 2 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 2 +- .../src/kernel/conv/conv2d/layout_swap.rs | 4 +- .../kernel/conv/conv2d/transpose_direct.rs | 2 +- crates/burn-jit/src/kernel/conv/conv3d.rs | 2 +- .../kernel/conv/deform_conv_transpose2d.rs | 4 +- crates/burn-jit/src/kernel/index/flip.rs | 2 +- crates/burn-jit/src/kernel/index/gather.rs | 2 +- .../burn-jit/src/kernel/index/repeat_dim.rs | 2 +- crates/burn-jit/src/kernel/index/scatter.rs | 2 +- crates/burn-jit/src/kernel/index/select.rs | 2 +- .../src/kernel/index/select_assign.rs | 2 +- crates/burn-jit/src/kernel/index/slice.rs | 2 +- .../src/kernel/interpolate/bicubic.rs | 2 +- .../src/kernel/interpolate/bilinear.rs | 2 +- .../src/kernel/interpolate/nearest.rs | 2 +- .../kernel/interpolate/nearest_backward.rs | 2 +- crates/burn-jit/src/kernel/mask/mask_fill.rs | 4 +- crates/burn-jit/src/kernel/mask/mask_where.rs | 4 +- .../src/kernel/pool/avg_pool2d_backward.rs | 2 +- .../src/kernel/pool/max_pool2d_backward.rs | 2 +- .../src/kernel/quantization/dequantize.rs | 4 +- .../src/kernel/quantization/quantize.rs | 10 ++--- crates/burn-jit/src/kernel/unary_float.rs | 2 +- crates/burn-jit/src/kernel/unary_int.rs | 2 +- crates/burn-jit/src/kernel/unary_numeric.rs | 2 +- crates/burn-jit/src/ops/numeric.rs | 2 +- examples/custom-cubecl-kernel/src/kernel.rs | 2 +- 35 files changed, 74 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f739c341ad..5f5b948a86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1590,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,9 +1688,11 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", + "cubecl-macros-internal", + "derive_more 1.0.0", "float-ord", "half", "num-traits", @@ -1701,7 +1703,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1715,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "darling", @@ -1725,10 +1727,21 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "cubecl-macros-internal" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1757,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1767,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1789,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index af983f62cb..db1073ad67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d7c4d789ab..f0da764a7a 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs index 06706a7d28..390bfc479e 100644 --- a/crates/burn-jit/src/kernel/binary_int.rs +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -85,7 +85,7 @@ pub(crate) fn kernel_scalar_binop_int( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -105,7 +105,7 @@ pub(crate) fn kernel_binop_int( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 798b79a0f0..43b24f071a 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -12,7 +12,7 @@ pub(crate) fn cast_element( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } let offset_input = index_offset_with_layout::( diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index e33687fb5a..a6de9025bb 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp>( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar))); @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 11fb3b4aee..4f6931f86d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -241,7 +241,7 @@ fn col2im_kernel( #[comptime] has_bias: bool, ) { if ABSOLUTE_POS >= image.len() { - return; + terminate!(); } let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index c724cfc3a3..1cd24f7c0c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -35,7 +35,7 @@ fn direct_conv2d_kernel( #[comptime] kernel_size_1_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 6b738ab988..f74cdaf8bc 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -53,7 +53,7 @@ fn im2col_kernel( let out_w = args.out_w; if ABSOLUTE_POS > args.num_elements { - return; + terminate!(); } let out_x = ABSOLUTE_POS % out_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index 62f0e56d8f..7cbe09dbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel( let batch = CUBE_POS_Z; if batch >= input.shape(0) { - return; + terminate!(); } let batch_offset = batch * input.stride(0); @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel( let hw = base_hw + mat_hw; if hw >= shape_hw { - return; + terminate!(); } let mat_c_start = mat_hw_start; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index d3e91d5947..a8cd1ceb7f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel( args: ConvArgs, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 157610794b..a616c432b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -41,7 +41,7 @@ fn conv3d_kernel( #[comptime] kernel_size_2_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index ddee1360e4..5840f4dc9f 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel( // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.len() { - return; + terminate!(); } let offset_channels = offset.shape(1); @@ -551,7 +551,7 @@ fn deform_col2img_kernel( ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] if ABSOLUTE_POS >= columns.len() { - return; + terminate!(); } let n_in_channels = args.in_channels; diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 583e0346d3..a682a76eac 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -11,7 +11,7 @@ fn flip_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 9e9b5685bb..c1aa56072e 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -12,7 +12,7 @@ fn gather_kernel( dim: &u32, ) { if ABSOLUTE_POS >= indices.len() { - return; + terminate!(); } let index = indices[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 3887bfbd8b..b19f9e2b21 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor, dim: u32) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 4ddd9c00fb..4cca94f824 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -46,7 +46,7 @@ fn scatter_kernel( let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { - return; + terminate!(); } for i in 0..shape_value { diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index b104bf504f..fe664ab420 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -10,7 +10,7 @@ fn select_kernel( dim: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index a0fed49dbd..cd4c013f63 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -29,7 +29,7 @@ fn select_assign_kernel( } if ABSOLUTE_POS >= num_elems { - return; + terminate!(); } let strides_tensor_dim = tensor.stride(dim); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index 7f20f033b8..b6daba8da5 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -52,7 +52,7 @@ fn slice_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 1d545d79c7..3f77ef1302 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 3557fcdbb8..f0cb95b536 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 0743a13567..0e6ba32552 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 5ea860a7ae..f0442ec92e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let out_h = output.shape(2); diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 386e7a5039..95096c7994 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -16,7 +16,7 @@ fn mask_fill_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -35,7 +35,7 @@ fn mask_fill_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 5518e9648b..99384fde98 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -16,7 +16,7 @@ fn mask_where_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -36,7 +36,7 @@ fn mask_where_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index bba68c7166..d2a5a21d0a 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -24,7 +24,7 @@ fn avg_pool2d_backward_kernel( #[comptime] count_include_pad: bool, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 6da6e2b37c..40259c4573 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -16,7 +16,7 @@ fn max_pool2d_with_indices_backward_kernel( #[comptime] kernel_size_1: i32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..270e32f854 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -48,7 +48,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( ) { // Last two positions contain the qparams if ABSOLUTE_POS >= input.len() - 2 { - return; + terminate!(); } let qparams = QParams::new(scheme); @@ -85,7 +85,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { - return; + terminate!(); } let qparams = QParams::new(scheme); diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..0a7b0ea553 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -34,7 +34,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -43,13 +43,13 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } // Cast the offset to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 2 { output[ABSOLUTE_POS] = u32::bitcast_from(offset); - return; + terminate!(); } let line_size = comptime!(input.line_size()); @@ -120,7 +120,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -128,7 +128,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } let line_size = comptime!(input.line_size()); diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs index 33a311ecbc..4664d3c0b3 100644 --- a/crates/burn-jit/src/kernel/unary_float.rs +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_float( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs index 5e60898699..17bced52d1 100644 --- a/crates/burn-jit/src/kernel/unary_int.rs +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_int( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs index 0b8dcb2cbc..aaeadbb685 100644 --- a/crates/burn-jit/src/kernel/unary_numeric.rs +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_numeric( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 2c2c7987ab..cf15916aab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -31,7 +31,7 @@ pub fn full_device( #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { - return; + terminate!(); } tensor[ABSOLUTE_POS] = value; diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs index 0809971327..08d4ded4d7 100644 --- a/examples/custom-cubecl-kernel/src/kernel.rs +++ b/examples/custom-cubecl-kernel/src/kernel.rs @@ -17,7 +17,7 @@ pub fn fused_matmul_add_relu_kernel( let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { - return; + terminate!(); } let offset_output = batch * n_rows * n_cols;