diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index eeee6ab7f0..2030b51200 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -108,6 +108,14 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps( return success(); } +namespace { +template +void copyResultType(OnnxOp opToCopyFrom, Value &valueToCopyTo) { + assert(opToCopyFrom->getNumResults() == 1); + valueToCopyTo.setType(opToCopyFrom->getResult(0).getType()); +} +} // namespace + // Element-wise unary ops lowering to TOSA dialect. //===----------------------------------------------------------------------===// template , %arg1: tensor<13x21x1xf32>) -> t // ----- +func.func @test_mul_dynamic(%arg0: tensor, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> + "func.return"(%0) : (tensor<13x?x?xf32>) -> () +// CHECK-LABEL: func @test_mul_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> +} + +// ----- + func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> { %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32> "func.return"(%0) : (tensor<13x21x1xf32>) -> ()