Skip to content

Commit

Permalink
Fix lowering of onnx.Mul with dynamic shape
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Feb 3, 2025
1 parent 3d6f0a2 commit 4558b04
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(
return success();
}

namespace {
/// Copies the (single) result type of \p opToCopyFrom onto \p valueToCopyTo.
/// Used to preserve the ONNX-inferred result type (e.g. dynamic shapes) on a
/// freshly created TOSA value whose builder-inferred type may be less precise.
/// \p opToCopyFrom must have exactly one result.
template <typename OnnxOp>
void copyResultType(OnnxOp opToCopyFrom, Value &valueToCopyTo) {
  assert(opToCopyFrom->getNumResults() == 1 &&
         "copyResultType expects a single-result op");
  valueToCopyTo.setType(opToCopyFrom->getResult(0).getType());
}
} // namespace

// Element-wise unary ops lowering to TOSA dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOpONNX, typename ElementwiseUnaryOpTOSA,
Expand Down Expand Up @@ -521,7 +529,7 @@ class ONNXHardSigmoidOpLoweringToTOSA
rewriter.getF32FloatAttr(0),
rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble()));
auto mulOp = tosaBuilder.mul(clampOp, constAlpha);

copyResultType(op, mulOp);
rewriter.replaceOp(op, {mulOp});
return success();
}
Expand Down
10 changes: 10 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> t

// -----

// Verifies that onnx.Mul with dynamic dimensions lowers to tosa.mul while
// keeping the mixed static/dynamic result type tensor<13x?x?xf32>.
func.func @test_mul_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
"func.return"(%0) : (tensor<13x?x?xf32>) -> ()
// CHECK-LABEL: func @test_mul_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<13x?x?xf32>
}

// -----

func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
Expand Down

0 comments on commit 4558b04

Please sign in to comment.