Skip to content

Commit

Permalink
Merge pull request #282 from Xilinx/jrickert.mul.dynamic
Browse files Browse the repository at this point in the history
Fix lowering of onnx.Mul with dynamic shape
  • Loading branch information
jorickert authored Feb 3, 2025
2 parents 3d6f0a2 + 149f379 commit 3939ed0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ LogicalResult checkBasicTosaRequirementsForBinaryOps(
return success();
}

namespace {
template <typename OnnxOp>
void copySingleResultType(OnnxOp opToCopyFrom, Value &valueToCopyTo) {
assert(opToCopyFrom->getNumResults() == 1);
valueToCopyTo.setType(opToCopyFrom->getResult(0).getType());
}
} // namespace

// Element-wise unary ops lowering to TOSA dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOpONNX, typename ElementwiseUnaryOpTOSA,
Expand Down Expand Up @@ -197,6 +205,7 @@ class ONNXMulOpLoweringToTosa : public OpConversionPattern<ONNXMulOp> {

TosaBuilder tosaBuilder(rewriter, op->getLoc());
Value mulOp = tosaBuilder.mul(lhs, rhs);
copySingleResultType(op, mulOp);
rewriter.replaceOp(op, {mulOp});

return success();
Expand Down
10 changes: 10 additions & 0 deletions test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> t

// -----

func.func @test_mul_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
"func.return"(%0) : (tensor<13x?x?xf32>) -> ()
// CHECK-LABEL: func @test_mul_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?x?xf32>, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor<?x?x?xf32>, tensor<13x?x?xf32>) -> tensor<13x?x?xf32>
}

// -----

func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> {
%0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32>
"func.return"(%0) : (tensor<13x21x1xf32>) -> ()
Expand Down

0 comments on commit 3939ed0

Please sign in to comment.