Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add a test case that motivates the Where fix
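
For illustration, the split-size rule used below when only the output count is known — a standalone sketch of this commit's logic, with `split_sizes` being a name invented for the example:

    // Split an axis of length `input_dim` into `num_outputs` chunks; any
    // remainder is folded into the last chunk, mirroring the code below.
    fn split_sizes(input_dim: usize, num_outputs: usize) -> Vec<usize> {
        let mut sizes = vec![input_dim / num_outputs; num_outputs];
        sizes[num_outputs - 1] += input_dim % num_outputs;
        sizes
    }

    fn main() {
        assert_eq!(split_sizes(6, 3), vec![2, 2, 2]);
        assert_eq!(split_sizes(7, 3), vec![2, 2, 3]); // remainder goes to the last chunk
    }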

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests
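
As a quick illustration of the default semantics (keepdims = 1), a sketch against candle's tensor API — the `candle` crate alias and the CPU device are assumptions from this workspace:

    use candle::{Device, Result, Tensor};

    fn reduce_sum_example() -> Result<Tensor> {
        // Sum [[1, 2], [3, 4]] over axis 1, keeping the reduced dim:
        let x = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        x.sum_keepdim(1) // [[3.], [7.]], shape [2, 1]
    }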

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceL2 examples as tests
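
ReduceL2 reduces to sqrt(ReduceSum(x²)) along the chosen axes, which is how it is implemented below; a sketch under the same candle assumptions as above:

    use candle::{Device, Result, Tensor};

    fn reduce_l2_example() -> Result<Tensor> {
        // sqrt(3^2 + 4^2) = 5 along axis 1:
        let x = Tensor::new(&[[3f32, 4.]], &Device::Cpu)?;
        x.sqr()?.sum_keepdim(1)?.sqrt() // [[5.]]
    }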

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. Checking the length of the input array alone is not enough, because non-optional args may follow optional ones.

I ran into this when trying to use the ONNX model at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2, for example.

The LSTM op had a utility for this, which I factored out to be more generally accessible, and I have used it in the ops I recently created or debugged.

This issue likely also manifests in other ops, but I didn't want to change anything I'm not testing.
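
Concretely, the failure mode looks like this (hypothetical input list; standard-library-only sketch):

    fn main() {
        // A Clip node with `min` omitted but `max` supplied: the optional slot
        // is an empty string, not a missing entry, so `inputs.len() >= 2`
        // wrongly concludes that `min` is present.
        let inputs = vec!["x".to_string(), String::new(), "max".to_string()];
        let min = inputs.get(1).filter(|s| !s.is_empty());
        let max = inputs.get(2).filter(|s| !s.is_empty());
        assert!(min.is_none());
        assert_eq!(max.map(String::as_str), Some("max"));
    }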

* fix formatting

* fix small mistake made during refactor
stevenlovegrove authored Sep 26, 2024
1 parent ad8a4c5 commit ed48f54
Showing 2 changed files with 528 additions and 14 deletions.
213 changes: 199 additions & 14 deletions candle-onnx/src/eval.rs
@@ -323,6 +323,13 @@ fn simple_eval_(
             Some(value) => Ok(value),
             None => bail!("cannot find {input_name} for op {}", node.name),
         };
+        let get_opt = |i: usize| {
+            node.input
+                .get(i)
+                .filter(|s: &&String| !s.is_empty())
+                .map(|s| get(s))
+        };
+
         // TODO: Validate node.input for each operator.
         match node.op_type.as_str() {
             "Add" => {
@@ -608,15 +615,13 @@ fn simple_eval_(
             }
             "Clip" => {
                 let xs = get(&node.input[0])?;
-                let xs = if node.input.len() >= 2 {
-                    let mins = get(&node.input[1])?;
-                    xs.broadcast_maximum(mins)?
+                let xs = if let Some(mins) = get_opt(1) {
+                    xs.broadcast_maximum(mins?)?
                 } else {
                     xs.clone()
                 };
-                let xs = if node.input.len() >= 3 {
-                    let maxs = get(&node.input[2])?;
-                    xs.broadcast_minimum(maxs)?
+                let xs = if let Some(maxs) = get_opt(2) {
+                    xs.broadcast_minimum(maxs?)?
                 } else {
                     xs.clone()
                 };
@@ -759,7 +764,14 @@ fn simple_eval_(
                 let cond = get(&node.input[0])?;
                 let a = get(&node.input[1])?;
                 let b = get(&node.input[2])?;
-                let output = cond.where_cond(a, b)?;
+
+                // where_cond requires that all inputs are the same shape.
+                // In contrast, the Where op in ONNX only requires that they are broadcastable.
+                let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
+                let cond = cond.broadcast_as(shape.clone())?;
+                let a = a.broadcast_as(shape.clone())?;
+                let b = b.broadcast_as(shape)?;
+                let output = cond.where_cond(&a, &b)?;
                 values.insert(node.output[0].clone(), output);
             }
             "Conv" => {
@@ -962,6 +974,7 @@ fn simple_eval_(
                     }
                     rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
                 };
+
                 values.insert(node.output[0].clone(), output);
             }
             // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@@ -1199,6 +1212,152 @@ fn simple_eval_(
                 };
                 values.insert(node.output[0].clone(), output);
             }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
+            // Version 18 impl
+            "Split" => {
+                let input_tensor = get(&node.input[0])?;
+                let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
+                let axis = input_tensor.normalize_axis(axis)?;
+
+                // Determine split sizes
+                let splits = if node.input.len() > 1 {
+                    // If the split tensor is provided, use it to determine sizes
+                    let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
+                    split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
+                } else {
+                    let num_outputs = if let Some(&num_outputs_attrib) =
+                        get_attr_opt::<i64>(node, "num_outputs")?
+                    {
+                        num_outputs_attrib as usize
+                    } else {
+                        node.output.len()
+                    };
+
+                    let input_dim = input_tensor.dim(axis)?;
+
+                    let mut split_sizes =
+                        vec![input_dim / num_outputs as usize; num_outputs as usize];
+                    let remainder = input_dim % num_outputs as usize;
+                    if remainder > 0 {
+                        // If there's a remainder, add it to the last split size
+                        split_sizes[num_outputs as usize - 1] += remainder;
+                    }
+
+                    split_sizes
+                };
+
+                // Perform the split operation
+                let mut outputs = vec![];
+                let mut start = 0;
+                for &size in &splits {
+                    let end = start + size;
+                    let slice = input_tensor.narrow(axis, start, size)?;
+                    outputs.push(slice);
+                    start = end;
+                }
+
+                // Insert the split outputs into the values map
+                for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
+                    values.insert(output.clone(), slice);
+                }
+            }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
+            // Version 13 impl
+            "Expand" => {
+                // unlike broadcast_to, expand allows for the output shape to
+                // be different from the specified shape.
+                let input_tensor = get(&node.input[0])?;
+                let input_shape = get(&node.input[1])?;
+
+                // Check that the shape tensor is 1D
+                if input_shape.rank() != 1 {
+                    bail!(
+                        "Expand expects 'shape' input to be 1D tensor: {:?}",
+                        input_shape
+                    );
+                }
+                let input_tensor_dims = input_tensor.dims();
+                let input_shape_dims = input_shape
+                    .to_vec1::<i64>()?
+                    .into_iter()
+                    .map(|x| x as usize)
+                    .collect::<Vec<_>>();
+
+                let target_shape =
+                    broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
+
+                let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
+
+                values.insert(node.output[0].clone(), expanded_tensor);
+            }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
+            // Version 13 impl
+            "ReduceSum" => {
+                let input = get(&node.input[0])?;
+                let axes = get_opt(1);
+                let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
+                let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
+                    .copied()
+                    .unwrap_or(0);
+
+                let axes = match axes {
+                    Some(axes) => axes?
+                        .to_vec1::<i64>()?
+                        .into_iter()
+                        .map(|x| x as usize)
+                        .collect::<Vec<_>>(),
+                    None => {
+                        if noop_with_empty_axes == 1 {
+                            vec![]
+                        } else {
+                            (0..input.rank()).collect()
+                        }
+                    }
+                };
+
+                let output = if keepdims == 1 {
+                    input.sum_keepdim(axes)?
+                } else {
+                    input.sum(axes)?
+                };
+
+                values.insert(node.output[0].clone(), output);
+            }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
+            // Version 18 impl
+            "ReduceL2" => {
+                let input = get(&node.input[0])?;
+                let axes = get_opt(1);
+                let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
+                let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
+                    .copied()
+                    .unwrap_or(0);
+
+                let input_sq = input.sqr()?;
+
+                let axes = match axes {
+                    Some(axes) => axes?
+                        .to_vec1::<i64>()?
+                        .into_iter()
+                        .map(|x| x as usize)
+                        .collect::<Vec<_>>(),
+                    None => {
+                        if noop_with_empty_axes == 1 {
+                            vec![]
+                        } else {
+                            (0..input_sq.rank()).collect()
+                        }
+                    }
+                };
+
+                let output = if keepdims == 1 {
+                    input_sq.sum_keepdim(axes)?.sqrt()?
+                } else {
+                    input_sq.sum(axes)?.sqrt()?
+                };
+
+                values.insert(node.output[0].clone(), output);
+            }
             random_type @ ("RandomUniform" | "RandomNormal") => {
                 let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
                                                                                   // type by
@@ -1395,13 +1554,6 @@ fn simple_eval_(
                 // This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
                 let r = get(&node.input[2])?;
 
-                let get_opt = |i: usize| {
-                    node.input
-                        .get(i)
-                        .filter(|s: &&String| !s.is_empty())
-                        .map(|s| get(s))
-                };
-
                 // The bias tensor for input gate.
                 // Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
                 // This tensor has shape `[num_directions, 8*hidden_size]`.
@@ -1580,3 +1732,36 @@ fn simple_eval_(
         })
         .collect()
 }
+
+fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
+    let (longest, shortest) = if shape_a.len() > shape_b.len() {
+        (shape_a, shape_b)
+    } else {
+        (shape_b, shape_a)
+    };
+    let diff = longest.len() - shortest.len();
+    let mut target_shape = longest[0..diff].to_vec();
+    for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
+        if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
+            target_shape.push(usize::max(*dim1, *dim2));
+        } else {
+            bail!(
+                "Expand: incompatible shapes for broadcast, {:?} and {:?}",
+                shape_a,
+                shape_b
+            );
+        }
+    }
+    Ok(target_shape)
+}
+
+fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
+    if shapes.is_empty() {
+        return Ok(Vec::new());
+    }
+    let mut shape_out = shapes[0].to_vec();
+    for shape in shapes[1..].iter() {
+        shape_out = broadcast_shape(&shape_out, shape)?;
+    }
+    Ok(shape_out)
+}
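
Both helpers follow NumPy-style broadcasting: trailing dimensions are aligned, dimensions of 1 stretch, and anything else is an error. A quick sketch of the expected behavior, assuming the two functions above are in scope:

    // [3, 1, 5] vs [2, 5]: align from the right, stretch the 1s.
    assert_eq!(broadcast_shape(&[3, 1, 5], &[2, 5]).unwrap(), vec![3, 2, 5]);
    // Fold many shapes pairwise, e.g. for Where's three inputs:
    assert_eq!(
        broadcast_shape_from_many(&[&[8, 1, 6, 1], &[7, 1, 5], &[8, 7, 6, 5]]).unwrap(),
        vec![8, 7, 6, 5]
    );
    // Mismatched non-1 dims fail:
    assert!(broadcast_shape(&[3, 2], &[4, 2]).is_err());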