Skip to content

Commit

Permalink
feat: Greater + GreaterOrEqual onnx import (#1801)
Browse files Browse the repository at this point in the history
  • Loading branch information
JachymPutta authored May 23, 2024
1 parent 1f31e20 commit ef4646c
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 4 deletions.
4 changes: 2 additions & 2 deletions crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ represent the corresponding Burn Op.
| [GlobalAveragePool][63] |||
| [GlobalLpPool][64] |||
| [GlobalMaxPool][65] |||
| [Greater][66] | ||
| [GreaterOrEqual][67] | ||
| [Greater][66] | ||
| [GreaterOrEqual][67] | ||
| [GridSample][68] |||
| [GroupNormalization][69] |||
| [GRU][70] |||
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ fn main() {
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/less/less.onnx")
.input("tests/less_or_equal/less_or_equal.onnx")
.input("tests/recip/recip.onnx")
Expand Down
17 changes: 17 additions & 0 deletions crates/burn-import/onnx-tests/tests/greater/greater.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.3.0:�
8
onnx::Greater_0
onnx::Greater_12/Greater"Greater
main_graphZ!
onnx::Greater_0


Z!
onnx::Greater_1


b
2
 

B
38 changes: 38 additions & 0 deletions crates/burn-import/onnx-tests/tests/greater/greater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/greater/greater.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
return torch.gt(x,y)

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "greater.onnx"

test_input1 = torch.randn(4, 4, device=device)
test_input2 = torch.randn(4, 4, device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.3.0:�
T
onnx::GreaterOrEqual_0
onnx::GreaterOrEqual_12/GreaterOrEqual"GreaterOrEqual
main_graphZ(
onnx::GreaterOrEqual_0


Z(
onnx::GreaterOrEqual_1


b
2
 

B
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx

import torch
import torch.nn as nn

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
return torch.ge(x,y)

def main():
# Set seed for reproducibility
torch.manual_seed(42)
torch.set_printoptions(precision=8)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "greater_or_equal.onnx"

test_input1 = torch.randn(4, 4, device=device)
test_input2 = torch.randn(4, 4, device=device)
torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

print("Test input data: {} {}".format(test_input1, test_input2))
output = model.forward(test_input1, test_input2)
print("Test output data: {}".format(output))

if __name__ == '__main__':
main()
32 changes: 32 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ include_models!(
mul,
neg,
not,
greater,
greater_or_equal,
less,
less_or_equal,
prelu,
Expand Down Expand Up @@ -1173,6 +1175,20 @@ mod tests {
assert_eq!(output, expected);
}

#[test]
fn greater() {
let device = Default::default();
let model: greater::Model<Backend> = greater::Model::new(&device);

let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 25.0]], &device);
let input2 = Tensor::<Backend, 2>::from_floats([[1.0, 5.0, 8.0, -25.0]], &device);

let output = model.forward(input1, input2);
let expected = Data::from([[false, false, true, true]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn less() {
let device = Default::default();
Expand All @@ -1183,6 +1199,21 @@ mod tests {

let output = model.forward(input1, input2);
let expected = Data::from([[false, true, false, false]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn greater_or_equal() {
let device = Default::default();
let model: greater_or_equal::Model<Backend> = greater_or_equal::Model::new(&device);

let input1 = Tensor::<Backend, 2>::from_floats([[1.0, 4.0, 9.0, 25.0]], &device);
let input2 = Tensor::<Backend, 2>::from_floats([[1.0, 5.0, 8.0, -25.0]], &device);

let output = model.forward(input1, input2);
let expected = Data::from([[true, false, true, true]]);

assert_eq!(output.to_data(), expected);
}

Expand All @@ -1196,6 +1227,7 @@ mod tests {

let output = model.forward(input1, input2);
let expected = Data::from([[true, true, false, false]]);

assert_eq!(output.to_data(), expected);
}

Expand Down
38 changes: 38 additions & 0 deletions crates/burn-import/src/burn/node/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub enum BinaryType {
Powi,
Min,
Max,
Greater,
GreaterOrEqual,
Less,
LessOrEqual,
}
Expand All @@ -32,6 +34,8 @@ impl BinaryType {
BinaryType::Powf => "powf",
BinaryType::Min => "min_pair",
BinaryType::Max => "max_pair",
BinaryType::Greater => "greater",
BinaryType::GreaterOrEqual => "greater_equal",
BinaryType::Less => "lower",
BinaryType::LessOrEqual => "lower_equal",
}
Expand Down Expand Up @@ -198,6 +202,30 @@ impl BinaryNode {
Self::new(lhs, rhs, output, BinaryType::Max, Arc::new(function))
}

pub(crate) fn greater(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.greater(#rhs) },
_ => panic!("greater is supported for tensor only"),
};
Self::new(lhs, rhs, output, BinaryType::Greater, Arc::new(function))
}

pub(crate) fn greater_equal(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => {
move |lhs, rhs| quote! { #lhs.greater_equal(#rhs) }
}
_ => panic!("greater_equal is supported for tensor only"),
};
Self::new(
lhs,
rhs,
output,
BinaryType::GreaterOrEqual,
Arc::new(function),
)
}

pub(crate) fn lower(lhs: Type, rhs: Type, output: Type) -> Self {
let function = match (&lhs, &rhs) {
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower(#rhs) },
Expand Down Expand Up @@ -384,6 +412,16 @@ mod tests {
test_binary_operator_on_tensors!(max_pair);
}

#[test]
fn test_binary_codegen_greater() {
test_binary_operator_on_tensors!(greater);
}

#[test]
fn test_binary_codegen_greater_or_equal() {
test_binary_operator_on_tensors!(greater_equal);
}

#[test]
fn test_binary_codegen_less() {
test_binary_operator_on_tensors!(lower);
Expand Down
26 changes: 26 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::Mul => same_as_input(node),
NodeType::Neg => same_as_input(node),
NodeType::Not => same_as_input(node),
NodeType::Greater => greater_update_outputs(node),
NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node),
NodeType::Less => less_update_outputs(node),
NodeType::LessOrEqual => less_or_equal_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
Expand Down Expand Up @@ -239,6 +241,18 @@ fn reshape_update_outputs(node: &mut Node) {
}
}

fn greater_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Bool,
..tensor.clone()
});
}
_ => panic!("Only tensor input is valid"),
}
}

fn less_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
Expand All @@ -251,6 +265,18 @@ fn less_update_outputs(node: &mut Node) {
}
}

fn greater_or_equal_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type: ElementType::Bool,
..tensor.clone()
});
}
_ => panic!("Only tensor input is valid"),
}
}

fn less_or_equal_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
Expand Down
18 changes: 16 additions & 2 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ impl OnnxGraph {
NodeType::MatMul => graph.register(Self::matmul_conversion(node)),
NodeType::Neg => graph.register(Self::neg_conversion(node)),
NodeType::Not => graph.register(Self::not_conversion(node)),
NodeType::Greater => graph.register(Self::greater_conversion(node)),
NodeType::GreaterOrEqual => graph.register(Self::greater_or_equal_conversion(node)),
NodeType::Less => graph.register(Self::less_conversion(node)),
NodeType::LessOrEqual => graph.register(Self::less_or_equal_conversion(node)),
NodeType::LayerNormalization => {
Expand Down Expand Up @@ -824,19 +826,31 @@ impl OnnxGraph {
UnaryNode::not(input, output)
}

fn less_conversion(node: Node) -> BinaryNode {
fn greater_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
BinaryNode::greater(lhs, rhs, output)
}

fn less_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
BinaryNode::lower(lhs, rhs, output)
}

fn less_or_equal_conversion(node: Node) -> BinaryNode {
fn greater_or_equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
BinaryNode::greater_equal(lhs, rhs, output)
}

fn less_or_equal_conversion(node: Node) -> BinaryNode {
let lhs = node.inputs.first().unwrap().to_type();
let rhs = node.inputs.get(1).unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
BinaryNode::lower_equal(lhs, rhs, output)
}

Expand Down

0 comments on commit ef4646c

Please sign in to comment.