Skip to content

Commit

Permalink
Interpolate tensor operation (Inference Only) (#1246)
Browse files Browse the repository at this point in the history
* squash

feat: bilinear interpolation for tch, ndarray and wgpu backend

fix: reduce test case size to avoid exceeding floating-point precision limits

feat: support nearest-neighbor interpolation for ndarray backend

feat: support nearest-neighbor interpolation for wgpu backend

feat: support fusion backend

fix: no-std support

build: upgrade dependencies

* feat: bicubic interpolation for ndarray backend

* fix: test case precision

* feat: bicubic interpolation for wgpu backend

* Update Cargo.lock

---------

Co-authored-by: Dilshod Tadjibaev <[email protected]>
Co-authored-by: Aasheesh Singh <[email protected]>
  • Loading branch information
3 people authored Mar 2, 2024
1 parent 1117757 commit 7d44f0b
Show file tree
Hide file tree
Showing 24 changed files with 1,256 additions and 249 deletions.
392 changes: 164 additions & 228 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions crates/burn-autodiff/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,14 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<4> {
panic!("Can't differentiate adaptive avg pool2d backward.");
}

fn interpolate(
_x: AutodiffTensor<B, 4>,
_output_size: [usize; 2],
_options: InterpolateOptions,
) -> AutodiffTensor<B, 4> {
unimplemented!()
}
}

#[derive(Debug)]
Expand Down
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod tests {
// test module
burn_tensor::testgen_module_forward!();
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_nearest_interpolate!();
// burn_tensor::testgen_module_conv2d!();
// burn_tensor::testgen_module_conv_transpose1d!();
// burn_tensor::testgen_module_conv_transpose2d!();
Expand Down
25 changes: 23 additions & 2 deletions crates/burn-candle/src/ops/module.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn_tensor::{
ops::{
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool2dBackward,
MaxPool2dWithIndices, ModuleOps, UnfoldOptions,
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateMode,
InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, UnfoldOptions,
},
Shape,
};
Expand Down Expand Up @@ -236,4 +236,25 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
) -> FloatTensor<Self, 4> {
panic!("adaptive_avg_pool2d_backward is not supported by Candle")
}

fn interpolate(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
let tensor = match options.mode {
InterpolateMode::Nearest => x
.tensor
.upsample_nearest2d(output_size[0], output_size[1])
.unwrap(),
InterpolateMode::Bilinear => {
panic!("bilinear interpolation is not supported by Candle")
}
InterpolateMode::Bicubic => {
panic!("bicubic interpolation is not supported by Candle")
}
};

CandleTensor::new(tensor)
}
}
49 changes: 43 additions & 6 deletions crates/burn-fusion/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use crate::{
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
ConvTranspose2dDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Operation,
OperationDescription,
ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, Operation, OperationDescription,
},
Fusion, FusionBackend, HandleContainer,
};
Expand All @@ -17,8 +17,8 @@ use burn_tensor::ops::{
calculate_conv_output_size, calculate_conv_transpose_output_size,
calculate_pool_output_size,
},
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, MaxPool1dBackward,
MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};

macro_rules! make_ops {
Expand Down Expand Up @@ -976,4 +976,41 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {

out
}

fn interpolate(
x: FloatTensor<Self, 4>,
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<Self, 4> {
make_ops!(
InterpolateOps,
InterpolateDescription,
|args: InterpolateDescription, handles: &mut HandleContainer<B>| {
let x = handles.get_float_tensor(&args.x);
let output = B::interpolate(x, args.output_size, args.options.clone().into());
handles.register_float_tensor(&args.out.id, output);
}
);

let stream = x.stream;
let shape = vec![x.shape[0], x.shape[1], output_size[0], output_size[1]];
let out = x.client.tensor_uninitialized(shape);

let desc = InterpolateDescription {
x: x.into_description(),
output_size,
options: options.into(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::Module(crate::stream::ModuleOperationDescription::Interpolate(
desc.clone(),
)),
InterpolateOps::new(desc),
);

out
}
}
25 changes: 17 additions & 8 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ use super::{
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
EmbeddingBackwardDescription, EmbeddingDescription, FloatOperationDescription,
GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription,
OperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
GatherOperationDescription, IntOperationDescription, InterpolateDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceOperationDescription, SwapDimsDescription,
UnaryOperationDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};
Expand Down Expand Up @@ -313,6 +314,14 @@ impl ModuleOperationDescription {
},
)
}
ModuleOperationDescription::Interpolate(desc) => {
ModuleOperationDescription::Interpolate(InterpolateDescription {
x: desc.x.to_relative(converter),
output_size: desc.output_size,
options: desc.options.clone(),
out: desc.out.to_relative(converter),
})
}
}
}
}
Expand Down
66 changes: 65 additions & 1 deletion crates/burn-fusion/src/stream/operation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::FusionBackend;
use crate::{HandleContainer, TensorDescription};
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions};
use burn_tensor::{Distribution, Element};
use serde::{Deserialize, Serialize};
use std::ops::Range;
Expand Down Expand Up @@ -120,6 +120,8 @@ pub enum ModuleOperationDescription {
/// Operation corresponding to
/// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward).
MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
/// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate).
Interpolate(InterpolateDescription),
}

/// Basic operations that can be done on any tensor type.
Expand Down Expand Up @@ -902,6 +904,65 @@ pub struct MaxPool2dWithIndicesBackwardDescription {
pub out: TensorDescription,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeDescription {
Nearest,
Bilinear,
Bicubic,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsDescription {
pub mode: InterpolateModeDescription,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateDescription {
pub x: TensorDescription,
pub output_size: [usize; 2],
pub options: InterpolateOptionsDescription,
pub out: TensorDescription,
}

impl From<InterpolateModeDescription> for InterpolateMode {
fn from(val: InterpolateModeDescription) -> Self {
match val {
InterpolateModeDescription::Nearest => Self::Nearest,
InterpolateModeDescription::Bilinear => Self::Bilinear,
InterpolateModeDescription::Bicubic => Self::Bicubic,
}
}
}

impl From<InterpolateOptionsDescription> for InterpolateOptions {
fn from(val: InterpolateOptionsDescription) -> Self {
Self {
mode: val.mode.into(),
}
}
}

impl From<InterpolateMode> for InterpolateModeDescription {
fn from(val: InterpolateMode) -> Self {
match val {
InterpolateMode::Nearest => Self::Nearest,
InterpolateMode::Bilinear => Self::Bilinear,
InterpolateMode::Bicubic => Self::Bicubic,
}
}
}

impl From<InterpolateOptions> for InterpolateOptionsDescription {
fn from(val: InterpolateOptions) -> Self {
Self {
mode: val.mode.into(),
}
}
}

impl OperationDescription {
/// Cleanup the remaining tensor handles that have not been used.
pub(crate) fn nodes(&self) -> Vec<&TensorDescription> {
Expand Down Expand Up @@ -1192,6 +1253,9 @@ impl ModuleOperationDescription {
ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc) => {
vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
}
ModuleOperationDescription::Interpolate(desc) => {
vec![&desc.x, &desc.out]
}
}
}
}
Expand Down
Loading

0 comments on commit 7d44f0b

Please sign in to comment.