Add shape ONNX op support (#1639)
* Add shape onnx op support

* Remove cast node from onnx graph

* Fix shape implementation

* Fix shape config error message

* Fix typo

* Fix clippy type complexity for generated code
laggui authored Apr 16, 2024
1 parent 6d96e8d commit 35b36bb
Showing 10 changed files with 199 additions and 7 deletions.
@@ -157,7 +157,7 @@ Here's how powf was added to burn fusion:
The way wgpu handles tensor-scalar operations is by transforming both into a sequence of vectorized
scalar operations. Since powf already existed in burn-wgpu, it was pretty easy to reuse the existing
implementation for the situation where both sides of the operation were tensors. The `burn-wgpu`
- crate is primarily concered with how the operation is compiled and executed by the gpu. The actual
+ crate is primarily concerned with how the operation is compiled and executed by the gpu. The actual
implementation is defined in `burn-jit`.
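As a loose conceptual sketch of the reuse described above (the names and signatures here are invented for illustration, not burn's actual API), a tensor-scalar `powf` can be routed through the same element-wise kernel as the tensor-tensor case:

```rust
// Conceptual sketch only: both the tensor-tensor and tensor-scalar variants of powf
// reduce to the same element-wise loop, so a single kernel implementation can serve both.
fn powf_elementwise(lhs: &[f32], exp_at: impl Fn(usize) -> f32) -> Vec<f32> {
    lhs.iter()
        .enumerate()
        .map(|(i, x)| x.powf(exp_at(i)))
        .collect()
}

fn powf_tensor(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    powf_elementwise(lhs, |i| rhs[i])
}

fn powf_scalar(lhs: &[f32], exp: f32) -> Vec<f32> {
    powf_elementwise(lhs, |_| exp)
}
```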

Here is where code was added for powf in `burn-jit` and `burn-wgpu`:
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -164,7 +164,7 @@ represent the corresponding Burn Op.
| [SequenceInsert][157] |||
| [SequenceLength][158] |||
| [SequenceMap][159] |||
- | [Shape][160] | ❌ | ❌ |
+ | [Shape][160] | ✅ | ✅ |
| [Shrink][161] |||
| [Sigmoid][162] |||
| [Sign][163] |||
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -39,6 +39,7 @@ fn main() {
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/shape/shape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/sin/sin.onnx")
.input("tests/softmax/softmax.onnx")
16 changes: 16 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -4,6 +4,8 @@
macro_rules! include_models {
($($model:ident),*) => {
$(
// Allow type complexity for generated code
#[allow(clippy::type_complexity)]
pub mod $model {
include!(concat!(env!("OUT_DIR"), concat!("/model/", stringify!($model), ".rs")));
}
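The `#[allow(clippy::type_complexity)]` attribute added above keeps the lint quiet for the generated model modules. As a made-up illustration (not actual generated code) of the kind of nested signature that can trip this lint:

```rust
// Illustrative only: clippy::type_complexity flags deeply nested type signatures,
// which machine-generated code can easily end up with; the allow silences it.
#[allow(clippy::type_complexity)]
fn generated_like() -> Vec<(String, Box<dyn Fn(Vec<Vec<f32>>) -> Option<Vec<f32>>>)> {
    Vec::new()
}
```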
@@ -46,6 +48,7 @@ include_models!(
reduce_mean,
relu,
reshape,
shape,
sigmoid,
sin,
softmax,
@@ -476,6 +479,19 @@
assert_eq!(output.to_data(), expected);
}

#[test]
fn shape() {
let device = Default::default();
let model: shape::Model<Backend> = shape::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 2>::ones([4, 2], &device);
let output = model.forward(input);
let expected = Data::from([4, 2]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn flatten() {
// Initialize the model without weights (because the exported file does not contain them)
12 changes: 12 additions & 0 deletions crates/burn-import/onnx-tests/tests/shape/shape.onnx
@@ -0,0 +1,12 @@
(binary ONNX protobuf payload omitted: `main_graph` exported by pytorch 2.1.2, containing a single `Shape` node `Shape_0` with input `x`)
66 changes: 66 additions & 0 deletions crates/burn-import/onnx-tests/tests/shape/shape.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/shape/shape.onnx

import onnx
import torch
import torch.nn as nn


# Trace with TorchScript to return the shape tensor (otherwise, would gather the shape
# of each dim as a scalar)
@torch.jit.script
def shape(x):
return torch.tensor(x.shape)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return shape(x)


def main():
# Set seed for reproducibility
torch.manual_seed(42)

torch.set_printoptions(precision=8)

# Export to onnx
device = torch.device("cpu")
model = Model()
model.eval()
test_input = torch.ones(4, 2, device=device)
file_name = "shape.onnx"

torch.onnx.export(
model,
test_input,
file_name,
input_names=["x"],
dynamic_axes={"x": {0: "b"}},
verbose=False,
opset_version=16,
)

m = onnx.load(file_name)
# Remove cast node
m.graph.node.pop(1)
m.graph.node[0].output[0] = m.graph.output[0].name
onnx.save(m, file_name)

print(f"Finished exporting model to {file_name}")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
print(f"Test input data shape: {test_input.shape}")
output = model.forward(test_input)
# print(f"Test output data shape: {output.shape}")

print(f"Test output: {output}")


if __name__ == "__main__":
main()
47 changes: 47 additions & 0 deletions crates/burn-import/src/burn/node/unary.rs
@@ -35,6 +35,7 @@ pub enum UnaryNodeKind {
Reciprocal,
LeakyRelu,
Relu,
Shape,
Sigmoid,
Sin,
Softmax,
@@ -60,6 +61,7 @@ impl UnaryNodeKind {
Self::Reciprocal => "reciprocal",
Self::LeakyRelu => "leaky_relu",
Self::Relu => "relu",
Self::Shape => "shape",
Self::Sigmoid => "sigmoid",
Self::Sin => "sin",
Self::Softmax => "softmax",
@@ -123,6 +125,9 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
UnaryNodeKind::Neg => {
imports.register("core::ops::Neg");
}
UnaryNodeKind::Shape => {
imports.register("burn::tensor::Int");
}
UnaryNodeKind::Not => {
imports.register("burn::tensor::Bool");
}
@@ -314,6 +319,22 @@ impl UnaryNode {
panic!("ReduceMean only supports tensor output");
}
}

pub(crate) fn shape(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
// Shape as defined by the ONNX op should return a tensor because other ops
// (e.g., Gather) will be used on a tensor
let function = move |input| {
quote! {
Tensor::<B, 1, Int>::from_data(
burn::tensor::Data::from(&#input.dims()[#start_dim..#end_dim])
.from_usize::<i64>()
.convert::<burn::tensor::ops::IntElem<B>>(),
&#input.device(),
)
}
};
Self::new(input, output, UnaryNodeKind::Shape, Rc::new(function))
}
}

#[cfg(test)]
@@ -784,4 +805,30 @@ mod tests {
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_shape() {
one_node_graph(
UnaryNode::shape(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_int("tensor2", 1)),
1,
3,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1, Int> {
let tensor2 = Tensor::<B, 1, Int>::from_data(
burn::tensor::Data::from(&tensor1.dims()[1usize..3usize])
.from_usize::<i64>()
.convert::<burn::tensor::ops::IntElem<B>>(),
&tensor1.device(),
);

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}
}
13 changes: 8 additions & 5 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -366,14 +366,17 @@ fn equal_update_outputs(node: &mut Node) {

fn shape_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
- panic!("Gather: multiple inputs are not supported: {:?}", node);
+ panic!("Shape: multiple inputs are not supported: {:?}", node);
}

// Extract the configuration of the linear layer (inputs are known)
let node_input = &mut node.inputs[0];
- if let ArgType::Tensor(tensor) = node_input.clone().ty {
-     // Update the output tensor
-     node.outputs[0].ty = ArgType::Shape(tensor.dim);
+ if let ArgType::Tensor(_tensor) = node_input.clone().ty {
+     // Output tensor is 1D int64
+     node.outputs[0].ty = ArgType::Tensor(TensorType {
+         elem_type: ElementType::Int64,
+         dim: 1,
+         ..Default::default()
+     });
} else {
panic!("Only tensor input is valid");
}
38 changes: 38 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -704,3 +704,41 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
Some(dim as usize)
}
}

pub fn shape_config(curr: &Node) -> (usize, usize) {
if curr.inputs.len() != 1 {
panic!(
"Shape: multiple inputs are not supported (got {:?})",
curr.inputs.len()
);
}

// Extract the shape of the input tensor
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Default: all axes up to the last one (included)
let mut start_dim: i64 = 0;
let mut end_dim: i64 = tensor.dim as i64;

// Extract the attributes
for (key, value) in curr.attrs.iter() {
match key.as_str() {
"start" => start_dim = value.clone().into_i64(),
"end" => end_dim = value.clone().into_i64(),
_ => {}
}
}

// If dim is negative, it is counted from the end
if start_dim < 0 {
start_dim += tensor.dim as i64;
}
if end_dim < 0 {
end_dim += tensor.dim as i64;
}

(start_dim as usize, end_dim as usize)
}
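To make the attribute handling above concrete, here is a minimal standalone sketch (the `normalize` helper is invented for illustration and is not part of the crate) of how negative `start`/`end` values are folded back into the `0..rank` range:

```rust
// Standalone illustration of the normalization done in `shape_config`:
// negative start/end indices count from the end of the input's rank.
fn normalize(rank: usize, mut start: i64, mut end: i64) -> (usize, usize) {
    if start < 0 {
        start += rank as i64;
    }
    if end < 0 {
        end += rank as i64;
    }
    (start as usize, end as usize)
}

fn main() {
    // Defaults for a rank-4 input cover the full shape: dims 0..4.
    assert_eq!(normalize(4, 0, 4), (0, 4));
    // start = 1, end = -1 selects dims 1..3, the same range seen in the codegen test above.
    assert_eq!(normalize(4, 1, -1), (1, 3));
}
```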
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
@@ -256,6 +256,7 @@ impl OnnxGraph {
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Shape => graph.register(Self::shape_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Sin => graph.register(Self::sin_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
@@ -474,6 +475,14 @@ impl OnnxGraph {
UnaryNode::reduce_mean(input, output, dim)
}

fn shape_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let (start_dim, end_dim) = shape_config(&node);

UnaryNode::shape(input, output, start_dim, end_dim)
}

fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let output = node.outputs.first().unwrap().to_tensor_type();