Skip to content

Commit

Permalink
ONNX: GatherElements, Xor
Browse files Browse the repository at this point in the history
  • Loading branch information
AnubhabB committed Oct 17, 2024
1 parent 6a8bdfb commit 0fab71a
Show file tree
Hide file tree
Showing 2 changed files with 582 additions and 0 deletions.
53 changes: 53 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,49 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), xs);
}
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
// A Note to fellow lurkers:
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
// and examples is incorrect.
// Use `torch.gather` for the validating/ verifying against the proper behaviour
"GatherElements" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;

let rank = data.rank();
if rank != indices.rank() {
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
}

let axis = {
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = data.normalize_axis(axis_i64)?;

if axis >= rank {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
axis_i64,
rank - 1
)
}

axis
};

// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};

values.insert(node.output[0].clone(), data.gather(indices, axis)?);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
Expand Down Expand Up @@ -1891,6 +1934,16 @@ fn simple_eval_(
);
}
}
// https://onnx.ai/onnx/operators/onnx__Xor.html
"Xor" => {
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
let a = get(&node.input[0])?.gt(0_u8)?;
let b = get(&node.input[1])?.gt(0_u8)?;

let out = a.broadcast_add(&b)?.eq(1_u8)?;

values.insert(node.output[0].clone(), out);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
Expand Down
Loading

0 comments on commit 0fab71a

Please sign in to comment.