diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index c9a874209e..0b858c6da6 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -237,6 +237,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` | | `tensor.sum()` | `tensor.sum()` | | `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` | + | `tensor.cumsum(dim)` | `tensor.cumsum(dim)` | | `tensor.topk(k, dim)` | `tensor.topk(k, dim).values` | | `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` | | `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` | diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 29ae1a03e3..63b8dd42fb 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -154,6 +154,10 @@ impl IntTensorOps for Autodiff { B::int_sum_dim(tensor, dim) } + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor { + B::int_cumsum(tensor, dim) + } + fn int_mean(tensor: IntTensor) -> IntTensor { B::int_mean(tensor) } diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index b5626167ec..541f9fb269 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -1578,6 +1578,44 @@ impl FloatTensorOps for Autodiff } } + fn float_cumsum( + tensor: FloatTensor, + dim: usize, + ) -> FloatTensor { + #[derive(Debug)] + struct CumSum; + + impl Backward for CumSum { + type State = usize; + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let dim = ops.state; + + unary::(ops.parents, ops.node, grads, |grad| { + let cumsum = B::float_cumsum(grad.clone(), dim); + B::float_flip(cumsum.clone(), &[dim]) + }); + } + } + + match CumSum + .prepare::([tensor.node]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(prep) => prep.finish( + dim, + B::float_cumsum(tensor.primitive, dim), + ), + OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)), + } + } + fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { B::float_argmax(tensor.primitive, dim) } diff --git a/crates/burn-autodiff/src/tests/cumsum.rs b/crates/burn-autodiff/src/tests/cumsum.rs new file mode 100644 index 0000000000..a9e1d19e5c --- /dev/null +++ b/crates/burn-autodiff/src/tests/cumsum.rs @@ -0,0 +1,20 @@ +#[burn_tensor_testgen::testgen(ad_cumsum)] +mod tests { + use super::*; + use burn_tensor::{loss, Tensor, TensorData}; + + #[test] + fn should_diff_cumsum() { + let device = Default::default(); + let tensor_0 = TestAutodiffTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device).require_grad(); + + let dim = 1; + let tensor_1 = tensor_0.clone().cumsum(dim); + + let grads = tensor_1.backward(); + + let grad_0 = tensor_0.grad(&grads).unwrap(); + let grad_0_expected = TensorData::from([[3., 2., 1.], [3., 2., 1.]]); + grad_0.into_data().assert_approx_eq(&grad_0_expected, 2); + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index af583d4ae1..df99e87695 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -21,6 +21,8 @@ mod conv_transpose2d; mod conv_transpose3d; mod cos; mod cross_entropy; + +mod cumsum; mod div; mod erf; mod exp; @@ -103,6 +105,7 @@ macro_rules! 
testgen_all { burn_autodiff::testgen_ad_cat!(); burn_autodiff::testgen_ad_cos!(); burn_autodiff::testgen_ad_cross_entropy_loss!(); + burn_autodiff::testgen_ad_cumsum!(); burn_autodiff::testgen_ad_div!(); burn_autodiff::testgen_ad_erf!(); burn_autodiff::testgen_ad_exp!(); diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 3502bea5c3..28be228210 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -321,6 +321,10 @@ impl IntTensorOps for Candle(tensor: IntTensor, dim: usize) -> IntTensor { + CandleTensor::new(tensor.tensor.cumsum(dim).unwrap()) + } + fn int_prod(tensor: IntTensor) -> IntTensor { todo!("prod is not implemented for Candle IntTensor (see https://github.com/tracel-ai/burn/issues/1454)") } diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index b50dfc193e..5172ca7954 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -373,6 +373,13 @@ impl FloatTensorOps for Candle CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap()) } + fn float_cumsum( + tensor: FloatTensor, + dim: usize, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.cumsum(dim).unwrap()) + } + fn float_mean_dim( tensor: FloatTensor, dim: usize, diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 676992ccb3..04f64257f8 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -1335,6 +1335,32 @@ impl FloatTensorOps for Fusion { out } + fn float_cumsum( + tensor: FloatTensor, + dim: usize, + ) -> FloatTensor { + scalar_float_ops!(CumSumOps, B::float_cumsum, usize, noconvert); + + let stream = tensor.stream; + let shape = tensor.shape.clone(); + let out = tensor + .client + .tensor_uninitialized(shape, B::FloatElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: tensor.into_description(), + rhs: dim, + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::NumericFloat(NumericOperationDescription::CumSum(desc.clone())), + CumSumOps::::new(desc), + ); + + out + } + fn float_mean(tensor: FloatTensor) -> FloatTensor { unary_float_ops!(MeanOps, B::float_mean, reduce); diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index e6012a07da..e859232b52 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -628,6 +628,13 @@ impl RelativeOpsScalar for NumericOperationDescription { out: desc.out.to_relative(converter), }) } + NumericOperationDescription::CumSum(desc) => { + NumericOperationDescription::CumSum(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, // Dim should stay the same. 
+ out: desc.out.to_relative(converter), + }) + } NumericOperationDescription::Prod(desc) => { NumericOperationDescription::Prod(UnaryOperationDescription { input: desc.input.to_relative(converter), diff --git a/crates/burn-jit/src/kernel/accumulate/base.rs b/crates/burn-jit/src/kernel/accumulate/base.rs new file mode 100644 index 0000000000..74b117dd9b --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/base.rs @@ -0,0 +1,91 @@ +#[cfg(feature = "autotune")] +use crate::kernel::accumulate::accumulate_dim_autotune; +use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; + +use super::{ + naive::{base::AccumulateDimNaive, shader::accumulate_dim_naive}, + shared::{base::AccumulateDimShared, shader::accumulate_dim_shared}, +}; + +#[allow(dead_code)] +pub(crate) trait AccumulateDimAlgorithm: +AccumulateDimNaive + AccumulateDimShared +{ +} + +/// Creates an empty output tensor with accumulate output shape +pub fn init_accumulate_output( + input: &JitTensor, +) -> JitTensor { + let mut shape_out = input.shape.clone(); + + // Create output handle + let num_elems_output = shape_out.num_elements(); + let handle = input + .client + .empty(num_elems_output * core::mem::size_of::()); + JitTensor::new_contiguous( + input.client.clone(), + input.device.clone(), + shape_out.clone(), + handle, + ) +} + +#[derive(Copy, Clone, Debug)] +#[allow(missing_docs)] +pub enum AccumulateStrategy { + Naive, + SharedMemory, + #[cfg(feature = "autotune")] + Autotune, +} + +impl Default for AccumulateStrategy { + fn default() -> Self { + // if autotune is enabled, default to autotune + #[cfg(feature = "autotune")] + return AccumulateStrategy::Autotune; + + #[cfg(not(feature = "autotune"))] + AccumulateStrategy::Naive + } +} + +#[cfg(feature = "autotune")] +#[cfg(not(feature = "autotune"))] +impl Default for AccumulateStrategy { + fn default() -> Self { + AccumulateStrategy::Naive + } +} + +macro_rules! accumulate_operation { + ($name:ident, $ops:ident) => { + pub(crate) struct $ops; + impl AccumulateDimAlgorithm for $ops {} + + /// Executes the accumulate operation with the given strategy. + pub fn $name( + tensor: JitTensor, + dim: usize, + strategy: AccumulateStrategy, + ) -> JitTensor { + match strategy { + AccumulateStrategy::Naive => { + let output = init_accumulate_output(&tensor, dim); + accumulate_dim_naive::<$ops, R, EI, EO, D>(tensor, output, dim) + } + AccumulateStrategy::SharedMemory => { + let output = init_accumulate_output(&tensor, dim); + accumulate_dim_shared::<$ops, R, EI, EO, D>(tensor, output, dim) + } + #[cfg(feature = "autotune")] + AccumulateStrategy::Autotune => accumulate_dim_autotune::<$ops, R, EI, EO, D>(tensor, dim), + } + } + }; +} + +// Autotunable reduce operation variants +accumulate_operation!(cumsum, CumSum); diff --git a/crates/burn-jit/src/kernel/accumulate/mod.rs b/crates/burn-jit/src/kernel/accumulate/mod.rs new file mode 100644 index 0000000000..675510eb55 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/mod.rs @@ -0,0 +1,7 @@ +//! Code for accumulate kernels +//! +//! Accumulate is similar to reduce but the output shape is the same as the input shape. +//! Each element in the output contains the accumulated value up to that point. 
+mod base; +mod naive; +mod shared; \ No newline at end of file diff --git a/crates/burn-jit/src/kernel/accumulate/naive/base.rs b/crates/burn-jit/src/kernel/accumulate/naive/base.rs new file mode 100644 index 0000000000..5659fba245 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/naive/base.rs @@ -0,0 +1,32 @@ +use cubecl::ir::{Item, Scope, Variable}; + +use crate::JitElement; + +/// Specifies the accumulate dim algorithm in use +pub trait AccumulateDimNaive: Send + Sync + 'static { + /// The accumulator + type Accumulator: Copy; + + /// Initialization for naive algorithm + fn initialize_naive( + scope: &mut Scope, + input_item: Item, + output_item: Item, + ) -> Self::Accumulator; + + /// Inner loop for naive algorithm + fn inner_loop_naive( + scope: &mut Scope, + accumulator: Self::Accumulator, + current_value: Variable, + i: Variable, + ); + + /// Assignation for naive algorithm + fn assign_naive( + scope: &mut Scope, + output: Variable, + accumulator: Self::Accumulator, + shape_reduce_dim: Variable, + ); +} diff --git a/crates/burn-jit/src/kernel/accumulate/naive/mod.rs b/crates/burn-jit/src/kernel/accumulate/naive/mod.rs new file mode 100644 index 0000000000..2b36d11cd3 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/naive/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod base; +pub(crate) mod shader; \ No newline at end of file diff --git a/crates/burn-jit/src/kernel/accumulate/naive/shader.rs b/crates/burn-jit/src/kernel/accumulate/naive/shader.rs new file mode 100644 index 0000000000..cf3f0680c1 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/naive/shader.rs @@ -0,0 +1,172 @@ +use cubecl::{ + cpa, + frontend::TensorHandleRef, + ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, + CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, + OutputInfo, +}; +use std::marker::PhantomData; + +use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; + +use super::base::AccumulateDimNaive; + +pub(crate) struct NaiveAccumulateDimComputeShader> { + tensor: Variable, + dim: usize, + output: Variable, + _accumulate_dim: PhantomData, + _elem: PhantomData, +} + +#[derive(new)] +pub(crate) struct NaiveAccumulateDimEagerKernel< + RD: AccumulateDimNaive, + R: JitRuntime, + EI: JitElement, + EO: JitElement, +> { + dim: usize, + accumulate_dim: PhantomData, + _runtime: PhantomData, + _elem_in: PhantomData, + _elem_out: PhantomData, +} + +impl, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel +for NaiveReduceDimEagerKernel +{ + fn define(&self) -> KernelDefinition { + let mut scope = Scope::root(); + let item_input = EI::cube_elem().into(); + let item_output = EO::cube_elem().into(); + + let tensor = Variable::GlobalInputArray { + id: 0, + item: item_input, + }; + let output = Variable::GlobalOutputArray { + id: 0, + item: item_output, + }; + + NaiveAccumulateDimComputeShader { + tensor, + dim: self.dim, + output, + _accumulate_dim: PhantomData::, + _elem: PhantomData::, + } + .expand(&mut scope); + + scope.write_global_custom(output); + + let tensor = InputInfo::Array { + item: item_input, + visibility: Visibility::Read, + }; + + let out = OutputInfo::Array { item: item_output }; + + let info = KernelExpansion { + inputs: vec![tensor], + outputs: vec![out], + scope, + }; + + let settings = KernelSettings::default(); + KernelIntegrator::new(info).integrate(settings) + } + + fn id(&self) -> cubecl::KernelId { + cubecl::KernelId::new::().info(self.dim) + } +} + +impl> NaiveAccumulateDimComputeShader { + pub(crate) 
fn expand(self, scope: &mut Scope) { + let tensor = self.tensor; + let dim: Variable = self.dim.into(); + let id = Variable::AbsolutePos; + let output = self.output; + + let offset_input = scope.zero(Elem::UInt); + let stride_input_dim = scope.create_local(Elem::UInt); + let shape_input_dim = scope.create_local(Elem::UInt); + + cpa!( + scope, + range(0u32, Variable::Rank).for_each(|i, scope| { + let stride_input = scope.create_local(Elem::UInt); + let stride_output = scope.create_local(Elem::UInt); + let shape_output = scope.create_local(Elem::UInt); + + cpa!(scope, stride_input = stride(tensor, i)); + cpa!(scope, stride_output = stride(output, i)); + cpa!(scope, shape_output = shape(output, i)); + + let offset_local = scope.create_local(Elem::UInt); + cpa!(scope, offset_local = id / stride_output); + cpa!(scope, offset_local = offset_local % shape_output); + + let is_dim_accumulate = scope.create_local(Elem::Bool); + cpa!(scope, is_dim_accumulate = i == dim); + + cpa!(scope, if(is_dim_accumulate).then(|scope|{ + cpa!(scope, shape_input_dim = shape(tensor, i)); + cpa!(scope, stride_input_dim = stride_input); + cpa!(scope, offset_input += offset_local); + }).else(|scope|{ + cpa!(scope, offset_local = offset_local * stride_input); + cpa!(scope, offset_input += offset_local); + })); + }) + ); + + let accumulator = RD::initialize_naive(scope, tensor.item(), output.item()); + + cpa!( + scope, + range(0u32, shape_input_dim).for_each(|i, scope| { + let index = scope.create_local(Elem::UInt); + cpa!(scope, index = i * stride_input_dim); + cpa!(scope, index += offset_input); + let value = scope.create_local(tensor.item()); + cpa!(scope, value = tensor[index]); + RD::inner_loop_naive(scope, accumulator, value, i); + }) + ); + + RD::assign_naive(scope, output, accumulator, shape_input_dim); + } +} + +/// Executes the naive kernel for accumulate dim +pub fn accumulate_dim_naive< + RD: AccumulateDimNaive, + R: JitRuntime, + EI: JitElement, + EO: JitElement, + const D: usize, +>( + input: JitTensor, + output: JitTensor, + dim: usize, +) -> JitTensor { + let kernel = NaiveAccumulateDimEagerKernel::::new(dim); + + Execution::start(kernel, input.client) + .inputs(&[TensorHandleRef::::new( + &input.handle, + &input.strides, + &input.shape.dims, + )]) + .outputs(&[TensorHandleRef::new( + &output.handle, + &output.strides, + &output.shape.dims, + )]) + .execute(CubeCountSettings::Output { pos: 0 }); + + output +} diff --git a/crates/burn-jit/src/kernel/accumulate/shared/base.rs b/crates/burn-jit/src/kernel/accumulate/shared/base.rs new file mode 100644 index 0000000000..294b2f1371 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/shared/base.rs @@ -0,0 +1,49 @@ +use cubecl::ir::{Item, Scope, Variable}; + +use crate::JitElement; + +/// Specifies the accumulate dim algorithm in use +pub trait AccumulateDimShared: Send + Sync + 'static { + /// The accumulator + type Accumulator: Copy; + + /// Initialization for shared algorithm + fn initialize_shared( + scope: &mut Scope, + shared_memory_size: u32, + write_position: Variable, + input_item: Item, + ) -> Self::Accumulator; + + /// How to write to shared memory + fn write_to_shared( + scope: &mut Scope, + shared_memory: Self::Accumulator, + write_position: Variable, + value: Self::Accumulator, + ); + + /// How to read from input in shared algorithm + fn read_from_input( + scope: &mut Scope, + input: Variable, + read_position: Variable, + i: Variable, + ) -> Self::Accumulator; + + /// How to read from shared memory + fn read_from_shared( + scope: &mut Scope, 
+ shared_memory: Self::Accumulator, + read_position: Variable, + ) -> Self::Accumulator; + + /// How to assign from shared memory + fn assign_shared( + scope: &mut Scope, + shared_memory: Self::Accumulator, + output: Variable, + write_position: Variable, + shape_reduce_dim: Variable, + ); +} diff --git a/crates/burn-jit/src/kernel/accumulate/shared/mod.rs b/crates/burn-jit/src/kernel/accumulate/shared/mod.rs new file mode 100644 index 0000000000..2b36d11cd3 --- /dev/null +++ b/crates/burn-jit/src/kernel/accumulate/shared/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod base; +pub(crate) mod shader; \ No newline at end of file diff --git a/crates/burn-jit/src/kernel/accumulate/shared/shader.rs b/crates/burn-jit/src/kernel/accumulate/shared/shader.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index cb3bbb8a19..1e48c51bcd 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -27,6 +27,7 @@ pub mod pool; pub mod prng; /// Reduction algorithms pub mod reduce; +mod accumulate; pub(crate) use clamp::*; pub(crate) use comparison::*; diff --git a/crates/burn-jit/src/kernel/reduce/naive/cumsum.rs b/crates/burn-jit/src/kernel/reduce/naive/cumsum.rs new file mode 100644 index 0000000000..71591478dd --- /dev/null +++ b/crates/burn-jit/src/kernel/reduce/naive/cumsum.rs @@ -0,0 +1,29 @@ +use crate::{kernel::reduce::SumDim, JitElement}; +use cubecl::{ + cpa, + ir::{Item, Scope, Variable}, +}; + +use super::base::ReduceDimNaive; + +impl ReduceDimNaive for SumDim { + type Accumulator = Variable; + + fn initialize_naive(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable { + scope.zero(output_item) + } + + fn inner_loop_naive(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) { + cpa!(scope, accumulator += value); + } + + fn assign_naive( + scope: &mut Scope, + output: Variable, + accumulator: Variable, + _shape_reduce_dim: Variable, + ) { + let id = Variable::AbsolutePos; + cpa!(scope, output[id] = accumulator); + } +} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mod.rs b/crates/burn-jit/src/kernel/reduce/naive/mod.rs index 22dcfe141f..c59d9070e4 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/mod.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod argmax; pub(crate) mod argmin; pub(crate) mod base; +pub(crate) mod cumsum; pub(crate) mod mean_dim; pub(crate) mod prod_dim; pub(crate) mod shader; diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index d025ee3bcf..58e5e42590 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -310,6 +310,13 @@ where reduce::sum_dim(tensor, dim, Default::default()) } + fn float_cumsum( + tensor: FloatTensor, + dim: usize, + ) -> FloatTensor { + reduce::cumsum(tensor, dim, Default::default()) + } + fn float_mean_dim( tensor: FloatTensor, dim: usize, diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 5cc0573130..271c32fc07 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -3,7 +3,7 @@ use burn_tensor::ElementConversion; use burn_tensor::TensorData; use core::fmt::Debug; use core::{marker::PhantomData, ops::Range}; -use ndarray::s; +use ndarray::{ArcArray, s}; use ndarray::Array2; use ndarray::IntoDimension; use ndarray::SliceInfo; @@ -291,6 +291,16 @@ where } } + pub fn cumsum(tensor: NdArrayTensor, dim: usize) -> 
NdArrayTensor { + let mut new = tensor.array.to_owned(); + new.accumulate_axis_inplace(Axis(dim), |&prev, curr| { + *curr += prev; + }); + NdArrayTensor { + array: ArcArray::from(new) + } + } + pub fn prod_dim( tensor: NdArrayTensor, dim: usize, diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index bd1e78eb33..6aae40498d 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -286,6 +286,13 @@ impl IntTensorOps for NdArray( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::cumsum(tensor, dim) + } + fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { NdArrayMathOps::prod(tensor) } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 7013504b1d..8523594584 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -338,6 +338,13 @@ impl FloatTensorOps for NdArray( + tensor: NdArrayTensor, + dim: usize, + ) -> NdArrayTensor { + NdArrayMathOps::cumsum(tensor, dim) + } + fn float_argmax( tensor: NdArrayTensor, dim: usize, diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index a1d62fac00..e3415b370b 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -329,6 +329,15 @@ impl TchOps { ) } + pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor { + TchTensor::from_existing( + tensor + .tensor + .cumsum(dim as i64, E::KIND), + tensor.storage, + ) + } + pub fn prod(tensor: TchTensor) -> TchTensor { let tensor = tensor.tensor.prod(E::KIND); TchTensor::new(tensor) diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index cf53bf1c36..59857ebc72 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -269,6 +269,10 @@ impl IntTensorOps for LibTorch { TchOps::sum_dim(tensor, dim) } + fn int_cumsum(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::cumsum(tensor, dim) + } + fn int_prod(tensor: TchTensor) -> TchTensor { TchOps::prod(tensor) } diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index e2b979ca5c..ed1b1d1f19 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -331,6 +331,10 @@ impl FloatTensorOps for LibTorch { TchOps::sum_dim(tensor, dim) } + fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor { + TchOps::cumsum(tensor, dim) + } + fn float_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::mean_dim(tensor, dim) } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 00591e20e2..4fa75bcfc4 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -321,6 +321,11 @@ pub enum NumericOperationDescription { /// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim). /// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim). SumDim(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Float => [cumsum](crate::ops::FloatTensorOps::float_cumsum). + /// Int => [cumsum](crate::ops::IntTensorOps::int_cumsum). 
+ CumSum(ScalarOperationDescription), /// Operation corresponding to: /// diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 0f9432cc0a..1d359bc8b7 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -142,6 +142,12 @@ where Self::new(K::sum_dim(self.primitive, dim)) } + /// Aggregate all elements along the given *dimension* or *axis* with the + /// cumulative sum operation. + pub fn cumsum(self, dim: usize) -> Tensor { + Tensor::new(K::cumsum(self.primitive, dim)) + } + /// Aggregate all elements along the given *dimension* or *axis* /// in the tensor with the product operation. pub fn prod(self) -> Tensor { @@ -1162,6 +1168,27 @@ where /// which is more high-level and designed for public use. fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + /// Computes the cumulative sum of all the elements of the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// The cumulative sum of all the elements of the tensor along the specified dimension. + /// + /// # Remarks + /// + /// This is a low-level function used internally by the library to call different backend functions + /// with static dispatch. It is not designed for direct usage by users, and not recommended to import + /// or use this function directly. + /// + /// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::cumsum](Tensor::cumsum) function, + /// which is more high-level and designed for public use. + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive; + /// Computes the product of all the elements of the tensor. /// /// # Arguments @@ -2162,6 +2189,10 @@ impl Numeric for Int { B::int_sum_dim(tensor, dim) } + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + B::int_cumsum(tensor, dim) + } + fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { B::int_prod(tensor) } @@ -2504,6 +2535,10 @@ impl Numeric for Float { TensorPrimitive::Float(B::float_sum_dim(tensor.tensor(), dim)) } + fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive { + TensorPrimitive::Float(B::float_cumsum(tensor.tensor(), dim)) + } + fn prod(tensor: Self::Primitive) -> Self::Primitive<1> { TensorPrimitive::Float(B::float_prod(tensor.tensor())) } diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index 28a2eb803b..007bef266e 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -752,6 +752,18 @@ pub trait IntTensorOps { /// The sum of all elements in the tensor along the dimension. fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor; + /// Computes the cumulative sum of all elements in the tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension to sum along. + /// + /// # Returns + /// + /// The cumulative sum of all elements in the tensor along the dimension. + fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor; + /// Computes the product of all elements in the tensor. 
/// /// # Arguments diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 0edd5c8ee4..2f2daacebf 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -832,6 +832,18 @@ pub trait FloatTensorOps { /// A tensor with the sum of all elements in `tensor` along `dim`. fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor; + /// Computes the cumulative sum of all elements in a tensor along a dimension. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to sum. + /// * `dim` - The dimension along which to sum. + /// + /// # Returns + /// + /// A tensor with the cumulative sum of all elements in `tensor` along `dim`. + fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor; + /// Product of all elements in a tensor. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index b4c92a84bc..0a5169f014 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -53,6 +53,7 @@ macro_rules! testgen_all { burn_tensor::testgen_close!(); burn_tensor::testgen_cos!(); burn_tensor::testgen_create_like!(); + burn_tensor::testgen_cumsum!(); burn_tensor::testgen_div!(); burn_tensor::testgen_erf!(); burn_tensor::testgen_exp!(); diff --git a/crates/burn-tensor/src/tests/ops/cumsum.rs b/crates/burn-tensor/src/tests/ops/cumsum.rs new file mode 100644 index 0000000000..bc34f66b2d --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/cumsum.rs @@ -0,0 +1,40 @@ +#[burn_tensor_testgen::testgen(cumsum)] +mod tests { + use super::*; + use burn_tensor::{backend::Backend, Int, Tensor, TensorData}; + + #[test] + fn should_support_cumsum_ops() { + let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let device = Default::default(); + let tensor = Tensor::::from_data(data, &device); + + let output = tensor.clone().cumsum(0); + let expected = TensorData::from([[0.0, 1.0, 2.0], [3.0, 5.0, 7.0]]); + + output.into_data().assert_eq(&expected, false); + + let output = tensor.cumsum(1); + let expected = TensorData::from([[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]); + + output.into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_cumsum_ops_int() { + let data = TensorData::from([[0, 1, 2], [3, 4, 5]]); + let device = Default::default(); + let tensor = Tensor::::from_data(data, &device); + + let output = tensor.clone().cumsum(0); + let expected = TensorData::from([[0, 1, 2], [3, 5, 7]]); + + output.into_data().assert_eq(&expected, false); + + let output = tensor.cumsum(1); + let expected = TensorData::from([[0, 1, 3], [3, 7, 12]]); + + output.into_data().assert_eq(&expected, false); + } + +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index 6299b800f4..fb8ac50aeb 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -16,6 +16,8 @@ mod clamp; mod close; mod cos; mod create_like; + +mod cumsum; mod div; mod erf; mod exp;
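Taken together, the user-facing surface added by this diff is `Tensor::cumsum(dim)` on both `Float` and `Int` kinds. Below is a minimal usage sketch mirroring the data in the new `burn-tensor` tests; it assumes the crate paths used by those tests, and `B` stands in for any backend that implements the new `float_cumsum` / `int_cumsum` ops.

```rust
use burn_tensor::{backend::Backend, Int, Tensor, TensorData};

// Demo only: `B` is any backend providing the new `*_cumsum` ops from this diff.
fn cumsum_demo<B: Backend>(device: &B::Device) {
    let float = Tensor::<B, 2>::from_data(
        TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]),
        device,
    );
    // Running sum along dim 0 (down each column): [[0, 1, 2], [3, 5, 7]]
    let _along_dim0 = float.clone().cumsum(0);
    // Running sum along dim 1 (across each row): [[0, 1, 3], [3, 7, 12]]
    let _along_dim1 = float.cumsum(1);

    // Integer tensors dispatch to `int_cumsum` with the same semantics.
    let int = Tensor::<B, 2, Int>::from_data(TensorData::from([[0, 1, 2], [3, 4, 5]]), device);
    let _int_along_dim1 = int.cumsum(1); // [[0, 1, 3], [3, 7, 12]]
}
```

Unlike `sum_dim`, the output keeps the input shape: each element holds the running total up to and including its position along `dim`.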
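For reference on the autodiff side, the vector-Jacobian product of a cumulative sum is a reverse cumulative sum: if `y = cumsum(x)` along a dimension, then `dL/dx[i] = sum over k >= i of dL/dy[k]`, which is what the flip-based backward in the autodiff node is meant to produce. This is a plain-`Vec` sketch of that rule (not the backend implementation); it reproduces the `[3., 2., 1.]` expectation from the new autodiff test, where `backward()` seeds an all-ones upstream gradient.

```rust
/// Reverse cumulative sum: out[i] = sum of upstream[i..], the VJP of a 1-D cumsum.
fn cumsum_vjp(upstream: &[f32]) -> Vec<f32> {
    let mut out = vec![0.0; upstream.len()];
    let mut running = 0.0;
    for i in (0..upstream.len()).rev() {
        running += upstream[i];
        out[i] = running;
    }
    out
}

fn main() {
    // All-ones upstream gradient over 3 elements, as in the new test: expected [3, 2, 1].
    assert_eq!(cumsum_vjp(&[1.0, 1.0, 1.0]), vec![3.0, 2.0, 1.0]);
    // A non-uniform upstream gradient, for completeness.
    assert_eq!(cumsum_vjp(&[1.0, 0.0, 2.0]), vec![3.0, 2.0, 2.0]);
}
```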
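The CPU path in `burn-ndarray` leans on `ndarray`'s `accumulate_axis_inplace`, which folds each lane of an axis in place. A self-contained sketch of that call on a bare `ndarray` array, with the tensor wrapper types from the diff omitted:

```rust
use ndarray::{array, Axis};

fn main() {
    let mut a = array![[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
    // Running sum along axis 1, the same fold used by the new `NdArrayMathOps::cumsum`:
    // each element is replaced by itself plus the already-accumulated previous element.
    a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
    assert_eq!(a, array![[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]);
}
```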