Skip to content

Commit 71bd5ef

Browse files
authored
feat: resize onnx import (#1863)
* feat: resize onnx import
* fix: resize import proc macro output
* fix: lint
* fix: simplify resize onnx
* fix: onnx-tests passing
* feedback: remove dead code and resolve merge conflicts
1 parent 671ec8c commit 71bd5ef

File tree

12 files changed

+344
-5
lines changed

12 files changed

+344
-5
lines changed

crates/burn-import/SUPPORTED-ONNX-OPS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ represent the corresponding Burn Op.
147147
| [ReduceSumSquare][140] |||
148148
| [Relu][141] |||
149149
| [Reshape][142] |||
150-
| [Resize][143] | ||
150+
| [Resize][143] | ||
151151
| [ReverseSequence][144] |||
152152
| [RNN][145] |||
153153
| [RoiAlign][146] |||

crates/burn-import/onnx-tests/build.rs

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ fn main() {
5656
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
5757
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
5858
.input("tests/reshape/reshape.onnx")
59+
.input("tests/resize/resize.onnx")
5960
.input("tests/shape/shape.onnx")
6061
.input("tests/sigmoid/sigmoid.onnx")
6162
.input("tests/sign/sign.onnx")

crates/burn-import/onnx-tests/tests/onnx_tests.rs

+25
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ include_models!(
6767
reduce_sum_opset11,
6868
relu,
6969
reshape,
70+
resize,
7071
shape,
7172
sigmoid,
7273
sign,
@@ -789,6 +790,30 @@ mod tests {
789790
assert_eq!(output.to_data(), expected);
790791
}
791792

793+
#[test]
794+
fn resize() {
795+
// Initialize the model without weights (because the exported file does not contain them)
796+
let device = Default::default();
797+
let model: resize::Model<Backend> = resize::Model::new(&device);
798+
799+
// Run the model
800+
let input = Tensor::<Backend, 4>::from_floats(
801+
[[[
802+
[0.0, 1.0, 2.0, 3.0],
803+
[4.0, 5.0, 6.0, 7.0],
804+
[8.0, 9.0, 10.0, 11.0],
805+
[12.0, 13.0, 14.0, 15.0],
806+
]]],
807+
&device,
808+
);
809+
let size = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 3], &device);
810+
811+
let output = model.forward(input, size);
812+
let expected = Data::from([[[[0.0, 1.5, 3.0], [12.0, 13.5, 15.0]]]]);
813+
814+
assert_eq!(output.to_data(), expected);
815+
}
816+
792817
#[test]
793818
fn shape() {
794819
let device = Default::default();
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: onnx-tests/tests/resize/resize.onnx
4+
5+
import onnx
6+
from onnx import helper, TensorProto
7+
8+
def main() -> None:
9+
input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4])
10+
sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4])
11+
12+
resize_node = helper.make_node(
13+
"Resize",
14+
name="resize_node",
15+
inputs=["input_tensor", "", "", "sizes"],
16+
outputs=["output"],
17+
mode="linear",
18+
)
19+
20+
graph_def = helper.make_graph(
21+
nodes=[resize_node],
22+
name="ResizeGraph",
23+
inputs=[input_tensor, sizes_tensor],
24+
outputs=[
25+
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2])
26+
],
27+
)
28+
29+
model_def = helper.make_model(graph_def, producer_name="resize")
30+
31+
onnx.save(model_def, "resize.onnx")
32+
33+
34+
if __name__ == "__main__":
35+
main()

