Skip to content

Commit

Permalink
Merge pull request #279 from Xilinx/chaitany.convtranspose_as_option
Browse files Browse the repository at this point in the history
Chaitany.convtranspose as option
  • Loading branch information
chaitanyakamarapu authored Feb 3, 2025
2 parents 27157da + a883dbc commit 3d6f0a2
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 17 deletions.
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ project(onnx-mlir)
option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON)
option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF)
option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
Expand Down Expand Up @@ -223,12 +222,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
endif()

if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE)
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 1)
else()
set(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED 0)
endif()

add_subdirectory(utils)
add_subdirectory(include)
Expand Down
7 changes: 7 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ bool enableParallel; // onnx-mlir only
bool disableSimdOption; // onnx-mlir only
bool enableFastMathOption; // onnx-mlir only
bool disableRecomposeOption; // onnx-mlir only
bool disableConvTransposeDecomposeOption; // onnx-mlir only
bool enableSimdDataLayout; // onnx-mlir only
bool verifyInputTensors; // onnx-mlir only
bool allowSorting; // onnx-mlir only
Expand Down Expand Up @@ -247,6 +248,12 @@ static llvm::cl::opt<bool, true> disableRecomposeOptionOpt("disable-recompose",
llvm::cl::location(disableRecomposeOption), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> disableConvTranposeDecomposeOptionOpt(
"disable-convtranspose-decompose",
llvm::cl::desc("Disable decomposition of ONNX ConvTranspose operator."),
llvm::cl::location(disableConvTransposeDecomposeOption),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));

// Options for onnx-mlir only
static llvm::cl::opt<EmissionTargetType, true> emissionTargetOpt(
llvm::cl::desc("Choose target to emit:"),
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ extern bool enableParallel; // onnx-mlir only
extern bool disableSimdOption; // onnx-mlir only
extern bool enableFastMathOption; // onnx-mlir only
extern bool disableRecomposeOption; // onnx-mlir only
extern bool disableConvTransposeDecomposeOption; // onnx-mlir only
extern bool enableSimdDataLayout; // onnx-mlir only
extern bool verifyInputTensors; // onnx-mlir only
extern bool allowSorting; // onnx-mlir only
Expand Down
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
Expand Down Expand Up @@ -451,15 +452,14 @@ Value replaceSequenceAt(
}

bool shouldDecomposeConvTransposeOp(Value convTransposeResult) {
#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE
if (onnx_mlir::disableConvTransposeDecomposeOption) {
// Disable the ONNXConvTransposeOp decomposition patterns.
return false;
}
ONNXConvTransposeOp op =
mlir::cast<ONNXConvTransposeOp>(convTransposeResult.getDefiningOp());
return hasShapeAndRank(convTransposeResult) &&
hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW());
#else
// Disable the ONNXConvTransposeOp decomposition patterns.
return false;
#endif
}

// Split on the specified axis. The length of each output is one.
Expand Down
3 changes: 0 additions & 3 deletions test/mlir/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,3 @@
# execution based on the available targets
for arch in config.targets_to_build.split():
config.available_features.add(arch.lower())

if config.decomp_onnx_convtranspose:
config.available_features.add("decomp_onnx_convtranspose")
1 change: 0 additions & 1 deletion test/mlir/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ config.onnx_mlir_obj_root = r"@ONNX_MLIR_BIN_ROOT@"

config.enable_stablehlo = @ONNX_MLIR_STABLEHLO_ENABLED@
config.enable_nnpa= 0x0@NNPA_LIT_ENABLED@
config.decomp_onnx_convtranspose = @ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE_ENABLED@

# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.
Expand Down
1 change: 0 additions & 1 deletion test/mlir/onnx/onnx_decompose_convtranspose.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx %s -split-input-file | FileCheck %s

// REQUIRES: decomp_onnx_convtranspose

// -----

Expand Down
104 changes: 104 additions & 0 deletions test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose %s -split-input-file | FileCheck %s


// -----

// Test unit strides. Only convert weight tensor

func.func @test_convtrans_unitstrides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
onnx.Return %1 : tensor<1x2x5x5xf32>
// CHECK-LABEL: func.func @test_convtrans_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x5xf32>
// CHECK: }
}

// -----

// Test 1d input

func.func @test_convtrans1d_unitstrides(%arg0: tensor<1x1x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
onnx.Return %1 : tensor<1x2x5xf32>
// CHECK-LABEL: func.func @test_convtrans1d_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5xf32>
// CHECK: }
}

// -----

// Test 3d input

func.func @test_convtrans3d_unitstrides(%arg0: tensor<1x1x3x4x5xf32>, %arg1: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
onnx.Return %1 : tensor<1x2x5x6x7xf32>
// CHECK-LABEL: func.func @test_convtrans3d_unitstrides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x4x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x6x7xf32>
// CHECK: }
}

// -----

// Test non unit strides. Added pads between elements in input data.

func.func @test_convtrans_strides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
onnx.Return %1 : tensor<1x2x7x3xf32>
// CHECK-LABEL: func.func @test_convtrans_strides(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x7x3xf32>
// CHECK: }
}

// -----

// Test output_padding. Additional pads are inserted after Conv op

func.func @test_convtrans_outputpadding(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
onnx.Return %1 : tensor<1x2x10x8xf32>
// CHECK-LABEL: func.func @test_convtrans_outputpadding(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x10x8xf32>
// CHECK: }
}

// -----

// Test for unknown dimension in spatial dimensions

func.func @test_convtranspose_unknown_spatial_dim(%arg0: tensor<?x?x3x3xf32>, %arg1: tensor<?x?x3x3xf32>) -> tensor<?x?x10x8xf32> {
%0 = "onnx.NoValue"() {value} : () -> none
%1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x3x3xf32>, tensor<?x?x3x3xf32>, none) -> tensor<?x?x10x8xf32>
onnx.Return %1 : tensor<?x?x10x8xf32>
// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x3x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?x3x3xf32>) -> tensor<?x?x10x8xf32> {
// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none
// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor<?x?x3x3xf32>, tensor<?x?x3x3xf32>, none) -> tensor<?x?x10x8xf32>
// CHECK: onnx.Return %[[VAL_3]] : tensor<?x?x10x8xf32>
// CHECK: }
}

0 comments on commit 3d6f0a2

Please sign in to comment.