Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ONNXPadOp now also lowers to TOSA integer types
Browse files Browse the repository at this point in the history
roberteg16 committed Jan 31, 2025
1 parent 4219d8b commit 307a03f
Showing 2 changed files with 121 additions and 24 deletions.
57 changes: 43 additions & 14 deletions src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp
Original file line number Diff line number Diff line change
@@ -40,6 +40,18 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
Value data = adaptor.getData();
Value pads = adaptor.getPads();
Value constValue = adaptor.getConstantValue();

auto dataType = dyn_cast<RankedTensorType>(data.getType());
if (!dataType || !dataType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "input type has no static shape");
}

auto elementDtype =
dyn_cast<RankedTensorType>(data.getType()).getElementType();
if (!isa<FloatType>(elementDtype) && !isTOSAInt(elementDtype)) {
return rewriter.notifyMatchFailure(op, "unsupported type");
}

if (!adaptor.getAxes().getDefiningOp<ONNXNoneOp>()) {
return rewriter.notifyMatchFailure(op, "only default axes are supported");
}
@@ -78,27 +90,44 @@ class ONNXPadOpLoweringToTOSA : public OpConversionPattern<ONNXPadOp> {
mlir::Type resultType =
getTypeConverter()->convertType(op.getResult().getType());

float valueFloat = 0.0F;
if (!isa<NoneType>(constValue.getType())) {
auto valueAttr = tosa::getValueFromTosaConst<ElementsAttr>(constValue);
auto valueIt = valueAttr.getValues<FloatAttr>().begin();
// Need float for F32 Type
float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();

TosaBuilder tosaBuilder(rewriter, loc);
Value constTosaTensor =
tosaBuilder.getSplattedConst(valueFloat, valueAttr.getElementType());

Value constTosaTensor;
if (isa<FloatType>(valueAttr.getElementType())) {
auto valueIt = valueAttr.getValues<FloatAttr>().begin();
const float valueFloat = cast<FloatAttr>(*valueIt).getValueAsDouble();
constTosaTensor = tosaBuilder.getSplattedConst(
valueFloat, valueAttr.getElementType());
} else {
assert(isTOSAInt(elementDtype) && "Already validated");
auto valueIt = valueAttr.getValues<IntegerAttr>().begin();
const int64_t valueInt =
cast<IntegerAttr>(*valueIt).getValue().getSExtValue();
constTosaTensor =
tosaBuilder.getSplattedConst(valueInt, valueAttr.getElementType());
}
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
op, resultType, data, padsList1, constTosaTensor);
} else {
auto constType = RankedTensorType::get({}, rewriter.getF32Type());
auto constAttr = DenseElementsAttr::get(constType, valueFloat);
Value constTosaTensor = rewriter.create<mlir::tosa::ConstOp>(
op->getLoc(), constType, constAttr);

rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
op, resultType, data, padsList1, constTosaTensor);
} else {
auto constType = RankedTensorType::get({}, elementDtype);

DenseElementsAttr constAttr;
if (isa<FloatType>(elementDtype)) {
constAttr = DenseElementsAttr::get(constType, 0.0F);
} else {
assert(isTOSAInt(elementDtype) && "Already validated");
auto tyAsInt = cast<IntegerType>(elementDtype);
constAttr = DenseElementsAttr::get(constType,
llvm::APInt(tyAsInt.getWidth(), 0, tyAsInt.getSignedness()));
}

rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(op, resultType, data,
padsList1,
rewriter.create<mlir::tosa::ConstOp>(
op->getLoc(), constType, constAttr));
}

return success();
88 changes: 78 additions & 10 deletions test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir
Original file line number Diff line number Diff line change
@@ -1,55 +1,123 @@
// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s

func.func @test_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> {
%noval = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32>
return %2 : tensor<24x22x52x42xf32>
// CHECK-LABEL: test_pad
// CHECK-LABEL: test_pad_f32
// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4.500000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
}

// -----
func.func @test_no_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
func.func @test_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
%noval = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32>
return %2 : tensor<20x16x44x32xf32>
// CHECK-LABEL: test_no_pad
// CHECK-LABEL: test_no_pad_f32
// CHECK: return %arg0
}

