From da8de562b0f67869c8a8c629b8535f938fd317f9 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 8 Jan 2025 15:11:59 -0500 Subject: [PATCH] Fix/autotune error handling (#2670) --- Cargo.lock | 24 +- Cargo.toml | 4 +- .../src/fusion/matmul/optimization.rs | 3 +- .../burn-jit/src/kernel/conv/conv2d/base.rs | 16 +- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 20 +- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 6 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 12 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 23 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 6 +- .../kernel/conv/conv2d/transpose_direct.rs | 6 +- .../src/kernel/conv/conv2d/tune/conv2d.rs | 40 +- .../burn-jit/src/kernel/conv/deform_conv2d.rs | 10 +- .../kernel/conv/deform_conv_transpose2d.rs | 32 +- crates/burn-jit/src/kernel/conv/error.rs | 20 + crates/burn-jit/src/kernel/conv/mod.rs | 2 + crates/burn-jit/src/kernel/matmul/base.rs | 12 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 12 +- crates/burn-jit/src/kernel/reduce/base.rs | 4 +- .../src/kernel/reduce/naive/kernel.rs | 4 +- crates/burn-jit/src/kernel/reduce/prod.rs | 2 +- .../src/kernel/reduce/shared/kernel.rs | 4 +- .../src/kernel/reduce/subcube/kernel.rs | 4 +- crates/burn-jit/src/kernel/reduce/sum.rs | 2 +- crates/burn-jit/src/ops/float_ops.rs | 12 +- crates/burn-jit/src/ops/int_ops.rs | 10 +- crates/burn-jit/src/ops/module_ops.rs | 6 +- crates/burn-jit/src/tests/mod.rs | 2 - crates/burn-jit/src/tests/reduce.rs | 566 ------------------ 28 files changed, 158 insertions(+), 706 deletions(-) create mode 100644 crates/burn-jit/src/kernel/conv/error.rs delete mode 100644 crates/burn-jit/src/tests/reduce.rs diff --git a/Cargo.lock b/Cargo.lock index 20d136df9c..11ae2b3de3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1581,7 +1581,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1613,7 +1613,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1630,7 +1630,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1649,7 +1649,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-core", @@ -1717,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1769,7 +1769,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "async-channel", "async-lock", @@ -1790,7 +1790,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1804,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index f51aa31ce3..3a03582fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. 
### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index b1d8431c67..d0cd8749ad 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -122,7 +122,8 @@ impl MatmulOptimization { rhs_tensor, None, matmul::MatmulStrategy::default(), - ); + ) + .unwrap(); (out_tensor, out) }; context diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 0b3a35dc45..f015677a2b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -1,6 +1,8 @@ use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; -use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime}; +use crate::{ + kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime, +}; #[cfg(feature = "autotune")] use super::{conv2d_autotune, conv_transpose2d_autotune}; @@ -75,11 +77,11 @@ pub fn conv2d( bias: Option>, options: ConvOptions<2>, strategy: Conv2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] - Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Autotune => Ok(conv2d_autotune::(input, weight, bias, options)), Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), Conv2dStrategy::ImplicitGemmComplex => { @@ -102,15 +104,15 @@ pub fn conv_transpose2d( bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { ConvTranspose2dStrategy::Direct => { conv_transpose2d_direct::(input, weight, bias, options) } #[cfg(feature = "autotune")] - ConvTranspose2dStrategy::Autotune => { - conv_transpose2d_autotune::(input, weight, bias, options) - } + ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::( + input, weight, bias, options, + )), ConvTranspose2dStrategy::Gemm => { conv_transpose2d_col2im::(input, weight, bias, options) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0d9c48dc30..11fb3b4aee 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -6,6 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ + conv::ConvLaunchError, into_contiguous, matmul::{matmul, MatmulStrategy}, slice, @@ -29,7 +30,7 @@ pub fn conv_transpose2d_col2im( weight: JitTensor, bias: Option>, options: 
ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); let [batch_size, _, input_h, input_w] = input.shape.dims(); let groups = options.groups; @@ -94,9 +95,12 @@ pub fn conv_transpose2d_col2im( options.clone(), kernel_h, kernel_w, - ); + )?; } - reshape(image, Shape::new([batch_size, im_channels, im_h, im_w])) + Ok(reshape( + image, + Shape::new([batch_size, im_channels, im_h, im_w]), + )) } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); @@ -108,8 +112,8 @@ pub fn conv_transpose2d_col2im( options, kernel_h, kernel_w, - ); - image + )?; + Ok(image) } } @@ -135,7 +139,7 @@ fn execute( options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [batch_size, _, input_h, input_w] = input.shape.dims(); let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims(); @@ -145,12 +149,14 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = matmul::(weight, input, None, MatmulStrategy::default()); + let columns = matmul::(weight, input, None, MatmulStrategy::default())?; let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( columns, bias, image, kernel_h, kernel_w, input_h, input_w, options, ); + + Ok(()) } #[allow(clippy::too_many_arguments)] diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index d5154ecc4b..c724cfc3a3 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -5,7 +5,7 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -125,7 +125,7 @@ pub fn conv2d_direct( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let channels_per_group = out_channels / options.groups; @@ -193,5 +193,5 @@ pub fn conv2d_direct( kernel_w_unroll, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index c99861c82d..abc94d1a9a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -23,7 +23,7 @@ use crate::{ algorithm::{Algorithm, ImplicitCmmaConv}, base::{ConvolutionLaunch, ConvolutionProblem}, }, - nchw_to_nhwc, Conv2dAutotuneKey, + nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError, }, into_contiguous, }, @@ -44,7 +44,7 @@ pub fn conv2d_gemm_cmma_large_m( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -60,7 +60,7 @@ pub fn conv2d_gemm_cmma_balanced( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -74,7 +74,7 @@ fn conv2d_gemm_cmma_strategy< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { 
+) -> Result, ConvLaunchError> { if TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) } else if TypeId::of::() == TypeId::of::() || TypeId::of::() == TypeId::of::() @@ -102,7 +102,7 @@ pub fn conv2d_gemm_with_algo< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor +) -> Result, ConvLaunchError> where SP::EG: JitElement, { @@ -221,7 +221,7 @@ where // Reset to NCHW let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels])); - permute(out, &[0, 3, 1, 2]) + Ok(permute(out, &[0, 3, 1, 2])) } pub fn problem_from_key( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index a65c29466c..7f9914989a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -6,7 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ - conv::index, + conv::{index, ConvLaunchError}, into_contiguous, launch_binop, matmul::{matmul, MatmulStrategy}, AddOp, @@ -188,7 +188,7 @@ pub fn conv2d_im2col( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, in_channels, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -237,13 +237,13 @@ pub fn conv2d_im2col( options.clone(), out_h, out_w, - ); + )?; } let out = swap_dims(out, 1, 2); reshape(out, Shape::new([batch_size, out_channels, out_h, out_w])) } else { let out = empty_device::(input.client.clone(), input.device.clone(), matmul_shape); - execute::(input, weight, out.clone(), options, out_h, out_w); + execute::(input, weight, out.clone(), options, out_h, out_w)?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); swap_dims(out, 0, 1) }; @@ -252,7 +252,8 @@ pub fn conv2d_im2col( let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); out = launch_binop::(out, bias) } - out + + Ok(out) } fn execute_1x1_kernel( @@ -260,7 +261,7 @@ fn execute_1x1_kernel( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, height, width] = input.shape.dims(); let [out_channels, in_c_per_grp, _, _] = weight.shape.dims(); let groups = options.groups; @@ -271,7 +272,7 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = matmul::(weight, input, None, MatmulStrategy::default()); + let out = matmul::(weight, input, None, MatmulStrategy::default())?; let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { @@ -279,7 +280,7 @@ fn execute_1x1_kernel( out = launch_binop::(out, bias) } - swap_dims(out, 0, 1) + Ok(swap_dims(out, 0, 1)) } fn execute( @@ -289,7 +290,7 @@ fn execute( options: ConvOptions<2>, out_h: usize, out_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -301,5 +302,7 @@ fn execute( let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); - matmul::(weight, columns, Some(out), Default::default()); + matmul::(weight, columns, Some(out), 
Default::default())?; + + Ok(()) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 2e4f469068..2e8e147170 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -12,7 +12,7 @@ use cubecl::{ use half::f16; use crate::{ - kernel::{into_contiguous, slice, slice_assign}, + kernel::{conv::ConvLaunchError, into_contiguous, slice, slice_assign}, ops::{ numeric::{empty_device, zeros_device}, permute, @@ -35,7 +35,7 @@ pub fn conv2d_implicit_gemm( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32) && input .client @@ -210,7 +210,7 @@ pub fn conv2d_implicit_gemm( let out = slice::(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); // Reset to NCHW - permute(out, &[0, 3, 1, 2]) + Ok(permute(out, &[0, 3, 1, 2])) } fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 6a97ab8759..d3e91d5947 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ element::JitElement, - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -126,7 +126,7 @@ pub fn conv_transpose2d_direct( weight: JitTensor, bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims(); @@ -184,5 +184,5 @@ pub fn conv_transpose2d_direct( ), ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 157d4d443d..c6eb31ea9c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -97,34 +97,6 @@ macro_rules! 
check_algo { _ => can_launch::<$algo, R, ($float, $float, f32)>($input, $problem), } }; - - ($algo:tt, $input:expr, $problem:expr) => { - let plane_dim = 32; - let conv_problem = $problem; - - let (selection, config_input) = $algo::select_kernel::(plane_dim); - let cube_dim = ImplicitCmmaConv::cube_dim(&selection); - let cube_count = ImplicitCmmaConv::cube_count(&selection, &conv_problem); - - let advanced_config = Default::default(); - let config = ImplicitCmmaConv::make_config( - config_input, - &conv_problem, - &cube_dim, - &cube_count, - &advanced_config, - ); - - match config { - Ok(config) => ImplicitCmmaConv::can_launch::( - &op.input.client, - &conv_problem, - &config, - &selection, - ), - Err(_) => false, - } - }; } fn should_run( @@ -180,13 +152,21 @@ fn can_launch, R: JitRuntime, CS: ConvPrecisio input: &JitTensor, conv_problem: &ConvolutionProblem, ) -> bool { - let plane_dim = 32; + let plane_dim = match input + .client + .properties() + .hardware_properties() + .defined_plane_size() + { + Some(val) => val, + None => return false, + }; let (selection, config_input) = S::select_kernel::(plane_dim); let cube_dim = ImplicitCmmaConv::cube_dim(&selection); let cube_count = ImplicitCmmaConv::cube_count(&selection, conv_problem); - let advanced_config = Default::default(); + let config = ImplicitCmmaConv::make_config( config_input, conv_problem, diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index b22821aef1..300d714335 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -19,6 +19,8 @@ use crate::{ FloatElement, JitRuntime, }; +use super::ConvLaunchError; + #[derive(CubeLaunch)] struct DeformConv2dArgs { conv_stride_h: u32, @@ -262,7 +264,7 @@ pub(crate) fn deform_conv2d( mask: Option>, bias: Option>, options: DeformConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let offset = into_contiguous(offset); let weight = into_contiguous(weight); @@ -298,15 +300,15 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = matmul::(weight, columns, None, MatmulStrategy::default()); + let out = matmul::(weight, columns, None, MatmulStrategy::default())?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); - launch_binop::(out, bias) + Ok(launch_binop::(out, bias)) } else { - out + Ok(out) } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index b75ac43182..ad9e11c6c5 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -19,7 +19,7 @@ use crate::{ FloatElement, IntElement, JitBackend, JitRuntime, }; -use super::{bilinear_interpolate, deform_im2col, index}; +use super::{bilinear_interpolate, deform_im2col, index, ConvLaunchError}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. 
#[allow(clippy::single_range_in_vec_init)] @@ -36,7 +36,7 @@ pub(crate) fn deform_conv2d_backward< bias: Option>, out_grad: JitTensor, options: DeformConvOptions<2>, -) -> DeformConv2dBackward> { +) -> Result>, ConvLaunchError> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); @@ -60,7 +60,7 @@ pub(crate) fn deform_conv2d_backward< out_grad.clone(), &options, (kernel_h, kernel_w), - ); + )?; let weight_grad = compute_weight_grad::( input, @@ -70,15 +70,15 @@ pub(crate) fn deform_conv2d_backward< options, (kernel_h, kernel_w), (out_h, out_w), - ); + )?; - DeformConv2dBackward::new( + Ok(DeformConv2dBackward::new( input_gradient, offset_gradient, weight_grad, mask_gradient, gradient_bias, - ) + )) } fn compute_weight_grad( @@ -89,7 +89,7 @@ fn compute_weight_grad( options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), -) -> JitTensor { +) -> Result, ConvLaunchError> { let [_, in_channels, _, _] = input.shape.dims(); let [_, out_channels, _, _] = out_grad.shape.dims(); let (kernel_h, kernel_w) = kernel_dims; @@ -108,12 +108,12 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default()); + let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default())?; - reshape( + Ok(reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), - ) + )) } type InputGradients = (JitTensor, JitTensor, Option>); @@ -126,7 +126,7 @@ fn backward_gradient_inputs( out_grad: JitTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> InputGradients { +) -> Result, ConvLaunchError> { let client = out_grad.client.clone(); let device = out_grad.device.clone(); @@ -150,7 +150,7 @@ fn backward_gradient_inputs( for group in 0..groups { let weight = swap_dims(index::(weight.clone(), group), 0, 1); let out_grad = index::(out_grad.clone(), group); - let values = matmul::(weight, out_grad, None, MatmulStrategy::default()); + let values = matmul::(weight, out_grad, None, MatmulStrategy::default())?; let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign::( columns, @@ -169,12 +169,12 @@ fn backward_gradient_inputs( mask.clone(), options, kernel_dims, - ); + )?; let input_gradient = compute_input_grad::(columns, offset, mask, options, kernel_dims, input_shape); - (input_gradient, offset_gradient, mask_gradient) + Ok((input_gradient, offset_gradient, mask_gradient)) } fn compute_offset_and_mask_gradient( @@ -184,7 +184,7 @@ fn compute_offset_and_mask_gradient( mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> (JitTensor, Option>) { +) -> Result<(JitTensor, Option>), ConvLaunchError> { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_height, kernel_width) = kernel_dims; @@ -238,7 +238,7 @@ fn compute_offset_and_mask_gradient( }; let mask_gradient = if use_mask { Some(grad_mask) } else { None }; - (grad_offset, mask_gradient) + Ok((grad_offset, mask_gradient)) } #[derive(CubeLaunch)] diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs new file mode 100644 index 0000000000..2f15bc9886 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -0,0 +1,20 @@ +use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; + +#[derive(Debug)] +pub enum 
ConvLaunchError { + Matmul(MatmulLaunchError), + Unknown, +} + +impl From for ConvLaunchError { + fn from(value: MatmulLaunchError) -> Self { + Self::Matmul(value) + } +} + +#[allow(clippy::from_over_into)] +impl Into for ConvLaunchError { + fn into(self) -> AutotuneError { + AutotuneError::Unknown(format!("{self:?}")) + } +} diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs index 5d6794495f..04794e9b42 100644 --- a/crates/burn-jit/src/kernel/conv/mod.rs +++ b/crates/burn-jit/src/kernel/conv/mod.rs @@ -3,11 +3,13 @@ mod conv3d; mod conv_transpose3d; mod deform_conv2d; mod deform_conv_transpose2d; +mod error; pub(crate) use conv2d::*; pub(crate) use conv3d::*; pub(crate) use conv_transpose3d::*; pub(crate) use deform_conv2d::*; pub(crate) use deform_conv_transpose2d::*; +pub(crate) use error::*; pub use conv2d::{conv2d, conv_transpose2d, nchw_to_nhwc, Conv2dStrategy, ConvTranspose2dStrategy}; diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 7fa141cf67..611f1e32d4 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,3 +1,5 @@ +use cubecl::linalg::matmul::kernels::MatmulLaunchError; + use super::init_matmul_output; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; @@ -30,7 +32,7 @@ pub fn matmul( rhs: JitTensor, out: Option>, strategy: MatmulStrategy, -) -> JitTensor { +) -> Result, MatmulLaunchError> { match strategy { MatmulStrategy::Cube => { let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); @@ -43,11 +45,11 @@ pub fn matmul( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ) - .unwrap(); - out + )?; + + Ok(out) } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs, out), + MatmulStrategy::Autotune => Ok(matmul_autotune::(lhs, rhs, out)), } } diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 3f3232db10..46b1dfacc6 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -83,7 +83,7 @@ fn matmul_accelerated( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Standard, &lhs.client, @@ -91,14 +91,14 @@ fn matmul_accelerated( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_tiling2d( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Tiling2D(Tiling2dConfig::default()), &lhs.client, @@ -106,14 +106,14 @@ fn matmul_tiling2d( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_simple( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Simple, &lhs.client, @@ -121,5 +121,5 @@ fn matmul_simple( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 730cc83f37..57cdf13b1e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -63,13 +63,13 @@ macro_rules! 
reduce_operation { tensor: JitTensor, dim: usize, strategy: ReduceStrategy, - ) -> JitTensor { + ) -> Result, String> { match strategy { ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim), + ReduceStrategy::Autotune => Ok(reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim)), } } }; diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs index a3a1a5441b..c862e7070d 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs @@ -50,7 +50,7 @@ fn naive_reduce, EI: Numeric, EO: Numeric>( pub fn reduce_dim_naive( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let output = init_reduce_output::(&input, dim); let cube_dim = CubeDim::default(); @@ -67,5 +67,5 @@ pub fn reduce_dim_naive( let shape = Shape::new([input.shape.num_elements()]); let input: JitTensor = JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy) + prod_dim::(input, 0, strategy).unwrap() } diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs index 1b2dcb356e..1c15e4523f 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs @@ -85,7 +85,7 @@ pub fn reduce_dim_shared< >( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let output = init_reduce_output::(&input, dim); let num_elems_output = output.shape.num_elements(); @@ -113,5 +113,5 @@ pub fn reduce_dim_shared< divisible_shape, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index 4a32b5d641..26f65f5d68 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -88,7 +88,7 @@ pub fn reduce_dim_subcube< >( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let topology = input.client.properties().hardware_properties(); if !input.client.properties().feature_enabled(Feature::Plane) @@ -130,5 +130,5 @@ pub fn reduce_dim_subcube< divisible_shape, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs index fea80bccf0..d3c9416dc1 100644 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ b/crates/burn-jit/src/kernel/reduce/sum.rs @@ -11,5 +11,5 @@ pub fn sum( let shape = Shape::new([input.shape.num_elements()]); let input: JitTensor = JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy) + sum_dim::(input, 0, strategy).unwrap() } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 6090e895ae..c59b9df83c 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -165,7 +165,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - matmul::(lhs, rhs, None, MatmulStrategy::default()) + matmul::(lhs, rhs, None, MatmulStrategy::default()).unwrap() ) } @@ -363,7 +363,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - 
reduce::sum_dim::(tensor, dim, Default::default()) + reduce::sum_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -371,7 +371,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::mean_dim::(tensor, dim, Default::default()) + reduce::mean_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -387,7 +387,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod_dim::(tensor, dim, Default::default()) + reduce::prod_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -467,7 +467,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmax::(tensor, dim, Default::default()) + reduce::argmax::(tensor, dim, Default::default()).unwrap() ) } @@ -475,7 +475,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmin::(tensor, dim, Default::default()) + reduce::argmin::(tensor, dim, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index a0e181a9c7..ed99258826 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -197,7 +197,7 @@ where } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim::(tensor, dim, Default::default()) + kernel::reduce::sum_dim::(tensor, dim, Default::default()).unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { @@ -205,19 +205,19 @@ where } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim::(tensor, dim, Default::default()) + kernel::reduce::prod_dim::(tensor, dim, Default::default()).unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim::(tensor, dim, Default::default()) + kernel::reduce::mean_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax::(tensor, dim, Default::default()) + kernel::reduce::argmax::(tensor, dim, Default::default()).unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin::(tensor, dim, Default::default()) + kernel::reduce::argmin::(tensor, dim, Default::default()).unwrap() } fn int_clamp( diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index b5c96058f9..c7f7b18b32 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -25,7 +25,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()).unwrap() } fn deform_conv2d( @@ -36,7 +36,7 @@ where bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { - kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options).unwrap() } fn deform_conv2d_backward( @@ -57,6 +57,7 @@ where output_grad, options, ) + .unwrap() } fn conv3d( @@ -81,6 +82,7 @@ where options, ConvTranspose2dStrategy::default(), ) + .unwrap() } fn conv_transpose3d( diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index f60edc2a1b..378eb035ed 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -17,7 +17,6 @@ mod max_pool2d; mod max_pool2d_backward; mod normal; mod quantization; -mod reduce; mod repeat_dim; mod scatter; mod select; @@ -48,7 +47,6 @@ macro_rules! 
testgen_all { mod kernel { use super::*; - burn_jit::testgen_reduction!(); burn_jit::testgen_conv2d!(); burn_jit::testgen_conv3d!(); burn_jit::testgen_conv_transpose2d!(); diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs deleted file mode 100644 index 3e8f81fa8c..0000000000 --- a/crates/burn-jit/src/tests/reduce.rs +++ /dev/null @@ -1,566 +0,0 @@ -#[burn_tensor_testgen::testgen(reduction)] -mod reduction { - use super::*; - use burn_jit::kernel::reduce::{ - argmax, argmin, mean_dim, prod, prod_dim, sum, sum_dim, ReduceStrategy, - }; - use burn_tensor::{ - backend::Backend, ops::IntTensorOps, Distribution, Int, Shape, Tensor, TensorData, - TensorPrimitive, - }; - - #[test] - fn reduction_sum_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.sum_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.prod_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmin_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn sum_dim_should_work_with_int() { - let summed_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(sum_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - let sum_as_data = TensorData::from([10]); - val.into_data().assert_approx_eq(&sum_as_data, 1); - } - - #[test] - fn mean_dim_should_work_with_int() { - let mean_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(mean_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - // Mean 
calculation truncates to an integer - let mean_as_data = TensorData::from([2]); - val.into_data().assert_approx_eq(&mean_as_data, 1); - } - - #[test] - fn reduction_sum_dim_shared_memory_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium_divisible() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_medium_divisible() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium_not_divisible() { - let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_medium_not_divisible() { - let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn 
reduction_sum_dim_shared_memory_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_mean_dim_shared_memory_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_mean_dim_subcube_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmin_shared_memory_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmin_subcube_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_shared_memory_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - 
Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_subcube_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_sum_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.sum(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.prod(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_float() { - let data = TensorData::from([-999999., -999997., -999998.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_float() { - let data = TensorData::from([999999., 999998., 999997.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_i32() { - let data = TensorData::from([999999, 999998, 999997]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_i32() { - let data = TensorData::from([-999999, -999997, -999998]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive(), - 0, - 
ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } -}
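
---

The pattern this patch applies across `burn-jit` is consistent: kernel entry points (`matmul`, `conv2d`, `conv_transpose2d`, `deform_conv2d`, the reduce kernels) now return `Result` instead of unwrapping internally, matmul launch failures are lifted into the new `ConvLaunchError` via `From`, autotune candidates surface failures as `AutotuneError` via `Into`, and only the outermost backend ops (`float_ops.rs`, `int_ops.rs`, `module_ops.rs`) call `.unwrap()`. The sketch below is a self-contained, illustrative reproduction of that conversion chain; the `MatmulLaunchError` and `AutotuneError` definitions here are stand-ins for the cubecl types, not the library's actual API.

```rust
// Stand-in for cubecl's MatmulLaunchError (illustrative only).
#[derive(Debug)]
enum MatmulLaunchError {
    Unavailable(String),
}

// Stand-in for cubecl's AutotuneError (illustrative only).
#[derive(Debug)]
enum AutotuneError {
    Unknown(String),
}

// Mirrors crates/burn-jit/src/kernel/conv/error.rs added by this patch.
#[derive(Debug)]
enum ConvLaunchError {
    Matmul(MatmulLaunchError),
    Unknown,
}

impl From<MatmulLaunchError> for ConvLaunchError {
    fn from(value: MatmulLaunchError) -> Self {
        Self::Matmul(value)
    }
}

#[allow(clippy::from_over_into)]
impl Into<AutotuneError> for ConvLaunchError {
    fn into(self) -> AutotuneError {
        AutotuneError::Unknown(format!("{self:?}"))
    }
}

// A conv kernel that delegates to matmul: `?` converts the matmul
// error into ConvLaunchError through the From impl above.
fn matmul() -> Result<(), MatmulLaunchError> {
    Err(MatmulLaunchError::Unavailable("cmma not supported".into()))
}

fn conv2d_im2col() -> Result<(), ConvLaunchError> {
    matmul()?;
    Ok(())
}

// An autotune candidate maps the kernel error into AutotuneError
// instead of panicking, so the tuner can skip the failing variant.
fn autotune_candidate() -> Result<(), AutotuneError> {
    conv2d_im2col().map_err(Into::into)
}

fn main() {
    // Backend-level ops (module_ops.rs, float_ops.rs) still unwrap at
    // the outermost boundary; intermediate layers only propagate.
    println!("{:?}", autotune_candidate());
}
```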