Autodiff/training support for Nearest Interpolation (#1414)
Add training support for nearest interpolation

---------

Co-authored-by: yurzhang <[email protected]>
Co-authored-by: Dilshod Tadjibaev <[email protected]>
3 people authored Mar 6, 2024
1 parent 0601dc7 commit 0c92c8c
Showing 14 changed files with 469 additions and 8 deletions.
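
With this change, nearest-neighbor interpolation participates in the autodiff graph, so upsampling layers can sit inside a trained model. A minimal usage sketch follows; the `Autodiff<NdArray>` backend choice and the exact re-export paths are assumptions, while the `interpolate` call mirrors the new test added in this commit:

```rust
// Sketch only: backend and import paths are assumed, not taken from this diff.
use burn::backend::{Autodiff, NdArray};
use burn::tensor::module::interpolate;
use burn::tensor::ops::{InterpolateMode, InterpolateOptions};
use burn::tensor::Tensor;

type B = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // A 1x1x4x4 input marked as requiring gradients.
    let x = Tensor::<B, 4>::ones([1, 1, 4, 4], &device).require_grad();

    // Nearest-neighbor upsampling to 8x8 is now tracked by autodiff.
    let y = interpolate(
        x.clone(),
        [8, 8],
        InterpolateOptions::new(InterpolateMode::Nearest),
    );

    // Backward propagates through the interpolation back to `x`.
    let grads = y.backward();
    let x_grad = x.grad(&grads).unwrap();

    // Each input pixel is the nearest source for a 2x2 block of outputs,
    // so every gradient entry is 4 when the output gradient is all ones.
    println!("{:?}", x_grad.into_data());
}
```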
52 changes: 49 additions & 3 deletions crates/burn-autodiff/src/ops/module.rs
@@ -978,11 +978,57 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
}

fn interpolate(
_x: AutodiffTensor<B, 4>,
x: AutodiffTensor<B, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> AutodiffTensor<B, 4> {
#[derive(Debug)]
struct Interpolate;
impl<B: Backend> Backward<B, 4, 1> for Interpolate {
type State = (NodeID, [usize; 2], InterpolateOptions);

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);

let (x_state, output_size, options) = ops.state;
let state = checkpointer.retrieve_node_output(x_state);

if let Some(node) = node_parent {
let grad = B::interpolate_backward(state, grad, output_size, options);
grads.register::<B, 4>(node, grad);
}
}
}

match Interpolate
.prepare::<C>([x.node.clone()], [x.graph.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let output = B::interpolate(x.primitive.clone(), output_size, options.clone());
prep.finish((x_state, output_size, options), output)
}
OpsKind::UnTracked(prep) => {
prep.finish(B::interpolate(x.primitive, output_size, options))
}
}
}

fn interpolate_backward(
_x: FloatTensor<Autodiff<B, C>, 4>,
_grad: FloatTensor<Autodiff<B, C>, 4>,
_output_size: [usize; 2],
_options: InterpolateOptions,
) -> AutodiffTensor<B, 4> {
unimplemented!()
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<4> {
panic!("Can't differentiate interpolate backward.");
}
}
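
In the tracked branch above, the op checkpoints the input, stores `(x_state, output_size, options)` as state, and on the backward pass retrieves the checkpointed input and delegates to the new `B::interpolate_backward`, registering the result as the parent's gradient; untracked inputs skip the bookkeeping and call `B::interpolate` directly. The trait-level `interpolate_backward` on the autodiff backend still panics, since differentiating the backward pass itself (a second-order gradient) is not supported here.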

2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -33,6 +33,7 @@ mod maxpool1d;
mod maxpool2d;
mod mul;
mod multithread;
mod nearest_interpolate;
mod neg;
mod nonzero;
mod permute;
@@ -77,6 +78,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_adaptive_avg_pool1d!();
burn_autodiff::testgen_ad_adaptive_avg_pool2d!();
burn_autodiff::testgen_module_backward!();
burn_autodiff::testgen_ad_nearest_interpolate!();

// Tensor
burn_autodiff::testgen_ad_complex!();
100 changes: 100 additions & 0 deletions crates/burn-autodiff/src/tests/nearest_interpolate.rs
@@ -0,0 +1,100 @@
#[burn_tensor_testgen::testgen(ad_nearest_interpolate)]
mod tests {
use super::*;
use burn_tensor::module::interpolate;
use burn_tensor::ops::{InterpolateMode, InterpolateOptions};
use burn_tensor::{Data, Shape, Tensor};

#[test]
fn test_upsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 2,
channels: 1,
height: 7,
width: 5,
height_out: 8,
width_out: 7,
};

test.assert_output(TestTensor::from([
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
[[
[4., 2., 4., 2., 2.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
[2., 1., 2., 1., 1.],
]],
]));
}

#[test]
fn test_downsample_interpolation() {
let test = InterpolateTestCase {
batch_size: 1,
channels: 1,
height: 8,
width: 8,
height_out: 4,
width_out: 6,
};

test.assert_output(TestTensor::from([[[
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 1., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
]]]));
}

struct InterpolateTestCase {
batch_size: usize,
channels: usize,
height: usize,
width: usize,
height_out: usize,
width_out: usize,
}

impl InterpolateTestCase {
fn assert_output(self, x_grad: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &x_grad.device())
.reshape(shape_x)
.into_data()
.convert(),
&device,
)
.require_grad();

let output = interpolate(
x.clone(),
[self.height_out, self.width_out],
InterpolateOptions::new(InterpolateMode::Nearest),
);

let grads = output.backward();
let x_grad_actual = x.grad(&grads).unwrap();

x_grad
.to_data()
.assert_approx_eq(&x_grad_actual.into_data(), 3);
}
}
}
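
The expected values follow from counting how many output positions map back to each input pixel under the nearest-source rule `floor(o * input_size / output_size)`. For the 7x5 -> 8x7 upsample, input row 0 is the source of output rows 0 and 1 while rows 1-6 each feed one output row, and input columns 0 and 2 each feed two output columns while columns 1, 3, and 4 feed one. Since `backward()` seeds the output gradient with ones, each input gradient is the product of its row and column counts: 4 where a doubly-hit row meets a doubly-hit column, 2 for a single double hit, and 1 otherwise. In the 8x8 -> 4x6 downsample, only a subset of input pixels is ever selected, so the remaining rows and columns receive a zero gradient.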
9 changes: 9 additions & 0 deletions crates/burn-candle/src/ops/module.rs
@@ -257,4 +257,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I

CandleTensor::new(tensor)
}

fn interpolate_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
panic!("interpolate_backward is not supported by Candle")
}
}
41 changes: 41 additions & 0 deletions crates/burn-fusion/src/ops/module.rs
@@ -1,3 +1,4 @@
use crate::stream::InterpolateBackwardDescription;
use crate::{
client::FusionClient,
stream::{
@@ -1013,4 +1014,44 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {

out
}

fn interpolate_backward(
x: FloatTensor<Self, 4>,
grad: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
make_ops!(
InterpolateBackwardOps,
InterpolateBackwardDescription,
|args: InterpolateBackwardDescription, handles: &mut HandleContainer<B>| {
let x = handles.get_float_tensor(&args.x);
let grad = handles.get_float_tensor(&args.grad);
let output =
B::interpolate_backward(x, grad, args.output_size, args.options.clone().into());

handles.register_float_tensor(&args.out.id, output);
}
);

let stream_1 = x.stream;
let stream_2 = grad.stream;
let out = x.client.tensor_uninitialized(x.shape.clone());

let desc = InterpolateBackwardDescription {
x: x.into_description(),
grad: grad.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};
out.client.register(
vec![stream_1, stream_2],
OperationDescription::Module(
crate::stream::ModuleOperationDescription::InterpolateBackward(desc.clone()),
),
InterpolateBackwardOps::new(desc),
);
out
}
}
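
Since `x` and `grad` may have been produced on different execution streams, the operation is registered against both (`vec![stream_1, stream_2]`); the relative-description conversion in `context.rs` below gains a matching `InterpolateBackward` arm so the new description can be rebuilt inside a fused stream.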
17 changes: 13 additions & 4 deletions crates/burn-fusion/src/stream/context.rs
@@ -5,10 +5,10 @@ use super::{
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
EmbeddingBackwardDescription, EmbeddingDescription, FloatOperationDescription,
GatherOperationDescription, IntOperationDescription, InterpolateDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
GatherOperationDescription, IntOperationDescription, InterpolateBackwardDescription,
InterpolateDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
@@ -322,6 +322,15 @@ impl ModuleOperationDescription {
out: desc.out.to_relative(converter),
})
}
ModuleOperationDescription::InterpolateBackward(desc) => {
ModuleOperationDescription::InterpolateBackward(InterpolateBackwardDescription {
x: desc.x.to_relative(converter),
grad: desc.grad.to_relative(converter),
output_size: desc.output_size,
options: desc.options.clone(),
out: desc.out.to_relative(converter),
})
}
}
}
}
17 changes: 17 additions & 0 deletions crates/burn-fusion/src/stream/operation.rs
@@ -122,6 +122,8 @@ pub enum ModuleOperationDescription {
MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
/// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate).
Interpolate(InterpolateDescription),
/// Operation corresponding to [interpolate backward](burn_tensor::ops::ModuleOps::interpolate_backward).
InterpolateBackward(InterpolateBackwardDescription),
}

/// Basic operations that can be done on any tensor type.
@@ -982,6 +984,16 @@ impl From<InterpolateOptions> for InterpolateOptionsDescription {
}
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardDescription {
pub x: TensorDescription,
pub grad: TensorDescription,
pub output_size: [usize; 2],
pub options: InterpolateOptionsDescription,
pub out: TensorDescription,
}

impl OperationDescription {
/// Cleanup the remaining tensor handles that have not been used.
pub(crate) fn nodes(&self) -> Vec<&TensorDescription> {
@@ -1279,9 +1291,13 @@ impl ModuleOperationDescription {
ModuleOperationDescription::Interpolate(desc) => {
vec![&desc.x, &desc.out]
}
ModuleOperationDescription::InterpolateBackward(desc) => {
vec![&desc.x, &desc.out, &desc.grad]
}
}
}
}

impl core::hash::Hash for RandomOperationDescription {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.out.hash(state);
@@ -1294,6 +1310,7 @@ impl core::hash::Hash for RandomOperationDescription {
}
}
}

impl<E> core::hash::Hash for ScalarOperationDescription<E> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.lhs.hash(state);
37 changes: 37 additions & 0 deletions crates/burn-ndarray/src/ops/interpolate.rs
@@ -49,6 +49,43 @@ pub(crate) fn nearest_interpolate<E: FloatNdArrayElement>(
NdArrayTensor::new(output.into_dyn().into_shared())
}

pub(crate) fn nearest_interpolate_backward<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
grad: NdArrayTensor<E, 4>,
output_size: [usize; 2],
) -> NdArrayTensor<E, 4> {
let [batch_size, channels, input_height, input_width] = x.shape().dims;
let [output_height, output_width] = output_size;

let mut output_grad =
Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);

run_par!(|| {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;

let output_grad = unsafe_shared_out.get();

for oh in 0..output_height {
for ow in 0..output_width {
let ih = start_index(oh, output_height, input_height);
let iw = start_index(ow, output_width, input_width);

output_grad[[b, c, ih, iw]] += grad.array[[b, c, oh, ow]]
}
}
})
});

NdArrayTensor::new(output_grad.into_dyn().into_shared())
}

fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
libm::floorf((output_size_index as f32 * input_size as f32) / output_size as f32) as usize
}

pub(crate) fn bilinear_interpolate<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
output_size: [usize; 2],
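The ndarray kernel above is the reference gradient: each element of the incoming gradient is accumulated onto the input position the forward pass copied it from, using the same `floor(o * input_size / output_size)` mapping as `start_index`. A self-contained sketch of that scatter (plain Rust, no burn types; integer division stands in for the `libm::floorf` call, which it matches for these non-negative operands):

```rust
/// Index of the input element that output position `o` copies from in the
/// forward pass (same mapping as `start_index` in the kernel above).
fn source_index(o: usize, output_size: usize, input_size: usize) -> usize {
    (o * input_size) / output_size
}

/// Scatter an out_h x out_w output gradient back onto an in_h x in_w input.
fn nearest_backward_2d(grad: &[Vec<f32>], in_h: usize, in_w: usize) -> Vec<Vec<f32>> {
    let (out_h, out_w) = (grad.len(), grad[0].len());
    let mut x_grad = vec![vec![0.0f32; in_w]; in_h];
    for oh in 0..out_h {
        for ow in 0..out_w {
            let ih = source_index(oh, out_h, in_h);
            let iw = source_index(ow, out_w, in_w);
            x_grad[ih][iw] += grad[oh][ow];
        }
    }
    x_grad
}

fn main() {
    // Upsampling 2x2 -> 4x4 with an all-ones output gradient: every input
    // pixel is the nearest source for a 2x2 block, so each gradient is 4.
    let grad = vec![vec![1.0f32; 4]; 4];
    assert_eq!(nearest_backward_2d(&grad, 2, 2), vec![vec![4.0f32; 2]; 2]);
}
```
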
18 changes: 18 additions & 0 deletions crates/burn-ndarray/src/ops/module.rs
@@ -5,6 +5,7 @@ use super::{
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
use crate::ops::interpolate::nearest_interpolate_backward;
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArray};
use burn_tensor::ops::*;

@@ -112,4 +113,21 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
InterpolateMode::Bicubic => bicubic_interpolate(x, output_size),
}
}

fn interpolate_backward(
x: NdArrayTensor<E, 4>,
grad: NdArrayTensor<E, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> NdArrayTensor<E, 4> {
match options.mode {
InterpolateMode::Nearest => nearest_interpolate_backward(x, grad, output_size),
InterpolateMode::Bilinear => {
panic!("bilinear interpolation backward is not supported for ndarray backend")
}
InterpolateMode::Bicubic => {
panic!("bicubic interpolation backward is not supported for ndarray backend")
}
}
}
}