From 7c09215ef443256523d2de2579db56d1b59fd683 Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Thu, 17 Oct 2024 23:52:35 +0530
Subject: [PATCH] ONNX: GatherElements, Xor (#2568)

---
 candle-onnx/src/eval.rs  |  53 ++++
 candle-onnx/tests/ops.rs | 529 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 582 insertions(+)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 629b3f93d..358af7acf 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -670,6 +670,49 @@ fn simple_eval_(
             };
             values.insert(node.output[0].clone(), xs);
         }
+        // https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
+        // A note to fellow lurkers:
+        // The numpy-based `gather_elements` implementation in the `onnx` tests
+        // [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
+        // and examples is incorrect.
+        // Use `torch.gather` to validate/verify the correct behaviour.
+        "GatherElements" => {
+            let data = get(&node.input[0])?;
+            let indices = get(&node.input[1])?;
+
+            let rank = data.rank();
+            if rank != indices.rank() {
+                bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
+            }
+
+            let axis = {
+                let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
+                let axis = data.normalize_axis(axis_i64)?;
+
+                if axis >= rank {
+                    bail!(
+                        "axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
+                        axis_i64,
+                        rank - 1
+                    )
+                }
+
+                axis
+            };
+
+            // index_select does not support negative indices, so normalize them
+            // to positive indices.
+            let indices = &{
+                let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
+                let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
+                    .to_dtype(indices.dtype())?;
+                let mask = indices.lt(&zeros)?;
+                mask.to_dtype(indices.dtype())?
+                    .broadcast_mul(&max)?
+                    .add(indices)?
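+                // Worked example of the mask trick above: with
+                // data.dims()[axis] == 3, index -1 has mask 1, so it becomes
+                // 1 * 3 + (-1) = 2; index 1 has mask 0 and stays
+                // 0 * 3 + 1 = 1.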
+            };
+
+            values.insert(node.output[0].clone(), data.gather(indices, axis)?);
+        }
         "Shape" => {
             // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
             let xs = get(&node.input[0])?;
@@ -1891,6 +1934,16 @@ fn simple_eval_(
                 );
             }
         }
+        // https://onnx.ai/onnx/operators/onnx__Xor.html
+        "Xor" => {
+            // Since we don't have a `DType::Bool` yet, this ensures that we are
+            // working with `0` (false) & `1` (true). For 0/1 values, `a XOR b`
+            // is equivalent to `(a + b) == 1`: the sum is 0 for (0, 0), 2 for
+            // (1, 1), and 1 exactly when the inputs differ.
+            let a = get(&node.input[0])?.gt(0_u8)?;
+            let b = get(&node.input[1])?.gt(0_u8)?;
+
+            let out = a.broadcast_add(&b)?.eq(1_u8)?;
+
+            values.insert(node.output[0].clone(), out);
+        }
         op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
     }
 }
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 450a9879e..a84ba481e 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -1159,6 +1159,163 @@ fn test_gather_operation() -> Result<()> {
     Ok(())
 }
 
+// GatherElements
+#[test]
+fn test_gather_elements() -> Result<()> {
+    // All the tests below are verified against `torch.gather()`
+
+    // Rank 1 index
+    test(&[1.0, 2.0, 3.0, 4.0], &[3i64], 0, &[4.0])?;
+
+    // Rank 2 index
+    test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 1, &[[4.0]])?;
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_0
+    test(
+        &[[1., 2.], [3., 4.]],
+        &[[0i64, 0], [1, 0]],
+        1,
+        &[[1., 1.], [4., 3.]],
+    )?;
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_1
+    test(
+        &[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
+        &[[1i64, 2, 0], [2, 0, 0]],
+        0,
+        &[[4., 8., 3.], [7., 2., 3.]],
+    )?;
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_negative_indices
+    test(
+        &[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]],
+        &[[-1_i64, -2, 0], [-2, 0, 0]],
+        0,
+        &[[7., 5., 3.], [4., 2., 3.]],
+    )?;
+    test(
+        &[[1.0], [2.0], [3.0], [4.0]],
+        &[[3i64], [2]],
+        0,
+        &[[4.], [3.]],
+    )?;
+
+    // Rank 3
+    test(
+        &[
+            [[1.0, 2.0], [3.0, 4.0]],
+            [[5.0, 6.0], [7.0, 8.0]],
+            [[9.0, 10.0], [11.0, 12.0]],
+            [[13.0, 14.0], [15.0, 16.0]],
+        ],
+        &[[[1i64]]],
+        0,
+        &[[[5.]]],
+    )?;
+
+    test(
+        &[
+            [[1.0, 2.0], [3.0, 4.0]],
+            [[5.0, 6.0], [7.0, 8.0]],
+            [[9.0, 10.0], [11.0, 12.0]],
+            [[13.0, 14.0], [15.0, 16.0]],
+        ],
+        &[[[1i64]]],
+        1,
+        &[[[3.]]],
+    )?;
+
+    test(
+        &[
+            [[1.0, 2.0], [3.0, 4.0]],
+            [[5.0, 6.0], [7.0, 8.0]],
+            [[9.0, 10.0], [11.0, 12.0]],
+            [[13.0, 14.0], [15.0, 16.0]],
+        ],
+        &[[[1i64], [0]]],
+        2,
+        &[[[2.], [3.]]],
+    )?;
+
+    // Error cases
+    // Invalid index
+    assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 0, &[[1., 2., 3., 4.]]).is_err());
+    // Invalid axis/dim
+    assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 2, &[[1., 2., 3., 4.]]).is_err());
+    // Invalid rank
+    assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[3i64], 0, &[[1.]]).is_err());
+
+    fn test(
+        data: impl NdArray,
+        indices: impl NdArray,
+        axis: i64,
+        expected: impl NdArray,
+    ) -> Result<()> {
+        let att_axis = AttributeProto {
+            name: "axis".to_string(),
+            ref_attr_name: "axis".to_string(),
+            i: axis,
+            doc_string: "axis".to_string(),
+            r#type: 2, // AttributeType::INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "GatherElements".to_string(),
+                domain: "".to_string(),
+                attribute: vec![att_axis],
+                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+        inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+        match expected.dims().len() {
+            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}
+
 // "Size"
 #[test]
 fn test_size_operation() -> Result<()> {
@@ -5340,3 +5497,375 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
 
     Ok(())
 }
+
+// Xor
+#[test]
+fn test_xor() -> Result<()> {
+    // Tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor (xor)
+
+    // 2d
+    test(
+        &[[0_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1]],
+        &[[1_u8, 1, 0, 0], [1, 0, 0, 1], [1, 1, 1, 0]],
+        &[[1_u8, 0, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]],
+    )?;
+
+    // 3d
+    test(
+        &[
+            [
+                [0_u8, 1, 1, 1, 1],
+                [0, 1, 1, 0, 0],
+                [1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 1],
+            ],
+            [
+                [0, 0, 1, 1, 1],
+                [1, 0, 1, 1, 1],
+                [1, 1, 0, 0, 1],
+                [1, 0, 0, 1, 0],
+            ],
+            [
+                [1, 0, 0, 1, 1],
+                [1, 1, 1, 0, 0],
+                [1, 1, 0, 0, 1],
+                [1, 0, 0, 0, 1],
+            ],
+        ],
+        &[
+            [
+                [1_u8, 0, 0, 1, 1],
+                [0, 0, 1, 0, 1],
+                [1, 0, 0, 1, 0],
+                [0, 0, 0, 0, 0],
+            ],
+            [
+                [1, 0, 0, 1, 1],
+                [1, 0, 1, 1, 1],
+                [0, 1, 0, 1, 1],
+                [1, 1, 1, 0, 0],
+            ],
+            [
+                [0, 1, 1, 1, 0],
+                [1, 1, 0, 1, 0],
+                [0, 1, 1, 1, 0],
+                [1, 1, 0, 1, 0],
+            ],
+        ],
+        &[
+            [
+                [1_u8, 1, 1, 0, 0],
+                [0, 1, 0, 0, 1],
+                [0, 1, 1, 0, 1],
+                [0, 0, 0, 0, 1],
+            ],
+            [
+                [1, 0, 1, 0, 0],
+                [0, 0, 0, 0, 0],
+                [1, 0, 0, 1, 0],
+                [0, 1, 1, 1, 0],
+            ],
+            [
+                [1, 1, 1, 0, 1],
+                [0, 0, 1, 1, 0],
+                [1, 0, 1, 1, 1],
+                [0, 1, 0, 1, 1],
+            ],
+        ],
+    )?;
+
+    // 4d
+    test(
+        &[
+            [
+                [[0_u8, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1]],
+                [[1, 1, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]],
+            ],
+            [
+                [[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 0]],
+                [[1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1]],
+            ],
+        ],
+        &[
+            [
+                [[1_u8, 0, 1, 0], [0, 0, 1, 1], [1, 0, 1, 0]],
+                [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]],
+            ],
+            [
+                [[1, 1, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]],
+                [[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]],
+            ],
+        ],
+        &[
+            [
+                [[1_u8, 1, 0, 0], [1, 0, 1, 1], [0, 1, 1, 1]],
+                [[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 0]],
+            ],
+            [
+                [[0, 0, 1, 0], [1, 0, 1, 1], [1, 0, 1, 0]],
+                [[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 1, 0]],
+            ],
+        ],
+    )?;
+
+    // Tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor (xor_broadcast)
+    // 3d vs 1d
+    test(
+        // Shape (3, 4, 5)
+        &[
+            [
+                [0_u8, 0, 0, 0, 1],
+                [0, 1, 0, 1, 1],
+                [1, 0, 0, 1, 1],
+                [0, 0, 1, 0, 1],
+            ],
+            [
+                [0, 1, 0, 1, 1],
+                [1, 1, 0, 0, 1],
+                [0, 1, 1, 1, 0],
+                [0, 0, 0, 0, 1],
+            ],
+            [
+                [1, 1, 0, 1, 1],
+                [0, 0, 0, 1, 1],
+                [0, 1, 1, 0, 1],
+                [1, 1, 0, 1, 1],
+            ],
+        ],
+        // shape (5)
+        &[1_u8, 0, 0, 1, 1],
+        // shape (3, 4, 5)
+        &[
+            [
[1_u8, 0, 0, 1, 0], + [1, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 1, 0], + ], + [ + [1, 1, 0, 0, 0], + [0, 1, 0, 1, 0], + [1, 1, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + [ + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 1, 1, 1, 0], + [0, 1, 0, 0, 0], + ], + ], + )?; + + // 3d vs 2d + test( + // Shape (3, 4, 5) + &[ + [ + [0_u8, 0, 0, 0, 1], + [0, 1, 0, 1, 1], + [1, 0, 0, 1, 1], + [0, 0, 1, 0, 1], + ], + [ + [0, 1, 0, 1, 1], + [1, 1, 0, 0, 1], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 1], + ], + [ + [1, 1, 0, 1, 1], + [0, 0, 0, 1, 1], + [0, 1, 1, 0, 1], + [1, 1, 0, 1, 1], + ], + ], + // shape (4, 5) + &[ + [0_u8, 1, 0, 1, 0], + [0, 0, 1, 0, 0], + [1, 1, 0, 1, 1], + [1, 1, 0, 1, 0], + ], + // shape (3, 4, 5) + &[ + [ + [0_u8, 1, 0, 1, 1], + [0, 1, 1, 1, 1], + [0, 1, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + [ + [0, 0, 0, 0, 1], + [1, 1, 1, 0, 1], + [1, 0, 1, 0, 1], + [1, 1, 0, 1, 1], + ], + [ + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 1], + [1, 0, 1, 1, 0], + [0, 0, 0, 0, 1], + ], + ], + )?; + + // 4d vs 2d + test( + // Shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]], + [[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]], + [[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]], + ], + [ + [[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]], + [[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]], + [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]], + ], + ], + // shape (3, 4) + &[[0_u8, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]], + // shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]], + [[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0]], + [[1, 0, 1, 1], [0, 0, 0, 1], [0, 1, 1, 0]], + ], + [ + [[0, 1, 1, 0], [0, 0, 1, 0], [1, 1, 1, 0]], + [[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 0]], + [[1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]], + ], + ], + )?; + + // 4d vs 3d + test( + // Shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]], + [[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]], + [[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]], + ], + [ + [[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]], + [[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]], + [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]], + ], + ], + // shape (3, 3, 4) + &[ + [[1_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 0, 0]], + [[0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]], + [[0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]], + ], + // shape (2, 3, 3, 4) + &[ + [ + [[0_u8, 1, 0, 1], [1, 1, 1, 1], [0, 0, 0, 0]], + [[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]], + [[1, 1, 1, 0], [0, 1, 0, 1], [1, 1, 1, 0]], + ], + [ + [[1, 0, 0, 1], [1, 1, 1, 0], [1, 1, 1, 1]], + [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], + [[1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]], + ], + ], + )?; + + // 4d vs 4d + test( + // Shape (1, 4, 1, 2) + &[[[[1_u8, 0]], [[1, 0]], [[1, 0]], [[1, 1]]]], + // shape (2, 1, 4, 2) + &[ + [[[0_u8, 0], [1, 1], [1, 1], [1, 1]]], + [[[0, 1], [1, 0], [0, 1], [0, 0]]], + ], + // shape (2, 4, 4, 2) + &[ + [ + [[1_u8, 0], [0, 1], [0, 1], [0, 1]], + [[1, 0], [0, 1], [0, 1], [0, 1]], + [[1, 0], [0, 1], [0, 1], [0, 1]], + [[1, 1], [0, 0], [0, 0], [0, 0]], + ], + [ + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 0], [0, 1], [1, 0], [1, 1]], + ], + ], + )?; + + fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Xor".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + 
doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let inputs: HashMap = HashMap::from([ + (INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?), + (INPUT_Y.to_string(), Tensor::new(other, &Device::Cpu)?), + ]); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + + match expected.dims().len() { + 0 => { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } + 1 => { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } + 2 => { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } + 3 => { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } + 4 => { + // Candle has no method equivallent to `to_vec4()` + // So, as a hack, we flatten it to a single dim vec to test the results + assert_eq!( + z.flatten_all()?.to_vec1::()?, + expected.flatten_all()?.to_vec1::()? + ) + } + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +}