Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace return with terminate #2742

Merged
merged 3 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
Expand All @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/binary_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub(crate) fn kernel_scalar_binop_int<C: Int, O: BinaryOpIntFamily>(
output: &mut Tensor<Line<C>>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
Expand All @@ -105,7 +105,7 @@ pub(crate) fn kernel_binop_int<C: Int, O: BinaryOpIntFamily>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/cast/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) fn cast_element<I: CubePrimitive, O: CubePrimitive>(
let offset_output = ABSOLUTE_POS;

if offset_output >= output.len() {
return;
terminate!();
}

let offset_input = index_offset_with_layout::<I, O>(
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp<SS: ScalarOpSpec, O: ComparisonOp<SS::C>>(
let offset_output = ABSOLUTE_POS;

if offset_output >= output.len() {
return;
terminate!();
}

output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar)));
Expand All @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp<SS: ScalarOpSpec, O: ComparisonOp<SS::C>>(
let mut offset_rhs = ABSOLUTE_POS;

if offset_out >= out.len() {
return;
terminate!();
}

if to_contiguous_lhs {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ fn col2im_kernel<F: Float>(
#[comptime] has_bias: bool,
) {
if ABSOLUTE_POS >= image.len() {
return;
terminate!();
}

let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn direct_conv2d_kernel<F: Float>(
#[comptime] kernel_size_1_unroll: Option<u32>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_channels = weight.shape(1);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn im2col_kernel<F: Float>(
let out_w = args.out_w;

if ABSOLUTE_POS > args.num_elements {
return;
terminate!();
}

let out_x = ABSOLUTE_POS % out_w;
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel<E: Numeric>(
let batch = CUBE_POS_Z;

if batch >= input.shape(0) {
return;
terminate!();
}

let batch_offset = batch * input.stride(0);
Expand Down Expand Up @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel<E: Numeric>(
let hw = base_hw + mat_hw;

if hw >= shape_hw {
return;
terminate!();
}

let mat_c_start = mat_hw_start;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel<E: Numeric>(
args: ConvArgs,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_c_per_group = weight.shape(0) / args.groups;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv3d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn conv3d_kernel<F: Float>(
#[comptime] kernel_size_2_unroll: Option<u32>,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let in_channels = weight.shape(1);
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel<F: Float>(
// Alternatively : [batch, offset_channels, out_h, out_w]

if ABSOLUTE_POS >= grad_offset.len() {
return;
terminate!();
}

let offset_channels = offset.shape(1);
Expand Down Expand Up @@ -551,7 +551,7 @@ fn deform_col2img_kernel<F: Float, FAdd: FloatAtomicAdd>(
) {
// Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
if ABSOLUTE_POS >= columns.len() {
return;
terminate!();
}

let n_in_channels = args.in_channels;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/flip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn flip_kernel<E: CubePrimitive, Bool: Int>(
#[comptime] rank: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn gather_kernel<T: Numeric, I: Numeric>(
dim: &u32,
) {
if ABSOLUTE_POS >= indices.len() {
return;
terminate!();
}

let index = indices[ABSOLUTE_POS];
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/repeat_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*};
#[cube(launch_unchecked)]
fn repeat_dim_kernel<E: CubePrimitive>(input: &Tensor<E>, output: &mut Tensor<E>, dim: u32) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn scatter_kernel<T: Numeric, I: Int>(

let should_stop = ABSOLUTE_POS >= num_elems;
if should_stop {
return;
terminate!();
}

for i in 0..shape_value {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn select_kernel<T: Numeric, I: Numeric>(
dim: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/select_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn select_assign_kernel<F: Numeric, I: Numeric>(
}

if ABSOLUTE_POS >= num_elems {
return;
terminate!();
}

let strides_tensor_dim = tensor.stride(dim);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/index/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn slice_kernel<E: CubePrimitive>(
#[comptime] rank: u32,
) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let mut offset_input = 0;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/bicubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch)]
fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/bilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch)]
fn interpolate_bilinear_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/nearest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch_unchecked)]
fn interpolate_nearest_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/interpolate/nearest_backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime};
#[cube(launch_unchecked)]
fn interpolate_nearest_backward_kernel<F: Float>(grad: &Tensor<F>, output: &mut Tensor<F>) {
if ABSOLUTE_POS >= output.len() {
return;
terminate!();
}

let out_h = output.shape(2);
Expand Down
Loading
Loading