Add cumsum tensor op
laggui committed Jan 6, 2025
1 parent ec8e45a commit b538f24
Showing 33 changed files with 395 additions and 3 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -196,6 +196,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.contains_nan()` | N/A |
| `tensor.cumsum(dim)` | `tensor.cumsum(dim)` |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
| `tensor.equal_elem(other)` | `tensor.eq(other)` |
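For reference, a minimal usage sketch of the new op through the public `Tensor` API (backend choice and values are illustrative, not part of this commit; shown with the `NdArray` backend):

```rust
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

fn main() {
    let device = Default::default();
    // Inclusive cumulative sum along dim 1, mirroring `torch.Tensor.cumsum(dim)`.
    let tensor = Tensor::<NdArray, 2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);
    let result = tensor.cumsum(1);
    // Expected values: [[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]
    println!("{result}");
}
```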
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -127,6 +127,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_sum(tensor)
}

fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cumsum(tensor, dim)
}

fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_sum_dim(tensor, dim)
}
32 changes: 32 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -1488,6 +1488,38 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
#[derive(Debug)]
struct CumSum;

impl<B: Backend> Backward<B, 1> for CumSum {
type State = usize;

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let dim = ops.state;

unary::<B, _>(ops.parents, ops.node, grads, |grad| {
    // The gradient of an inclusive cumulative sum is the reverse cumulative
    // sum of the incoming gradient along `dim`: flip, cumsum, then flip back.
    let flipped = B::float_flip(grad, &[dim]);
    let cumsum = B::float_cumsum(flipped, dim);
    B::float_flip(cumsum, &[dim])
});
}
}

match CumSum
.prepare::<C>([tensor.node])
.compute_bound()
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(dim, B::float_cumsum(tensor.primitive, dim)),
OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),
}
}

fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
#[derive(Debug)]
struct MeanDim;
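A note on the backward rule used above: for `y = cumsum(x)` along a dim, the gradient w.r.t. `x[i]` is the sum of the incoming gradients at positions `j >= i`, i.e. the reverse cumulative sum of the upstream gradient (flip, cumsum, flip back). A self-contained sketch that checks this on a 1-D vector against the closed form (plain Rust, no burn dependency; illustrative only):

```rust
fn cumsum(x: &[f64]) -> Vec<f64> {
    let mut acc = 0.0;
    x.iter()
        .map(|v| {
            acc += *v;
            acc
        })
        .collect()
}

/// Backward of cumsum: reverse cumulative sum of the upstream gradient.
fn cumsum_backward(grad: &[f64]) -> Vec<f64> {
    let mut flipped: Vec<f64> = grad.iter().rev().copied().collect();
    flipped = cumsum(&flipped);
    flipped.reverse();
    flipped
}

fn main() {
    // Upstream gradient g; the analytic input gradient is sum_{j >= i} g[j].
    let g = [0.5, 2.0, -1.0];
    assert_eq!(cumsum_backward(&g), vec![1.5, 1.0, -1.0]);
    // With an all-ones upstream gradient of length 3, the result is [3., 2., 1.],
    // which is exactly each row of `grad_0_expected` in the test below.
    assert_eq!(cumsum_backward(&[1.0, 1.0, 1.0]), vec![3.0, 2.0, 1.0]);
}
```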
22 changes: 22 additions & 0 deletions crates/burn-autodiff/src/tests/cumsum.rs
@@ -0,0 +1,22 @@
#[burn_tensor_testgen::testgen(ad_cumsum)]
mod tests {
use super::*;
use burn_tensor::TensorData;

#[test]
fn should_diff_cumsum() {
let device = Default::default();
let tensor_0 =
TestAutodiffTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device)
.require_grad();

let dim = 1;
let tensor_1 = tensor_0.clone().cumsum(dim);

let grads = tensor_1.backward();

let grad_0 = tensor_0.grad(&grads).unwrap();
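// With an all-ones output gradient (what backward() seeds for a non-scalar result),
// each row's expected input gradient is the reverse cumulative sum of ones: [3., 2., 1.].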
let grad_0_expected = TensorData::from([[3., 2., 1.], [3., 2., 1.]]);
grad_0.into_data().assert_approx_eq(&grad_0_expected, 2);
}
}
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -22,6 +22,7 @@ mod conv_transpose2d;
mod conv_transpose3d;
mod cos;
mod cross_entropy;
mod cumsum;
mod deform_conv2d;
mod div;
mod erf;
@@ -188,5 +189,6 @@ macro_rules! testgen_with_float_param {
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat_dim!();
burn_autodiff::testgen_ad_cumsum!();
};
}
2 changes: 2 additions & 0 deletions crates/burn-candle/src/lib.rs
@@ -95,6 +95,7 @@ mod tests {
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
burn_tensor::testgen_cumsum!();

// TODO: https://github.com/tracel-ai/burn/issues/1237
//
@@ -175,4 +176,5 @@ mod tests {
burn_autodiff::testgen_ad_round!();
burn_autodiff::testgen_ad_floor!();
burn_autodiff::testgen_ad_ceil!();
burn_autodiff::testgen_ad_cumsum!();
}
28 changes: 28 additions & 0 deletions crates/burn-candle/src/ops/base.rs
@@ -145,3 +145,31 @@ pub fn mask_where_broadcasted(

CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap())
}

// Taken from: https://github.com/mokeyish/candle-ext/blob/main/src/cumsum.rs
fn cumsum_ext<D: candle_core::shape::Dim>(
input: &candle_core::Tensor,
dim: D,
) -> candle_core::Result<candle_core::Tensor> {
let dim = dim.to_index(input.shape(), "cumsum")?;
let dim_size = input.dim(dim)?;

let mut tensors = Vec::with_capacity(dim_size);

let mut a = input.clone();
for i in 0..dim_size {
if i > 0 {
a = a.narrow(dim, 1, dim_size - i)?;
let b = input.narrow(dim, 0, dim_size - i)?;
a = (a + b)?;
}
tensors.push(a.narrow(dim, 0, 1)?);
}
let cumsum = candle_core::Tensor::cat(&tensors, dim)?;
Ok(cumsum)
}

/// Cumulative sum (used for int tensors since the default candle implementation uses matmul).
pub fn cumsum(tensor: CandleTensor, dim: usize) -> CandleTensor {
CandleTensor::new(cumsum_ext(&tensor.tensor, dim).unwrap())
}
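The `cumsum_ext` loop builds the result one slice at a time: at step `i` it drops the first element of the running tensor, adds the matching prefix of the input, and emits the new first element, which at that point equals the sum of the first `i + 1` inputs. A plain-Rust sketch of the same recurrence on a slice (illustrative only, not part of the commit):

```rust
fn cumsum_like_candle_ext(input: &[f32]) -> Vec<f32> {
    let n = input.len();
    let mut out = Vec::with_capacity(n);
    let mut a: Vec<f32> = input.to_vec();
    for i in 0..n {
        if i > 0 {
            // Mirrors the two `narrow` calls: a <- a[1..] + input[..n - i], element-wise.
            let next: Vec<f32> = a[1..]
                .iter()
                .zip(&input[..n - i])
                .map(|(x, y)| x + y)
                .collect();
            a = next;
        }
        // a[0] now holds input[0] + ... + input[i].
        out.push(a[0]);
    }
    out
}

fn main() {
    assert_eq!(
        cumsum_like_candle_ext(&[1.0, 2.0, 3.0, 4.0]),
        vec![1.0, 3.0, 6.0, 10.0]
    );
}
```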
4 changes: 4 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -372,4 +372,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I>
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
sign(tensor)
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
super::base::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -481,4 +481,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I>
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}
}
27 changes: 27 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
@@ -2263,4 +2263,31 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {

out
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
scalar_float_ops!(CumsumOps, B::float_cumsum, usize, noconvert);

let stream = tensor.stream;
let dtype = tensor.dtype;
let shape = tensor.shape.clone();
let out = tensor
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericFloat(
dtype,
NumericOperationDescription::CumSum(desc.clone()),
),
CumsumOps::<B>::new(desc),
);

out
}
}
27 changes: 27 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -1819,4 +1819,31 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

out
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
scalar_int_ops!(CumsumOps, B::int_cumsum, usize, noconvert);

let stream = tensor.stream;
let dtype = tensor.dtype;
let shape = tensor.shape.clone();
let out = tensor
.client
.tensor_uninitialized(shape, B::IntElem::dtype());

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::NumericInt(
dtype,
NumericOperationDescription::CumSum(desc.clone()),
),
CumsumOps::<B>::new(desc),
);

out
}
}
7 changes: 7 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
@@ -961,6 +961,13 @@ impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
out: desc.out.to_relative(converter),
})
}
NumericOperationDescription::CumSum(desc) => {
NumericOperationDescription::CumSum(ScalarOperationDescription {
lhs: desc.lhs.to_relative(converter),
rhs: desc.rhs,
out: desc.out.to_relative(converter),
})
}
}
}
}
4 changes: 4 additions & 0 deletions crates/burn-jit/src/ops/float_ops.rs
@@ -665,4 +665,8 @@ where
_ => unimplemented!("Unsupported floating point type cast"),
}
}

fn float_cumsum(_tensor: FloatTensor<Self>, _dim: usize) -> FloatTensor<Self> {
todo!()
}
}
4 changes: 4 additions & 0 deletions crates/burn-jit/src/ops/int_ops.rs
@@ -283,4 +283,8 @@ where
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
kernel::flip::<R, I, BT>(tensor, axes)
}

fn int_cumsum(_tensor: IntTensor<Self>, _dim: usize) -> IntTensor<Self> {
todo!()
}
}
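Both JIT entry points are left as `todo!()` in this commit. On a GPU backend, a cumulative sum is typically implemented as an inclusive scan; purely as an illustration of the algorithm such a kernel would follow (plain Rust, not burn-jit code), a Hillis-Steele inclusive scan looks like this:

```rust
/// Hillis-Steele inclusive scan: ceil(log2(n)) passes, each adding the element
/// `offset` positions to the left. A GPU kernel runs one thread per element
/// with a barrier between passes.
fn inclusive_scan(input: &[f32]) -> Vec<f32> {
    let mut current = input.to_vec();
    let mut offset = 1;
    while offset < current.len() {
        let previous = current.clone();
        for i in offset..current.len() {
            current[i] = previous[i] + previous[i - offset];
        }
        offset *= 2;
    }
    current
}

fn main() {
    assert_eq!(
        inclusive_scan(&[1.0, 2.0, 3.0, 4.0, 5.0]),
        vec![1.0, 3.0, 6.0, 10.0, 15.0]
    );
}
```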
10 changes: 10 additions & 0 deletions crates/burn-ndarray/src/ops/base.rs
@@ -262,6 +262,16 @@ where
NdArrayTensor::from_data(data)
}

pub fn cumsum(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
let mut array = tensor.array.into_owned();
array.accumulate_axis_inplace(Axis(dim), |&prev, curr| {
*curr += prev;
});
let array = array.into_shared();

NdArrayTensor { array }
}

pub fn mean_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
let ndims = tensor.shape().num_dims();
match ndims {
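The ndarray implementation leans on `accumulate_axis_inplace` from the `ndarray` crate, which folds each lane along the given axis in place. A stand-alone sketch against `ndarray` directly (not burn code):

```rust
use ndarray::{array, Axis};

fn main() {
    let mut a = array![[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
    // Cumulative sum along axis 1: each element gets the previous element of
    // its row added in place, the same closure used in `NdArrayMathOps::cumsum`.
    a.accumulate_axis_inplace(Axis(1), |&prev, curr| *curr += prev);
    assert_eq!(a, array![[0.0, 1.0, 3.0], [3.0, 7.0, 12.0]]);
}
```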
4 changes: 4 additions & 0 deletions crates/burn-ndarray/src/ops/int_tensor.rs
@@ -351,4 +351,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps
fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
NdArrayOps::expand(tensor, shape)
}

fn int_cumsum(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
NdArrayMathOps::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-ndarray/src/ops/tensor.rs
@@ -575,4 +575,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps
_ => panic!("Invalid cast types"),
}
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
}
}
19 changes: 19 additions & 0 deletions crates/burn-router/src/ops/op_float.rs
@@ -1491,4 +1491,23 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {

out
}

fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};

client.register(OperationDescription::NumericFloat(
dtype,
NumericOperationDescription::CumSum(desc),
));

out
}
}
19 changes: 19 additions & 0 deletions crates/burn-router/src/ops/op_int.rs
@@ -1173,4 +1173,23 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {

out
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);

let desc = ScalarOperationDescription {
lhs: tensor.into_description(),
rhs: dim,
out: out.to_description_out(),
};

client.register(OperationDescription::NumericInt(
dtype,
NumericOperationDescription::CumSum(desc),
));

out
}
}
6 changes: 6 additions & 0 deletions crates/burn-router/src/runner.rs
@@ -573,6 +573,9 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
NumericOperationDescription::Powf(desc) => {
binary_float_ops!(handles, desc, B::float_powf)
}
NumericOperationDescription::CumSum(desc) => {
scalar_float_dim_ops!(handles, desc, B::float_cumsum)
}
},
OperationDescription::NumericInt(_dtype, op) => match op {
NumericOperationDescription::Add(desc) => {
@@ -764,6 +767,9 @@
let output = B::int_powf(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
}
NumericOperationDescription::CumSum(desc) => {
scalar_int_dim_ops!(handles, desc, B::int_cumsum)
}
},
OperationDescription::Bool(op) => match op {
BoolOperationDescription::IntoFloat(desc) => {
7 changes: 7 additions & 0 deletions crates/burn-tch/src/ops/base.rs
@@ -299,6 +299,13 @@ impl TchOps {
TchTensor::new(tensor)
}

pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
tensor.storage,
)
}

pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor
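For LibTorch the op maps straight onto the `cumsum` method generated in the `tch` crate, passing the tensor's own kind so no implicit cast happens. A minimal sketch against `tch` itself (assumes a recent `tch` with `Tensor::from_slice` and a working libtorch install; not burn code):

```rust
use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[0.0_f32, 1.0, 2.0, 3.0]);
    // Same call the backend makes: dim as i64, output kind matching the input.
    let c = t.cumsum(0, t.kind());
    c.print(); // values: [0., 1., 3., 6.]
}
```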
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/int_tensor.rs
@@ -416,4 +416,8 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
TchOps::argsort(tensor, dim, descending)
}

fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
TchOps::cumsum(tensor, dim)
}
}
4 changes: 4 additions & 0 deletions crates/burn-tch/src/ops/tensor.rs
@@ -479,4 +479,8 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
TchTensor::new(tensor.tensor.to_kind(kind))
}
}

fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchOps::cumsum(tensor, dim)
}
}