Skip to content

Commit

Permalink
onnx: ReduceMin/Max Ops (#2563)
Browse files Browse the repository at this point in the history
* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README

* WIP: ONNX Reduce-max ops

* WIP: tests for ReduceMin

* Reduce min/ max v18+

* Reformatting tests for better review readability

* Error on empty set, backward compatibility (13 and below) with 'axes'
  • Loading branch information
AnubhabB authored Oct 15, 2024
1 parent 3d1dc06 commit a01aa89
Show file tree
Hide file tree
Showing 2 changed files with 1,211 additions and 1 deletion.
174 changes: 173 additions & 1 deletion candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

pub type Value = Tensor;

Expand Down Expand Up @@ -1189,6 +1189,92 @@ fn simple_eval_(
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};

axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();

if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}

if axes.len() > 1 {
axes.sort();
}

Some(axes)
} else {
None
};

// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.max_keepdim(axis)?
} else {
result.max(axis)?
}
}

result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.max_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.max(0)?
}
}
};

values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
Expand All @@ -1212,6 +1298,92 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};

axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();

if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}

if axes.len() > 1 {
axes.sort();
}

Some(axes)
} else {
None
};

// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.min_keepdim(axis)?
} else {
result.min(axis)?
}
}

result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.min_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.min(0)?
}
}
};

values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
Expand Down
Loading

0 comments on commit a01aa89

Please sign in to comment.