Add flip tensor operator (#1468)
carrotflakes authored Mar 19, 2024
1 parent 6e58663 commit 8911093
Showing 37 changed files with 758 additions and 34 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -149,6 +149,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
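For readers coming from PyTorch, here is a minimal usage sketch of the new operator. The backend choice (`NdArray`) and the `from_data` constructor are illustrative assumptions, not part of this diff:

```rust
use burn::backend::NdArray;
use burn::tensor::{Data, Tensor};

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 2>::from_data(
        Data::from([[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        &device,
    );

    // Reverse the element order along both dimensions,
    // the counterpart of `torch.flip(t, dims=(0, 1))`.
    let flipped = t.flip([0, 1]);

    assert_eq!(
        flipped.into_data(),
        Data::from([[6.0f32, 5.0, 4.0], [3.0, 2.0, 1.0]])
    );
}
```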
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -117,6 +117,10 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_permute(tensor, axes)
}

fn bool_flip<const D: usize>(tensor: BoolTensor<B, D>, axes: &[usize]) -> BoolTensor<B, D> {
B::bool_flip(tensor, axes)
}

fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
B::bool_argwhere(tensor)
}
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -353,6 +353,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_permute(tensor, axes)
}

fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
B::int_flip(tensor, axes)
}

fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
B::int_sign(tensor)
}
56 changes: 56 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -738,6 +738,62 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct FlipDim;

#[derive(new, Debug)]
struct RetroFlipDims<B: Backend, const D: usize> {
input_id: NodeID,
axes: Vec<usize>,
_backend: PhantomData<B>,
}

impl<B: Backend, const D: usize> RetroForward for RetroFlipDims<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let input = states.get_state::<B::FloatTensorPrimitive<D>>(&self.input_id);
let out = B::float_flip(input, &self.axes);
states.save(out_node, out)
}
}

impl<B: Backend, const D: usize> Backward<B, D, 1> for FlipDim {
type State = Vec<usize>;

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let axes = ops.state;

unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
B::float_flip(grad, &axes)
});
}
}

match FlipDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroFlipDims::<B, D>::new(
tensor.node.id.clone(),
axes.to_vec(),
))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(prep) => {
prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes))
}
OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)),
}
}

fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
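The backward implementation above leans on the fact that `flip` merely permutes elements and is its own inverse, so the gradient of a flip is the same flip applied to the incoming gradient:

$$y = \mathrm{flip}(x, A) \quad\Longrightarrow\quad \frac{\partial L}{\partial x} = \mathrm{flip}\!\left(\frac{\partial L}{\partial y},\, A\right), \qquad \text{since } \mathrm{flip}(\mathrm{flip}(x, A), A) = x.$$

No checkpointed state beyond the axes is needed, which is why `State = Vec<usize>` suffices.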
28 changes: 28 additions & 0 deletions crates/burn-autodiff/src/tests/flip.rs
@@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_flip)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_flip() {
let data_1: Data<f32, 3> = Data::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2: Data<f32, 3> = Data::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3

let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

let tensor_3 = tensor_2.clone().flip([1, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

assert_eq!(grad_1.to_data(), Data::from([[[7.2, 12.0], [7.2, 12.0]]])); // 1x2x2
assert_eq!(
grad_2.to_data(),
Data::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]) // 1x2x3
);
}
}
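A quick hand check of the asserted values, with the all-ones gradient $\mathbf{1}$ seeded into `tensor_4` by `backward()`:

$$\frac{\partial L}{\partial t_3} = t_1^{\top}\,\mathbf{1}_{2\times 3} = \begin{bmatrix} 3 & 3 & 3 \\ 10 & 10 & 10 \end{bmatrix}, \qquad \texttt{grad\_2} = \mathrm{flip}\!\left(\frac{\partial L}{\partial t_3},\,[1,2]\right) = \begin{bmatrix} 10 & 10 & 10 \\ 3 & 3 & 3 \end{bmatrix},$$

and `grad_1` has rows equal to the row sums of $\mathrm{flip}(t_2,[1,2])$, namely $(1 + 3.2 + 3,\; 7 + 2 + 3) = (7.2,\; 12)$.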
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -21,6 +21,7 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod flip;
mod gather_scatter;
mod gelu;
mod gradients;
@@ -114,6 +115,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sigmoid!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_permute!();
burn_autodiff::testgen_ad_flip!();
burn_autodiff::testgen_ad_nonzero!();
burn_autodiff::testgen_ad_sign!();
};
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
@@ -81,6 +81,7 @@ mod tests {
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();

21 changes: 20 additions & 1 deletion crates/burn-candle/src/ops/base.rs
@@ -23,7 +23,6 @@ pub fn from_data<E: CandleElement, const D: usize>(
) -> CandleTensor<E, D> {
CandleTensor::from_data(data, *device)
}

pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> Data<E, D> {
Data::new(
tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(),
@@ -60,6 +59,26 @@ pub fn permute<E: CandleElement, const D: usize>(
CandleTensor::new(tensor.tensor.permute(axes).unwrap())
}

pub fn flip<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
axes: &[usize],
) -> CandleTensor<E, D> {
// FIXME: Replace with an appropriate method when Candle provides one.
let mut tensor = tensor.tensor;
for &axis in axes {
let indexes = candle_core::Tensor::arange_step(
tensor.dim(axis).unwrap() as i64 - 1,
-1,
-1,
tensor.device(),
)
.unwrap();
tensor = tensor.index_select(&indexes, axis).unwrap();
}

CandleTensor::new(tensor)
}

pub fn reshape<E: CandleElement, const D1: usize, const D2: usize>(
tensor: CandleTensor<E, D1>,
shape: Shape<D2>,
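The workaround reverses one axis at a time: for an axis of length `n` it materializes the index vector `n-1, n-2, …, 0` and gathers along that axis with `index_select`. A minimal sketch of the index pattern, using a hypothetical standalone helper (not part of the PR):

```rust
// Flipping an axis of length `n` is a gather with indices n-1, n-2, ..., 0 —
// the same values `arange_step(n - 1, -1, -1)` produces above.
fn flip_indices(n: i64) -> Vec<i64> {
    (0..n).rev().collect()
}

fn main() {
    assert_eq!(flip_indices(4), vec![3, 2, 1, 0]);
}
```

One `index_select` per flipped axis keeps the code simple at the cost of k gathers for k axes, hence the FIXME above.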
11 changes: 8 additions & 3 deletions crates/burn-candle/src/ops/bool_tensor.rs
@@ -8,8 +8,6 @@ use crate::{
Candle, CandleTensor,
};

-use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
super::base::empty(shape, device)
@@ -133,6 +131,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
tensor: BoolTensor<Self, D>,
axes: [usize; D],
) -> BoolTensor<Self, D> {
-permute(tensor, axes)
+super::base::permute(tensor, axes)
}

fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],
) -> BoolTensor<Self, D> {
super::base::flip(tensor, axes)
}
}
8 changes: 5 additions & 3 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -8,8 +8,6 @@ use crate::{
Candle, CandleTensor,
};

-use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
super::base::empty(shape, device)
@@ -425,7 +423,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
-permute(tensor, axes)
+super::base::permute(tensor, axes)
}

fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
super::base::flip(tensor, axes)
}

// TODO add sign operator once Candle supports it:
11 changes: 8 additions & 3 deletions crates/burn-candle/src/ops/tensor.rs
@@ -11,8 +11,6 @@ use crate::{
Candle, CandleTensor,
};

-use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data<const D: usize>(
data: Data<F, D>,
@@ -522,7 +520,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
-permute(tensor, axes)
+super::base::permute(tensor, axes)
}

fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
super::base::flip(tensor, axes)
}

// TODO add sign operator once Candle supports it:
41 changes: 38 additions & 3 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -4,9 +4,9 @@ use crate::{
ops::binary::binary_ops_shape,
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
-CatOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
-ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
-SwapDimsDescription, UnaryOperationDescription,
+CatOperationDescription, FlipOperationDescription, Operation, OperationDescription,
+PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
+SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@@ -466,4 +466,39 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {

out
}

fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],
) -> BoolTensor<Self, D> {
#[derive(new)]
struct FlipOps<const D: usize> {
desc: FlipOperationDescription,
}

impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_flip(input, self.desc.axes.as_slice());
handles.register_bool_tensor(&self.desc.out.id, output);
}
}

let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = FlipOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
axes: axes.to_vec(),
};

out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())),
FlipOps::<D>::new(desc),
);

out
}
}
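For context, the shape of `FlipOperationDescription` can be read off its construction here; its actual definition lands in the stream module elsewhere in this PR (not shown in this excerpt), so the following is an inferred sketch whose derives and field order may differ:

```rust
use crate::TensorDescription;

// Inferred from usage in this diff (assumption): the description captures
// everything needed to replay the flip lazily against recorded tensor ids.
#[derive(Clone, Debug)]
pub struct FlipOperationDescription {
    pub input: TensorDescription,
    pub out: TensorDescription,
    pub axes: Vec<usize>,
}
```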
49 changes: 42 additions & 7 deletions crates/burn-fusion/src/ops/float.rs
@@ -6,13 +6,13 @@ use crate::{
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
stream::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
-ClampOperationDescription, FloatOperationDescription, GatherOperationDescription,
-MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
-Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
-ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
-ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
-SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
-UnaryOperationDescription,
+ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
+GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
+NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
+RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
+ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
+SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+StreamId, SwapDimsDescription, UnaryOperationDescription,
},
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
@@ -1846,4 +1846,39 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {

out
}

fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
#[derive(new)]
struct FlipOps<const D: usize> {
desc: FlipOperationDescription,
}

impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::float_flip(input, &self.desc.axes);
handles.register_float_tensor(&self.desc.out.id, output);
}
}

let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());

let desc = FlipOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Flip(desc.clone())),
FlipOps::<D>::new(desc),
);

out
}
}