From 7ff921c5385e1f08dc534b67a969cd06b91714d5 Mon Sep 17 00:00:00 2001
From: mokulus <36231852+mokulus@users.noreply.github.com>
Date: Tue, 21 May 2024 21:47:32 +0200
Subject: [PATCH] Add RandomNormal ONNX operator (#2200)

---
 candle-onnx/src/eval.rs  |  22 +++---
 candle-onnx/tests/ops.rs | 144 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 158 insertions(+), 8 deletions(-)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 78e0554ac..65fb6d77b 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -971,7 +971,7 @@ pub fn simple_eval(
                 };
                 values.insert(node.output[0].clone(), output);
             }
-            "RandomUniform" => {
+            random_type @ ("RandomUniform" | "RandomNormal") => {
                 let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
                                                                                   // type by
                                                                                   // default
@@ -979,36 +979,42 @@
                     Ok(dt) => match dtype(dt) {
                         Some(DType::U8 | DType::U32 | DType::I64) => {
                             bail!(
-                                "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
+                                "unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
                                 node.name
                             )
                         }
                         Some(dt) => dt,
                         None => {
                             bail!(
-                                "unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
+                                "unsupported 'dtype' value {dt:?} for {random_type} {}",
                                 node.name
                             )
                         }
                     },
                     Err(_) => {
                         bail!(
-                            "unsupported 'dtype' value {dt:?} for RandomUniform {}",
+                            "unsupported 'dtype' value {dt:?} for {random_type} {}",
                             node.name
                         )
                     }
                 };
-                let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
-                let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
                 let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
                 if seed.is_some() {
-                    bail!("seed for RandomUniform is currently not supported")
+                    bail!("seed for {random_type} is currently not supported")
                 };
                 let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
                     .iter()
                     .map(|x| *x as usize)
                     .collect();
-                let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
+                let output = if random_type == "RandomUniform" {
+                    let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
+                    let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
+                    Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
+                } else {
+                    let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
+                    let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
+                    Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
+                };
                 values.insert(node.output[0].clone(), output);
             }
             op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 30e2480bf..a53ad8c59 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
     Ok(())
 }
 
+// "RandomNormal"
+#[test]
+fn test_random_normal() -> Result<()> {
+    test(vec![3, 2, 1, 4], None, None)?;
+    test(vec![2, 2, 2, 2], Some(-10.0), None)?;
+    test(vec![2, 2, 2, 2], None, Some(10.0))?;
+    test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
+
+    fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
+        let att_mean = AttributeProto {
+            name: "mean".to_string(),
+            ref_attr_name: "mean".to_string(),
+            i: 0,
+            doc_string: "mean".to_string(),
+            r#type: 1, // FLOAT
+            f: mean.unwrap_or(0.0),
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_scale = AttributeProto {
+            name: "scale".to_string(),
+            ref_attr_name: "scale".to_string(),
+            i: 0,
+            doc_string: "scale".to_string(),
+            r#type: 1, // FLOAT
+            f: scale.unwrap_or(1.0),
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_shape = AttributeProto {
+            name: "shape".to_string(),
+            ref_attr_name: "shape".to_string(),
+            i: 0,
+            doc_string: "shape".to_string(),
+            r#type: 7, // INTS
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: shape,
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_dtype = AttributeProto {
+            name: "dtype".to_string(),
+            ref_attr_name: "dtype".to_string(),
+            i: 11, // DOUBLE
+            doc_string: "dtype".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let attrs = {
+            let mut mut_attrs = vec![att_shape, att_dtype];
+            if mean.is_some() {
+                mut_attrs.push(att_mean);
+            }
+            if scale.is_some() {
+                mut_attrs.push(att_scale);
+            }
+            mut_attrs
+        };
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "RandomNormal".to_string(),
+                domain: "".to_string(),
+                attribute: attrs,
+                input: vec![],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+        let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let data = z.flatten_all()?.to_vec1::<f64>()?;
+
+        // test if values are unique
+        for (i, a) in data.iter().enumerate() {
+            for (j, b) in data.iter().enumerate() {
+                if i == j {
+                    continue;
+                };
+                assert_ne!(a, b);
+            }
+        }
+
+        Ok(())
+    }
+
+    Ok(())
+}
+
 // "Range"
 #[test]
 fn test_range() -> Result<()> {
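
Note (not part of the patch): for reference, a minimal self-contained sketch of
the sampling the new RandomNormal arm performs, written against candle's public
Tensor API. The shape and the cast to F64 are arbitrary choices for the example;
mean and scale use the ONNX attribute defaults (0.0 and 1.0).

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // ONNX RandomNormal attributes: `mean` is the distribution mean,
    // `scale` its standard deviation; the spec defaults are 0.0 and 1.0.
    let (mean, scale) = (0.0f32, 1.0f32);
    // Mirrors the patch: sample in f32 with Tensor::randn, then cast to the
    // requested dtype (here DOUBLE, i.e. the dtype attribute value 11).
    let output = Tensor::randn(mean, scale, (3, 2, 1, 4), &Device::Cpu)?
        .to_dtype(DType::F64)?;
    println!("{output}");
    Ok(())
}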