crates/burn-import/src/burn/node/base.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use super::{
77
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
88
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
99
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
10-
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
11-
unsqueeze::UnsqueezeNode,
10+
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
11+
unary::UnaryNode, unsqueeze::UnsqueezeNode,
1212
};
1313
use crate::burn::{BurnImports, Scope, Type};
1414
use burn::backend::NdArray;
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
102102
MaxPool2d(MaxPool2dNode),
103103
Range(RangeNode),
104104
Reshape(ReshapeNode),
105+
Resize(ResizeNode),
105106
Slice(SliceNode),
106107
Squeeze(SqueezeNode),
107108
Sum(SumNode),
@@ -140,6 +141,7 @@ macro_rules! match_all {
140141
Node::MaxPool2d(node) => $func(node),
141142
Node::Range(node) => $func(node),
142143
Node::Reshape(node) => $func(node),
144+
Node::Resize(node) => $func(node),
143145
Node::Slice(node) => $func(node),
144146
Node::Squeeze(node) => $func(node),
145147
Node::Sum(node) => $func(node),
@@ -188,6 +190,7 @@ impl<PS: PrecisionSettings> Node<PS> {
188190
Node::MaxPool2d(_) => "max_pool2d",
189191
Node::Range(_) => "range",
190192
Node::Reshape(_) => "reshape",
193+
Node::Resize(_) => "resize",
191194
Node::Slice(_) => "slice",
192195
Node::Squeeze(_) => "squeeze",
193196
Node::Sum(_) => "add",

crates/burn-import/src/burn/node/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub(crate) mod random_normal;
2727
pub(crate) mod random_uniform;
2828
pub(crate) mod range;
2929
pub(crate) mod reshape;
30+
pub(crate) mod resize;
3031
pub(crate) mod slice;
3132
pub(crate) mod squeeze;
3233
pub(crate) mod sum;
+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
use super::{Node, NodeCodegen};
2+
use crate::burn::{OtherType, Scope, TensorType, Type};
3+
use burn::module::Module;
4+
use burn::record::PrecisionSettings;
5+
use proc_macro2::TokenStream;
6+
use quote::quote;
7+
8+
#[derive(Module, Debug, Clone)]
9+
pub enum ResizeMode {
10+
Nearest,
11+
Linear,
12+
Cubic,
13+
}
14+
15+
#[derive(new, Module, Debug, Clone)]
16+
pub struct ResizeOptions {
17+
pub mode: ResizeMode,
18+
}
19+
20+
#[derive(Debug, Clone)]
21+
pub struct ResizeNode {
22+
pub field: OtherType,
23+
pub input: TensorType,
24+
pub output: TensorType,
25+
pub output_size: TensorType,
26+
pub config: ResizeOptions,
27+
}
28+
29+
impl ResizeNode {
30+
pub fn new<S: AsRef<str>>(
31+
name: S,
32+
input: TensorType,
33+
output: TensorType,
34+
output_size: TensorType,
35+
config: ResizeOptions,
36+
) -> Self {
37+
Self {
38+
field: OtherType::new(
39+
name,
40+
quote! {
41+
burn::module::Ignored<InterpolateOptions>
42+
},
43+
),
44+
input,
45+
output,
46+
output_size,
47+
config,
48+
}
49+
}
50+
}
51+
52+
impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
53+
fn output_types(&self) -> Vec<Type> {
54+
vec![Type::Tensor(self.output.clone())]
55+
}
56+
57+
fn input_types(&self) -> Vec<Type> {
58+
vec![
59+
Type::Tensor(self.input.clone()),
60+
Type::Tensor(self.output_size.clone()),
61+
]
62+
}
63+
64+
fn field_type(&self) -> Option<Type> {
65+
Some(Type::Other(self.field.clone()))
66+
}
67+
68+
fn field_init(&self) -> Option<TokenStream> {
69+
let name = &self.field.name;
70+
71+
let mode = match self.config.mode {
72+
ResizeMode::Linear => quote! { InterpolateMode::Bilinear },
73+
ResizeMode::Nearest => quote! { InterpolateMode::Nearest },
74+
ResizeMode::Cubic => quote! { InterpolateMode::Bicubic },
75+
};
76+
77+
let tokens = quote! {
78+
let #name = InterpolateOptions {
79+
mode: #mode,
80+
};
81+
let #name = burn::module::Ignored(#name);
82+
};
83+
84+
Some(tokens)
85+
}
86+
87+
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88+
S::serialize_none(serializer)
89+
}
90+
91+
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
92+
let input = scope.tensor_use_owned(&self.input, node_position);
93+
let output_size = scope.tensor_use_owned(&self.output_size, node_position);
94+
let output = &self.output.name;
95+
96+
let field = &self.field.name;
97+
98+
quote! {
99+
let output_size_raw = #output_size.to_data().value;
100+
let mut output_size = [0usize; 2];
101+
102+
for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() {
103+
output_size[i] = x.elem::<i64>() as usize;
104+
}
105+
106+
let #output = interpolate(
107+
#input,
108+
output_size,
109+
self.#field.0.clone(),
110+
);
111+
}
112+
}
113+
114+
fn into_node(self) -> Node<PS> {
115+
Node::Resize(self)
116+
}
117+
118+
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
119+
imports.register("burn::tensor::ElementConversion");
120+
imports.register("burn::tensor::module::interpolate");
121+
imports.register("burn::tensor::ops::InterpolateMode");
122+
imports.register("burn::tensor::ops::InterpolateOptions");
123+
}
124+
}
125+
126+
#[cfg(test)]
127+
mod tests {
128+
use burn::record::FullPrecisionSettings;
129+
130+
use super::*;
131+
use crate::burn::{
132+
graph::BurnGraph,
133+
node::{resize::ResizeNode, test::assert_tokens},
134+
TensorType,
135+
};
136+
137+
#[test]
138+
fn test_codegen_nodes() {
139+
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
140+
141+
graph.register(ResizeNode::new(
142+
"resize",
143+
TensorType::new_float("tensor1", 4),
144+
TensorType::new_float("tensor2", 4),
145+
TensorType::new_int("output_size", 1),
146+
ResizeOptions::new(ResizeMode::Linear),
147+
));
148+
149+
graph.register_input_output(
150+
vec!["tensor1".to_string(), "output_size".to_string()],
151+
vec!["tensor2".to_string()],
152+
);
153+
154+
let expected = quote! {
155+
use burn::tensor::module::interpolate;
156+
use burn::tensor::ops::InterpolateMode;
157+
use burn::tensor::ops::InterpolateOptions;
158+
use burn::tensor::ElementConversion;
159+
use burn::tensor::Int;
160+
use burn::{
161+
module::Module,
162+
tensor::{backend::Backend, Tensor},
163+
};
164+
165+
#[derive(Module, Debug)]
166+
pub struct Model<B: Backend> {
167+
resize: burn::module::Ignored<InterpolateOptions>,
168+
phantom: core::marker::PhantomData<B>,
169+
device: burn::module::Ignored<B::Device>,
170+
}
171+
172+
impl<B: Backend> Model <B> {
173+
#[allow(unused_variables)]
174+
pub fn new(device: &B::Device) -> Self {
175+
let resize = InterpolateOptions {
176+
mode: InterpolateMode::Bilinear,
177+
};
178+
let resize = burn::module::Ignored(resize);
179+
Self {
180+
resize,
181+
phantom: core::marker::PhantomData,
182+
device: burn::module::Ignored(device.clone()),
183+
}
184+
}
185+
#[allow(clippy::let_and_return, clippy::approx_constant)]
186+
pub fn forward(
187+
&self,
188+
tensor1: Tensor<B, 4>,
189+
output_size: Tensor<B, 1, Int>
190+
) -> Tensor<B, 4> {
191+
let output_size_raw = output_size.to_data().value;
192+
let mut output_size = [0usize; 2];
193+
194+
for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() {
195+
output_size[i] = x.elem::<i64>() as usize;
196+
}
197+
198+
let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone());
199+
200+
tensor2
201+
}
202+
}
203+
};
204+
205+
assert_tokens(graph.codegen(), expected);
206+
}
207+
}

crates/burn-import/src/onnx/dim_inference.rs

+28
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node) {
5959
NodeType::ReduceSum => reduce_sum_update_outputs(node),
6060
NodeType::Relu => same_as_input(node),
6161
NodeType::Reshape => reshape_update_outputs(node),
62+
NodeType::Resize => resize_update_outputs(node),
6263
NodeType::Shape => shape_update_outputs(node),
6364
NodeType::Sigmoid => same_as_input(node),
6465
NodeType::Sign => same_as_input(node),
@@ -285,6 +286,33 @@ fn reshape_update_outputs(node: &mut Node) {
285286
}
286287
}
287288

289+
fn resize_update_outputs(node: &mut Node) {
290+
let input = match &node.inputs[0].ty {
291+
ArgType::Tensor(tensor) => tensor.clone(),
292+
_ => panic!("Resize: invalid input type"),
293+
};
294+
295+
let output = match &node.outputs[0].ty {
296+
ArgType::Tensor(tensor) => tensor.clone(),
297+
_ => panic!("Resize: invalid output type"),
298+
};
299+
300+
let output_size = match &node.inputs[3].ty {
301+
ArgType::Tensor(output_size) => output_size.clone(),
302+
_ => panic!("Resize: invalid output_size type"),
303+
};
304+
305+
if output_size.dim != 1 {
306+
panic!("Resize: output_size must be 1D");
307+
}
308+
309+
node.outputs[0].ty = ArgType::Tensor(TensorType {
310+
dim: input.dim,
311+
shape: None, // shape is calculated at runtime
312+
..output
313+
});
314+
}
315+
288316
fn greater_update_outputs(node: &mut Node) {
289317
match &node.inputs[0].ty {
290318
ArgType::Tensor(tensor) => {

0 commit comments

Comments (0)