diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d799d1caea..d7c4d789ab 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -4,34 +4,60 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, + tensor_line_size_parallel, }; use super::into_contiguous; +pub(crate) trait BinaryOpFamily: Send + Sync + 'static { + type BinaryOp: BinaryOp; +} + #[cube] pub(crate) trait BinaryOp: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Line, rhs: Line) -> Line; } -pub(crate) trait BinaryOpSpec: Send + Sync + 'static { - type C: Numeric; -} -pub(crate) struct Spec { - _c: PhantomData, -} - -impl BinaryOpSpec for Spec { - type C = C; -} - pub(crate) struct AddOp; pub(crate) struct SubOp; pub(crate) struct MulOp; pub(crate) struct DivOp; pub(crate) struct RemainderOp; -pub(crate) struct PowOp; + +/// Since Powf only works on float, but we still want to implement the numeric binary op family, we +/// set another precision in the family type to cast, when necessary, the input value to a valid +/// float. +/// +/// Because of this we won't benefit from the cubecl rust compilation speed improvement from using +/// the family pattern for [PowOp], but at least we don't duplicate code. +pub(crate) struct PowOp { + _f: PhantomData, +} + +impl BinaryOpFamily for AddOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for SubOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for MulOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for DivOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for RemainderOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for PowOp { + type BinaryOp = Self; +} #[cube] impl BinaryOp for AddOp { @@ -69,30 +95,34 @@ impl BinaryOp for RemainderOp { } #[cube] -impl BinaryOp for PowOp { +impl BinaryOp for PowOp { fn execute(lhs: Line, rhs: Line) -> Line { - Line::powf(lhs, rhs) + let lhs = Line::::cast_from(lhs); + let rhs = Line::::cast_from(rhs); + let out = Line::powf(lhs, rhs); + + Line::cast_from(out) } } -#[cube(launch)] -pub(crate) fn kernel_scalar_binop>( - input: &Tensor>, - scalar: BS::C, - output: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { return; } - output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], Line::new(scalar)); + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); } -#[cube(launch)] -pub(crate) fn kernel_binop>( - lhs: &Tensor>, - rhs: &Tensor>, - out: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, #[comptime] rank: Option, #[comptime] to_contiguous_lhs: bool, #[comptime] to_contiguous_rhs: bool, @@ -106,7 +136,7 @@ pub(crate) fn kernel_binop>( } if to_contiguous_lhs { - offset_lhs = index_offset_with_layout::( + offset_lhs = index_offset_with_layout::( lhs, out, offset_out, @@ -117,7 +147,7 @@ pub(crate) fn kernel_binop>( } if to_contiguous_rhs { - offset_rhs = index_offset_with_layout::( + offset_rhs = index_offset_with_layout::( rhs, out, offset_out, @@ -127,20 +157,27 @@ pub(crate) fn kernel_binop>( ); } - out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]); + out[offset_out] = 
O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); } -pub(crate) fn launch_binop>( +pub(crate) fn launch_binop( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { let ndims = lhs.shape.num_dims(); - let vectorization_factor_lhs = - tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); - let vectorization_factor_rhs = - tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1); - - let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); let mut shape_out = vec![0; ndims]; lhs.shape @@ -157,59 +194,60 @@ pub(crate) fn launch_binop>( let num_elems = shape_out.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if lhs.can_mut_broadcast(&rhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - None, - false, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - ); - - lhs - } else if rhs.can_mut_broadcast(&lhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(1), - None, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - false, - ); - - rhs - } else { - let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); - let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; - let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; - - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - None, - to_contiguous_lhs, - to_contiguous_rhs, - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } } } -pub(crate) fn 
launch_scalar_binop>( +pub(crate) fn launch_scalar_binop( mut tensor: JitTensor, scalar: E, ) -> JitTensor { @@ -219,42 +257,47 @@ pub(crate) fn launch_scalar_binop>( // Vectorization is only enabled when the last dimension is contiguous. let ndims = tensor.shape.num_dims(); - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); let client = tensor.client.clone(); let num_elems = tensor.shape.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if tensor.can_mut() { - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - TensorArg::alias(0), - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - output.as_tensor_arg::(vectorization_factor), - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } } } diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 683e8aff8f..ec2bc93d1f 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -1,7 +1,11 @@ use cubecl::prelude::*; -use crate::kernel::{launch_unary, UnaryOp}; -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}, + tensor::JitTensor, + JitRuntime, +}; #[derive(CubeLaunch)] struct Options { @@ -16,28 +20,25 @@ pub(crate) fn clamp( ) -> JitTensor { struct ClampOp; - impl UnaryOp for ClampOp { - type Options = Options; + #[cube] + impl NumericUnaryOp for ClampOp { + type Options = Options; - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - options: OptionsExpand, - ) -> as CubeType>::ExpandType { - #[cube] - fn execute(input: Line, options: &Options) -> Line { - Line::clamp( - input, - Line::new(options.min_value), - Line::new(options.max_value), - ) - } - - execute::expand(context, input, options) + fn execute(input: Line, options: &Self::Options) -> Line { + Line::clamp( + input, + Line::new(options.min_value), + Line::new(options.max_value), + ) } } - launch_unary::(input, |_| { + impl NumericUnaryOpFamily for ClampOp { + type Options = Options; + type Unary = Self; + } + + launch_unary_numeric::(input, |_| { OptionsLaunch::new(ScalarArg::new(min_value), ScalarArg::new(max_value)) }) } diff --git a/crates/burn-jit/src/kernel/mod.rs 
b/crates/burn-jit/src/kernel/mod.rs index afa8ecd6fa..660ae2f6fd 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -5,13 +5,15 @@ mod comparison; mod contiguous; mod index; mod mask; -mod unary; +mod unary_float; +mod unary_numeric; pub(crate) use binary::*; pub use cast::*; pub use contiguous::*; pub use mask::*; -pub(crate) use unary::*; +pub(crate) use unary_float::*; +pub(crate) use unary_numeric::*; pub use cubecl::{Kernel, PLANE_DIM_APPROX}; diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs deleted file mode 100644 index 09f9c77689..0000000000 --- a/crates/burn-jit/src/kernel/unary.rs +++ /dev/null @@ -1,158 +0,0 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use cubecl::{ - calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, unexpanded, -}; - -#[cube] -pub(crate) trait UnaryOp: 'static + Send + Sync { - type Options: LaunchArg; - - /// Execute a unary operation. - fn execute(_input: Line, _options: &Self::Options) -> Line { - unexpanded!(); - } -} - -#[cube(launch)] -pub(crate) fn unary_kernel>( - input: &Tensor>, - output: &mut Tensor>, - options: &O::Options, - #[comptime] rank: Option, - #[comptime] to_contiguous: bool, -) { - let offset_output = ABSOLUTE_POS; - - if offset_output >= output.len() { - return; - } - - if to_contiguous { - let offset_input = index_offset_with_layout::( - input, - output, - offset_output, - 0, - rank.unwrap_or_else(|| output.rank()), - rank.is_some(), - ); - - output[offset_output] = O::execute(input[offset_input], options); - } else { - output[offset_output] = O::execute(input[offset_output], options); - } -} - -pub(crate) fn launch_unary, F>( - tensor: JitTensor, - options: F, -) -> JitTensor -where - // Magic fix for lifetime, the closure is supposed to capture everything required to create the - // argument. - for<'a> F: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, -{ - let ndims = tensor.shape.num_dims(); - // Vectorization is only enabled when the last dimension is contiguous. - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); - - let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); - - let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let is_contiguous = tensor.is_contiguous(); - - if tensor.can_mut() && tensor.is_contiguous_buffer() { - unary_kernel::launch::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - options(&()), - None, - false, - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - unary_kernel::launch::( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - options(&()), - Some(ndims as u32), - !is_contiguous, - ); - output - } -} - -macro_rules! 
unary_op { - ($name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = (); - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - _options: ::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input) - } - } - }; - (scalar $name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = C; - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - scalar: C::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input, scalar) - } - } - }; - (float($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Float, $exp); - launch_unary::($tensor, |_| ()) - }}; - (int($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Numeric, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; - (float($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Float, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; -} - -pub(crate) use unary_op; diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs new file mode 100644 index 0000000000..33a311ecbc --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -0,0 +1,181 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: FloatUnaryOp>; +} + +#[cube] +pub(crate) trait FloatUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_float( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_float(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. 
+ for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Float, + O: FloatUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_float::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_float::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +/// Use comptime enum to implement all unary operations that don't have any input argument in the +/// kernel definition. +pub(crate) mod unary_basic { + use crate::execute_with_dtype; + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicFloatUnaryKind, + { + execute_with_dtype!( + float(tensor.dtype), + F, + launch_unary_float::(tensor, |input| { + BasicFloatUnaryOptionsLaunch::new(args(input)) + }) + ) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicFloatUnaryKind { + Exp, + Log, + Log1p, + Sqrt, + Abs, + Cos, + Sin, + Tanh, + Round, + Floor, + Ceil, + Erf, + Recip, + } + + #[derive(CubeLaunch)] + struct BasicFloatUnaryOptions { + #[cube(comptime)] + kind: BasicFloatUnaryKind, + } + struct BasicFloatUnary; + + #[cube] + impl FloatUnaryOp for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicFloatUnaryKind::Exp => Line::exp(input), + BasicFloatUnaryKind::Log => Line::log(input), + BasicFloatUnaryKind::Log1p => Line::log1p(input), + BasicFloatUnaryKind::Sqrt => Line::sqrt(input), + BasicFloatUnaryKind::Abs => Line::abs(input), + BasicFloatUnaryKind::Cos => Line::cos(input), + BasicFloatUnaryKind::Sin => Line::sin(input), + BasicFloatUnaryKind::Tanh => Line::tanh(input), + BasicFloatUnaryKind::Round => Line::round(input), + BasicFloatUnaryKind::Floor => Line::floor(input), + BasicFloatUnaryKind::Ceil => Line::ceil(input), + BasicFloatUnaryKind::Erf => Line::erf(input), + BasicFloatUnaryKind::Recip => Line::recip(input), + } + } + } + + impl FloatUnaryOpFamily for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs new file mode 100644 index 0000000000..0b8dcb2cbc --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -0,0 +1,106 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait 
NumericUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: NumericUnaryOp>; +} + +#[cube] +pub(crate) trait NumericUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_numeric( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_numeric(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Numeric, + O: NumericUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_numeric::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_numeric::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 2dc8a4a6f2..6090e895ae 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -1,6 +1,9 @@ use super::{expand, numeric, permute}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; -use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; +use crate::kernel::unary_basic::BasicFloatUnaryKind; +use crate::kernel::{ + self, launch_unary_float, reduce, unary_basic, FloatUnaryOp, FloatUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::matmul::{matmul, MatmulStrategy}, @@ -389,185 +392,75 @@ where } fn float_exp(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::exp(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Exp) } fn float_log(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log(input) - } 
- execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log1p(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log1p) } fn float_powf_scalar(lhs: FloatTensor, rhs: f32) -> FloatTensor { + struct Powf; + + #[cube] + impl FloatUnaryOp for Powf { + type Options = F; + + fn execute(input: Line, options: &Self::Options) -> Line { + Line::powf(input, Line::new(*options)) + } + } + + impl FloatUnaryOpFamily for Powf { + type Options = F; + type Unary = Self; + } + execute_with_dtype!( float(lhs.dtype), F, - unary_op!(float(lhs, rhs.elem::()) => |context, tensor, scalar| { - #[cube] - fn execute(input: Line, scalar: C) -> Line { - Line::powf(input, Line::new(scalar)) - } - execute::expand::(context, tensor, scalar) - }) + launch_unary_float::(lhs, |_| ScalarArg::new(rhs.elem::())) ) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sqrt(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sqrt) } fn float_abs(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::abs(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Abs) } fn float_cos(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::cos(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Cos) } fn float_sin(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sin(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sin) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::tanh(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Tanh) } fn float_round(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::round(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Round) } fn float_floor(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::floor(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Floor) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, 
- unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::ceil(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Ceil) } fn float_erf(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::erf(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Erf) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { @@ -603,17 +496,7 @@ where } fn float_recip(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::recip(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Recip) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 25bb92521f..a0e181a9c7 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,5 @@ use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -229,13 +229,23 @@ where } fn int_abs(tensor: IntTensor) -> IntTensor { - unary_op!(int(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { + struct Abs; + + #[cube] + impl NumericUnaryOp for Abs { + type Options = (); + + fn execute(input: Line, _options: &Self::Options) -> Line { Line::abs(input) } - execute::expand::(context, tensor) - }) + } + + impl NumericUnaryOpFamily for Abs { + type Options = (); + type Unary = Self; + } + + launch_unary_numeric::(tensor, |_| ()) } fn int_into_float(tensor: IntTensor) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 5632425198..d0d5be8468 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -137,5 +137,5 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) } pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_binop::(lhs, rhs) + launch_binop::>(lhs, rhs) } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index e114b2f8e6..b586c4a6b7 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,5 +1,5 @@ use crate::element::JitElement; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; use crate::JitRuntime; use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; @@ -314,15 +314,29 @@ where /// Copy the current tensor. 
     pub fn copy(&self) -> Self {
-        execute_with_dtype!(self.dtype, E, {
-            unary_op!(numeric(self.clone()) => |context, tensor| {
-                #[cube]
-                fn execute(input: Line) -> Line {
-                    input
-                }
-                execute::expand::(context, tensor)
-            })
-        })
+        struct Copy;
+
+        #[cube]
+        impl NumericUnaryOp for Copy {
+            type Options = ();
+
+            fn execute(input: Line, _options: &Self::Options) -> Line {
+                input
+            }
+        }
+
+        impl NumericUnaryOpFamily for Copy {
+            type Options = ();
+            type Unary = Self;
+        }
+
+        let tensor = self.clone();
+
+        execute_with_dtype!(
+            tensor.dtype,
+            E,
+            launch_unary_numeric::(tensor, |_| ())
+        )
     }
 
     /// Check if the tensor is safe to mutate.
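
Note on the op-family pattern introduced above: the family traits (BinaryOpFamily, FloatUnaryOpFamily, NumericUnaryOpFamily) let the launch entry points stay generic over a single non-generic marker type, binding the element type later through a generic associated type. A minimal, standalone sketch of the idea in plain Rust follows; the names and bounds are illustrative only and are not the cubecl API.

// Illustrative sketch only: a simplified, plain-Rust analogue of the family
// pattern used in this diff (BinaryOpFamily and friends). Names and bounds are
// hypothetical; this is not the cubecl API.
use std::ops::Add;

// The op itself is generic over the element type C.
trait BinaryOp<C>: 'static + Send + Sync {
    fn execute(lhs: C, rhs: C) -> C;
}

// The family is a plain, non-generic marker type; the element type is bound
// later through a generic associated type, so launch code can be generic over
// `O: BinaryOpFamily` alone.
trait BinaryOpFamily: 'static + Send + Sync {
    type Op<C: Add<Output = C> + Copy + Send + Sync + 'static>: BinaryOp<C>;
}

struct AddOp;

impl<C: Add<Output = C> + Copy + Send + Sync + 'static> BinaryOp<C> for AddOp {
    fn execute(lhs: C, rhs: C) -> C {
        lhs + rhs
    }
}

impl BinaryOpFamily for AddOp {
    type Op<C: Add<Output = C> + Copy + Send + Sync + 'static> = Self;
}

// Mirrors the shape of launch_binop: the family type carries no element
// parameter of its own; the element type is chosen at the call site.
fn launch<O, C>(lhs: C, rhs: C) -> C
where
    O: BinaryOpFamily,
    C: Add<Output = C> + Copy + Send + Sync + 'static,
{
    <O::Op<C> as BinaryOp<C>>::execute(lhs, rhs)
}

fn main() {
    assert_eq!(launch::<AddOp, i32>(2, 3), 5);
}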
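
Note on the `for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>` bound (the "magic fix for lifetime" comment): the closure must be able to produce the launch argument for whatever lifetime the launch body chooses, which is what the higher-ranked bound expresses. A simplified analogue of the mechanism, with a hypothetical Arg type standing in for RuntimeArg:

// Illustrative sketch only; `Arg` is hypothetical and stands in for RuntimeArg.
struct Arg<'a> {
    value: &'a u32,
}

// The callee decides the lifetime of the anchor it hands to the closure, so the
// closure must work for every lifetime: a higher-ranked trait bound (HRTB).
fn launch<F>(build_arg: F) -> u32
where
    for<'a> F: FnOnce(&'a u32) -> Arg<'a>,
{
    let anchor = 41u32;
    let arg = build_arg(&anchor);
    *arg.value + 1
}

fn main() {
    // The closure may also capture its own data; here it just borrows the anchor.
    assert_eq!(launch(|x| Arg { value: x }), 42);
}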