// -----
func.func @test_novalue_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x45x33xf32> {
func.func @test_novalue_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x45x33xf32> {
%0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xf32>
return %2 : tensor<20x16x45x33xf32>
// CHECK-LABEL: test_novalue_pad
// CHECK-LABEL: test_novalue_pad_f32
// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
}

// -----
func.func @test_novalue_no_pad(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
func.func @test_novalue_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> {
%0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xf32>
return %2 : tensor<20x16x44x32xf32>
// CHECK-LABEL: test_novalue_no_pad
// CHECK-LABEL: test_novalue_no_pad_f32
// CHECK: return %arg0
}

// -----
func.func @test_no_const_pad(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<8xi64>, %arg2: tensor<1xf32>) -> tensor<20x16x44x32xf32> {
func.func @test_no_const_pad_f32(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<8xi64>, %arg2: tensor<1xf32>) -> tensor<20x16x44x32xf32> {
%noval = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32>
return %2 : tensor<20x16x44x32xf32>
// CHECK-LABEL: test_no_const_pad
// CHECK-LABEL: test_no_const_pad_f32
// CHECK: "onnx.Pad"
}

// -----
// Pad of an i64 tensor with asymmetric pad amounts and an explicit i64
// constant value (4): expects lowering to tosa.pad with the pads reshaped
// into a 4x2 i64 tensor and the pad value splatted as a rank-0 i64 const.
func.func @test_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<24x22x52x42xi64> {
%noval = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64>
%2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<24x22x52x42xi64>
return %2 : tensor<24x22x52x42xi64>
// CHECK-LABEL: test_pad_i64
// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor<i64>}> : () -> tensor<i64>
// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
}

// -----
// Pad of an i64 tensor where every pad amount is zero: the op is a no-op,
// so the lowering is expected to fold it away and return the input directly.
func.func @test_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> {
%noval = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64>
%2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64>
return %2 : tensor<20x16x44x32xi64>
// CHECK-LABEL: test_no_pad_i64
// CHECK: return %arg0
}

// -----
// Pad of an i64 tensor with no constant value operand (onnx.NoValue):
// expects the lowering to synthesize a default rank-0 i64 zero constant
// as the tosa.pad pad value.
func.func @test_novalue_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x45x33xi64> {
%0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xi64>
return %2 : tensor<20x16x45x33xi64>
// CHECK-LABEL: test_novalue_pad_i64
// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
// CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
}

// -----
// Pad of an i64 tensor with neither a constant value nor any non-zero pad
// amounts: the op has no effect and is expected to fold to the input.
func.func @test_novalue_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> {
%0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xi64>
return %2 : tensor<20x16x44x32xi64>
// CHECK-LABEL: test_novalue_no_pad_i64
// CHECK: return %arg0
}

// -----
// Negative test: pads and pad value arrive as function arguments rather than
// constants, so the pattern must not apply and onnx.Pad must remain in the IR.
func.func @test_no_const_pad_i64(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<8xi64>, %arg2: tensor<1xi64>) -> tensor<20x16x44x32xi64> {
%noval = "onnx.NoValue"() {value} : () -> none
%2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64>
return %2 : tensor<20x16x44x32xi64>
// CHECK-LABEL: test_no_const_pad_i64
// CHECK: "onnx.Pad"
}

// -----
// Pad of an unsigned (ui32) tensor with an explicit ui32 constant value:
// expects the same tosa.pad lowering shape as the signed case, with the pad
// value splatted as a rank-0 ui32 const.
// NOTE(review): relies on ui32 being accepted by the isTOSAInt check in the
// C++ pattern — confirm TOSA verifier support for unsigned element types.
func.func @test_pad_ui32(%arg0: tensor<20x16x44x32xui32>) -> tensor<24x22x52x42xui32> {
%noval = "onnx.NoValue"() {value} : () -> none
%0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64>
%1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xui32>} : () -> tensor<1xui32>
%2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xui32>, tensor<8xi64>, tensor<1xui32>, none) -> tensor<24x22x52x42xui32>
return %2 : tensor<24x22x52x42xui32>
// CHECK-LABEL: test_pad_ui32
// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64>
// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor<ui32>}> : () -> tensor<ui32>
// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]]
}

0 comments on commit 307a03f

Please sign in to comment.