Skip to content

Commit

Permalink
Replace return with terminate (#2742)
Browse files Browse the repository at this point in the history
* replace return with terminate

* bump cubecl

* cargo fmt
  • Loading branch information
maxtremblay authored Jan 27, 2025
1 parent 894fdbc commit 29c383b
Show file tree
Hide file tree
Showing 35 changed files with 74 additions and 61 deletions.
41 changes: 27 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
Expand All @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/binary_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub(crate) fn kernel_scalar_binop_int<C: Int, O: BinaryOpIntFamily>(
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
Expand All @@ -105,7 +105,7 @@ pub(crate) fn kernel_binop_int<C: Int, O: BinaryOpIntFamily>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/cast/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) fn cast_element<I: CubePrimitive, O: CubePrimitive>(
let offset_output = ABSOLUTE_POS;

if offset_output >= output.len() {
return;
terminate!();
}

let offset_input = index_offset_with_layout::<I, O>(
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp<SS: ScalarOpSpec, O: ComparisonOp<SS::C>>(
let offset_output = ABSOLUTE_POS;

if offset_output >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar)));
Expand All @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp<SS: ScalarOpSpec, O: ComparisonOp<SS::C>>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ fn col2im_kernel<F: Float>(
#[comptime] has_bias: bool,
) {
if ABSOLUTE_POS >= image.len() {
return;
terminate!();
}

let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn direct_conv2d_kernel<F: Float>(
#[comptime] kernel_size_1_unroll: Option<u32>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_channels = weight.shape(1);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn im2col_kernel<F: Float>(
let out_w = args.out_w;

if ABSOLUTE_POS > args.num_elements {
return;
terminate!();
}

let out_x = ABSOLUTE_POS % out_w;
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel<E: Numeric>(
let batch = CUBE_POS_Z;

if batch >= input.shape(0) {
return;
terminate!();
}

let batch_offset = batch * input.stride(0);
Expand Down Expand Up @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel<E: Numeric>(
let hw = base_hw + mat_hw;

if hw >= shape_hw {
return;
terminate!();
}

let mat_c_start = mat_hw_start;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel<E: Numeric>(
args: ConvArgs,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_c_per_group = weight.shape(0) / args.groups;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn conv3d_kernel<F: Float>(
#[comptime] kernel_size_2_unroll: Option<u32>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_channels = weight.shape(1);
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel<F: Float>(
// Alternatively : [batch, offset_channels, out_h, out_w]

if ABSOLUTE_POS >= grad_offset.len() {
return;
terminate!();
}

let offset_channels = offset.shape(1);
Expand Down Expand Up @@ -551,7 +551,7 @@ fn deform_col2img_kernel<F: Float, FAdd: FloatAtomicAdd>(
) {
// Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
if ABSOLUTE_POS >= columns.len() {
return;
terminate!();
}

let n_in_channels = args.in_channels;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/flip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn flip_kernel<E: CubePrimitive, Bool: Int>(
#[comptime] rank: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn gather_kernel<T: Numeric, I: Numeric>(
dim: &u32,
) {
if ABSOLUTE_POS >= indices.len() {
return;
terminate!();
}

let index = indices[ABSOLUTE_POS];
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/repeat_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*};
#[cube(launch_unchecked)]
fn repeat_dim_kernel<E: CubePrimitive>(input: &Tensor<E>, output: &mut Tensor<E>, dim: u32) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn scatter_kernel<T: Numeric, I: Int>(

let should_stop = ABSOLUTE_POS >= num_elems;
if should_stop {
return;
terminate!();
}

for i in 0..shape_value {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn select_kernel<T: Numeric, I: Numeric>(
dim: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
}

if ABSOLUTE_POS >= num_elems {
return;
terminate!();
}

let strides_tensor_dim = tensor.stride(dim);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn slice_kernel<E: CubePrimitive>(
#[comptime] rank: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/bicubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch)]
fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/bilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch)]
fn interpolate_bilinear_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/nearest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch_unchecked)]
fn interpolate_nearest_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/nearest_backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch_unchecked)]
fn interpolate_nearest_backward_kernel<F: Float>(grad: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let out_h = output.shape(2);
Expand Down
Loading

0 comments on commit 29c383b

Please sign in to comment.