diff --git a/.buildbot/Jenkinsfile b/.buildbot/Jenkinsfile index e63fb6de3d..daab23bbce 100644 --- a/.buildbot/Jenkinsfile +++ b/.buildbot/Jenkinsfile @@ -49,25 +49,25 @@ def call() { DOCKER_DAEMON_SOCKET = 'unix://var/run/docker.sock' DOCKER_REGISTRY_TOKEN_ACCESS = 'true' /* Settings for docker.io */ + /* DOCKER_REGISTRY_HOST_NAME = '' DOCKER_REGISTRY_USER_NAME = 'onnxmlir' DOCKER_REGISTRY_LOGIN_NAME = 'onnxmlir' + */ /* Settings for ghcr.io */ - /* DOCKER_REGISTRY_HOST_NAME = 'ghcr.io' DOCKER_REGISTRY_USER_NAME = 'onnxmlir' DOCKER_REGISTRY_LOGIN_NAME = 'onnxmlir' - */ /* Credentials defined in Jenkins */ JENKINS_REST_API_TOKEN = credentials('Jenkins-REST-API-Token') GITHUB_REPO_ACCESS_TOKEN = credentials('jenkins-buildbot-access-token') /* Settings for docker.io */ + /* DOCKER_REGISTRY_LOGIN_TOKEN = credentials('DOCKERHUB-ONNXMLIR-TOKEN') + */ /* Settings for ghcr.io */ - /* DOCKER_REGISTRY_LOGIN_TOKEN = credentials('GITHUB-ONNXMLIR-TOKEN') - */ /* Environment variables that depend on the arch */ JENKINS_REST_API_URL = sh(returnStdout: true, diff --git a/.buildbot/jenkins-build-llvm-project.py b/.buildbot/jenkins-build-llvm-project.py index de7b8f2ba7..06c5bb115e 100755 --- a/.buildbot/jenkins-build-llvm-project.py +++ b/.buildbot/jenkins-build-llvm-project.py @@ -7,7 +7,7 @@ LLVM_PROJECT_DOCKERFILE = "docker/Dockerfile.llvm-project" LLVM_PROJECT_GITHUB_URL = "https://api.github.com/repos/llvm/llvm-project" LLVM_PROJECT_BASE_IMAGE = { - "static": "ubuntu:jammy", + "static": "ghcr.io/onnxmlir/ubuntu:jammy", "shared": "registry.access.redhat.com/ubi8-minimal:latest", } LLVM_PROJECT_IMAGE = { @@ -187,7 +187,7 @@ def setup_per_pr_llvm_project(image_type, exp): ): if "stream" in line: # Keep track of the latest successful image layer - m = re.match("^\s*---> ([0-9a-f]+)$", line["stream"]) + m = re.match(r"^\s*---> ([0-9a-f]+)$", line["stream"]) if m: layer_sha256 = m.group(1) print(line["stream"], end="", flush=True) diff --git a/.buildbot/jenkins-build-onnx-mlir.py b/.buildbot/jenkins-build-onnx-mlir.py index f3c83c9d7a..4ff4d589d7 100755 --- a/.buildbot/jenkins-build-onnx-mlir.py +++ b/.buildbot/jenkins-build-onnx-mlir.py @@ -178,7 +178,7 @@ def build_per_pr_onnx_mlir(image_type, exp): ): if "stream" in line: # Keep track of the latest successful image layer - m = re.match("^\s*---> ([0-9a-f]+)$", line["stream"]) + m = re.match(r"^\s*---> ([0-9a-f]+)$", line["stream"]) if m: layer_sha256 = m.group(1) print(line["stream"], end="", flush=True) diff --git a/.buildbot/jenkins-watch-llvm-project.py b/.buildbot/jenkins-watch-llvm-project.py index abaae7d198..b6c6c396bb 100755 --- a/.buildbot/jenkins-watch-llvm-project.py +++ b/.buildbot/jenkins-watch-llvm-project.py @@ -336,7 +336,7 @@ def build_watch_image(repo, commit, dockerfile, base_image, image_repo, image_ta ): if "stream" in line: # Keep track of the latest successful image layer - m = re.match("^\s*---> ([0-9a-f]+)$", line["stream"]) + m = re.match(r"^\s*---> ([0-9a-f]+)$", line["stream"]) if m: layer_sha256 = m.group(1) print(line["stream"], end="", flush=True) diff --git a/.buildbot/jenkins_common.py b/.buildbot/jenkins_common.py index d2fbb1a1e2..eb36a16ae6 100755 --- a/.buildbot/jenkins_common.py +++ b/.buildbot/jenkins_common.py @@ -31,10 +31,11 @@ MEMORY_IN_GB = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / (1024.0**3) NPROC = str(math.ceil(min(max(2, MEMORY_IN_GB / 8), os.cpu_count()))) -RETRY_LIMIT = 5 +RETRY_LIMIT = 10 READ_CHUNK_SIZE = 1024 * 1024 BASE_BRANCH = "main" +DOCKER_API_TIMEOUT = 3600 DOCKER_DIST_MANIFEST = "application/vnd.docker.distribution.manifest.v2+json" DOCKER_DIST_MANIFEST_LIST = "application/vnd.docker.distribution.manifest.list.v2+json" @@ -47,7 +48,7 @@ docker_registry_login_token = os.getenv("DOCKER_REGISTRY_LOGIN_TOKEN") docker_registry_token_access = os.getenv("DOCKER_REGISTRY_TOKEN_ACCESS") docker_rwlock = fasteners.InterProcessReaderWriterLock(docker_pushpull_rwlock) -docker_api = docker.APIClient(base_url=docker_daemon_socket) +docker_api = docker.APIClient(base_url=docker_daemon_socket, timeout=DOCKER_API_TIMEOUT) github_repo_access_token = os.getenv("GITHUB_REPO_ACCESS_TOKEN") github_repo_name = os.getenv("GITHUB_REPO_NAME") diff --git a/.github/workflows/macos-amd64-build.yml b/.github/workflows/macos-amd64-build.yml index 5268d6825d..78723aa499 100644 --- a/.github/workflows/macos-amd64-build.yml +++ b/.github/workflows/macos-amd64-build.yml @@ -1,6 +1,9 @@ name: GitHub Action MacOS amd64 -on: [push, pull_request] +on: + pull_request: + push: + branches: [ main, feature/onnx-to-tosa ] jobs: build: @@ -18,6 +21,12 @@ jobs: - name: install tools that are needed for compilation run: | brew install automake ninja pybind11 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + # Create symlinks to intercept compiler calls without modifying the cmake invocation + create-symlink: true + key: ${{ runner.os }}-ccache - name: install protobuf run: | cd ~/work @@ -46,6 +55,8 @@ jobs: python3 -m pip install -v . - name: build onnx-mlir run: | + # Disable stablehlo to ease bumping; we don't need it. + export EXTRA_CMAKE_ARGS="-DONNX_MLIR_ENABLE_STABLEHLO=OFF -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache" cd ~/work/onnx-mlir sh ~/work/onnx-mlir/onnx-mlir/utils/install-onnx-mlir.sh - name: build and run docs/doc_example tests @@ -57,6 +68,9 @@ jobs: cd ~/work/onnx-mlir sh ~/work/onnx-mlir/onnx-mlir/utils/check-unittest.sh - name: run onnx-mlir backend and numerical tests + # Those tests are not relevant to the work on the xilinx fork, but take + # 40 min. Don't run them on PRs. + if: github.event_name != 'pull_request' run: | cd ~/work/onnx-mlir sh ~/work/onnx-mlir/onnx-mlir/utils/check-onnx-backend-numerical.sh diff --git a/.github/workflows/ubuntu-build-intree.yml b/.github/workflows/ubuntu-build-intree.yml new file mode 100644 index 0000000000..67f857dbf6 --- /dev/null +++ b/.github/workflows/ubuntu-build-intree.yml @@ -0,0 +1,86 @@ +name: In-tree build + +on: + pull_request: + push: + branches: [ main, feature/onnx-to-tosa ] + +concurrency: + # Build every push. + # Only build the newest PR; cancel older builds of a PR + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build-intree: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: install tools that are needed for compilation + run: | + sudo apt-get update + sudo apt-get install -y gcc g++ cmake ninja-build + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1 + with: + # A full build seems to take ~ 250 MB. Leave a bit more room + # so we don't run out of cache space in the future. + max-size: 400M + key: sccache-intree + variant: sccache + create-symlink: true + + - name: install dependencies + run: | + utils/install-protobuf.sh + utils/install-venv.sh + + - name: clone llvm-project + run: sh utils/clone-mlir.sh + + - name: build + run: | + cmake llvm-project/llvm \ + -Bbuild \ + -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="onnx-mlir" \ + -DONNX_MLIR_ENABLE_STABLEHLO=OFF \ + -DLLVM_EXTERNAL_ONNX_MLIR_SOURCE_DIR=. \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DLLVM_BUILD_TOOLS=OFF \ + -DLLVM_BUILD_UTILS=OFF \ + -DLLVM_BUILD_RUNTIMES=OFF \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_ENABLE_LIBEDIT=OFF \ + -DLLVM_USE_LINKER=lld \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=sccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + + cmake --build build --target onnx-mlir + + - name: run LIT tests + run: | + export LIT_OPTS=-v + cmake --build build --target check-onnx-lit + + + - name: build and run docs/doc_example tests + run: | + cd .. + sh onnx-mlir/utils/check-doc-example.sh + + - name: build and run unit tests + run: | + cd .. + sh onnx-mlir/utils/check-unittest.sh diff --git a/.github/workflows/ubuntu-build.yml b/.github/workflows/ubuntu-build.yml new file mode 100644 index 0000000000..df9718c86c --- /dev/null +++ b/.github/workflows/ubuntu-build.yml @@ -0,0 +1,70 @@ +name: Out-of-tree build + +on: + pull_request: + push: + branches: [ main, feature/onnx-to-tosa ] + +concurrency: + # Build every push to main + # Only build the newest PR; cancel older builds of a PR + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: install tools that are needed for compilation + run: | + sudo apt-get update + sudo apt-get install -y gcc g++ cmake ninja-build + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1 + with: + # A full build of llvm, clang, lld, and lldb takes about 250MB + # of ccache space. There's not much reason to have more than this, + # because we usually won't need to save cache entries from older + # builds. Also, there is an overall 10GB cache limit, and each + # run creates a new cache entry so we want to ensure that we have + # enough cache space for all the tests to run at once and still + # fit under the 10 GB limit. + max-size: 500M + key: sccache + variant: sccache + create-symlink: true + + - name: install dependencies + run: | + utils/install-protobuf.sh + utils/install-venv.sh + + - name: clone & build MLIR + run: | + cd .. + export EXTRA_CMAKE_ARGS="-DLLVM_USE_LINKER=lld -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache" + sh onnx-mlir/utils/clone-mlir.sh + sh onnx-mlir/utils/build-mlir.sh + + - name: build onnx-mlir + run: | + cd .. + export EXTRA_CMAKE_ARGS="-DONNX_MLIR_ENABLE_STABLEHLO=OFF -DLLVM_USE_LINKER=lld -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache" + bash onnx-mlir/utils/install-onnx-mlir.sh + + - name: build and run docs/doc_example tests + run: | + cd .. + sh onnx-mlir/utils/check-doc-example.sh + + - name: build and run unit tests + run: | + cd .. + sh onnx-mlir/utils/check-unittest.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 0cc121f643..cbeac44cfb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,10 +8,12 @@ project(onnx-mlir) option(ONNX_MLIR_BUILD_TESTS "Build ONNX-MLIR test executables. If OFF, just generate build targets." ON) option(ONNX_MLIR_CCACHE_BUILD "Set to ON for a ccache enabled build." OFF) option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON) -option(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE "Enable ONNXConvTransposeOp decomposition." ON) option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF) option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON) option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON) +option(ONNX_MLIR_INSTALL_HEADERS "Install onnx-mlir headers" ON) +option(ONNX_MLIR_INSTALL_LIBS "Install onnx-mlir libraries" ON) +option(ONNX_MLIR_INSTALL_PYTHON_EXTENSIONS "Install onnx-mlir python bindings" ON) set(CMAKE_CXX_STANDARD 17) @@ -55,14 +57,20 @@ if (ONNX_MLIR_CCACHE_BUILD) endif() endif() -# Enable warnings as errors -# Leverage the imported LLVM_ENABLE_WERROR for compiler logic -set(LLVM_ENABLE_WERROR ${ONNX_MLIR_ENABLE_WERROR}) +if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + # Enable warnings as errors + # Leverage the imported LLVM_ENABLE_WERROR for compiler logic + set(LLVM_ENABLE_WERROR ${ONNX_MLIR_ENABLE_WERROR}) -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/bin) -set(CMAKE_INCLUDE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/include) + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/lib) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/bin) + set(CMAKE_INCLUDE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/$/include) + set(ONNX_MLIR_BUILD_INTREE OFF) +else() + message(STATUS "ONNX-MLIR in-tree build.") + set(ONNX_MLIR_BUILD_INTREE ON) +endif() set(ONNX_MLIR_SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(ONNX_MLIR_BIN_ROOT ${CMAKE_CURRENT_BINARY_DIR}) @@ -161,15 +169,19 @@ set(CMAKE_MESSAGE_LOG_LEVEL NOTICE) # Add third party subdirectories and define options appropriate to run their cmakes. set(pybind11_FIND_QUIETLY ON) -add_subdirectory(third_party/onnx) +add_subdirectory(third_party/onnx EXCLUDE_FROM_ALL) add_subdirectory(third_party/pybind11) -add_subdirectory(third_party/rapidcheck) +add_subdirectory(third_party/rapidcheck EXCLUDE_FROM_ALL) if (ONNX_MLIR_ENABLE_STABLEHLO) + if (ONNX_MLIR_BUILD_INTREE) + message(FATAL_ERROR "In tree builds don't support stablehlo yet. " + "Please pass -DONNX_MLIR_ENABLE_STABLEHLO=OFF to cmake or build out-of-tree.") + endif() add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL) endif() -if (NOT TARGET benchmark) +if (NOT ONNX_MLIR_BUILD_INTREE AND NOT TARGET benchmark) set(BENCHMARK_USE_BUNDLED_GTEST OFF) set(BENCHMARK_ENABLE_GTEST_TESTS OFF) set(BENCHMARK_ENABLE_TESTING OFF) @@ -183,8 +195,10 @@ endif() # compile flags updated via llvm_update_compile_flags, so we need to do that to # benchmark and rapidcheck as well, so that we can successfully link against them. # Otherwise, some of the flags for exceptions (among others) are not set correctly. -llvm_update_compile_flags(benchmark) -llvm_update_compile_flags(benchmark_main) +if (NOT ONNX_MLIR_BUILD_INTREE) + llvm_update_compile_flags(benchmark) + llvm_update_compile_flags(benchmark_main) +endif() llvm_update_compile_flags(rapidcheck) if (ONNX_MLIR_ENABLE_STABLEHLO) @@ -208,9 +222,6 @@ if (ONNX_MLIR_ENABLE_STABLEHLO) add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO) endif() -if (ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE) - add_compile_definitions(ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE) -endif() add_subdirectory(utils) add_subdirectory(include) diff --git a/CODING_PRACTICE.md b/CODING_PRACTICE.md new file mode 100644 index 0000000000..a2a4f2be62 --- /dev/null +++ b/CODING_PRACTICE.md @@ -0,0 +1,29 @@ + + +# Coding Practices + +This document contains coding practices to use when adding or updating code to the onnx-mlir project. + +## Practices + +* Use C++ style casting instead of C style when casting in cpp. + +For example, use C++ style casting: +``` + Value one = create.llvm.constant(llvmI64Ty, static_cast(1)); +``` + +Not, C style casting: +``` + Value one = create.llvm.constant(llvmI64Ty, (int64_t)1); +``` + +* Perform bitwise operations on unsigned types and not signed. +* Check the result of malloc() invocations. +* Check the result of input/output operations, such as fopen() and fprintf(). +* Use parentheses around parameter names in macro definitions. + +## Contributing + +We are welcoming contributions from the community. +Please consult the [CONTRIBUTING](CONTRIBUTING.md) page for help on how to proceed. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5b7f4b5c73..2e0b230951 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,6 +30,10 @@ A comprehensive list of documents is found [here](docs/DocumentList.md). * The Krnl Dialect is used to lower ONNX operators to MLIR affine. The Krnl Dialect is defined [here](docs/Dialects/krnl.md). * To update the internal documentation on our dialects when there are changes, please look for guidance [here](docs/ImportONNXDefs.md#update-your-operations-status). +## Coding practices for ONNX-MLIR + +* When adding or updating code, see [here](CODING_PRACTICE.md) for coding practices. + ## Testing and debugging ONNX-MLIR * To test new code, see [here](docs/Testing.md) for instructions. diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 0000000000..f92ca4d2b8 --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,6 @@ + + +# Governance + +The overall governance of the ONNX-MLIR project is described at https://github.com/onnx/onnx/blob/main/community/readme.md#onnx-open-governance. +The ONNX-MLIR project is under the purview of the Compilers Special Interest Group (Compilers SIG). diff --git a/MLIR.cmake b/MLIR.cmake index 1a66f6e41e..5a3a0f8b23 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -3,26 +3,36 @@ # Must unset LLVM_DIR in cache. Otherwise, when MLIR_DIR changes LLVM_DIR # won't change accordingly. unset(LLVM_DIR CACHE) -if (NOT DEFINED MLIR_DIR) - message(FATAL_ERROR "MLIR_DIR is not configured but it is required. " - "Set the cmake option MLIR_DIR, e.g.,\n" - " cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n" - ) -endif() +if (NOT ONNX_MLIR_BUILD_INTREE) + if (NOT DEFINED MLIR_DIR) + message(FATAL_ERROR "MLIR_DIR is not configured but it is required. " + "Set the cmake option MLIR_DIR, e.g.,\n" + " cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n" + ) + endif() -find_package(MLIR REQUIRED CONFIG) + find_package(MLIR REQUIRED CONFIG) -message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") -message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") + message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") + list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -include(TableGen) -include(AddLLVM) -include(AddMLIR) + include(TableGen) + include(AddLLVM) + include(AddMLIR) -include(HandleLLVMOptions) + include(HandleLLVMOptions) +else() + set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) + set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include) + set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) + set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") + + set(MLIR_CMAKE_DIR ${MLIR_MAIN_SRC_DIR}/cmake/modules) + set(LLVM_CMAKE_DIR ${LLVM_MAIN_SRC_DIR}/cmake/modules) +endif() include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) @@ -201,7 +211,7 @@ function(add_onnx_mlir_library name) endif() endif() - if (NOT ARG_NO_INSTALL) + if (NOT ARG_NO_INSTALL AND ONNX_MLIR_INSTALL_LIBS) install(TARGETS ${name} ARCHIVE DESTINATION lib LIBRARY DESTINATION lib diff --git a/README.md b/README.md index d6dce7b082..93a2e32626 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,11 @@ Practically, each `git commit` needs to be signed, see [here](docs/Workflow.md#s The ONNX-MLIR code of conduct is described at https://onnx.ai/codeofconduct.html. +## Adopters + + +* IBM [zDLC compiler](https://github.com/IBM/zDLC) uses onnx-mlir technology to transform ONNX models into executable binary for [IBM Telum](https://www.ibm.com/z/telum) servers. + ## Projects related/using onnx-mlir * The [onnx-mlir-serving](https://github.com/IBM/onnx-mlir-serving) project implements a GRPC server written with C++ to serve onnx-mlir compiled models. Benefiting from C++ implementation, ONNX Serving has very low latency overhead and high throughput. diff --git a/docker/Dockerfile.llvm-project b/docker/Dockerfile.llvm-project index ab5e50d6ad..ecee3d1ea8 100644 --- a/docker/Dockerfile.llvm-project +++ b/docker/Dockerfile.llvm-project @@ -1,5 +1,5 @@ # By default, use ubuntu:jammy, remember to change Jenkins build script as well -ARG BASE_IMAGE="ubuntu:jammy" +ARG BASE_IMAGE="ghcr.io/onnxmlir/ubuntu:jammy" FROM ${BASE_IMAGE} # Label the image for various checking and cleanup @@ -47,11 +47,11 @@ RUN distro=$(cat /etc/os-release|grep -Po '(?<=^ID=").*(?=")|(?<=^ID=)[^"].*[^"] autoconf automake ca-certificates clang cmake diffutils \ file java-11-openjdk-devel java-11-openjdk-headless \ gcc gcc-c++ git libtool make ncurses-devel ninja-build \ - python39 python39-devel python39-numpy python39-pip \ - python39-setuptools python39-wheel tzdata-java zlib-devel && \ + python39 python39-devel python39-pip python39-setuptools \ + python39-wheel tzdata-java zlib-devel && \ # Use same versions as those in ubuntu:jammy pip3 install -q \ - Cython pytest==6.2.5 pytest-forked==1.4.0 \ + Cython pytest==6.2.5 numpy==1.21.5 pytest-forked==1.4.0 \ pytest-xdist==2.5.0 typing-extensions==3.10.0.2 && \ rm -rf /var/cache/dnf/* && \ echo -e "/usr/local/lib" > /etc/ld.so.conf.d/local.conf; \ diff --git a/docker/Dockerfile.onnx-mlir b/docker/Dockerfile.onnx-mlir index a9f8688e9a..170f3d63b8 100644 --- a/docker/Dockerfile.onnx-mlir +++ b/docker/Dockerfile.onnx-mlir @@ -26,7 +26,7 @@ RUN ONNX_ROOT=${WORK_DIR}/onnx-mlir/third_party/onnx \ ARG NPROC=4 ARG ACCEL=NNPA ARG TEST_NOFLOAT16 -ARG TEST_MCPU +ARG TEST_MARCH ARG KEEPSRC RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ @@ -53,20 +53,26 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ ([ "$(uname -m)" = "x86_64" ] && echo true || \ ([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \ # User image is built with SIMD (currently on s390x only) - && TEST_MCPU=${TEST_MCPU:-$([ "$(uname -m)" = "s390x" ] && echo z14 || \ + && TEST_MARCH=${TEST_MARCH:-$([ "$(uname -m)" = "s390x" ] && echo z16 || \ ([ "$(uname -m)" = "x86_64" ] && echo || \ ([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \ - && TEST_ARGS="-mcpu=${TEST_MCPU}" \ + && TEST_ARGS="-march=${TEST_MARCH}" \ && make check-docs \ && make check-unittest \ && make check-multiple-models \ && make NPROC=${NPROC} \ CTEST_PARALLEL_LEVEL=${NPROC} \ TEST_NOFLOAT16=${TEST_NOFLOAT16} \ - TEST_MCPU=${TEST_MCPU} \ + TEST_MARCH=${TEST_MARCH} \ TEST_ARGS="${TEST_ARGS}" \ -j${NPROC} \ check-onnx-backend-numerical \ + && if [ "${TEST_MARCH}" = "z16" ]; then \ + make NPROC=${NPROC} \ + CTEST_PARALLEL_LEVEL=${NPROC} \ + -j${NPROC} \ + check-onnx-backend-numerical-nnpa; \ + fi \ && make -j${NPROC} install && ldconfig \ # Clean up && cd ${WORK_DIR} \ diff --git a/docker/Dockerfile.onnx-mlir-dev b/docker/Dockerfile.onnx-mlir-dev index 574737c1a9..344fa273b5 100644 --- a/docker/Dockerfile.onnx-mlir-dev +++ b/docker/Dockerfile.onnx-mlir-dev @@ -20,7 +20,7 @@ RUN ONNX_ROOT=${WORK_DIR}/onnx-mlir/third_party/onnx \ ARG NPROC=4 ARG ACCEL=NNPA ARG TEST_NOFLOAT16 -ARG TEST_MCPU +ARG TEST_MARCH RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ && ONNX_MLIR_ROOT=${WORK_DIR}/onnx-mlir \ @@ -51,10 +51,10 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ ([ "$(uname -m)" = "x86_64" ] && echo true || \ ([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \ # Dev image is built without SIMD, placeholder for easy SIMD enablement - && TEST_MCPU=$([ "$(uname -m)" = "s390x" ] && echo || \ + && TEST_MARCH=$([ "$(uname -m)" = "s390x" ] && echo || \ ([ "$(uname -m)" = "x86_64" ] && echo || \ ([ "$(uname -m)" = "ppc64le" ] && echo || echo))) \ - && TEST_ARGS="-mcpu=${TEST_MCPU}" \ + && TEST_ARGS="-march=${TEST_MARCH}" \ && TEST_OPTLEVEL=0 \ && make check-docs \ && make check-unittest \ @@ -62,7 +62,7 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \ && make NPROC=${NPROC} \ CTEST_PARALLEL_LEVEL=${NPROC} \ TEST_NOFLOAT16=${TEST_NOFLOAT16} \ - TEST_MCPU=${TEST_MCPU} \ + TEST_MARCH=${TEST_MARCH} \ TEST_ARGS="${TEST_ARGS}" \ TEST_OPTLEVEL=${TEST_OPTLEVEL} \ -j${NPROC} \ diff --git a/docker/onnx-mlir.py b/docker/onnx-mlir.py index b26f2c49aa..d9d95db65e 100755 --- a/docker/onnx-mlir.py +++ b/docker/onnx-mlir.py @@ -33,7 +33,7 @@ import sys DOCKER_SOCKET = "/var/run/docker.sock" -ONNX_MLIR_IMAGE = "onnxmlir/onnx-mlir" +ONNX_MLIR_IMAGE = "ghcr.io/onnxmlir/onnx-mlir" KNOWN_INPUT_TYPE = (".onnx", ".json", ".mlir") mount_dirs = [] diff --git a/docs/BuildONNX.md b/docs/BuildONNX.md index 2e932df876..c47eefbce0 100644 --- a/docs/BuildONNX.md +++ b/docs/BuildONNX.md @@ -6,32 +6,40 @@ Backend tests are triggered by `make check-onnx-backend` in the build directory You will need to install python 3.x if its not default in your environment, and possibly set the cmake `PYTHON_EXECUTABLE` variable in your top cmake file. -You will also need `pybind11` which may need to be installed (mac: `brew install pybind11` or linux: `apt -y install python3-pybind11` for example) and you may need to indicate where to find the software (Mac, POWER, possibly other platforms: `export pybind11_DIR=`). Then install the `third_party/onnx` software (Mac: `pip install -e third_party/onnx`) typed in the top directory. +You will also need `pybind11` which may need to be installed (mac: `brew install pybind11` or linux: `apt -y install python3-pybind11` for example) and you may need to indicate where to find the software (Mac, POWER, possibly other platforms: `export pybind11_DIR=`). Then install the `third_party/onnx` software (Mac: `pip install third_party/onnx`) typed in the top directory. ## Upgrading ONNX in ONNX-MLIR - + Here are the steps taken to upgrade the ONNX version: 1. Create your own branch -2. cd into `third_party/onnx` and checkout the commit for the latest version of onnx (You can find the latest commit here: https://github.com/onnx/onnx/releases) +2. "cd" into `third_party/onnx` and checkout the commit for the latest version of onnx (You can find the latest commit here: https://github.com/onnx/onnx/releases) -3. pip uninstall onnx (remove older version) +3. "pip uninstall onnx" (remove older version) -4. In `onnx-mlir/` directory, pip install —user third_party/onnx (install onnx from the commit and not online version) +4. In `onnx-mlir/` directory, "pip install third_party/onnx" (install onnx from the commit and not online version) 5. Update `utils/gen_onnx_mlir.py` file with the correct version number 6. Build onnx in the `build/` directory using: set CMAKE_ARGS=-DONNX_USE_LITE_PROTO=ON -7. Run in the `build/` directory : make OMONNXOpsIncTranslation +7. Run in the `build/` directory : "make OMONNXOpsIncTranslation" -8. Run in `build/` directory : make onnx-mlir-docs +8. Run in `build/` directory : "make onnx-mlir-docs" +9. Run in `build/` directory : "make check-onnx-backend-case" -**Note: Please use `git add ` for files that might have been changed before doing a PR.** +10. Update the [new backend tests](https://github.com/onnx/onnx-mlir/blob/main/test/backend/all_test_names.txt) based on the results from `step 9` + +11. Update the [Opset documentation for cpu](https://github.com/onnx/onnx-mlir/blob/main/test/backend/inference_backend.py) and then issue the following command in the `build/` directory: "make onnx_mlir_supported_ops_cpu" -**Tip: Check that we have the right version of `third_party/onnx` committed by looking in the PR and clicking on the files that are changed. You should be redirected to the appropriate onnx commit hash.** +12. Update the [Opset documentation for NNPA](https://github.com/onnx/onnx-mlir/blob/main/test/backend/inference_backend.py) and then issue the following command in the `build/` directory: "make onnx_mlir_supported_ops_NNPA" + +13. Ensure the lit tests and backend tests pass successfully and then you are done! + + +**Note: Please use `git add ` for files that might have been changed before doing a PR.** ## Known issues diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md index eb2014d748..7029b52e96 100644 --- a/docs/BuildOnLinuxOSX.md +++ b/docs/BuildOnLinuxOSX.md @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd .. +cd llvm-project && git checkout e8be3bea2ce0ec51b614cd7eb7d5d3a1e56d9524 && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md index 77650910c1..61863953bf 100644 --- a/docs/BuildOnWindows.md +++ b/docs/BuildOnWindows.md @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd .. +cd llvm-project && git checkout e8be3bea2ce0ec51b614cd7eb7d5d3a1e56d9524 && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/docs/DebuggingNumericalError.md b/docs/DebuggingNumericalError.md index 62b513ff33..0eeabcb505 100644 --- a/docs/DebuggingNumericalError.md +++ b/docs/DebuggingNumericalError.md @@ -65,7 +65,7 @@ optional arguments: ## Helper script to compare a model under two distinct compile option. Based on the above `utils/runONNXModel.py`, the `utils/checkONNXModel.py` allows a user to run a given model twice, under two distinct compile options, and compare its results. -This let a user simply test a new option, comparing the safe version of the compiler (e.g. `-O0` or `-O3`) with a more advanced version (e.g. `-O3` or `-O3 -march=x86-64`). Simply specify the compile options using the `--ref-compile-args` and `--test-compile-args` flags, a model using the `--model` flag, and possibly a `--shape-info` in presence of dynamic shape inputs. +This let a user simply test a new option, comparing the safe version of the compiler (e.g. `-O0` or `-O3`) with a more advanced version (e.g. `-O3` or `-O3 --march=x86-64`). Simply specify the compile options using the `--ref-compile-args` and `--test-compile-args` flags, a model using the `--model` flag, and possibly a `--shape-info` in presence of dynamic shape inputs. Full options are listed under the `--help` flag. ## Debugging the Code Generated for an Operator. diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md index c5c76450bd..797aef9ae6 100644 --- a/docs/Dialects/krnl.md +++ b/docs/Dialects/krnl.md @@ -191,6 +191,12 @@ Interfaces: `MemoryEffectOpInterface` | :-----: | ----------- | | `parameters` | variadic of any type +#### Results: + +| Result | Description | +| :----: | ----------- | +| `returnValue` | variadic of floating-point or integer + ### `krnl.copy_from_tile_buffer` (KrnlCopyFromBufferOp) _Copy from buffer._ @@ -323,6 +329,7 @@ _Indicate ONNX entry point_ The "krnl.entry_point" function indicates the main entry point of ONNX model. + ### `krnl.erf` (KrnlErfOp) _Krnl erf scalar operation_ @@ -929,6 +936,35 @@ Typically it is used for optional arguments used in KrnlCallop. | :----: | ----------- | | `none_val` | none type +### `krnl.parallel_clause` (KrnlParallelClauseOp) + +_Attach OpenMP clauses to an index varialbe_ + + +Syntax: + +``` +operation ::= `krnl.parallel_clause` `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)? + attr-dict `:` type($parallel_loop_index) +``` + +Attach OpenMP clauses to an index variable. That index variable +is used to uniquely associate a parallel loop with its clauses. + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
proc_bind::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `parallel_loop_index` | index +| `num_threads` | 32-bit signless integer + ### `krnl.parallel` (KrnlParallelOp) _Mark Krnl loops as parallel loops_ @@ -937,7 +973,7 @@ _Mark Krnl loops as parallel loops_ Syntax: ``` -operation ::= `krnl.parallel` `(` $loops `)` attr-dict `:` type($loops) +operation ::= `krnl.parallel` `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops) ``` Parallelize the specified loops. When multiple loop specifiers are passed @@ -945,15 +981,30 @@ as parameters, there loops can be parallelized as a collapsed loop. krnl.parallel should be placed as the last operator before krnl.iterate, Since we do not want to parallelize the loop until we interpret krnl.block, krnl.permute and krnl.unroll. + +Optionally, a value may specifiy the number of threads requested for the +parallel loop. A proc_bind string may also be specified; valid values are +"primary", "close", or "spread". Default values are used when not specified. + ``` krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop ``` +Traits: `AttrSizedOperandSegments` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
proc_bind::mlir::StringAttrstring attribute
+ #### Operands: | Operand | Description | | :-----: | ----------- | | `loops` | variadic of any type +| `num_threads` | 32-bit signless integer ### `krnl.permute` (KrnlPermuteOp) @@ -1149,6 +1200,25 @@ create a new memref inside the region and use it outside of the region. Traits: `AffineScope`, `NoTerminator`, `SingleBlock` +### `krnl.round_even` (KrnlRoundEvenOp) + +_Krnl round to nearest even operation_ + +Krnl round to nearest even operation. Accept scalar or vector float values. +Vector must be 1D of a size that is a multiple of the hardware vector size. + +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `in` | floating-point-like + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `out` | floating-point-like + ### `krnl.seqalloc` (KrnlSeqAllocOp) _Krnl create a sequence_ diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 3b074a59ec..e32d13e332 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -7,7 +7,7 @@ Absolute takes one input data (Tensor) and produces one output data (Tensor) where absolute value, y = abs(x), is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -31,7 +31,7 @@ _ONNX Acos operation_ Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -41,13 +41,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Acosh` (ONNXAcoshOp) @@ -55,7 +55,7 @@ _ONNX Acosh operation_ Calculates the hyperbolic arccosine of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -65,13 +65,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Adagrad` (ONNXAdagradOp) @@ -128,7 +128,7 @@ Compute one iteration of ADAGRAD, a stochastic gradient based optimization In that reference paper, this operator is a special case of the Figure 1's composite mirror descent update. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -223,7 +223,7 @@ Compute one iteration of Adam, a stochastic gradient based optimization If there are multiple inputs to be optimized, the pseudo code will be applied independently to each of them. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -264,7 +264,7 @@ This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; fo (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -292,7 +292,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<7>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -323,7 +323,7 @@ is selected if the max appears more than once in the input. Otherwise the index first occurrence is selected. The type of the output tensor is integer. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -362,7 +362,7 @@ is selected if the min appears more than once in the input. Otherwise the index first occurrence is selected. The type of the output tensor is integer. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -396,7 +396,7 @@ _ONNX ArrayFeatureExtractor operation_ Select elements of the input tensor based on the indices passed.
The indices are applied to the last axes of the tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -421,7 +421,7 @@ _ONNX Asin operation_ Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -431,13 +431,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Asinh` (ONNXAsinhOp) @@ -445,7 +445,7 @@ _ONNX Asinh operation_ Calculates the hyperbolic arcsine of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -455,13 +455,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Atan` (ONNXAtanOp) @@ -469,7 +469,7 @@ _ONNX Atan operation_ Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -479,13 +479,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Atanh` (ONNXAtanhOp) @@ -493,7 +493,7 @@ _ONNX Atanh operation_ Calculates the hyperbolic arctangent of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -503,13 +503,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.AveragePool` (ONNXAveragePoolOp) @@ -529,7 +529,7 @@ AveragePool consumes an input tensor X and applies average pooling across ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -548,7 +548,7 @@ AveragePool consumes an input tensor X and applies average pooling across The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -571,13 +571,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.BatchNormalizationInferenceMode` (ONNXBatchNormalizationInferenceModeOp) @@ -600,7 +600,7 @@ by an argument that is present) may also be simply omitted. This operation is not part of the standard and was added to assist onnx-mlir. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<15>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -674,7 +674,7 @@ For previous (depreciated) non-spatial cases, implementors are suggested to flatten the input shape to (N x C * D1 * D2 * ... * Dn) before a BatchNormalization Op. This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<15>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -707,6 +707,55 @@ Effects: `MemoryEffects::Effect{}` | `running_mean` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values or none type | `running_var` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values or none type +### `onnx.BatchNormalizationV9` (ONNXBatchNormalizationV9Op) + +_ONNX BatchNormalization operation_ + +Carries out batch normalization as described in the paper +https://arxiv.org/abs/1502.03167. Depending on the mode it is being run, +there are multiple cases for the number of outputs, which we list below: + +Output case #1: Y, mean, var, saved_mean, saved_var (training mode) +Output case #2: Y (test mode) + +For previous (depreciated) non-spatial cases, implementors are suggested +to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op. +This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<9>` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
epsilon::mlir::FloatAttr32-bit float attribute
momentum::mlir::FloatAttr32-bit float attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `scale` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `mean` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `var` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `out_mean` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `out_var` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `saved_mean` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `saved_var` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type + ### `onnx.Bernoulli` (ONNXBernoulliOp) _ONNX Bernoulli operation_ @@ -718,7 +767,7 @@ where an output of 1 is produced with probability p and an output of 0 is produc This operator is non-deterministic and may not produce the same values in different implementations (even if a seed is specified). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -736,13 +785,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 1-bit signless integer values +| `output` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 1-bit signless integer values ### `onnx.Binarizer` (ONNXBinarizerOp) @@ -750,7 +799,7 @@ _ONNX Binarizer operation_ Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -792,7 +841,7 @@ Because this operator supports Numpy-style broadcasting, X's and Y's shapes are not necessarily identical. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -827,7 +876,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -852,7 +901,7 @@ _ONNX BitwiseNot operation_ Returns the bitwise not of the input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -879,7 +928,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -907,7 +956,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -932,7 +981,7 @@ _ONNX BlackmanWindow operation_ Generates a Blackman window as described in the paper https://ieeexplore.ieee.org/document/1455106. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -966,7 +1015,7 @@ The operator casts the elements of a given input tensor (the first input) to the same data type as the elements of the second input tensor. See documentation of the Cast operator for further details. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1000,7 +1049,7 @@ Converts a map to a tensor.
The map key must be an int64 and the values will in ascending order based on this key.
The operator supports dense packing or sparse packing. If using sparse packing, the key cannot exceed the max_map-1 value. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1096,7 +1145,7 @@ The rules then become: | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN | | else | RNE | RNE | RNE | RNE | -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1114,13 +1163,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values ### `onnx.CategoryMapper` (ONNXCategoryMapperOp) @@ -1135,7 +1184,7 @@ Converts strings to integers and vice versa.
If the string default value is set, it will convert integers to strings. If the int default value is set, it will convert strings to integers. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1171,7 +1220,7 @@ Ceil takes one input data (Tensor) and produces one output data (Tensor) where the ceil is, y = ceil(x), is applied to the tensor elementwise. If x is integral, +0, -0, NaN, or infinite, x itself is returned. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1201,7 +1250,7 @@ using formula: max(0,x) + min(0,alpha*(exp(x/alpha)-1)) ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<12>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1218,13 +1267,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 32-bit float values +| `X` | tensor of 32-bit float values or tensor of bfloat16 type values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 32-bit float values +| `Y` | tensor of 32-bit float values or tensor of bfloat16 type values ### `onnx.CenterCropPad` (ONNXCenterCropPadOp) @@ -1239,7 +1288,7 @@ If the input dimensions are bigger than the crop shape, a centered cropping wind If the input dimensions are smaller than the crop shape, the input is padded on each side equally, so that the input is centered in the output. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1273,7 +1322,7 @@ Clip operator limits the given input within an interval. The interval is specified by the inputs 'min' and 'max'. They default to numeric_limits::lowest() and numeric_limits::max(), respectively. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1301,7 +1350,7 @@ Clip operator limits the given input within an interval. The interval is specified by the inputs 'min' and 'max'. They default to numeric_limits::lowest() and numeric_limits::max(), respectively. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1329,7 +1378,7 @@ Clip operator limits the given input within an interval. The interval is specified by the inputs 'min' and 'max'. They default to numeric_limits::lowest() and numeric_limits::max(), respectively. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<12>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1357,7 +1406,7 @@ Clip operator limits the given input within an interval. The interval is specified with arguments 'min' and 'max'. They default to numeric_limits::lowest() and numeric_limits::max() respectively. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<6>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1398,7 +1447,7 @@ NOTE: convolution formulas, it is required as input for more advanced scenarios as explained at PyTorch's implementation (https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Col2Im.cpp#L10) -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1436,7 +1485,7 @@ Selects slices from an input tensor along a given axis where condition evaluates Compress behaves like numpy.compress: https://docs.scipy.org/doc/numpy/reference/generated/numpy.compress.html -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1471,7 +1520,7 @@ All input tensors must have the same shape, except for the dimension size of the By default 'new_axis' is 0, the behavior is similar to numpy.concatenate. When 'new_axis' is 1, the behavior is similar to numpy.stack. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1503,7 +1552,7 @@ _ONNX Concat operation_ Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1574,7 +1623,7 @@ _ONNX ConstantOfShape operation_ Generate a tensor with given value and shape. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1606,7 +1655,7 @@ _ONNX Constant operation_ This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, or value_* must be specified. -Traits: `AlwaysSpeculatableImplTrait`, `ConstantLike` +Traits: `AlwaysSpeculatableImplTrait`, `ConstantLike`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1639,7 +1688,7 @@ _ONNX ConvInteger operation_ The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point, and computes the output. The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1679,7 +1728,7 @@ _ONNX Conv operation_ The convolution operator consumes an input tensor and a filter, and computes the output. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1701,15 +1750,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.ConvTranspose` (ONNXConvTransposeOp) @@ -1730,7 +1779,7 @@ output_shape can also be explicitly specified in which case pads values are auto -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1754,15 +1803,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Cos` (ONNXCosOp) @@ -1770,7 +1819,7 @@ _ONNX Cos operation_ Calculates the cosine of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1780,13 +1829,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Cosh` (ONNXCoshOp) @@ -1794,7 +1843,7 @@ _ONNX Cosh operation_ Calculates the hyperbolic cosine of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1804,13 +1853,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.CumSum` (ONNXCumSumOp) @@ -1837,7 +1886,7 @@ output = [5, 3, 0] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -1970,7 +2019,7 @@ The actual shape of the output is specified in the \"output\" section. Reference: https://docs.scipy.org/doc/scipy/tutorial/fft.html -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2004,7 +2053,7 @@ _ONNX DFT operation_ Computes the discrete Fourier transform of input. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2039,7 +2088,7 @@ _ONNX DeformConv operation_ Performs deformable convolution as described in https://arxiv.org/abs/1703.06211 and https://arxiv.org/abs/1811.11168. This operator specification supports the general N-D case. Note that most common use cases have 2D or 3D data. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2061,17 +2110,17 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `offset` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `mask` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `offset` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `mask` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.DepthToSpace` (ONNXDepthToSpaceOp) @@ -2101,7 +2150,7 @@ tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2139,7 +2188,7 @@ there's no zero point (zero point is supposed to be 0). `zero-point` is usually not used in the case of float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz quantization, but the dequantization formula remains the same for consistency and 'x_scale' still determines the output type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2176,7 +2225,7 @@ and the inner-most 2 dimensions form square matrices. The output is a tensor of shape `[*]`, containing the determinants of all input submatrices. e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2186,13 +2235,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.DictVectorizer` (ONNXDictVectorizerOp) @@ -2211,7 +2260,7 @@ Uses an index mapping to convert a dictionary to an array.
then an input of ``{\"a\": 4, \"c\": 8}`` will produce an output of ``[4, 8, 0, 0]``. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2322,7 +2371,7 @@ This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; fo (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2358,7 +2407,7 @@ scale = 1. / (1. - ratio). ``` This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2375,15 +2424,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `data` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values -| `ratio` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `data` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `ratio` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or none type | `training_mode` | tensor of 1-bit signless integer values or none type #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values | `mask` | tensor of 1-bit signless integer values or none type ### `onnx.DynamicQuantizeLinear` (ONNXDynamicQuantizeLinearOp) @@ -2418,7 +2467,7 @@ y = saturate (round (x / y_scale) + y_zero_point) * for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported. * rounding to nearest ties to even. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2468,7 +2517,7 @@ Specifically, every occurrence of ellipsis in the equation must represent the sa The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the beginning of the output. The equation string may contain space (U+0020) character. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<12>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2502,7 +2551,7 @@ Elu takes one input data (Tensor) and produces one output data 0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2519,13 +2568,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.EntryPoint` (ONNXEntryPointOp) @@ -2551,7 +2600,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>`, `SameOperandsElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2576,7 +2625,7 @@ _ONNX Erf operation_ Computes the error function of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2600,7 +2649,7 @@ _ONNX Exp operation_ Calculates the exponential of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2610,13 +2659,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Expand` (ONNXExpandOp) @@ -2631,7 +2680,7 @@ but the major difference is numpy.broadcast_to() does not allow shape to be smal It is possible that the output.shape is not equal to shape, when some dimensions in shape is equal to 1, or the shape.ndim < input.shape.ndim. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2662,9 +2711,9 @@ is populated with ones, but attribute 'k' can be used to populate upper or lower The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message and be valid as an output type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` -Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` Effects: `MemoryEffects::Effect{}` @@ -2680,13 +2729,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values +| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 1-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values +| `output` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 1-bit signless integer values ### `onnx.FeatureVectorizer` (ONNXFeatureVectorizerOp) @@ -2697,7 +2746,7 @@ Concatenates input tensors into one continuous output.
Inputs are copied to the output maintaining the order of the input arguments.
All inputs must be integers or floats, while the output will be all floating point values. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2730,7 +2779,7 @@ Flattens the input tensor into a 2D matrix. If input tensor has shape (d_0, d_1, ... d_n) then the output will have shape (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2763,7 +2812,7 @@ Floor takes one input data (Tensor) and produces one output data (Tensor) where the floor is, y = floor(x), is applied to the tensor elementwise. If x is integral, +0, -0, NaN, or infinite, x itself is returned. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2833,7 +2882,7 @@ Equations (Default: f=Sigmoid, g=Tanh): * Ht = (1 - zt) (.) ht + zt (.) Ht-1 This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -2857,19 +2906,19 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `R` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `R` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type | `sequence_lens` | tensor of 32-bit signless integer values or none type -| `initial_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `initial_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `Y_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type ### `onnx.GatherElements` (ONNXGatherElementsOp) @@ -2927,7 +2976,7 @@ output = [ ] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3042,7 +3091,7 @@ indices = [[1],[0]] # indices_shape = [2, 1] output = [[2,3],[4,5]] # output_shape = [2, 2] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3121,7 +3170,7 @@ output = [ ] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3159,7 +3208,7 @@ $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and appli to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3201,7 +3250,7 @@ computation if attribute transA is non-zero, same for B and transB. This operator supports **unidirectional broadcasting** (tensor C should be unidirectional broadcastable to tensor A * B); for more details please check [the doc](Broadcasting.md). This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3239,7 +3288,7 @@ GlobalAveragePool consumes an input tensor X and applies average pooling across the values in the same channel. This is equivalent to AveragePool with kernel size equal to the spatial dimension of input tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3249,13 +3298,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.GlobalLpPool` (ONNXGlobalLpPoolOp) @@ -3265,7 +3314,7 @@ GlobalLpPool consumes an input tensor X and applies lp pool pooling across the values in the same channel. This is equivalent to LpPool with kernel size equal to the spatial dimension of input tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<2>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3282,13 +3331,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.GlobalMaxPool` (ONNXGlobalMaxPoolOp) @@ -3298,7 +3347,7 @@ GlobalMaxPool consumes an input tensor X and applies max pooling across the values in the same channel. This is equivalent to MaxPool with kernel size equal to the spatial dimension of input tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3308,13 +3357,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Gradient` (ONNXGradientOp) @@ -3444,7 +3493,7 @@ forward pass can be reused if the gradient is computed via reverse-mode auto-differentiation. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3480,7 +3529,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3508,7 +3557,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>`, `SameOperandsElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3531,6 +3580,57 @@ Effects: `MemoryEffects::Effect{}` _ONNX GridSample operation_ +Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. +For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), +the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), +the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). +More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), +the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + +The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). +The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values +at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) +and a padding mode (for `grid` positions falling outside the 2-dimensional image). + +For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. +They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + +The GridSample operator is often used in doing grid generator and sampler in the +[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). +See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
align_corners::mlir::IntegerAttr64-bit signed integer attribute
mode::mlir::StringAttrstring attribute
padding_mode::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `grid` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values + +### `onnx.GridSampleV16` (ONNXGridSampleV16Op) + +_ONNX GridSample operation_ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. Currently, only spatial (4-D) inputs are supported. For input `X` with shape (N, C, H, W) and `grid` with shape (N, H_out, W_out, 2), the output `Y` will have shape (N, C, H_out, W_out). @@ -3545,7 +3645,58 @@ They are used to interpolate output values of `Y[N, C, H_out, W_out]`. The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
align_corners::mlir::IntegerAttr64-bit signed integer attribute
mode::mlir::StringAttrstring attribute
padding_mode::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `grid` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values + +### `onnx.GridSampleV20` (ONNXGridSampleV20Op) + +_ONNX GridSample operation_ + +Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. +For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), +the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), +the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). +More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), +the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + +The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). +The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values +at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) +and a padding mode (for `grid` positions falling outside the 2-dimensional image). + +For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. +They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + +The GridSample operator is often used in doing grid generator and sampler in the +[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). +See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3589,11 +3740,68 @@ where the mean and variance are computed per instance per group of channels, and groups `num_groups` should be divisible by the number of channels so that there are an equal number of channels per group. +The overall computation has two stages: the first stage normalizes the elements to +have zero mean and unit variance for each instance in each group, and the second +stage scales and shifts the results of the first stage. The floating-point precision +used in the first stage is determined by the `stash_type` attribute. For example, +if `stash_type` is 1, the operator casts all input variables to 32-bit float, +performs the computation, and finally casts the normalized results back to the +original type of `X`. The second stage does not depend on `stash_type`. + +When the number of groups is the same as the number of channels, this operator is +equivalent to InstanceNormalization. When there is only one group, this operator +is equivalent to LayerNormalization. + +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<21>` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
epsilon::mlir::FloatAttr32-bit float attribute
num_groups::mlir::IntegerAttr64-bit signed integer attribute
stash_type::mlir::IntegerAttr64-bit signed integer attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `scale` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `bias` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values + +### `onnx.GroupNormalizationV18` (ONNXGroupNormalizationV18Op) + +_ONNX GroupNormalization operation_ + +A GroupNormalization function. Carries out group normalization as described in +the paper https://arxiv.org/abs/1803.08494 + +This operator transforms input according to +``` +y = scale * (x - mean) / sqrt(variance + epsilon) + bias, +``` +where the mean and variance are computed per instance per group of channels, and +`scale` and `bias` should be specified for each group of channels. The number of +groups `num_groups` should be divisible by the number of channels so that there are +an equal number of channels per group. + When the number of groups is the same as the number of channels, this operator is equivalent to InstanceNormalization. When there is only one group, this operator is equivalent to LayerNormalization. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3627,7 +3835,7 @@ _ONNX HammingWindow operation_ Generates a Hamming window as described in the paper https://ieeexplore.ieee.org/document/1455106. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3659,7 +3867,7 @@ _ONNX HannWindow operation_ Generates a Hann window as described in the paper https://ieeexplore.ieee.org/document/1455106. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3693,7 +3901,7 @@ HardSigmoid takes one input data (Tensor) and produces one output data (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)), is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3711,13 +3919,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.HardSwish` (ONNXHardSwishOp) @@ -3727,7 +3935,7 @@ HardSwish takes one input data (Tensor) and produces one output data (Tensor< the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x), where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3737,13 +3945,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Hardmax` (ONNXHardmaxOp) @@ -3757,7 +3965,7 @@ The \"axis\" attribute indicates the dimension along which Hardmax will be performed. The output tensor has the same shape and contains the Hardmax values of the corresponding input. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3788,7 +3996,7 @@ _ONNX Identity operation_ Identity operator -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3812,7 +4020,7 @@ _ONNX If operation_ If conditional -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `HasOnnxSubgraphOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3843,7 +4051,7 @@ Replaces inputs that equal one value with another, leaving all other elements al which one depends on whether floats or integers are being processed.
The imputed_value attribute length can be 1 element, or it can have one element per input feature.
In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3882,7 +4090,7 @@ y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance are computed per instance per channel. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3899,15 +4107,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `scale` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `scale` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.IsInf` (ONNXIsInfOp) @@ -3915,7 +4123,7 @@ _ONNX IsInf operation_ Map infinity to true and other values to false. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3947,7 +4155,7 @@ _ONNX IsNaN operation_ Returns which elements of the input are NaN. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -3980,7 +4188,7 @@ where `max(0, c - floor((size - 1) / 2)) <= i <= min(C - 1, c + ceil((size - 1) `Y[n, c, d1, ..., dk] = X[n, c, d1, ..., dk] / (bias + alpha / size * square_sum[n, c, d1, ..., dk] ) ^ beta` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4063,7 +4271,7 @@ Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): * Ht = ot (.) h(Ct) This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4087,22 +4295,22 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `R` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `R` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type | `sequence_lens` | tensor of 32-bit signless integer values or none type -| `initial_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `initial_c` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `P` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `initial_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `initial_c` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `P` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `Y_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `Y_c` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y_c` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type ### `onnx.LabelEncoder` (ONNXLabelEncoderOp) @@ -4126,7 +4334,7 @@ Maps each element in the input tensor to another value.
For key look-up, bit-wise comparison is used so even a float NaN can be mapped to a value in 'values_*' attribute.
-Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<2>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4201,9 +4409,11 @@ This is layer normalization defined in ONNX as function. Let `d[i]` indicate the i-th dimension of `X`. If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`, the shape of `Mean` and `InvStdDev` is `[d[0], ..., d[axis-1], 1, ..., 1]`. - `Y` and `X` have the same shape. + `Y` and `X` have the same shape. This operator supports unidirectional broadcasting + (tensors `Scale` and `B` should be unidirectional broadcastable to tensor `X`); + for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4285,7 +4495,7 @@ LeakyRelu takes input data (Tensor) and an argument alpha, and produces one output data (Tensor) where the function `f(x) = alpha * x for x < 0`, `f(x) = x for x >= 0`, is applied to the data tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4319,7 +4529,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4347,7 +4557,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>`, `SameOperandsElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4372,7 +4582,7 @@ _ONNX LinearClassifier operation_ Linear classifier -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4414,7 +4624,7 @@ Generalized linear regression evaluation.
The coefficients array is of length n, and the coefficients for each target are contiguous. Intercepts are optional but if provided must match the number of targets. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4448,7 +4658,7 @@ _ONNX Log operation_ Calculates the natural log of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4478,7 +4688,7 @@ The \"axis\" attribute indicates the dimension along which LogSoftmax will be performed. The output tensor has the same shape and contains the LogSoftmax values of the corresponding input. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4643,7 +4853,7 @@ point-wise operators (e.g. dropout, residual connections, linear layer). The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `HasOnnxSubgraphOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4669,7 +4879,7 @@ _ONNX LpNormalization operation_ Given a matrix, apply Lp-normalization along the provided axis. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4687,13 +4897,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.LpPool` (ONNXLpPoolOp) @@ -4723,7 +4933,7 @@ LpPool consumes an input tensor X and applies Lp pooling across pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + {kernelSpatialShape} - input_spatial_shape[i] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4746,22 +4956,22 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.MatMulInteger` (ONNXMatMulIntegerOp) _ONNX MatMulInteger operation_ -Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. +Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4786,9 +4996,9 @@ Effects: `MemoryEffects::Effect{}` _ONNX MatMul operation_ -Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html +Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4815,7 +5025,7 @@ Element-wise max of each of the input tensors (with Numpy-style broadcasting sup All inputs and outputs must have the same data type. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4851,7 +5061,7 @@ MaxPool consumes an input tensor X and applies max pooling across ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -4870,7 +5080,7 @@ MaxPool consumes an input tensor X and applies max pooling across The output of each pooling window is maximum number of elements exclude pad. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4893,13 +5103,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 8-bit unsigned integer values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 8-bit unsigned integer values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 8-bit unsigned integer values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 8-bit unsigned integer values | `Indices` | tensor of 64-bit signless integer values or none type ### `onnx.MaxPoolSingleOut` (ONNXMaxPoolSingleOutOp) @@ -4911,7 +5121,7 @@ See ONNXMaxPoolOp for a full description of the MaxPool semantics. This operation is not part of the standard and was added to assist onnx-mlir. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<12>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4950,7 +5160,7 @@ ROI max pool consumes an input tensor X and region of interests (RoIs) to apply max pooling across each RoI, to produce output 4-D tensor of shape (num_rois, channels, pooled_shape[0], pooled_shape[1]). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -4968,14 +5178,14 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `rois` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `rois` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.MaxUnpool` (ONNXMaxUnpoolOp) @@ -5000,7 +5210,7 @@ In addition to the inputs, MaxUnpool takes three attributes, namely kernel_shape which define the exact unpooling op. The attributes typically have the same values as the corresponding pooling op that the unpooling op is trying to invert. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5019,7 +5229,7 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values | `I` | tensor of 64-bit signless integer values | `output_shape` | tensor of 64-bit signless integer values or none type @@ -5027,7 +5237,7 @@ Effects: `MemoryEffects::Effect{}` | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Mean` (ONNXMeanOp) @@ -5037,7 +5247,7 @@ Element-wise mean of each of the input tensors (with Numpy-style broadcasting su All inputs and outputs must have the same data type. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5062,7 +5272,7 @@ _ONNX MeanVarianceNormalization operation_ A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: `(X-EX)/sqrt(E(X-EX)^2)` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5100,7 +5310,7 @@ In the returned matrix, all the triangles (filterbanks) have a peak value of 1.0 The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of linear scale spectrum values (e.g. STFT magnitudes) to generate a \"mel spectrogram\" M of shape [frames, num_mel_bins]. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5137,7 +5347,7 @@ Element-wise min of each of the input tensors (with Numpy-style broadcasting sup All inputs and outputs must have the same data type. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5167,7 +5377,7 @@ Perform the linear unit element-wise on the input tensor X using formula: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5177,13 +5387,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Mod` (ONNXModOp) @@ -5203,7 +5413,7 @@ Performs element-wise binary modulus (with Numpy-style broadcasting support). This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5294,7 +5504,7 @@ Compute one iteration of stochastic gradient update with momentum. concatenation of \"X_1\" and \"X_2\" (of course, their gradient and accumulate gradient should be concatenated too) and then our pseudo code becomes applicable. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5334,7 +5544,7 @@ This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; fo (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5360,7 +5570,7 @@ _ONNX Multinomial operation_ Generate a tensor of samples from a multinomial distribution according to the probabilities of each of the possible outcomes. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5379,7 +5589,7 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: @@ -5395,7 +5605,7 @@ Neg takes one input data (Tensor) and produces one output data (Tensor) where each element flipped sign, y = -x, is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5519,7 +5729,7 @@ loss = np.sum(loss) / weight_total // -1.57 ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5537,15 +5747,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values | `target` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values -| `weight` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `weight` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `loss` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `loss` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.NonMaxSuppression` (ONNXNonMaxSuppressionOp) @@ -5559,7 +5769,7 @@ result in the same boxes being selected by the algorithm. The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5598,7 +5808,7 @@ Returns the indices of the elements that are non-zero https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html, but for scalar input, NonZero produces output shape (0, N) instead of (1, N), which is different from Numpy's behavior. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5665,7 +5875,7 @@ Normalize the input. There are three normalization modes, which have the corres For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row of the batch is normalized independently. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5696,7 +5906,7 @@ _ONNX Not operation_ Returns the negation of the input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5727,7 +5937,7 @@ Replace each input element with an array of ones and zeros, where a single If the input is a tensor of float, int32, or double, the data will be cast to integers and the cats_int64s category list will be used for the lookups. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5778,7 +5988,7 @@ Produces a one-hot tensor based on inputs. output[i, j, k, input[i, j, k]] = 1 for all i, j, k and 0 otherwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5813,7 +6023,7 @@ If the input is a tensor or sequence type, it returns the input. If the input is an optional type, it outputs the element in the input. It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5839,7 +6049,7 @@ Returns true if (1) the input is an optional-type and contains an element, or, (2) the input is a tensor or sequence type. If the input is not provided or is an empty optional-type, this op returns false. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5864,7 +6074,7 @@ _ONNX Optional operation_ Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute, or a non-empty value containing the input element. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<15>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5898,7 +6108,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<7>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -5926,7 +6136,7 @@ output data (Tensor) where the function `f(x) = slope * x for x < 0`, `f(x) = x for x >= 0`., is applied to the data tensor elementwise. This operator supports **unidirectional broadcasting** (tensor slope should be unidirectional broadcastable to input tensor X); for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6050,7 +6260,7 @@ output = [ ] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6158,7 +6368,7 @@ Example 3 (`edge` mode): ] -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6265,7 +6475,7 @@ Example 3 (`edge` mode): ] -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6372,7 +6582,7 @@ output = [ ] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6421,7 +6631,7 @@ Example: ], ] -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<2>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6457,7 +6667,7 @@ produces one output data (Tensor) where the function `f(x) = x^exponent`, is applied to the data tensor elementwise. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<15>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6510,7 +6720,7 @@ Each input or output and its related zero point must have same type. When bias is present it must be quantized using scale = input scale * weight scale and zero point as 0. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6552,7 +6762,7 @@ Effects: `MemoryEffects::Effect{}` _ONNX QLinearMatMul operation_ -Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. +Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. @@ -6564,7 +6774,7 @@ for per column quantization. If the input is N-D tensor with shape [D1, D2, M, K have shape [D1, D2, M, 1] for per row quantization and shape [D1, D2, 1, K] for per column quantization. Production must never overflow, and accumulation may overflow if and only if in 32 bits. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6603,7 +6813,7 @@ For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipe but the quantization formula remains the same for consistency and the type of the attribute 'y_zero_point' still determines the quantization type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6748,7 +6958,7 @@ Equations (Default: f=Tanh): * Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6771,19 +6981,19 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `W` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `R` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `B` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `W` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `R` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `B` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type | `sequence_lens` | tensor of 32-bit signless integer values or none type -| `initial_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `initial_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type -| `Y_h` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type +| `Y_h` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or none type ### `onnx.RandomNormalLike` (ONNXRandomNormalLikeOp) @@ -6797,7 +7007,7 @@ The data type is specified by the 'dtype' argument, or copied from the input ten The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message, and be valid as an output type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6817,13 +7027,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.RandomNormal` (ONNXRandomNormalOp) @@ -6837,7 +7047,7 @@ The data type is specified by the 'dtype' argument. The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6858,7 +7068,7 @@ Effects: `MemoryEffects::Effect{}` | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.RandomUniformLike` (ONNXRandomUniformLikeOp) @@ -6872,7 +7082,7 @@ The data type is specified by the 'dtype' argument, or copied from the input ten The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message and be valid as an output type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6892,13 +7102,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values +| `input` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of string type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.RandomUniform` (ONNXRandomUniformOp) @@ -6911,7 +7121,7 @@ The data type is specified by the 'dtype' argument. The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6932,7 +7142,7 @@ Effects: `MemoryEffects::Effect{}` | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Range` (ONNXRangeOp) @@ -6969,7 +7179,7 @@ Inputs: start = 10, limit = 4, delta = -2 Output: [10, 8, 6] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -6997,7 +7207,7 @@ Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is, y = 1/x, is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7028,7 +7238,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7068,7 +7278,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7107,7 +7317,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7147,7 +7357,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7186,7 +7396,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7226,7 +7436,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7265,7 +7475,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7305,7 +7515,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7346,7 +7556,7 @@ If the input data type is Boolean, the comparison should consider `False < True` The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7386,7 +7596,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7425,7 +7635,7 @@ valid. Reduction over an empty set of values yields minus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7465,7 +7675,7 @@ valid. Reduction over an empty set of values yields undefined. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7505,7 +7715,7 @@ valid. Reduction over an empty set of values yields undefined. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7546,7 +7756,7 @@ If the input data type is Boolean, the comparison should consider `False < True` The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<20>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7586,7 +7796,7 @@ valid. Reduction over an empty set of values yields plus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7625,7 +7835,7 @@ valid. Reduction over an empty set of values yields plus infinity (if supported The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7665,7 +7875,7 @@ valid. Reduction over an empty set of values yields 1. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7705,7 +7915,7 @@ valid. Reduction over an empty set of values yields 1. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7744,7 +7954,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7784,7 +7994,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7824,7 +8034,7 @@ valid. Reduction over an empty set of values yields 0. The above behavior is similar to numpy, with the exception that numpy defaults `keepdims` to `False` instead of `True`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7861,7 +8071,7 @@ the resulted tensor have the reduced dimension pruned. The above behavior is similar to numpy, with the exception that numpy defaults keepdims to False instead of True. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7895,7 +8105,7 @@ Relu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = max(0, x), is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7931,7 +8141,7 @@ If the attribute 'allowzero' is set, it is invalid for the specified shape to contain both a zero value and -1, as the value of the dimension corresponding to -1 cannot be determined uniquely. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -7968,7 +8178,7 @@ output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) ``` if input \\"sizes\\" is not specified. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8012,7 +8222,7 @@ Resize the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * scale). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8046,7 +8256,7 @@ Resize the input tensor. In general, it calculates every value in the output ten Each dimension value of the output tensor is: output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8087,7 +8297,7 @@ Resize the input tensor. In general, it calculates every value in the output ten Each dimension value of the output tensor is: output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \\"sizes\\" is not specified. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8129,7 +8339,7 @@ Each dimension value of the output tensor is:
`output_dimension = floor(input_dimension * (roi_end - roi_start) * scale)`
if input \\"sizes\\" is not specified. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8236,7 +8446,7 @@ Example 2: [10.0, 9.0, 8.0, 11.0], [15.0, 14.0, 13.0, 12.0]] -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8279,7 +8489,7 @@ map and from feature map into RoI feature; in each ROI bin, the value of the sampled locations are computed directly through bilinear interpolation. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8301,15 +8511,15 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values -| `rois` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `rois` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values | `batch_indices` | tensor of 64-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Round` (ONNXRoundOp) @@ -8330,7 +8540,7 @@ round([1.5]) = [2.0] round([-4.5]) = [-4.0] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8340,13 +8550,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.STFT` (ONNXSTFTOp) @@ -8354,7 +8564,7 @@ _ONNX STFT operation_ Computes the Short-time Fourier Transform of the signal. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8388,7 +8598,7 @@ _ONNX SVMClassifier operation_ Support Vector Machine classifier -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8430,7 +8640,7 @@ _ONNX SVMRegressor operation_ Support Vector Machine regression prediction and one-class SVM anomaly detection. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8468,7 +8678,7 @@ _ONNX Scaler operation_ Rescale input data, for example to standardize features by removing the mean and scaling to unit variance. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8620,7 +8830,7 @@ values are computed in the outer graph, they need to be passed in as extra state } -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `HasOnnxSubgraphOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ResultTypeInferenceOpInterface`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8716,7 +8926,7 @@ axis = 1 output = [[1.0, 1.1, 3.0, 2.1, 5.0]] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8824,7 +9034,7 @@ output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8909,7 +9119,7 @@ Example 2: output = [[1.0, 1.1, 3.0, 2.1, 5.0]] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8945,7 +9155,7 @@ Selu takes one input data (Tensor) and produces one output data `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`, is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -8963,13 +9173,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.SequenceAt` (ONNXSequenceAtOp) @@ -8979,7 +9189,7 @@ Outputs a tensor copy from the tensor at 'position' in 'input_sequence'. Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'. Negative value means counting positions from the back. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9005,7 +9215,7 @@ _ONNX SequenceConstruct operation_ Construct a tensor sequence containing 'inputs' tensors. All tensors in 'inputs' must have the same data type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9029,7 +9239,7 @@ _ONNX SequenceEmpty operation_ Construct an empty tensor sequence, with given data type. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9057,7 +9267,7 @@ Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of te Negative value means counting positions from the back. 'position' is optional, by default it erases the last tensor from 'input_sequence'. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9086,7 +9296,7 @@ Accepted range for 'position' is in `[-n, n]`, where `n` is the number of tensor Negative value means counting positions from the back. 'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9112,7 +9322,7 @@ _ONNX SequenceLength operation_ Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9149,7 +9359,7 @@ the input. This operator assumes that processing each sample is independent and could executed in parallel or in any order. Users cannot expect any specific ordering in which each subgraph is computed. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<17>` Interfaces: `ConditionallySpeculatable`, `HasOnnxSubgraphOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9211,7 +9421,7 @@ end: 2 Output: [3] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9306,7 +9516,7 @@ having same datatype and shape with input. It has two attributes, lambd and bias. The formula of this operator is: If x < -lambd, y = x + bias; If x > lambd, y = x - bias; Otherwise, y = 0. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<9>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9340,7 +9550,7 @@ Sigmoid takes one input data (Tensor) and produces one output data (Tensor) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9365,7 +9575,7 @@ _ONNX Sign operation_ Calculate the sign of the given input tensor element-wise. If input > 0, output 1. if input < 0, output -1. if input == 0, output 0. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9389,7 +9599,7 @@ _ONNX Sin operation_ Calculates the sine of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9399,13 +9609,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Sinh` (ONNXSinhOp) @@ -9413,7 +9623,7 @@ _ONNX Sinh operation_ Calculates the hyperbolic sine of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9423,13 +9633,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Size` (ONNXSizeOp) @@ -9437,7 +9647,7 @@ _ONNX Size operation_ Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<19>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9522,7 +9732,7 @@ result = [ ] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9589,7 +9799,7 @@ Finally, L is optionally reduced: * If reduction = 'mean', the output is scalar: ReduceMean(L), or if weight is provided: `ReduceSum(L) / ReduceSum(W)`, where tensor W is of shape `(N, D1, D2, ..., Dk)` and `W[n][d1][d2]...[dk] = weights[labels[i][d1][d2]...[dk]]`. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9630,7 +9840,7 @@ The \"axis\" attribute indicates the dimension along which Softmax will be performed. The output tensor has the same shape and contains the Softmax values of the corresponding input. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9674,7 +9884,7 @@ Each of these dimensions must be matched correctly, or else the operator will throw errors. The output tensor has the same shape and contains the softmax values of the corresponding input. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9707,7 +9917,7 @@ Softplus takes one input data (Tensor) and produces one output data (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9717,13 +9927,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Softsign` (ONNXSoftsignOp) @@ -9731,7 +9941,7 @@ _ONNX Softsign operation_ Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9741,13 +9951,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.SpaceToDepth` (ONNXSpaceToDepthOp) @@ -9757,7 +9967,7 @@ SpaceToDepth rearranges blocks of spatial data into depth. More specifically, this op outputs a copy of the input tensor where values from the height and width dimensions are moved to the depth dimension. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9792,7 +10002,7 @@ If the attribute 'num_outputs' is specified, then the tensor is split into equal If the tensor is not evenly splittable into `num_outputs`, the last chunk will be smaller. If the input 'split' is specified, it indicates the sizes of each output in the split. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<18>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9836,7 +10046,7 @@ If 'split' is a 1-dimensional tensor, the input tensor is split into 'size(split with lengths of the parts on 'axis' specified in 'split'. In this scenario, the sum of entries in 'split' must be equal to the dimension size of input tensor on 'axis'. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9871,7 +10081,7 @@ Split a tensor into a list of tensors, along the specified 'axis'. Lengths of the parts can be specified using argument 'split'. Otherwise, the tensor is split to equal sized parts. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9905,7 +10115,7 @@ Split a tensor into a list of tensors, along the specified 'axis'. Lengths of the parts can be specified using input 'split'. Otherwise, the tensor is split to equal sized parts. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9939,7 +10149,7 @@ Square root takes one input data (Tensor) and produces one output data (Tensor) where the square root is, y = x^0.5, is applied to the tensor elementwise. If x is negative, then it will return NaN. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9966,7 +10176,7 @@ Takes an input `axes` with a list of axes to squeeze. If `axes` is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -9994,7 +10204,7 @@ Takes a parameter `axes` with a list of axes to squeeze. If `axes` is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10033,7 +10243,7 @@ This operator only accepts [C]- and [1, C]-tensor. If all elements in X are dropped, the output will be the empty value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C]. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<10>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10071,7 +10281,7 @@ This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; fo (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10098,7 +10308,7 @@ Element-wise sum of each of the input tensors (with Numpy-style broadcasting sup All inputs and outputs must have the same data type. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>`, `SameOperandsAndResultElementType` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10122,7 +10332,7 @@ _ONNX Tan operation_ Calculates the tangent of the given input tensor, element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10132,13 +10342,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Tanh` (ONNXTanhOp) @@ -10146,7 +10356,7 @@ _ONNX Tanh operation_ Calculates the hyperbolic tangent of the given input tensor element-wise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10156,13 +10366,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `input` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of bfloat16 type values +| `output` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.TfIdfVectorizer` (ONNXTfIdfVectorizerOp) @@ -10196,7 +10406,7 @@ this operator first computes the counts of all n-grams and then scale them by th Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor. If pool_strings is set, the input must be a string tensor. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<9>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10237,7 +10447,7 @@ ThresholdedRelu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, is applied to the tensor elementwise. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<22>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10254,13 +10464,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `X` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values #### Results: | Result | Description | | :----: | ----------- | -| `Y` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values +| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values ### `onnx.Tile` (ONNXTileOp) @@ -10270,7 +10480,7 @@ Constructs a tensor by tiling a given tensor. This is the same as function `tile` in Numpy, but no broadcast. For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]] -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10294,11 +10504,11 @@ Effects: `MemoryEffects::Effect{}` _ONNX TopK operation_ Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of -shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: +shape [a_0, a_1, ..., a_{n-1\}\] and integer argument k, return two outputs: -* Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] +* Value tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1\}\] which contains the values of the top k elements along the specified axis -* Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which +* Index tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1\}\] which contains the indices of the top k elements (original indices from the input tensor). @@ -10309,7 +10519,7 @@ shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: Given two equivalent values, this operator uses the indices along the axis as a tiebreaker. That is, the element with the lower index will appear first. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10346,7 +10556,7 @@ Transpose the input tensor similar to numpy.transpose. For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape will be (2, 1, 3). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10385,7 +10595,7 @@ Tree Ensemble classifier. Returns the top class for each of N inputs.
One and only one of classlabels_strings or classlabels_int64s will be defined. The class_ids are indices into this list. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10442,7 +10652,7 @@ Tree Ensemble regressor. Returns the regressed values for each input in N.
All trees must have their node ids start at 0 and increment by 1.
Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10500,7 +10710,7 @@ A negative k value retains the main diagonal and |k| diagonals below it. If upper is set to false, a positive k retains the lower triangular matrix including the main diagonal and k diagonals above it. A negative k value excludes the main diagonal and (|k|-1) diagonals below it. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<14>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10628,7 +10838,7 @@ output_counts: [2, 1, 1] ``` -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10672,7 +10882,7 @@ The rank of the output tensor (`output_rank`) is the rank of the input tensor (` Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. The order of values in `axes` does not matter and can come in any order. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<13>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10708,7 +10918,7 @@ Each value in `axes` should be within the (inclusive) range [-output_rank , outp The order of values in `axes` does not matter and can come in any order. -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<11>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10741,7 +10951,7 @@ Upsample the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * scale). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<9>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10775,7 +10985,7 @@ Upsample the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * scale). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<7>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10812,7 +11022,7 @@ with three parameters. This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<16>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10841,7 +11051,7 @@ elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting supp This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). -Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<7>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` @@ -10898,7 +11108,7 @@ Creates a map from the input and the attributes.
Must provide keys in either classlabels_strings or classlabels_int64s (but not both).
The columns of the tensor correspond one-by-one to the keys specified by the attributes. There must be as many columns as keys.
-Traits: `AlwaysSpeculatableImplTrait` +Traits: `AlwaysSpeculatableImplTrait`, `OpVersionTrait<1>` Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` diff --git a/docs/Dialects/zhigh.md b/docs/Dialects/zhigh.md index 4780cbe551..dd87eeecf5 100644 --- a/docs/Dialects/zhigh.md +++ b/docs/Dialects/zhigh.md @@ -337,6 +337,61 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `hn_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS +### `zhigh.Gelu` (::onnx_mlir::zhigh::ZHighGeluOp) + +_ZHigh Gelu operation_ + +"ZHigh operation to perform a Gelu." + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
approximate::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +### `zhigh.InvSqrt` (::onnx_mlir::zhigh::ZHighInvSqrtOp) + +_ZHigh InvSqrt operation_ + +ZHigh operation to perform a InvSqrt. + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + ### `zhigh.LSTM` (::onnx_mlir::zhigh::ZHighLSTMOp) _ZHigh LSTM operation_ @@ -389,6 +444,37 @@ Effects: `MemoryEffects::Effect{}` | `hn_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS | `cf_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS +### `zhigh.LeakyRelu` (::onnx_mlir::zhigh::ZHighLeakyReluOp) + +_ZHigh LeakyRelu operation_ + +"ZHigh operation to perform a LeakyRelu." + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
alpha::mlir::FloatAttr32-bit float attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + ### `zhigh.Log` (::onnx_mlir::zhigh::ZHighLogOp) _ZHigh Log operation_ @@ -425,6 +511,14 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac Effects: `MemoryEffects::Effect{}` +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
transposeA::mlir::IntegerAttr64-bit signed integer attribute
transposeB::mlir::IntegerAttr64-bit signed integer attribute
+ #### Operands: | Operand | Description | @@ -577,6 +671,168 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH +### `zhigh.QuantizedMatMul` (::onnx_mlir::zhigh::ZHighQuantizedMatMulOp) + +_ZHigh QuantizedMatMul operation_ + +ZHigh operation to perform a quantized MatMul. + +`OutRecScaleIn` and `OutOffsetIn` are recscale and offset for the output. +If `OutRecScaleIn` is given, it will be passed to `OutRecScale`. If it is +None, `OutRescScale` is set to 1.0. +If `OutOffsetIn` is given, it will be passed to `OutOffset`. If it is +None, `OutOffset` is set to 0.0. + +* PreComputedBias: -1 bias is re-computed, 0: bias is not pre-computed. + +`DequantizeOutput` indicates if the output +is dequantized to real dfloat16 or not. If not, the output is int8 but stored in dlfloat (int8-as-dlfloat). +* DequantizeOutput: -1 output is dequantized, 0: output is not dequantized. + +Traits: `AlwaysSpeculatableImplTrait` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
PreComputedBias::mlir::IntegerAttr64-bit signed integer attribute
DisableClipping::mlir::IntegerAttr64-bit signed integer attribute
DequantizeOutput::mlir::IntegerAttr64-bit signed integer attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS +| `XRecScale` | 0D tensor of 32-bit float values +| `XOffset` | 0D tensor of 32-bit float values +| `Y` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS +| `YRecScale` | 0D tensor of 32-bit float values +| `YOffset` | 0D tensor of 32-bit float values +| `B` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or none type +| `BRecScale` | 0D tensor of 32-bit float values or none type +| `BOffset` | 0D tensor of 32-bit float values or none type +| `OutRecScaleIn` | 0D tensor of 32-bit float values or none type +| `OutOffsetIn` | 0D tensor of 32-bit float values or none type + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS +| `OutRecScale` | 0D tensor of 32-bit float values +| `OutOffset` | 0D tensor of 32-bit float values + +### `zhigh.QuantizedStick` (::onnx_mlir::zhigh::ZHighQuantizedStickOp) + +_ZHigh QuantizedStick operation_ + +ZHigh operation to perform a quantized Stick. +Type is one of values: dlfloat16, int8, and weights. +`sym_mode` indicates whether to use symmetric quantization or not to compute the output rescale and offset. +`sym_mode` is only effective when the input rescale and offset are None. +By default, asymmetric quantization is used. + +Traits: `AlwaysSpeculatableImplTrait` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
quantized_type::mlir::StringAttrstring attribute
sym_mode::mlir::IntegerAttr64-bit signless integer attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `In` | tensor of 32-bit float values or tensor of 8-bit signless integer values or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS +| `InRecScale` | 0D tensor of 32-bit float values or none type +| `InOffset` | 0D tensor of 32-bit float values or none type + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS or none type +| `RecScale` | 0D tensor of 32-bit float values +| `Offset` | 0D tensor of 32-bit float values + +### `zhigh.ReduceMax` (::onnx_mlir::zhigh::ZHighReduceMaxOp) + +_ZHigh ReduceMax operation_ + +ZHigh operation to perform a ReduceMax. +op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM. + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
op_type::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `data` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +### `zhigh.ReduceMin` (::onnx_mlir::zhigh::ZHighReduceMinOp) + +_ZHigh ReduceMin operation_ + +ZHigh operation to perform a ReduceMin. +op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM. + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
op_type::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `data` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + ### `zhigh.Relu` (::onnx_mlir::zhigh::ZHighReluOp) _ZHigh Relu operation_ @@ -657,6 +913,30 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `Out` | unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS +### `zhigh.Sqrt` (::onnx_mlir::zhigh::ZHighSqrtOp) + +_ZHigh Sqrt operation_ + +ZHigh operation to perform a Sqrt. + +Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + ### `zhigh.StickForGRU` (::onnx_mlir::zhigh::ZHighStickForGRUOp) _ZHigh stick operation for GRU_ @@ -815,7 +1095,7 @@ Effects: `MemoryEffects::Effect{}` | Result | Description | | :----: | ----------- | -| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH +| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH or unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS ### `zhigh.Sub` (::onnx_mlir::zhigh::ZHighSubOp) diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md index ba6907fced..7be1c6457b 100644 --- a/docs/Dialects/zlow.md +++ b/docs/Dialects/zlow.md @@ -342,6 +342,52 @@ Interfaces: `MemoryEffectOpInterface` | `shape` | memref of 64-bit signless integer values | `hn_output` | memref of dlfloat16 type values +### `zlow.gelu` (::onnx_mlir::zlow::ZLowGeluOp) + +_ZLow gelu operation_ + +ZLow operation to perform a gelu. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + +### `zlow.invsqrt` (::onnx_mlir::zlow::ZLowInvSqrtOp) + +_ZLow invsqrt operation_ + +ZLow operation to perform a invsqrt. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + ### `zlow.lstm` (::onnx_mlir::zlow::ZLowLSTMOp) _ZLow lstm operation_ @@ -387,6 +433,30 @@ Interfaces: `MemoryEffectOpInterface` | `hn_output` | memref of dlfloat16 type values | `cf_output` | memref of dlfloat16 type values +### `zlow.leakyrelu` (::onnx_mlir::zlow::ZLowLeakyReluOp) + +_ZLow leakyrelu operation_ + +ZLow operation to perform a leakyrelu. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
alpha::mlir::FloatAttr32-bit float attribute
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + ### `zlow.log` (::onnx_mlir::zlow::ZLowLogOp) _ZLow log operation_ @@ -423,14 +493,18 @@ shape is a 1D MemRef (memref<3xi64>) whose items are: * 2nd item: n * 3rd item: p * In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) - or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) + or broadcasting1: X(m, n) * Y(s, n, p) + Bias(s, p) + or broadcasting23: X(s, m, n) * Y(n, p) + Bias(p) shape is a 1D MemRef (memref<4xi64>) whose items are: * 1st item: s * 2nd item: m * 3rd item: n * 4th item: p -* is_bcast: -1 broadcasting, 0: no broadcasting. +* is_bcast1: -1 broadcasting1, 0: no broadcasting1. +* is_bcast23: -1 broadcasting23, 0: no broadcasting23. * is_stacked: -1 stacked, 0: unstacked. +* transposeA: !0 transpose A, 0: do not transpose A. +* transposeB: !0 transpose B, 0: do not transpose B. Traits: `MemRefsNormalizable` @@ -440,8 +514,11 @@ Interfaces: `MemoryEffectOpInterface` - + + + +
AttributeMLIR TypeDescription
is_bcast::mlir::IntegerAttr64-bit signed integer attribute
is_bcast1::mlir::IntegerAttr64-bit signed integer attribute
is_bcast23::mlir::IntegerAttr64-bit signed integer attribute
is_stacked::mlir::IntegerAttr64-bit signed integer attribute
transposeA::mlir::IntegerAttr64-bit signed integer attribute
transposeB::mlir::IntegerAttr64-bit signed integer attribute
#### Operands: @@ -592,6 +669,144 @@ Interfaces: `MemoryEffectOpInterface` | `shape` | memref of 64-bit signless integer values | `Out` | memref of dlfloat16 type values +### `zlow.quantizedMatmul` (::onnx_mlir::zlow::ZLowQuantizedMatMulOp) + +_ZLow quantized matmul operation_ + +ZLow operation to perform a matmul. +work_area: a 4K-aligned buffer having the same layout as bias but dlfloat16 type. +* In case of unstacked: X(m, n) * Y(n, p) + Bias(p) +shape is a 1D MemRef (memref<3xi64>) whose items are: + * 1st item: m + * 2nd item: n + * 3rd item: p +* In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) + or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) +shape is a 1D MemRef (memref<4xi64>) whose items are: + * 1st item: s + * 2nd item: m + * 3rd item: n + * 4th item: p +* is_bcast: -1 broadcasting, 0: no broadcasting. +* is_stacked: -1 stacked, 0: unstacked. +* DequantizeOutput: -1 output is dequantized, 0: output is not dequantized. +* PreComputedBias: -1 bias is re-computed, 0: bias is not pre-computed. + +Values for `q_type` are "DLFLOAT16", "INT8", "WEIGHTS", "UNDEFINED". + + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + + + + + + + + + +
AttributeMLIR TypeDescription
x_q_type::mlir::StringAttrstring attribute
y_q_type::mlir::StringAttrstring attribute
bias_q_type::mlir::StringAttrstring attribute
out_q_type::mlir::StringAttrstring attribute
is_bcast::mlir::IntegerAttr64-bit signed integer attribute
is_stacked::mlir::IntegerAttr64-bit signed integer attribute
pre_computed_bias::mlir::IntegerAttr64-bit signed integer attribute
disable_clipping::mlir::IntegerAttr64-bit signed integer attribute
dequantize_output::mlir::IntegerAttr64-bit signed integer attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type or 8-bit signless integer values +| `x_rec_scale` | 0D memref of 32-bit float values +| `x_offset` | 0D memref of 32-bit float values +| `Y` | memref of dlfloat16 type or 8-bit signless integer values +| `y_rec_scale` | 0D memref of 32-bit float values +| `y_offset` | 0D memref of 32-bit float values +| `Bias` | memref of dlfloat16 type or 8-bit signless integer values +| `bias_rec_scale` | 0D memref of 32-bit float values +| `bias_offset` | 0D memref of 32-bit float values +| `work_area` | memref of dlfloat16 type or 8-bit signless integer values or none type +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type or 8-bit signless integer values +| `out_rec_scale` | 0D memref of 32-bit float values +| `out_offset` | 0D memref of 32-bit float values + +### `zlow.quantizedStick` (::onnx_mlir::zlow::ZLowQuantizedStickOp) + +_ZLow stick operation for quantization_ + +"ZLow operation to perform a quantization stick." +"Type is one of values: dlfloat16, int8, and weights." + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
q_type::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of 8-bit signless integer or 32-bit float values +| `rec_scale` | 0D memref of 32-bit float values +| `offset` | 0D memref of 32-bit float values +| `out` | memref of dlfloat16 type or 8-bit signless integer values + +### `zlow.reducemax` (::onnx_mlir::zlow::ZLowReduceMaxOp) + +_ZLow reducemax operation_ + +ZLow operation to perform a reducemax. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
op_type::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `work_area` | memref of 8-bit signless integer values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + +### `zlow.reducemin` (::onnx_mlir::zlow::ZLowReduceMinOp) + +_ZLow reducemin operation_ + +ZLow operation to perform a reducemin. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
op_type::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `work_area` | memref of 8-bit signless integer values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + ### `zlow.relu` (::onnx_mlir::zlow::ZLowReluOp) _ZLow relu operation_ @@ -670,6 +885,29 @@ Interfaces: `MemoryEffectOpInterface` | `shape` | memref of 64-bit signless integer values | `Out` | memref of dlfloat16 type values +### `zlow.sqrt` (::onnx_mlir::zlow::ZLowSqrtOp) + +_ZLow sqrt operation_ + +ZLow operation to perform a sqrt. + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + ### `zlow.stickForGRU` (::onnx_mlir::zlow::ZLowStickForGRUOp) _ZLow stick operation for GRU_ diff --git a/docs/Docker.md b/docs/Docker.md index bcc071a948..40b60c4757 100644 --- a/docs/Docker.md +++ b/docs/Docker.md @@ -12,7 +12,7 @@ There are three ways to use ONNX-MLIR with Docker. An easy way to get started with ONNX-MLIR is to use a prebuilt Docker image. These images are created as a result of a successful merge build on the trunk. This means that the latest image represents the tip of the trunk. -Currently there are both Release and Debug mode images for `amd64`, `ppc64le` and `s390x` saved in Docker Hub as, respectively, [onnxmlir/onnx-mlir](https://hub.docker.com/r/onnxmlir/onnx-mlir) and [onnxmlir/onnx-mlir-dev](https://hub.docker.com/r/onnxmlir/onnx-mlir-dev). +Currently there are both Release and Debug mode images for `amd64`, `ppc64le` and `s390x` saved in Docker Hub as, respectively, [onnxmlir/onnx-mlir](https://github.com/users/onnxmlir/packages/container/onnx-mlir) and [onnxmlir/onnx-mlir-dev](https://github.com/users/onnxmlir/packages/container/onnx-mlir-dev). To use one of these images either pull it directly from Docker Hub, launch a container and run an interactive bash shell in it, or use it as the base image in a Dockerfile. Here are the differences between the two Docker images. @@ -53,7 +53,7 @@ The Dockerfile is shown here, and should be modified according to one's need. Th [same-as-file]: <> (docs/docker-example/Dockerfile) ``` -FROM onnxmlir/onnx-mlir-dev +FROM ghcr.io/onnxmlir/onnx-mlir-dev WORKDIR /workdir ENV HOME=/workdir @@ -122,9 +122,9 @@ cd ~/DockerOnnxMlir # Edit the Dockerfile. vi Dockerfile # Build the Docker image. -docker build --tag onnx-mlir-dev . +docker build --tag ghcr.io/onnxmlir/onnx-mlir-dev . # Start a container using the Docker dashboard or a docker run command. -docker run -it onnx-mlir-dev +docker run -it ghcr.io/onnxmlir/onnx-mlir-dev ``` **NOTE:** If you are using a MacBook with the Apple M1 chip, please follow the steps below for configuration: @@ -135,11 +135,11 @@ cd ~/DockerOnnxMlir # Edit the Dockerfile. vi Dockerfile # Pull the Docker image with the specified platform -docker pull --platform linux/amd64 onnxmlir/onnx-mlir-dev +docker pull --platform linux/amd64 ghcr.io/onnxmlir/onnx-mlir-dev # Build the Docker image. -docker build --platform linux/amd64 --tag onnx-mlir-dev . +docker build --platform linux/amd64 --tag ghcr.io/onnxmlir/onnx-mlir-dev . # Start a container using the Docker dashboard or a docker run command. -docker run --platform linux/amd64 -it onnx-mlir-dev +docker run --platform linux/amd64 -it ghcr.io/onnxmlir/onnx-mlir-dev ``` Tip: Instead of adding the platform flag for every docker pull, build, and run command. You can set the environment variable `DOCKER_DEFAULT_PLATFORM` and use the first set of steps: diff --git a/docs/Instrumentation.md b/docs/Instrumentation.md index 31969ff15e..25b77153b6 100644 --- a/docs/Instrumentation.md +++ b/docs/Instrumentation.md @@ -61,11 +61,11 @@ The output for the memory measurement is explained here. Other example for NNPA - Performance profiling for onnx ops before lowering to zhigh ops: - `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` + `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` - Performance profiling for onnx and zhigh ops: - `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=ZHigh --instrument-ops=onnx.*,zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` + `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=ZHigh --instrument-ops=onnx.*,zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` - Performance profiling for zlow ops: - `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` + `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx` ## Control instrument at runtime By providing certain env variable at runtime, you can disable reports from instrument library. diff --git a/docs/LoweringCode.md b/docs/LoweringCode.md index 34d6523e8b..db0b154092 100644 --- a/docs/LoweringCode.md +++ b/docs/LoweringCode.md @@ -105,7 +105,7 @@ struct KrnlBuilder : public DialectBuilder { void iterate(ValueRange originalLoops, ValueRange optimizedLoops, ValueRange lbs, ValueRange ubs, - function_ref + function_ref bodyBuilderFn); }; ``` @@ -128,7 +128,7 @@ ValueRange loopDef = createKrnl.defineLoops(2); // Create the loop. createKrnl.iterate(loopDef, loopDef, {zero, zero}, {ub0, ub1}, - [&](KrnlBuilder &createKrnl, ValueRange loopInd){ + [&](const KrnlBuilder &createKrnl, ValueRange loopInd){ // Loop body. createKrnl.store(zero, array, loopInd); }); @@ -183,7 +183,7 @@ ValueRange loopBlockDef = createKrnl.block(loopDef, 4); createKrnl.permute({loopBlockDef[0], loopBlockDef[1], {0,1}); // Create the loop iterating over the blocks. createKrnl.iterate(loopDef, {loopBlockDef[0], loopBlockDef[0]}, {zero}, {ub0}, - [&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){ + [&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){ // Loop body. createKrnl.store(zero, array, loopInd); }); @@ -209,10 +209,10 @@ We now consider tiling our original 2-dimensional example below. // Create the loop iterating over the blocks. createKrnl.iterate(loopDef, {outerLoopBlockDef[0], innerLoopBlockDef[0]}, {zero, zero}, {ub0, ub1}, - [&](KrnlBuilder &createKrnl, ValueRange blockLoopInd){ + [&](const KrnlBuilder &createKrnl, ValueRange blockLoopInd){ // Create the loop iterating inside the blocks. createKrnl.iterate({}, {outerLoopBlockDef[1], innerLoopBlockDef[1]}, - {}, {}, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + {}, {}, [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Loop body. createKrnl.store(zero, array, loopInd); }); diff --git a/docs/ONNXAI.md b/docs/ONNXAI.md index 5529733442..73eba1fea5 100644 --- a/docs/ONNXAI.md +++ b/docs/ONNXAI.md @@ -3,7 +3,7 @@ # About ONNX-MLIR is an open-source project for compiling ONNX models into native code -on x86, P and Z machines (and more). It is built on top of Multi-Level +on x86, Power, s390x and other architectures. It is built on top of Multi-Level Intermediate Representation (MLIR) compiler infrastructure. # Slack channel diff --git a/docs/PythonPackage.md b/docs/PythonPackage.md new file mode 100644 index 0000000000..038615ea3f --- /dev/null +++ b/docs/PythonPackage.md @@ -0,0 +1,19 @@ +The Python package, onnxmlir, provides an installable package to use onnx-mlir +compiler in a similar way to onnxruntime. Also the package supports the way to +run model by `utils/RunONNXModel.py`. + +The source of the package is located at `onnx-mlir/utils/onnxmlir`. The main python code, `onnxmlir/src/onnxmlir/RunONNXModel.py` should be the same as `onnx-mlir/utils/RunONNXModel.py`. You can use target `OMCreateONNXMLIRSource` to create the installable directory in your build directory. +The package can be installed from your local directory with `pip3 install your_path/onnx-mlir/build/utils/onnxmlir` + +Follow instructions in https://packaging.python.org/en/latest/tutorials/packaging-projects/ +commands to use under the top directory onnxmlir +``` +python3 -m pip install --upgrade build +python3 -m build +#After get the api-token +python3 -m pip install --upgrade twine +python3 -m twine upload --repository testpypi dist/* +``` +Different from document, the prompt asked only for the api-token + +Examples can be found at onnx-mlir/util/onnxmlir/tests. diff --git a/docs/SupportedONNXOps-NNPA.md b/docs/SupportedONNXOps-NNPA.md index ab21ed5ef2..a0f85aef41 100644 --- a/docs/SupportedONNXOps-NNPA.md +++ b/docs/SupportedONNXOps-NNPA.md @@ -3,43 +3,43 @@ # Supported ONNX Operation for Target *NNPA*. -Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. + * A ^ indicates onnx-mlir is compatible with the latest level of the NNPA Architecture which is z16. -NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. +NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. NNPA currently only support DLFLOAT16 as its data type. Common data formats like FP32, FP16, BFLOAT need to undergo data conversions to the NNPA internal format DLFLOAT16. Hence ONNX ops which updated their tensors to BFLOAT16 will not be natively supported on NNPA. Onnx-mlir with NNPA utilizes hardware when possible. To accomplish this, the compiler converts ONNX ops to [ZHigh](Dialects/zhigh.md) ops, [ZLow](Dialects/zlow.md) ops, and are processed by the [IBM Z Deep Neural Network Library (zDNN)](https://github.com/IBM/zDNN). -| Op |Supported Opsets (inclusive) |Limitations |Notes | -| --- |--- |--- |--- | -| **Add** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **AveragePool** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors (N x C x H x W).
- `kernel_shape` must be static.
- `count_include_pad` must be default value(0).
- `ceil_mode` must be default value(0). | | -| **BatchNormalization** |6 - * |Input and output tensor must be 4D(N x C x H x W). | | -| **Conv** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- Dimension in Height and weight must be static.
- `group` must be default value(1).
- `dilations` must be default value(1).
- Input and output tensors must have 4D (N x C x H x W).
- `kernel_shape` must be static. | | -| **ConvTranspose** |6 - * |- 1D and 3D not supported because Conv1D and Conv3D not supported in zDNN. non-default `dilations` not supported because dilated convolution not supported in zDNN. | | -| **Div** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **Exp** |6 - * |Input tensor must have 4 dimensions. | | -| **GRU** |7 - * |- `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- If `B` and `initial_h` are given, they must have static dimensions.
- `sequence_lens` is not supported for bidirectional GRU.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `linear_before_reset` must be 1.
- `layout` is not supported. | | -| **Gemm** |6 - * |- `alpha` and `beta` must be default value(1).
- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same with the seconde dimension of `B`. | | -| **GlobalAveragePool** |6 - * |- Input shape must be 4D tensor(NCHW).
- Dimensions in `H` and `W` must be static. | | -| **LSTM** |7 - * |- `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.
- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `input_forget` must be default value(0).
- `layout` is not supported. | | -| **LeakyRelu** |6 - * |The operations immediately before and after the LeakyRelu operation must be executed on the NNPA. Otherwise, LeakyRelu is executed on the CPU. This limitation is set to avoid performance degradation. | | -| **Log** |6 - * |Input tensor must have 4 dimensions. | | -| **LogSoftmax** |6 - * | | | -| **MatMul** |6 - * |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. | | -| **Max** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **MaxPool** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors(N x C x H x W).
- `kernel_shape` must be static.
- `ceil_mode` must be default value(0).
- `dilations` must be default value(1). | | -| **Min** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **Mul** |6 - * |- Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **Pow** |7 - * |- Exponent should be a scalar integer and less or equal to 64. | | -| **ReduceMean** |6 - * |- `keepdims` must be 1.
- Input tensor must be 4D tensors and `axis` must be [2, 3]. | | -| **Relu** |6 - * |Input tensor must be less than or equal to 4 dimensions. | | -| **Sigmoid** |6 - * |Input tensor must be less than or equal to 4 dimensions. | | -| **Softmax** |6 - * |- `axis` must be the last dimension, i.e. `rank - 1` or -1. | | -| **Softplus** |6 - * |The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. | | -| **Sub** |6 - * |- Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | -| **Sum** |6 - * |- All inputs must have the same static shape (Broadcasting not supported.)
- Single input not supported. | | -| **Tanh** |6 - * |Input tensor must be less than or equal to 4 dimensions. | | +| Op |Supported Opsets (inclusive) |Minimum NNPA Level(Inclusive) |Limitations |Notes | +| --- |--- |--- |--- |--- | +| **Add** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **AveragePool** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors (N x C x H x W).
- `kernel_shape` must be static.
- `count_include_pad` must be default value(0).
- `ceil_mode` must be default value(0). | | +| **BatchNormalization** |6 - * |z16 |Input and output tensor must be 4D(N x C x H x W). | | +| **Conv** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- Dimension in Height and weight must be static.
- `group` must be default value(1).
- `dilations` must be default value(1).
- Input and output tensors must have 4D (N x C x H x W).
- `kernel_shape` must be static. | | +| **ConvTranspose** |6 - * |z16 |- 1D and 3D not supported because Conv1D and Conv3D not supported in zDNN. non-default `dilations` not supported because dilated convolution not supported in zDNN. | | +| **Div** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **Exp** |6 - * |z16 |Input tensor must have 4 dimensions. | | +| **GRU** |7 - * |z16 |- `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- If `B` and `initial_h` are given, they must have static dimensions.
- `sequence_lens` is not supported for bidirectional GRU.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `linear_before_reset` must be 1.
- `layout` is not supported. | | +| **Gemm** |6 - * |z16 |- `alpha` and `beta` must be default value(1).
- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same with the seconde dimension of `B`.
. | | +| **GlobalAveragePool** |6 - * |z16 |- Input shape must be 4D tensor(NCHW).
- Dimensions in `H` and `W` must be static. | | +| **LSTM** |7 - * |z16 |- `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.
- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `input_forget` must be default value(0).
- `layout` is not supported. | | +| **Log** |6 - * |z16 |Input tensor must have 4 dimensions. | | +| **LogSoftmax** |6 - * |z16 | | | +| **MatMul** |6 - * |z16 |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. | | +| **Max** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **MaxPool** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors(N x C x H x W).
- `kernel_shape` must be static.
- `ceil_mode` must be default value(0).
- `dilations` must be default value(1). | | +| **Min** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **Mul** |6 - * |z16 |- Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **Pow** |7 - * |z16 |- Exponent should be a scalar integer and less or equal to 64. | | +| **ReduceMean** |6 - * |z16 |- `keepdims` must be 1.
- Input tensor must be 4D tensors and `axis` must be [2, 3]. | | +| **Relu** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | | +| **Sigmoid** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | | +| **Softmax** |6 - * |z16 |- `axis` must be the last dimension, i.e. `rank - 1` or -1. | | +| **Softplus** |6 - * |z16 |The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. | | +| **Sub** |6 - * |z16 |- Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. | | +| **Sum** |6 - * |z16 |- All inputs must have the same static shape (Broadcasting not supported.)
- Single input not supported. | | +| **Tanh** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | | diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index a9206358ad..172c9ed7ca 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *cpu*. -Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. | Op |Supported Opsets (inclusive) |Limitations |Notes | @@ -36,8 +36,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **BitwiseOr** |18 - * | | | | **BitwiseXor** |18 - * | | | | **BlackmanWindow** |none | | | | -| **Cast** |6 - * |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. | | -| **CastLike** |19 - * |CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. | | +| **Cast** |6 - * |Cast only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. | | +| **CastLike** |19 - * |CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. | | | **CastMap** |none | | | | | **CategoryMapper** |none | | | | | **Ceil** |6 - * | | | @@ -48,8 +48,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Compress** |9 - * | | | | **Concat** |6 - * | | | | **ConcatFromSequence** |none | | | | -| **Constant** |6 - * | | | -| **ConstantOfShape** |9 - * | | | +| **Constant** |6 - * |Does not support int4 and uint4. | | +| **ConstantOfShape** |9 - * |Does not support int4 and uint4. | | | **Conv** |6 - * | | | | **ConvInteger** |none | | | | | **ConvTranspose** |6 - * |Spatial dimensions (H and W in input `X`, and kH and kW in input `W`) must be static dimension. | | @@ -59,7 +59,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **DFT** |17 - * | | | | **DeformConv** |none | | | | | **DepthToSpace** |13 - * | | | -| **DequantizeLinear** |10 - * |Only support for per-tensor or layer dequantization. No support for per-axis dequantization. | | +| **DequantizeLinear** |10 - * |Only support for per-tensor or layer dequantization. No support for per-axis dequantization. Does not support int4 and uint4. | | | **Det** |none | | | | | **DictVectorizer** |none | | | | | **Div** |6 - * |No support for short integers. | | @@ -73,7 +73,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Expand** |8 - * |Input `shape` must have static shape. | | | **EyeLike** |none | | | | | **FeatureVectorizer** |none | | | | -| **Flatten** |6 - * | | | +| **Flatten** |6 - * |Does not support int4 and uint4. | | | **Floor** |6 - * | | | | **GRU** |7 - * |W, B and R must be constants. | | | **Gather** |6 - * | | | @@ -94,8 +94,8 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **HardSigmoid** |6 - * | | | | **HardSwish** |none | | | | | **Hardmax** |6 - * | | | -| **Identity** |16 - * |Sequence identity not supported. | | -| **If** |16 - * |Sequence and Optional outputs are not supported. | | +| **Identity** |16 - * |Sequence identity not supported. Does not support int4 and uint4. | | +| **If** |16 - * |Sequence and Optional outputs are not supported. Does not support int4 and uint4. | | | **Imputer** |none | | | | | **InstanceNormalization** |6 - * | | | | **IsInf** |20 - * |Currently no support for float16 infinity value. Only for float32 and float64. | | @@ -111,7 +111,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **LinearRegressor** |none | | | | | **Log** |6 - * | | | | **LogSoftmax** |13 - * |Axis 0, 1, and default currently disabled due to changes in ONNX 1.8.1/Opset 13. |Temporally removed due to changes in onnx 1.8.1. | -| **Loop** |6 - * |Input must have static shape. | | +| **Loop** |6 - * |Input must have static shape. Does not support int4 and uint4. | | | **LpNormalization** |none | | | | | **LpPool** |none | | | | | **MatMul** |6 - * | | | @@ -142,11 +142,11 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **OptionalHasElement** |none | | | | | **Or** |7 - * | | | | **PRelu** |6 - * | | | -| **Pad** |6 - * |axes input not supported. | | +| **Pad** |6 - * |axes input not supported. Does not support int4 and uint4. | | | **Pow** |7 - * |No support for power with integer types. | | | **QLinearConv** |none | | | | | **QLinearMatMul** |none | | | | -| **QuantizeLinear** |10 - * |Do not support per-axis and i8 quantization. | | +| **QuantizeLinear** |10 - * |Does not support per-axis and i8 quantization. Does not support int4 and uint4. | | | **RNN** |7 - * |W, B and R must be constants. | | | **RandomNormal** |none | | | | | **RandomNormalLike** |none | | | | @@ -158,14 +158,14 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **ReduceL2** |13 - * |do_not_keep_dim not supported. | | | **ReduceLogSum** |13 - * |do_not_keep_dim not supported. | | | **ReduceLogSumExp** |13 - * |do_not_keep_dim not supported. | | -| **ReduceMax** |6 - * |do_not_keep_dim not supported. | | -| **ReduceMean** |6 - * |do_not_keep_dim not supported. | | -| **ReduceMin** |6 - * |do_not_keep_dim not supported. | | +| **ReduceMax** |6 - * |do_not_keep_dims not supported. | | +| **ReduceMean** |6 - * |do_not_keep_dims not supported. | | +| **ReduceMin** |6 - * |do_not_keep_dims not supported. | | | **ReduceProd** |13 - * |do_not_keep_dim not supported. | | | **ReduceSum** |6 - * |Default axis and do_not_keep_dim not supported. |Default axis and do_not_keep_dim temporarily removed due to changes in onnx 1.8.1. | | **ReduceSumSquare** |13 - * |Default axis and do_not_keep_dim not supported. | | | **Relu** |6 - * | | | -| **Reshape** |6 - * |allowzero not supported. Input `shape` must have static dimension. | | +| **Reshape** |6 - * |allowzero not supported. Input `shape` must have static dimension. Does not support int4 and uint4. | | | **Resize** |10 - * |Missing support for linear, cubic, crop, pytorch_half_pixel, and floor. Attributes antialias, axes and keep_aspect_ratio_policy are not supported. `scales` and `sizes` must have static dimension. | | | **ReverseSequence** |10 - * | | | | **RoiAlign** |none | | | | @@ -174,7 +174,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **SVMClassifier** |none | | | | | **SVMRegressor** |none | | | | | **Scaler** |none | | | | -| **Scan** |8 - * |Does not support dynamic shapes. |Precision issue with newer opset, maybe just unsupported. Dynamic shape?. | +| **Scan** |8 - * |Does not support dynamic shapes. Does not support int4 and uint4. |Precision issue with newer opset, maybe just unsupported. Dynamic shape?. | | **Scatter** |none | | | | | **ScatterElements** |11 - * |Does not support duplicate indices. | | | **ScatterND** |11 - * |Does not support scatternd add/multiply. | | @@ -186,13 +186,13 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **SequenceInsert** |11 - * |Does not support unranked sequence element. | | | **SequenceLength** |none | | | | | **SequenceMap** |none | | | | -| **Shape** |15 - * |Does not support start and end attributes. | | +| **Shape** |15 - * |Does not support start and end attributes. Does not support int4 and uint4. | | | **Shrink** |none | | | | | **Sigmoid** |6 - * | | | | **Sign** |9 - * | | | | **Sin** |7 - * | | | | **Sinh** |9 - * | | | -| **Size** |13 - * | | | +| **Size** |13 - * |Does not support int4 and uint4. | | | **Slice** |13 - * |Axis must be a constant argument. |Add tests to slices, currently have none. | | **Softmax** |6 - * | | | | **SoftmaxCrossEntropyLoss** |none | | | | @@ -202,7 +202,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **Split** |6 - * |Does not support static and dynamic shape, zero size splits. |Temporally removed due to changes in onnx 1.8.1. | | **SplitToSequence** |none | | | | | **Sqrt** |6 - * | | | -| **Squeeze** |6 - * |Does not support static and dynamic shape. |Temporally removed due to changes in onnx 1.8.1. | +| **Squeeze** |6 - * |Does not support static and dynamic shape. Does not support int4 and uint4. |Temporally removed due to changes in onnx 1.8.1. | | **StringNormalizer** |none | | | | | **Sub** |6 - * |Does not support short integers. | | | **Sum** |6 - * | | | @@ -212,12 +212,12 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitatio | **ThresholdedRelu** |none | | | | | **Tile** |6 - * | | | | **TopK** |10 - * |`K`, the number of top elements to retrieve, must have static shape. | | -| **Transpose** |6 - * | | | +| **Transpose** |6 - * |Does not support int4 and uint4. | | | **TreeEnsembleClassifier** |none | | | | | **TreeEnsembleRegressor** |none | | | | | **Trilu** |14 - * | | | | **Unique** |11 - * | | | -| **Unsqueeze** |6 - * |Does not support static and dynamic shape. |Temporally removed due to changes in onnx 1.8.1. | +| **Unsqueeze** |6 - * |Does not support static and dynamic shape. Does not support int4 and uint4. |Temporally removed due to changes in onnx 1.8.1. | | **Upsample** |7 - * |Input `X` and `Y` must have static shape. | | | **Where** |9 - * | | | | **Xor** |7 - * | | | diff --git a/docs/Testing.md b/docs/Testing.md index 9fafc75876..c5029f01a5 100644 --- a/docs/Testing.md +++ b/docs/Testing.md @@ -122,9 +122,9 @@ cmake --build . --config Release --target check-onnx-backend-signature ### Enable SIMD instructions -On supported platforms, currently s390x only, backend tests can generate SIMD instructions for the compiled models. To enable SIMD, set the TEST_MCPU environment variable, e.g., +On supported platforms (currently s390x z14 and up, x86, and arm), backend tests can generate SIMD instructions for the compiled models. To enable SIMD, set the TEST_MARCH environment variable, e.g., ``` -TEST_MCPU=z14 cmake --build . --config Release --target check-onnx-backend[-jni] +TEST_MARCH=z16 cmake --build . --config Release --target check-onnx-backend[-jni] ``` ### Execution of backend tests @@ -294,9 +294,9 @@ If you need to change ATOL and RTOL for accuracy checks, set the environment var ### Enable SIMD instructions -On supported platforms, currently s390x only, numerical tests can generate SIMD instructions for the compiled models. To enable SIMD, set the `TEST_ARGS` environment variable, e.g., +On supported platforms (currently s390x z14 and up, x86, and arm), numerical tests can generate SIMD instructions for the compiled models. To enable SIMD, set the `TEST_ARGS` environment variable, e.g., ``` -TEST_ARGS="-mcpu=z14" CTEST_PARALLEL_LEVEL=$(nproc) cmake --build . --config Release --target check-onnx-numerical +TEST_ARGS="-march=z16" CTEST_PARALLEL_LEVEL=$(nproc) cmake --build . --config Release --target check-onnx-numerical ``` ### Testing of specific accelerators @@ -395,7 +395,7 @@ Without specifying a model using `-m`, the script will check all models in the O If you want to gather performance info about a model zoo (or any models, for that matter), simplest is to request the desired statistic at compile time (using `-profile-ir` flag), divert the output statistic to a file, and then analyze it using `make-report.py`. For example: ``` -> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 -march=arm64 --profile-ir=Onnx" -m bertsquad-10 +> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 --march=arm64 --profile-ir=Onnx" -m bertsquad-10 ... > make-report.py -r run.log ... @@ -408,7 +408,7 @@ Statistics start (all ops). The runtime profiling info can be combined with specific compile-time statistics as well. Let's say that we are interested in SIMD statistics. We inform the compiler of the compile-time statistic to emit using `-opt-report` option, and inform `RunONNXModelZoo.py` that we want to preserve the compiler output using the `--log-to-file` option. For example ``` -> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 -march=arm64 -opt-report=Simd --profile-ir=Onnx" -m bertsquad-10 --log-to-file compile.log +> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 --march=arm64 -opt-report=Simd --profile-ir=Onnx" -m bertsquad-10 --log-to-file compile.log ... > make-report.py -c compile.log -r run.log ... diff --git a/docs/TestingHighLevel.md b/docs/TestingHighLevel.md index 48fc363447..fa53e66bce 100644 --- a/docs/TestingHighLevel.md +++ b/docs/TestingHighLevel.md @@ -40,7 +40,7 @@ If you run into protobuf related errors during the build, check the following po * llvm-project, onnx, and/or onnx-mlir may detect different versions of python3 (so watch their cmake output) if you have multiple python versions installed * cmake caches stuff and you should never use "make clean" when rebuilding. Instead remove everything under the build tree and start from scratch. -These and many other trickeries for setting up the build env are the reason why we recommend using the `onnxmlir/onnx-mlir-dev` docker image for development. +These and many other trickeries for setting up the build env are the reason why we recommend using the [onnxmlir/onnx-mlir-dev](https://github.com/users/onnxmlir/packages/container/onnx-mlir-dev) docker image for development. ## High level testing of ONNX-MLIR diff --git a/docs/doc_example/main.c b/docs/doc_example/main.c index 83537c21c5..3a940d0aca 100644 --- a/docs/doc_example/main.c +++ b/docs/doc_example/main.c @@ -1,5 +1,6 @@ #include #include +#include OMTensorList *run_main_graph(OMTensorList *); @@ -11,9 +12,16 @@ OMTensorList *create_input_list() { // Construct float arrays filled with 1s or 2s. float *x1Data = (float *)malloc(sizeof(float) * num_elements); + // Check if memory is allocated for generating the data. + if(!x1Data) return NULL; for (int i = 0; i < num_elements; i++) x1Data[i] = 1.0; float *x2Data = (float *)malloc(sizeof(float) * num_elements); + // Check if memory is allocated for generating the data. + if(!x2Data){ + free(x1Data); + return NULL; + } for (int i = 0; i < num_elements; i++) x2Data[i] = 2.0; @@ -32,7 +40,10 @@ OMTensorList *create_input_list() { int main() { // Generate input TensorList OMTensorList *input_list = create_input_list(); - + if(!input_list){ + // Return 2 for failure to create inputs. + return 2; + } // Call the compiled onnx model function. OMTensorList *output_list = run_main_graph(input_list); if (!output_list) { diff --git a/docs/docker-example/Dockerfile b/docs/docker-example/Dockerfile index 3db740090c..f44ff97be7 100644 --- a/docs/docker-example/Dockerfile +++ b/docs/docker-example/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxmlir/onnx-mlir-dev +FROM ghcr.io/onnxmlir/onnx-mlir-dev WORKDIR /workdir ENV HOME=/workdir diff --git a/docs/mnist_example/requirements.txt b/docs/mnist_example/requirements.txt index fe023b5220..4d01c849b1 100644 --- a/docs/mnist_example/requirements.txt +++ b/docs/mnist_example/requirements.txt @@ -1,4 +1,4 @@ numpy~=1.22.2 -pillow~=10.2.0 -torch~=2.0.0 -torchvision~=0.15.1 +pillow~=10.3.0 +torch~=2.5.0 +torchvision~=0.20.0 diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 6575d8d60e..31df2cf34b 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -2,5 +2,7 @@ add_subdirectory(onnx-mlir) -install(FILES OnnxMlirCompiler.h DESTINATION include) -install(FILES OnnxMlirRuntime.h DESTINATION include) +if(ONNX_MLIR_INSTALL_HEADERS) + install(FILES OnnxMlirCompiler.h DESTINATION include) + install(FILES OnnxMlirRuntime.h DESTINATION include) +endif() diff --git a/include/onnx-mlir/Compiler/CMakeLists.txt b/include/onnx-mlir/Compiler/CMakeLists.txt index ed756998a6..fc8b46b25a 100644 --- a/include/onnx-mlir/Compiler/CMakeLists.txt +++ b/include/onnx-mlir/Compiler/CMakeLists.txt @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -install(FILES OMCompilerTypes.h DESTINATION include/onnx-mlir/Compiler) -install(FILES OMCompilerMacros.h DESTINATION include/onnx-mlir/Compiler) +if(ONNX_MLIR_INSTALL_HEADERS) + install(FILES OMCompilerTypes.h DESTINATION include/onnx-mlir/Compiler) + install(FILES OMCompilerMacros.h DESTINATION include/onnx-mlir/Compiler) +endif() diff --git a/include/onnx-mlir/Runtime/CMakeLists.txt b/include/onnx-mlir/Runtime/CMakeLists.txt index 0c1c50b922..ce2e602d4a 100644 --- a/include/onnx-mlir/Runtime/CMakeLists.txt +++ b/include/onnx-mlir/Runtime/CMakeLists.txt @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -install(FILES OMEntryPoint.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OMInstrument.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OMSignature.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OMTensor.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OMTensorList.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OnnxDataType.h DESTINATION include/onnx-mlir/Runtime) -install(FILES OnnxDataTypeMetaData.inc DESTINATION include/onnx-mlir/Runtime) +if(ONNX_MLIR_INSTALL_HEADERS) + install(FILES OMEntryPoint.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OMInstrument.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OMSignature.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OMTensor.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OMTensorList.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OnnxDataType.h DESTINATION include/onnx-mlir/Runtime) + install(FILES OnnxDataTypeMetaData.inc DESTINATION include/onnx-mlir/Runtime) +endif() diff --git a/requirements.txt b/requirements.txt index f77bb46e5b..9f7c35fa97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ lit~=15.0 # numpy 1.24 deprecates np.object, np.bool, np.float, np.complex, np.str, # and np.int which are used heavily in onnx-mlir. -numpy~=1.22.2, <=1.23.5 +numpy==2.0.1 +onnx==1.16.2 protobuf==4.21.12 -pytest~=7.2 -pytest-xdist~=3.0 +pytest==8.3.2 +pytest-xdist==3.6.1 diff --git a/src/Accelerators/CMakeLists.txt b/src/Accelerators/CMakeLists.txt index 4a4f97a2e0..87f199d443 100644 --- a/src/Accelerators/CMakeLists.txt +++ b/src/Accelerators/CMakeLists.txt @@ -52,6 +52,11 @@ add_onnx_mlir_library(OMInitAccelerators MLIRIR ) +if (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "GNU") + target_compile_options(OMInitAccelerators PUBLIC -Wno-gnu-zero-variadic-macro-arguments) +endif() + + add_onnx_mlir_library(OMAccelerator Accelerator.cpp @@ -67,3 +72,6 @@ add_onnx_mlir_library(OMAccelerator LLVMSupport MLIRIR ) +if (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "GNU") + target_compile_options(OMAccelerator PUBLIC -Wno-gnu-zero-variadic-macro-arguments) +endif() diff --git a/src/Accelerators/NNPA/CMakeLists.txt b/src/Accelerators/NNPA/CMakeLists.txt index 51625e984b..d3687aabc9 100644 --- a/src/Accelerators/NNPA/CMakeLists.txt +++ b/src/Accelerators/NNPA/CMakeLists.txt @@ -33,7 +33,7 @@ else() endif() include(zdnn.cmake) -setup_zdnn(v1.0.1) +setup_zdnn(v1.1.1) add_subdirectory(Dialect) add_subdirectory(Conversion) diff --git a/src/Accelerators/NNPA/Compiler/CMakeLists.txt b/src/Accelerators/NNPA/Compiler/CMakeLists.txt index 83e4bdd9a2..a12b9126d8 100644 --- a/src/Accelerators/NNPA/Compiler/CMakeLists.txt +++ b/src/Accelerators/NNPA/Compiler/CMakeLists.txt @@ -19,6 +19,7 @@ add_onnx_mlir_library(OMNNPACompilerOptions add_onnx_mlir_library(OMNNPACompilerUtils NNPACompilerUtils.cpp + ZHighDisposableGarbageCollector.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index ee4e0ae363..52d7933888 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -61,7 +61,7 @@ llvm::cl::opt nnpaEnableCompilerStickUnstick( llvm::cl::opt nnpaEnableScalarBcastBinary( "nnpa-enable-scalar-bcast-binary", llvm::cl::desc("Enable the lowering to NNPA of binary operations with " - "broadcasting of a scalar operand." + "broadcasting of a scalar operand.\n" "Currently only enable ONNXDiv. Default is false."), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -70,9 +70,10 @@ llvm::cl::opt nnpaLoadDevicePlacementFile{ llvm::cl::desc( "Load device placement configuration from a JSON file. To " "have a template for the JSON file, use " - "--nnpa-save-device-placement-file=cfg.json. Note that we can use " + "--nnpa-save-device-placement-file=cfg.json.\nNote that we can use " "regex for " - "string values in the JSON file to match operations. The compiler uses " + "string values in the JSON file to match operations.\nThe compiler " + "uses " "C++ std::regex_match function for matching."), llvm::cl::init(""), llvm::cl::cat(OnnxMlirOptions)}; @@ -87,11 +88,11 @@ llvm::cl::opt nnpaPlacementHeuristic{ "[Optional] Choose NNPA-related heuristic to place operations " "on NNPA device:"), llvm::cl::values( - clEnumVal(QualifyingOps, "Place all qualifying ops on NNPA (default)"), - clEnumVal(FasterOps, "Place qualifying ops that are faster on NNPA"), - clEnumVal(FasterOpsWSU, "FasterOps with stick/unstick cost"), + clEnumVal(QualifyingOps, "Place all qualifying ops on NNPA (default)."), + clEnumVal(FasterOps, "Place qualifying ops that are faster on NNPA."), + clEnumVal(FasterOpsWSU, "FasterOps with stick/unstick cost."), clEnumVal(MuchFasterOpsWSU, - "Much/Significantly FasterOps with stick/unstick cost")), + "Much/Significantly FasterOps with stick/unstick cost.")), llvm::cl::init(QualifyingOps), llvm::cl::cat(OnnxMlirOptions)}; llvm::cl::opt nnpaEnableSaturation("nnpa-saturation", @@ -100,4 +101,27 @@ llvm::cl::opt nnpaEnableSaturation("nnpa-saturation", "Default is false."), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPU("nnpa-cpu-dql", + llvm::cl::desc("Use dynamic quantized linear on CPU. Default is false"), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); + +llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset( + "nnpa-cpu-dql-scale", + llvm::cl::desc("Use dynamic quantized linear computation of " + " scale and offset on CPU. Default is false"), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); + +llvm::cl::opt nnpaQuantization("nnpa-quantization", + llvm::cl::desc("Enable quantization with a specific type. Only " + "MatMul whose weight is a constant is supported."), + llvm::cl::values( + clEnumVal(DynSymI8, + "Dynamic Quantization to signed integer 8. Asymmetric " + "quant for activations and symmetric quant for weights."), + clEnumVal(SymSymI8, + "Dynamic Quantization to signed integer 8. Symmetric " + "quant for activations and symmetric quant for weights."), + clEnumVal(QNONE, "No quantization (default).")), + llvm::cl::init(QNONE), llvm::cl::cat(OnnxMlirOptions)); + } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index 2b0343295c..366efee3fe 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -55,6 +55,15 @@ typedef enum { MuchFasterOpsWSU, /* FasterOpsWSU only if significantly faster. */ } NNPAPlacementHeuristic; +// Quantization type +typedef enum { + DynSymI8, /* Dynamic quantization to signed integer 8. Asymmetric quant for + activations and symmetric quant for weights.*/ + SymSymI8, /* Dynamic quantization to signed integer 8. Symmetric quant for + activations and symmetric quant for weights.*/ + QNONE, /* Only qualifying ops that are faster on NNPA. */ +} NNPAQuantType; + extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirCommonOptions; extern llvm::cl::opt nnpaEmissionTarget; @@ -68,6 +77,9 @@ extern llvm::cl::opt profileZHighIR; extern llvm::cl::opt nnpaLoadDevicePlacementFile; extern llvm::cl::opt nnpaSaveDevicePlacementFile; extern llvm::cl::opt nnpaEnableSaturation; +extern llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPU; +extern llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset; +extern llvm::cl::opt nnpaQuantization; } // namespace onnx_mlir #endif diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 2d411da967..d7c5cfcac0 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -32,9 +32,11 @@ #include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp" #include "src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp" +#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" +#include "src/Accelerators/NNPA/Support/NNPALimit.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerPasses.hpp" #include "src/Pass/Passes.hpp" @@ -48,15 +50,10 @@ namespace onnx_mlir { void configurePassesNNPA() { configureOnnxToZHighLoweringPass(optReport == OptReport::NNPAUnsupportedOps); - // Compiler generated sticks supports saturation, so force its usage. - // TODO: remove this if zDNN adds support for saturation. - if (nnpaEnableSaturation) + // z16 does not support for hardware saturation. + // So, force its usage to compiler generated sticks. + if (nnpaEnableSaturation && isLessEqualNNPALevel(NNPALevel::M14)) nnpaEnableCompilerStickUnstick = true; - // Currently nnpaEnableCompilerStickUnstick not supported on zOS. - // TODO enable on zOS - if (mtriple == "s390x-ibm-zos") { - nnpaEnableCompilerStickUnstick = false; - } } void addONNXToZHighPasses(mlir::PassManager &pm) { @@ -88,12 +85,14 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass( onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); - pm.addPass(onnx_mlir::createONNXToZHighPass()); + pm.addPass(onnx_mlir::createONNXToZHighPass(nnpaQuantization)); pm.addNestedPass(onnx_mlir::createShapeInferencePass()); + // There are more opportunities for const propagation once all zhigh ops were // generated. pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); pm.addPass(mlir::createCanonicalizerPass()); + // Layout propagation at ZHighIR. pm.addNestedPass( onnx_mlir::zhigh::createZHighLayoutPropagationPass()); @@ -110,13 +109,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); } - // After all optimizations, if there are still light-weight ops (e.g. add, - // sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to - // use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle - // these ops, e.g vectorize the computation. - if (nnpaEnableZHighToOnnx) - pm.addNestedPass(onnx_mlir::createZHighToONNXPass()); - // One more call to ONNX shape inference/canonicalization/... to update shape // if possible. if (enableONNXHybridPass) { @@ -130,17 +122,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass(onnx_mlir::createShapeInferencePass()); } - // Replace every DisposableElementsAttr with DenseElementsAttr. - // ZHighConstPropagation currently assumes that DenseElementsAttr is used. - pm.addPass(createScrubDisposablePass()); - - // Constant propagation at ZHighIR: constant stickify. - // Only support BE machines. - bool isBE = llvm::endianness::native == llvm::endianness::big; - if (isBE) - pm.addNestedPass( - onnx_mlir::zhigh::createZHighConstPropagationPass()); - // Experimental feature: Decompose stick/unstick into two phases: layout // transform and data conversion. Do some optimizations after decomposing. // Then, recompose again layout and data conversion if they are not optimized. @@ -152,12 +133,28 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { onnx_mlir::zhigh::createZHighRecomposeToStickUnstickPass()); } + // After all optimizations, if there are still light-weight ops (e.g. add, + // sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to + // use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle + // these ops, e.g vectorize the computation. + if (nnpaEnableZHighToOnnx) + pm.addNestedPass(onnx_mlir::createZHighToONNXPass()); + + // Constant propagation at ZHighIR: constant stickify. + // Only support BE machines. + bool isBE = llvm::endianness::native == llvm::endianness::big; + if (isBE) + pm.addPass(onnx_mlir::zhigh::createZHighConstPropagationPass()); + // Remove common sub-expressions. pm.addPass(mlir::createCSEPass()); // Clean dead code. pm.addPass(mlir::createSymbolDCEPass()); + // Replace every DisposableElementsAttr with DenseElementsAttr. + pm.addPass(onnx_mlir::zhigh::createZHighScrubDisposablePass()); + // Insert an instrumentation after lowering onnx to zhigh to get profiling // for onnx and zhigh ops. // Keep this pass at the end of this function. @@ -198,7 +195,11 @@ void addPassesNNPA(mlir::OwningOpRef &module, // LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;); if (emissionTarget >= EmitONNXIR) { - addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty()); + pm.addInstrumentation( + std::make_unique( + pm.getContext())); + addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(), + /*donotScrubDisposableElementsAttr*/ true); pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile, nnpaSaveDevicePlacementFile, nnpaPlacementHeuristic)); } diff --git a/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp b/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp new file mode 100644 index 0000000000..d5c1da2d3f --- /dev/null +++ b/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- ZHighDisposableGarbageCollector.cpp -----------------===// +// +// Garbage collects DisposableElementsAttr attributes. +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" +#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp" +#include "src/Dialect/ONNX/ONNXDialect.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "mlir/IR/BuiltinOps.h" + +using namespace mlir; + +namespace onnx_mlir { +namespace zhigh { + +ZHighDisposableGarbageCollector::ZHighDisposableGarbageCollector( + MLIRContext *context) + : disposablePool(*DisposablePool::get(context)) {} + +ZHighDisposableGarbageCollector::~ZHighDisposableGarbageCollector() {} + +void ZHighDisposableGarbageCollector::runAfterPass(Pass *pass, Operation *op) { + if (!disposablePool.isActive()) + return; + ModuleOp moduleOp = mlir::dyn_cast(op); + if (!moduleOp) + return; + disposablePool.garbageCollectUnreachable( + moduleOp, {{ONNXConstantOp::getOperationName(), "value"}, + {ONNXConstantOfShapeOp::getOperationName(), "value"}, + {ZHighStickifiedConstantOp::getOperationName(), "value"}}); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp b/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp new file mode 100644 index 0000000000..c4a34d50eb --- /dev/null +++ b/src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- ZHighDisposableGarbageCollector.hpp -----------------===// +// +// Garbage collects DisposableElementsAttr attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H +#define ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H + +#include "mlir/Pass/PassInstrumentation.h" + +namespace mlir { +class MLIRContext; +} + +namespace onnx_mlir { +class DisposablePool; + +namespace zhigh { + +struct ZHighDisposableGarbageCollector : public mlir::PassInstrumentation { + ZHighDisposableGarbageCollector(mlir::MLIRContext *context); + ~ZHighDisposableGarbageCollector() override; + + void runAfterPass(mlir::Pass *pass, mlir::Operation *op) override; + +private: + DisposablePool &disposablePool; +}; + +} // namespace zhigh +} // namespace onnx_mlir +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp index 58e6439897..47724d8d3e 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp @@ -200,13 +200,13 @@ void DevicePlacementPass::runOnOperation() { // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh. // E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv. RewritePatternSet Patterns2(context); - getONNXToZHighMultipleOpPatterns(Patterns2); + getONNXToZHighMultipleOpPatterns(Patterns2, nnpaQuantization); (void)applyAnalysisConversion(module, target, std::move(Patterns2), ConversionConfig{.legalizableOps = &legalizedOps2}); // Call ONNXToZHigh pass for lowering a single ONNX op to ZHigh. RewritePatternSet Patterns3(context); - getONNXToZHighOneOpPatterns(Patterns3); + getONNXToZHighOneOpPatterns(Patterns3, nnpaQuantization); getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis); (void)applyAnalysisConversion(module, target, std::move(Patterns3), ConversionConfig{.legalizableOps = &legalizedOps3}); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp index 1818c37939..76fa3fa547 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp @@ -15,6 +15,7 @@ #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" #include "src/Accelerators/NNPA/Support/NNPALimit.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" @@ -38,12 +39,16 @@ bool onnxToZHighUnsupportedReport(Operation *op, const std::string &message) { /// Report incompatibility with NNPA Level. bool onnxToZHighInCompatibilityReport( - Operation *op, std::string inputNNPALevel) { - std::string message = - "onnx-mlir NNPA level (" + inputNNPALevel + - ") is not compatible with NNPA level specified by '-mcpu'(" + mcpu + - ")."; - return onnxToZHighUnsupportedReport(op, message); + Operation *op, const std::string &message) { + std::string compilerNNPALevelStr = getNNPAString(getNNPAFromFlags()); + std::string errorMessage = + "onnx-mlir NNPA level \"" + message + "\" is not compatible with " + + "NNPA level specified by \"" + compilerNNPALevelStr + "\"."; + return onnxToZHighUnsupportedReport(op, errorMessage); +} + +bool onnxToZHighInCompatibilityReport(Operation *op, NNPALevel level) { + return onnxToZHighInCompatibilityReport(op, getNNPAString(level)); } /// A function to check whether a value's element type is valid for zAIU or not. @@ -315,7 +320,7 @@ bool meetPoolParamRestrictions(Operation *op, int64_t inputShape, return onnxToZHighUnsupportedReport(op, message); } if (paddingType == "SAME_PADDING") { - int64_t reqOutputShape = ceil((float)inputShape / strides); + int64_t reqOutputShape = ceil(static_cast(inputShape) / strides); if (outputShape != reqOutputShape) { std::string message = "When the strides (" + std::to_string(strides) + @@ -329,7 +334,7 @@ bool meetPoolParamRestrictions(Operation *op, int64_t inputShape, } } else { // VALID_PADDING int64_t reqOutputShape = - ceil((float)(inputShape - kernelShape + 1) / strides); + ceil(static_cast(inputShape - kernelShape + 1) / strides); if (outputShape != reqOutputShape) { std::string message = "When the strides (" + std::to_string(strides) + ") and the padding type is VALID_PADDING, output " @@ -357,8 +362,8 @@ template <> bool isSuitableForZDNN( ONNXAddOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) { - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) { + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); } if (!isValidElementTypeAndRank(op.getOperation(), op.getA())) return false; @@ -376,8 +381,8 @@ template <> bool isSuitableForZDNN( ONNXSubOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getA())) return false; if (!isValidElementTypeAndRank(op.getOperation(), op.getB())) @@ -394,8 +399,8 @@ template <> bool isSuitableForZDNN( ONNXMulOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getA())) return false; if (!isValidElementTypeAndRank(op.getOperation(), op.getB())) @@ -414,8 +419,8 @@ bool isSuitableForZDNN( Value A = op.getA(); Value B = op.getB(); // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Broadcast with a scalar operand. if (isEnableScalarBcastBinary()) { if (isF32ScalarConstantTensor(A) && @@ -442,8 +447,8 @@ template <> bool isSuitableForZDNN( ONNXSumOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Do not support a single input. if (op.getData_0().size() < 2) return onnxToZHighUnsupportedReport(op.getOperation(), @@ -473,8 +478,8 @@ template <> bool isSuitableForZDNN( ONNXMinOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); int64_t opnum = op.getNumOperands(); if (opnum != 2) return onnxToZHighUnsupportedReport(op.getOperation(), @@ -491,13 +496,13 @@ bool isSuitableForZDNN( } /// Check legality for ONNXMax. -/// zDNN Min/Max do not support boradcasting, and getNumOperands != 2. +/// zDNN Min/Max do not support broadcasting, and getNumOperands != 2. template <> bool isSuitableForZDNN( ONNXMaxOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); int64_t opnum = op.getNumOperands(); if (opnum != 2) return onnxToZHighUnsupportedReport(op.getOperation(), @@ -520,8 +525,8 @@ template <> bool isSuitableForZDNN( ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getInput())) return false; ShapedType inputType = mlir::cast(op.getType()); @@ -541,13 +546,37 @@ bool isSuitableForZDNN( return true; } +/// Check legality for ONNXLeakyRelu. +template <> +bool isSuitableForZDNN( + ONNXLeakyReluOp op, const DimAnalysis *dimAnalysis) { + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); + if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) + return false; + return true; +} + /// Check legality for ONNXRelu. template <> bool isSuitableForZDNN( ONNXReluOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); + if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) + return false; + return true; +} + +/// Check legality for ONNXGelu. +template <> +bool isSuitableForZDNN( + ONNXGeluOp op, const DimAnalysis *dimAnalysis) { + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) return false; return true; @@ -558,8 +587,8 @@ template <> bool isSuitableForZDNN( ONNXTanhOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getInput())) return false; return true; @@ -570,8 +599,20 @@ template <> bool isSuitableForZDNN( ONNXSigmoidOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); + if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) + return false; + return true; +} + +/// Check legality for ONNXSqrt. +template <> +bool isSuitableForZDNN( + ONNXSqrtOp op, const DimAnalysis *dimAnalysis) { + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) return false; return true; @@ -582,8 +623,8 @@ template <> bool isSuitableForZDNN( ONNXLogOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getInput())) return false; return true; @@ -594,8 +635,8 @@ template <> bool isSuitableForZDNN( ONNXExpOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); if (!isValidElementTypeAndRank(op.getOperation(), op.getInput())) return false; return true; @@ -606,8 +647,8 @@ template <> bool isSuitableForZDNN( ONNXMatMulOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); int64_t opnum = op.getNumOperands(); if (opnum != 2) return onnxToZHighUnsupportedReport(op.getOperation(), @@ -663,10 +704,10 @@ bool isSuitableForZDNN( } return true; } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) { - // stacked w/ bcast + // stacked w/ bcast23 case if (aType.hasStaticShape() && bType.hasStaticShape()) { if (shapeA[2] != shapeB[0]) { - std::string message = "Stacked w/ bcast case: the 3rd dim of A (" + + std::string message = "Stacked w/ bcast23 case: the 3rd dim of A (" + std::to_string(shapeA[2]) + ") and the 1st dim of B (" + std::to_string(shapeB[0]) + ") are not the same."; @@ -674,6 +715,21 @@ bool isSuitableForZDNN( } } return true; + } else if ((shapeA.size() == 2) && (shapeB.size() == 3)) { + // stacked w/ bcast1 case + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M15); + if (aType.hasStaticShape() && bType.hasStaticShape()) { + if (shapeA[1] != shapeB[1]) { + std::string message = "Stacked w/ bcast1 case: the 2nd dim of A (" + + std::to_string(shapeA[1]) + + ") and the 2nd dim of B (" + + std::to_string(shapeB[1]) + ") are not the same."; + return onnxToZHighUnsupportedReport(op.getOperation(), message); + } + } + return true; } std::string message = "Dim size of A(" + std::to_string(shapeA.size()) + ") and B(" + std::to_string(shapeB.size()) + @@ -681,6 +737,141 @@ bool isSuitableForZDNN( return onnxToZHighUnsupportedReport(op.getOperation(), message); } +/// Check legality for ONNXMatMulInteger. +template <> +bool isSuitableForZDNN( + ONNXMatMulIntegerOp op, const DimAnalysis *dimAnalysis) { + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); + + // Only support per-tensor quantization. + Value AZeroPoint = op.getAZeroPoint(); + Value BZeroPoint = op.getBZeroPoint(); + if (!isScalarTensor(AZeroPoint)) + return onnxToZHighInCompatibilityReport( + op.getOperation(), "A's zeropoint is not scalar"); + if (!isScalarTensor(BZeroPoint)) + return onnxToZHighInCompatibilityReport( + op.getOperation(), "B's zeropoint is not scalar"); + + ShapedType aType = mlir::cast(op.getA().getType()); + ShapedType bType = mlir::cast(op.getB().getType()); + + // Illegal if A or B is unranked. + if (!aType.hasRank() || !bType.hasRank()) + return false; + + auto shapeA = aType.getShape(); + auto shapeB = bType.getShape(); + + // In case of Tensors with unknown dimension, check only size of matrices. + // Actual shape is not checked. If actual shape does not meet, get error at + // runtime. + // TODO: Support other cases + // (https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul) on zDNN + // by using broadcasting etc. + if ((shapeA.size() == 2) && (shapeB.size() == 2)) { + // unstacked case + if (aType.hasStaticShape() && bType.hasStaticShape()) + return (shapeA[1] == shapeB[0]); + else + return true; + } else if ((shapeA.size() == 3) && (shapeB.size() == 3)) { + // stacked w/o bcast case + if (aType.hasStaticShape() && bType.hasStaticShape()) + return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1])); + else + return true; + } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) { + // stacked w/ bcast + if (aType.hasStaticShape() && bType.hasStaticShape()) + return (shapeA[2] == shapeB[0]); + else + return true; + } + + return false; // unsupported case +} + +/// Check legality for ONNXQLinearMatMul. +template <> +bool isSuitableForZDNN( + ONNXQLinearMatMulOp op, const DimAnalysis *dimAnalysis) { + Value A = op.getA(); + Value AScale = op.getAScale(); + Value AZeroPoint = op.getAZeroPoint(); + Value B = op.getB(); + Value BScale = op.getBScale(); + Value BZeroPoint = op.getBZeroPoint(); + Value Y = op.getY(); + Value YScale = op.getYScale(); + + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); + + // Only support float32 <-> int8/uint8. + Type elemTyA = getElementType(A.getType()); + Type elemTyAScale = getElementType(AScale.getType()); + Type elemTyB = getElementType(B.getType()); + Type elemTyBScale = getElementType(BScale.getType()); + Type elemTyY = getElementType(Y.getType()); + Type elemTyYScale = getElementType(YScale.getType()); + + if (!elemTyAScale.isF32() || !elemTyBScale.isF32() || !elemTyYScale.isF32()) + return false; + if (!(elemTyA.isInteger(8) || elemTyA.isUnsignedInteger(8))) + return false; + if (!(elemTyB.isInteger(8) || elemTyB.isUnsignedInteger(8))) + return false; + if (!(elemTyY.isInteger(8) || elemTyY.isUnsignedInteger(8))) + return false; + + // Only support per-tensor quantization. + if (!isScalarTensor(AScale) || !isScalarTensor(BScale) || + !isScalarTensor(AZeroPoint) || !isScalarTensor(BZeroPoint)) + return false; + + ShapedType aType = mlir::cast(A.getType()); + ShapedType bType = mlir::cast(B.getType()); + + // Illegal if A or B is unranked. + if (!aType.hasRank() || !bType.hasRank()) + return false; + + auto shapeA = aType.getShape(); + auto shapeB = bType.getShape(); + + // In case of Tensors with unknown dimension, check only size of matrices. + // Actual shape is not checked. If actual shape does not meet, get error at + // runtime. + // TODO: Support other cases + // (https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul) on zDNN + // by using broadcasting etc. + if ((shapeA.size() == 2) && (shapeB.size() == 2)) { + // unstacked case + if (aType.hasStaticShape() && bType.hasStaticShape()) + return (shapeA[1] == shapeB[0]); + else + return true; + } else if ((shapeA.size() == 3) && (shapeB.size() == 3)) { + // stacked w/o bcast case + if (aType.hasStaticShape() && bType.hasStaticShape()) + return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1])); + else + return true; + } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) { + // stacked w/ bcast + if (aType.hasStaticShape() && bType.hasStaticShape()) + return (shapeA[2] == shapeB[0]); + else + return true; + } + + return false; // unsupported case +} + /// Check legality for ONNXGemm. template <> bool isSuitableForZDNN( @@ -690,8 +881,8 @@ bool isSuitableForZDNN( Value C = op.getC(); // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check data type. if (!isValidElementTypeAndRank(op.getOperation(), A)) @@ -759,13 +950,91 @@ bool isSuitableForZDNN( return true; } +// Common function for ReduceMax and ReduceMin +template +static bool checkReduceParam(OP_TYPE op) { + OpBuilder b(op); + Location loc = op.getLoc(); + IndexExprBuilderForAnalysis createIE(loc); + IndexExprScope ieScope(&b, loc); + + Value data = op.getData(); + Value axesVal = op.getAxes(); + int64_t keepdims = op.getKeepdims(); + int64_t noop_with_empty_axes = op.getNoopWithEmptyAxes(); + + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15); + + // Check data type. + int64_t rank = getRank(data.getType()); + if (!isValidElementTypeAndRank(op.getOperation(), data)) + return false; + + // NNPA does not support reduction over all axes. + if (isNoneValue(axesVal)) + return onnxToZHighUnsupportedReport( + op.getOperation(), "Does not support reduction over all axes."); + + // Check keepdims and noop_with_empty_axes, we only support the default + // value. Attributes: keepdims (default is 1) and noop_with_empty_axes + // (default is 0) + if ((noop_with_empty_axes == 1) || (keepdims == 0)) { + std::string message = "`noop_with_empty_axes` (" + + std::to_string(noop_with_empty_axes) + + ") must be 0 and `keepdims` (" + + std::to_string(keepdims) + ") must be 1."; + return onnxToZHighUnsupportedReport(op, message); + } + + // Check axes value + DimsExpr axesIE; + createIE.getIntFromArrayAsDims(axesVal, axesIE); + if (axesIE.size() != 1) + return onnxToZHighUnsupportedReport( + op.getOperation(), "Does not support multiple reduction axes."); + if (!axesIE[0].isLiteral()) + return onnxToZHighUnsupportedReport( + op.getOperation(), "Reduction axis is unknown at compile time."); + int64_t axis = axesIE[0].getLiteral(); + // Accepted range is [-r, r-1] where r = rank(data) + if (axis < -rank || axis > rank - 1) { + std::string message = + "Reduction axis is out of the accepted range which is [-r, r-1]"; + return onnxToZHighUnsupportedReport(op, message); + } + if ((axis != -1) && (axis != rank - 1)) { + std::string message = "Reduction axis must be the innermost dimension. "; + return onnxToZHighUnsupportedReport(op, message); + } + + return true; +} + +/// Check legality for ONNXReduceMax. +template <> +bool isSuitableForZDNN( + ONNXReduceMaxOp op, const DimAnalysis *dimAnalysis) { + // Check parameter restrictions for ReduceMax + return checkReduceParam(op); +} + +/// Check legality for ONNXReduceMin. +template <> +bool isSuitableForZDNN( + ONNXReduceMinOp op, const DimAnalysis *dimAnalysis) { + // Check parameter restrictions for ReduceMin + return checkReduceParam(op); +} + /// Check legality for ONNXReduceMeanV13. template <> bool isSuitableForZDNN( ONNXReduceMeanV13Op op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check data type. if (!isValidElementTypeAndRank(op.getOperation(), op.getData())) @@ -826,7 +1095,7 @@ template <> bool isSuitableForZDNN( ONNXSoftplusOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) return false; if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) return false; @@ -844,8 +1113,8 @@ bool isSuitableForZDNN( Value B = op.getB(); // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check direction. if ((direction != FORWARD) && (direction != REVERSE) && @@ -869,7 +1138,8 @@ bool isSuitableForZDNN( std::string message = "The first dimension of weight tensor `W` for `num_directions` (" + std::to_string(wShape[0]) + - ") must be 1 or 2, and the second dimension of it for `hidden_size` (" + + ") must be 1 or 2, and the second dimension of it for `hidden_size` " + "(" + std::to_string(wShape[1]) + ") must be static."; return onnxToZHighUnsupportedReport(op.getOperation(), message); } @@ -877,9 +1147,9 @@ bool isSuitableForZDNN( ArrayRef rShape = mlir::cast(R.getType()).getShape(); if (!mlir::cast(R.getType()).hasStaticShape() || (rShape[0] != 1 && rShape[0] != 2)) { - std::string message = - "The recurrence weight tensor `R` must have static dimension, and the " - "first dimension of it must be 1 or 2."; + std::string message = "The recurrence weight tensor `R` must have static " + "dimension, and the " + "first dimension of it must be 1 or 2."; return onnxToZHighUnsupportedReport(op.getOperation(), message); } // Check hidden_size. @@ -957,8 +1227,8 @@ bool isSuitableForZDNN( Value B = op.getB(); // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check direction. if ((direction != FORWARD) && (direction != REVERSE) && @@ -982,7 +1252,8 @@ bool isSuitableForZDNN( std::string message = "The first dimension of weight tensor `W` for `num_directions` (" + std::to_string(wShape[0]) + - ") must be 1 or 2, and the second dimension of it for `hidden_size` (" + + ") must be 1 or 2, and the second dimension of it for `hidden_size` " + "(" + std::to_string(wShape[1]) + ") must be static."; return onnxToZHighUnsupportedReport(op.getOperation(), message); } @@ -1062,8 +1333,8 @@ template <> bool isSuitableForZDNN( ONNXMaxPoolSingleOutOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check data type. if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) @@ -1094,8 +1365,8 @@ template <> bool isSuitableForZDNN( ONNXAveragePoolOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check data type. if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) @@ -1111,9 +1382,9 @@ bool isSuitableForZDNN( ONNXAveragePoolOpShapeHelper>(op, op.getY(), dimAnalysis); } -/// Check if input, output, kernel, strides, and paddingType for each axis meet -/// parameter restrictions for conv2d. See "Conv2D Parameter Restrictions" -/// in "zDNN API Reference" +/// Check if input, output, kernel, strides, and paddingType for each axis +/// meet parameter restrictions for conv2d. See "Conv2D Parameter +/// Restrictions" in "zDNN API Reference" static bool checkConv2DParamRestrictions(Operation *op, int64_t inputDim, int64_t kernelDim, int64_t stride, int64_t outputDim, StringRef paddingType) { @@ -1164,7 +1435,7 @@ static bool checkConv2DParamRestrictions(Operation *op, int64_t inputDim, } if (paddingType == "SAME_PADDING") { // height_out restriction. - int64_t reqOutputShape = ceil((float)inputDim / stride); + int64_t reqOutputShape = ceil(static_cast(inputDim) / stride); if (outputDim != reqOutputShape) { std::string message = "When the strides (" + std::to_string(stride) + @@ -1189,7 +1460,8 @@ static bool checkConv2DParamRestrictions(Operation *op, int64_t inputDim, return onnxToZHighUnsupportedReport(op, message); } // height_out restriction. - int64_t reqOutputShape = ceil((float)(inputDim - kernelDim + 1) / stride); + int64_t reqOutputShape = + ceil(static_cast(inputDim - kernelDim + 1) / stride); if (outputDim != reqOutputShape) { std::string message = "When the strides (" + std::to_string(stride) + @@ -1217,8 +1489,8 @@ template <> bool isSuitableForZDNN( ONNXConvOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // Check data type. if (!isValidElementTypeAndRank(op.getOperation(), op.getX())) @@ -1254,7 +1526,8 @@ bool isSuitableForZDNN( ShapedType::isDynamic(shapeOutput[2]) || ShapedType::isDynamic(shapeOutput[3])) return onnxToZHighUnsupportedReport(op, - "Height and/or width have dynamic dimensions. They are not supported."); + "Height and/or width have dynamic dimensions. They are not " + "supported."); // Do not support group. if (operandAdaptor.getGroup() != 1) @@ -1270,7 +1543,8 @@ bool isSuitableForZDNN( } // `getStrPaddingType` returns `SAME_PADDING`, `VALID_PADDING`, or empty. - // `zdnn_conv2d` only support padding for `SAME_PADDING` and `VALID_PADDING`. + // `zdnn_conv2d` only support padding for `SAME_PADDING` and + // `VALID_PADDING`. StringRef paddingType = getStrPaddingType( op); @@ -1323,8 +1597,8 @@ bool isSuitableForZDNN( ArrayRef shapeOutput = outputType.getShape(); // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14); // 4D tensors(N x C x H x W) are supported as input and output. if (shapeInput.size() != 4 || shapeOutput.size() != 4) @@ -1343,3 +1617,19 @@ bool isSuitableForZDNN( // Noop Reshape is suitable for zAIU as this pass removes such reshape ops. return isIdentityReshape(op, dimAnalysis); } + +/// Check legality for ONNXDequantizeLinearOp. +template <> +bool isSuitableForZDNN( + ONNXDequantizeLinearOp op, const DimAnalysis *dimAnalysis) { + // The pass rewrite-onnx-for-zhigh has a rule to rewrite the pattern + // `DequantizeLinear (QLinearMatMul inputs)` where ND inputs are reshaped + // into 3D inputs. This rule uses the function template + // `addDynamicallyLegalOpFor` to define legality using a custom lambda + // function instead of `isSuitableForZDNN`. Hence, the legality here should + // not be used/called. This legality is here to complete the function + // template `addDynamicallyLegalOpFor` so that it's not failed when building + // the compiler. + llvm_unreachable("Not used"); + return false; +} diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp index f9c36372c4..09bfa6f4f6 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp @@ -17,6 +17,7 @@ #ifndef ONNX_MLIR_LEGALITY_H #define ONNX_MLIR_LEGALITY_H +#include "src/Accelerators/NNPA/Support/NNPALimit.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -53,6 +54,8 @@ bool onnxToZHighUnsupportedReport( mlir::Operation *op, const std::string &message); bool onnxToZHighInCompatibilityReport( - mlir::Operation *op, std::string inputNNPALevel); + mlir::Operation *op, const std::string &message); + +bool onnxToZHighInCompatibilityReport(mlir::Operation *op, NNPALevel level); #endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 6b3abd8947..78e94a6a2a 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -4,7 +4,7 @@ //====------ ONNXToZHigh.cpp - ONNX dialect to ZHigh lowering -------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -13,6 +13,9 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Support/Debug.h" + +#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp" #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" @@ -25,6 +28,8 @@ #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" #include "src/Dialect/ONNX/Transforms/ShapeInference.hpp" +#define DEBUG_TYPE "onnx-to-zhigh" + using namespace mlir; // @@ -33,6 +38,17 @@ using namespace mlir; namespace onnx_mlir { +using namespace zhigh; + +#define QUANT_PATTERN_BENEFIT 1000 + +/// Checks whether a constant tensor's elements are of type FloatType. +bool isFloatType(Value constValue) { + ElementsAttr constElements = getElementAttributeFromONNXValue(constValue); + Type elemType = constElements.getElementType(); + return mlir::isa(elemType); +} + ArrayAttr getLSTMGRUBiasSplitShape( Location loc, PatternRewriter &rewriter, ArrayRef shapeR) { int64_t hiddenSize = shapeR[2]; @@ -207,7 +223,7 @@ Value getLSTMGRUGetYc( SmallVector emitONNXSplitOp(Location loc, PatternRewriter &rewriter, Value input, IntegerAttr axis, ArrayAttr split) { Type elementType = mlir::cast(input.getType()).getElementType(); - SmallVector outputTypes; + SmallVector outputTypes; int64_t splitNum = split.size(); ArrayRef inputShape = mlir::cast(input.getType()).getShape(); @@ -253,6 +269,349 @@ SmallVector getArrayStrides(OP op) { return shapeHelper.strides; } +/// Get approximate +template +StringRef getStrApproximateType(OP op) { + return op.getApproximate(); +} + +// Computes the folded bias to be passed to quantized matmul call when +// operation is MATMUL_OP_ADDITION. Zb should be equal to 0, meaning the +// correction term for input_a is also equal to 0. This allows the +// correction term for input_b to be folded into qc_tilde, which removes the +// need for correction being applied after the quantized matmul call. +// +// The original equation for qc_tilde is: +// M = (Sa * Sb) / Sy +// qc_tilde = Zy - (Sc / Sy) * Zc + (Sc / Sy) * input_c[j] + M*N*Za*Zb +// +// Given Zb = 0, the equation becomes: +// M = (Sa * Sb) / Sy +// qc_tilde = Zy - (Sc / Sy) * Zc + (Sc / Sy) * input_c[j] +// +// Given scales are stored as the reciprocal in zTensor, the modified equation +// becomes: +// M = RSy / (RSa * RSb) +// qc_tilde = Zy - (RSy / RSc) * Zc + (RSy / RSc) * input_c[j] +// +// where RS = 1/S. +// +// We can reorder this to: +// M = RSy / (RSa * RSb) +// qc_tilde = input_c[j] * (RSy / RSc) + Zy - (RSy / RSc) * Zc +// +// This allows us to pre-compute a scale and offset to apply to input_c[j]: +// M = RSy / (RSa * RSb). +// scale = (RSy / RSc) +// offset = Zy - scale * Zc +// qc_tilde[j] = input_c[j] * scale + offset +// +// The original equation for the correction term for input_b is: +// M = (RSa * RSb) / RSy +// term_b = M * Za * sum(input_b[:,j]) +// +// Given scales are stored as the reciprocal, the modified equation becomes: +// M = RSy / (RSa * RSb) +// term_b = M * Za * sum(input_b[:,j]) +// +// This gives us the equation: +// M = RSy / (RSa * RSb) +// MZa = M * Za +// scale = (RSy / RSc) +// offset = Zy - scale * Zc +// qc_tilde[j] = input_c[j] * scale + offset - MZa * sum(input_b[:,j]) +// +// In case of MatMulInteger, input_c = 0, RSc = 1, Zc = 0, the final equation +// is: +// M = RSy / (RSa * RSb) +// MZa = M * Za +// scale = RSy +// offset = Zy +// qc_tilde[j] = offset - Za * (RSy / RSa / RSb) * sum(input_b[:,j]) +// +// When Zy = 0, qc_tilde[j] = -Za * (RSy / RSa / RSb) * sum(input_b[:,j]) +static void preComputeBias(MultiDialectBuilder &create, Value RSa, + Value Za, Value BI8, Value RSb, Value RSy, Value Zy, Value &qcTilde, + Value &RSqctilde, Value &Zqctilde) { + OpBuilder rewriter = create.getBuilder(); + Location loc = create.getLoc(); + + Type i64Ty = rewriter.getI64Type(); + Type f32Ty = rewriter.getF32Type(); + auto cstMinus2Attr = DenseElementsAttr::get( + RankedTensorType::get({}, i64Ty), static_cast(-2)); + auto cst0Attr = DenseElementsAttr::get( + RankedTensorType::get({}, f32Ty), static_cast(0)); + auto cst1Attr = DenseElementsAttr::get( + RankedTensorType::get({}, f32Ty), static_cast(1)); + + Value cst0 = create.onnx.constant(cst0Attr); + Value cst1 = create.onnx.constant(cst1Attr); + + // Can be optimized further when Zy is zero. + bool ZyIsZero = isDenseONNXConstant(Zy) && isConstOf(Zy, 0.); + + Value qcF32; + Value B = create.onnx.cast(BI8, f32Ty); + Value lastSecondAxis = create.onnx.constant(cstMinus2Attr); + // Emit: sum(input_b[:,j]) + Value BSum = create.onnx.reduceSum( + UnrankedTensorType::get(f32Ty), B, lastSecondAxis, false); + // RSy, RSa, RSb, Za are scalar, do scalar computation. + // Emit: Za * (RSy / RSa / RSb) + Value RSyRSa = create.onnx.div(RSy, RSa); + Value RSyRSaRSb = create.onnx.div(RSyRSa, RSb); + Value MZa = create.onnx.mul(RSyRSaRSb, Za); + // Negate ZaRSyRSa to avoid broadcasting Sub: + // `Zy - Za * (RSy / RSa / RSb) * ...` + MZa = create.onnx.sub(cst0, MZa); + // Broadcast ops. + // Emit: - Za * (RSy / RSa / RSb) * sum(input_b[:,j]) + Value MZaBSum = create.onnx.mul(MZa, BSum); + // Emit: Zy - Za * (RSy / RSa / RSb) * sum(input_b[:,j]) + if (ZyIsZero) { + qcF32 = MZaBSum; + } else { + qcF32 = create.onnx.add(Zy, MZaBSum); + } + + // Use 1 for recscale and 0 for offset. This is a dlfloat16 stickification. + int64_t rank = getRank(qcF32.getType()); + StringAttr layoutAttr = + rewriter.getStringAttr((rank == 1) ? LAYOUT_1D : LAYOUT_2DS); + ZHighQuantizedStickOp qcOp = rewriter.create(loc, + qcF32, cst1, cst0, layoutAttr, rewriter.getStringAttr(QTYPE_DLFLOAT16)); + qcTilde = qcOp.getResult(0); + RSqctilde = qcOp.getResult(1); + Zqctilde = qcOp.getResult(2); +} + +static Value getOrCastToI8(Value val, MultiDialectBuilder &create, + bool simpleCast = false) { + if (!getElementType(val.getType()).isUnsignedInteger()) + return val; + + Type i8Ty = create.getBuilder().getI8Type(); + if (simpleCast) + return create.onnx.cast(val, i8Ty); + + // Use int16 to avoid integer overflow. + Type i16Ty = create.getBuilder().getI16Type(); + auto cst128Attr = DenseElementsAttr::get( + RankedTensorType::get({}, i16Ty), static_cast(128)); + Value valI16 = create.onnx.cast(val, i16Ty); + valI16 = create.onnx.sub(valI16, create.onnx.constant(cst128Attr)); + Value valI8 = create.onnx.cast(valI16, i8Ty); + return valI8; +} + +// Dynamic quantization helper to match and rewrite values A, B, C of A*B+C. +class DynQuantI8PatternHelper { +public: + DynQuantI8PatternHelper(PatternRewriter &rewriter, Location loc, + Operation *op, Value A, Value B, Value C, bool symForA) + : rewriter(rewriter), loc(loc), op(op), A(A), B(B), C(C), + symForA(symForA) {} + + // Check the inputs A, B, C of `A*B+C` to see if they are suitable for doing + // dynamic quantization on NNPA. + LogicalResult match() { + // A is of f32. + if (!mlir::isa(getElementType(A.getType()))) + return rewriter.notifyMatchFailure(op, "MatMul's A is not of f32."); + + // Weight is a constant. + if (!isDenseONNXConstant(B)) + return rewriter.notifyMatchFailure(op, "MatMul's B is not a constant."); + + if (C) { + // Bias is a constant. + if (!isDenseONNXConstant(C)) + return rewriter.notifyMatchFailure(op, "MatMul's C is not a constant"); + // B and C shapes must be consistent. The reduction shape of B on the + // second dim from the last is the same as the shape of B, e.g. If B is + // [2x3x4], C must be [2x4]. + ArrayRef bShape = getShape(B.getType()); + ArrayRef cShape = getShape(C.getType()); + int64_t bRank = bShape.size(); + int64_t cRank = cShape.size(); + if (bRank - 1 != cRank) + return rewriter.notifyMatchFailure( + op, "The ranks of B and C are imcompatible."); + if (bShape[bRank - 1] != cShape[cRank - 1]) + return rewriter.notifyMatchFailure( + op, "The last dimensions of B and C are not the same."); + if (bShape.drop_back(2) != cShape.drop_back(1)) + return rewriter.notifyMatchFailure( + op, "The shapes of B and C are imcompatible."); + } + + return success(); + } + + // clang-format off + /* + * Emit the following code to compute `A*B+C` using i8 dynamic quantization. + * A can be quantized using asymmetric or symmetric quantization depending on + * the flag `symForA`, while B is always quantized using symmetric quantization. + * (Note that: If C is given, it will be added into the pre_computed_bias) + * + * ``` + * (Quantize A using asymmetric/symmetric quant by setting `sym_mode` attr to the `symForA` flag) + * %qa, %a_recscale, %a_offset = zhigh.QuantizedStick(%A, none, none) { quantized_type = QUANTIZED_DLFLOAT16, sym_mode = 1/0} + * + * (Quantize B using symmetric quant) + * %b_offset = 0 // Symmetric quant mode for i8. Offset is always zero, qmin = * -127, qmax = 127. + * %absmax = onnx.ReduceMax(onnx.Abs(%B)) + * %b_rescale = onnx.Div(127, absmax) + * %qb = onnx.cast(onnx.Clip(onnx.Round(onnx.Mul(%B, %b_rescale)), qmin, qmax)) + * %qb, %b_recscale, %b_offset = zhigh.QuantizedStick(%qb, %b_recscale, %b_offset) { quantized_type = QUANTIZED_WEIGHTS_INT8 } + * + * (Pre computed bias, %C is added) + * %qc = emit_ops_for_pre_computed_bias_at_compile_time + * %qc = zhigh.Add(%qc, zhigh.Stick(%C)) // only done if C is given. + * %qc_recscale = 1 + * %qc_offset = 0 + * + * %Y_recscale = 1 + * %Y_offset = 0 + * %Y, %Y_recscale, %Y_offset = zhigh.QuantizedMatMul (%qa, %a_recscale, %a_offset, + * %qb, %b_recscale, %b_offset, + * %qc, %c_recscale, %c_offset, + * %Y_recscale, %Y_offset) { + * PreComputedBias = true, DisableClipping = true, DequantizeOutput = false + * } + * ``` + * + * where the computation of `%qb` and `%qb_recscale` are expected to be folded by constant + * propagation so that they become constants. + * + * For more information about dynamic quantization, see https://www.maartengrootendorst.com/blog/quantization + */ + // clang-format on + Value rewriteSym() { + MultiDialectBuilder create(rewriter, loc); + + Type i8Ty = rewriter.getIntegerType(8); + Type si64Ty = rewriter.getIntegerType(64, true); + Type f16Ty = rewriter.getF16Type(); + Type f32Ty = rewriter.getF32Type(); + RankedTensorType scalarTy = RankedTensorType::get({}, f32Ty); + + IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1); + IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0); + + Value none = create.onnx.none(); + Value cst0 = create.onnx.constant( + DenseElementsAttr::get(scalarTy, static_cast(0))); + Value cst1 = create.onnx.constant( + DenseElementsAttr::get(scalarTy, static_cast(1))); + Value cst127 = create.onnx.constant( + DenseElementsAttr::get(scalarTy, static_cast(127))); + Value cstNeg127 = create.onnx.constant( + DenseElementsAttr::get(scalarTy, static_cast(-127))); + + int64_t rankA = getRank(A.getType()); + int64_t rankB = getRank(B.getType()); + StringAttr aLayoutAttr = + rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS); + StringAttr bLayoutAttr = + rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS); + + // Quantize and stickify A. + IntegerAttr symModeAttr = + rewriter.getIntegerAttr(rewriter.getI64Type(), symForA ? 1 : 0); + ZHighQuantizedStickOp qAOp = + rewriter.create(loc, A, none, none, aLayoutAttr, + rewriter.getStringAttr(QTYPE_DLFLOAT16), symModeAttr); + Value AI8 = qAOp.getResult(0); + Value ARecScale = qAOp.getResult(1); + Value AOffset = qAOp.getResult(2); + + // Quantize B. All computations here would be folded by constprop. + // Though computation here can be generalized for other integer types by + // changing qmin and qmax, we optimize it for i8 since NNPA supports i8 only + // at this moment. + // Symmetric mode for i8, meaning offset = 0, qmin = -127, qmax = 127. + Value BOffset = cst0; + Value qmin = cstNeg127; + Value qmax = cst127; + // %absmax = onnx.ReduceMax(onnx.Abs(%B)) + // %b_rescale = onnx.Div(127, absmax) + Value absMax = + create.onnx.reduceMax(scalarTy, create.onnx.abs(B), none, false, false); + Value BRecScale = create.onnx.div(cst127, absMax); + // %qb = onnx.Cast( + // onnx.Clip(onnx.Round(onnx.Mul(%B, %b_rescale)), qmin, qmax)) + Value BI8 = create.onnx.cast( + create.onnx.clip( + create.onnx.round(create.onnx.mul(B, BRecScale)), qmin, qmax), + i8Ty); + // Stickify B. + ZHighQuantizedStickOp qBOp = + rewriter.create(loc, BI8, BRecScale, BOffset, + bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS)); + + // Output information. + Value YRecScale = cst1; + Value YOffset = cst0; + + // When A is also quantized using symmetric mode, both correction terms for + // A and B are canceled out. Thus, no precomputation is needed. + Value qcTilde = none, qcTildeRecScale = cst1, qcTildeOffset = cst0; + if (!symForA) { + // When only B is quantized using symmetric mode, precompute the + // correction term for B only. + preComputeBias(create, ARecScale, AOffset, BI8, BRecScale, YRecScale, + YOffset, qcTilde, qcTildeRecScale, qcTildeOffset); + } + // Add up C into bias if C is given. + if (C) { + int64_t rankC = getRank(C.getType()); + assert((rankC == rankB - 1) && + "C has a wrong shape to be added into pre_computed_bias"); + assert((rankC == 1 || rankC == 2) && "Wrong rank for C"); + StringAttr cLayoutAttr = + rewriter.getStringAttr((rankC == 1) ? LAYOUT_1D : LAYOUT_2DS); + Value stickC = rewriter.create(loc, C, cLayoutAttr); + if (symForA) + qcTilde = stickC; + else + qcTilde = rewriter.create( + loc, qcTilde.getType(), qcTilde, stickC); + } + + // Emit zhigh.QuantizedMatMul. + // No need to dequantize since Y's rescale is 1. + // Do not clip the output values to i8, keep i32. + SmallVector resTypes; + resTypes.emplace_back(UnrankedTensorType::get(f16Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + ZHighQuantizedMatMulOp zhighQuantizedMatMulOp = + rewriter.create(loc, resTypes, AI8, ARecScale, + AOffset, qBOp.getResult(0), BRecScale, BOffset, qcTilde, + qcTildeRecScale, qcTildeOffset, + /*OutRecScale*/ YRecScale, /*OutOffset*/ YOffset, + /*PreComputedBias*/ trueAttr, /*DisableClipping*/ trueAttr, + /*DequantizeOutput*/ falseAttr); + (void)zhighQuantizedMatMulOp.inferShapes([](Region ®ion) {}); + + // Unstickify the matmul result that is int8-as-float. + Value res = rewriter.create( + loc, zhighQuantizedMatMulOp.getResult(0)); + return res; + } + +private: + PatternRewriter &rewriter; + Location loc; + Operation *op; + Value A, B, C; + // Whether do symmetric quant for activation input A or not. + bool symForA = false; +}; + //===----------------------------------------------------------------------===// // ONNX to ZHigh Lowering Pass //===----------------------------------------------------------------------===// @@ -262,9 +621,9 @@ namespace { #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXONNXToZHigh.inc" // Enhance 'replaceONNXSumOpPatternRecursion' to allow operating recursively. -struct ONNXSumOpPatternEnhancedRecursion +struct replaceONNXSumOpPatternEnhancedRecursion : public replaceONNXSumOpPatternRecursion { - ONNXSumOpPatternEnhancedRecursion(MLIRContext *context) + replaceONNXSumOpPatternEnhancedRecursion(MLIRContext *context) : replaceONNXSumOpPatternRecursion(context) {} void initialize() { // This pattern recursively unpacks one variadic operand at a time. The @@ -274,6 +633,892 @@ struct ONNXSumOpPatternEnhancedRecursion } }; +/** + * This is a pattern for doing i8 dynamic quantization (symmetric mode) for + * onnx.MatMul(%A, %B), where %B is a constant. + */ + +class replaceONNXMatMulByDynQuantI8Pattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + replaceONNXMatMulByDynQuantI8Pattern( + MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) + : OpRewritePattern(context, benefit), symForA(symForA) {} + + LogicalResult matchAndRewrite( + ONNXMatMulOp mmOp, PatternRewriter &rewriter) const override { + Location loc = mmOp.getLoc(); + Operation *op = mmOp.getOperation(); + Value A = mmOp.getA(); + Value B = mmOp.getB(); + + // Dynamic quantization helper. + DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, nullptr, symForA); + + // Match + if (!isSuitableForZDNN(mmOp) || failed(dqHelper.match())) + return rewriter.notifyMatchFailure(op, "MatMul is not suitable for zDNN"); + + // Rewrite + Value res = dqHelper.rewriteSym(); + rewriter.replaceOp(op, res); + return success(); + } + +private: + bool symForA = false; +}; + +/** + * This is a pattern for doing i8 dynamic quantization (symmetric mode) for + * `onnx.Add(onnx.MatMul(%A, %B), %C)`. where + * - %B and %C are a constant and + * - %B and %C must have compatible shape, i.e. the reduction shape on the last + * second dim of %B is the same as %C's shape. + */ +class replaceONNXMatMulAddByDynQuantI8Pattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + replaceONNXMatMulAddByDynQuantI8Pattern( + MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) + : OpRewritePattern(context, benefit), symForA(symForA) {} + + LogicalResult matchAndRewrite( + ONNXAddOp addOp, PatternRewriter &rewriter) const override { + Location loc = addOp.getLoc(); + Operation *op = addOp.getOperation(); + Value lhs = addOp.getOperand(0); + Value rhs = addOp.getOperand(1); + + // Match A*B+C and C+A*B where B and C are constants, and then rewrite. + Value AB, C; + if (!areDefinedBy(lhs, rhs, AB, C)) + return rewriter.notifyMatchFailure( + op, "MatMulAdd is not suitable for zDNN."); + ONNXMatMulOp mmOp = AB.getDefiningOp(); + Value A = mmOp.getA(); + Value B = mmOp.getB(); + + // Match A, B, C. + DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, C, symForA); + if (succeeded(dqHelper.match())) { + Value res = dqHelper.rewriteSym(); + rewriter.replaceOp(op, res); + return success(); + } + + return failure(); + } + +private: + bool symForA = false; +}; + +/** + * This is a pattern for doing i8 dynamic quantization (symmetric mode) for + * onnx.Gemm(%A, %B, %C), where %B and %C are constants. + * + * This pattern is applied only when the compiler option + * `--nnpa-quantization={DynSymI8|SymSymI8}` is specified. + * + */ + +class replaceONNXGemmByDynQuantI8Pattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + replaceONNXGemmByDynQuantI8Pattern( + MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) + : OpRewritePattern(context, benefit), symForA(symForA) {} + + LogicalResult matchAndRewrite( + ONNXGemmOp gemmOp, PatternRewriter &rewriter) const override { + Location loc = gemmOp.getLoc(); + Operation *op = gemmOp.getOperation(); + + Value A = gemmOp.getA(); + Value B = gemmOp.getB(); + Value C = gemmOp.getC(); + bool transA = (gemmOp.getTransA() != 0); + bool transB = (gemmOp.getTransB() != 0); + + // Dynamic quantization helper. + DynQuantI8PatternHelper dqHelper( + rewriter, loc, op, A, B, isNoneValue(C) ? nullptr : C, symForA); + + // Match + // TODO: if B is a constant and it is transposed, we can do transpose + // explicitly. + if (transA || transB) + return rewriter.notifyMatchFailure(op, "Gemm is with transpose"); + if (!isSuitableForZDNN(gemmOp)) + return rewriter.notifyMatchFailure(op, "Gemm is not suitable for zDNN"); + if (failed(dqHelper.match())) + return failure(); + + // Rewrite + Value res = dqHelper.rewriteSym(); + rewriter.replaceOp(op, res); + return success(); + } + +private: + bool symForA = false; +}; + +class replaceONNXMatMulIntegerPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXMatMulIntegerOp mmiOp, PatternRewriter &rewriter) const override { + Location loc = mmiOp.getLoc(); + Operation *op = mmiOp.getOperation(); + MultiDialectBuilder create(rewriter, loc); + + // Match + if (failed(canBeRewritten(rewriter, mmiOp))) + return failure(); + + Type si64Ty = rewriter.getIntegerType(64, true); + Type f16Ty = rewriter.getF16Type(); + Type f32Ty = rewriter.getF32Type(); + Type outElemTy = getElementType(mmiOp.getY().getType()); + IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1); + IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0); + + auto cst0Attr = DenseElementsAttr::get( + RankedTensorType::get({}, f32Ty), static_cast(0)); + auto cst1Attr = DenseElementsAttr::get( + RankedTensorType::get({}, f32Ty), static_cast(1)); + Value none = create.onnx.none(); + Value zero = create.onnx.constant(cst0Attr); + Value zeroI64 = create.onnx.constantInt64({0}); + Value one = create.onnx.constant(cst1Attr); + + // Prepare inputs for zhigh QuantizedMatMul. + + // I8 tensors + Value AI8 = getOrCastToI8(mmiOp.getA(), create, true); + Value BI8 = getOrCastToI8(mmiOp.getB(), create, true); + + // Zero points in f32. + Value AZeroPointI8 = mmiOp.getAZeroPoint(); + if (getRank(AZeroPointI8.getType()) == 1) { + // Normalize the zeropoint tensor to tensor. + AZeroPointI8 = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(AZeroPointI8.getType())), + AZeroPointI8, {zeroI64}); + } + AZeroPointI8 = getOrCastToI8(AZeroPointI8, create, true); + Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty); + // TESTING: minus zeropoint in advance to cancel out the software part of + // zdnn quantized matmul. + // AI8 = create.onnx.sub(AI8, AZeroPointI8); + // Value AZeroPointF32 = zero; + Value BZeroPointI8 = mmiOp.getBZeroPoint(); + if (getRank(BZeroPointI8.getType()) == 1) { + // Normalize the zeropoint tensor to tensor. + BZeroPointI8 = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(BZeroPointI8.getType())), + BZeroPointI8, {zeroI64}); + } + BZeroPointI8 = getOrCastToI8(BZeroPointI8, create, true); + Value BZeroPointF32 = create.onnx.cast(BZeroPointI8, f32Ty); + // TESTING: minus zeropoint in advance to cancel out the software part of + // zdnn quantized matmul. + // BI8 = create.onnx.sub(BI8, AZeroPointI8); + // Value BZeroPointF32 = zero; + Value YZeroPointF32 = zero; + + // Recscale in f32. + // Set recscale of A and B to 1. In dynamic quantization the output of + // MatMulInteger is scaled later outside the op. + Value ARecScale = one; + Value BRecScale = one; + Value YRecScale = one; + + // Only pre-compute bias when B is a constant and BZeroPoint is zero. + bool canPreComputeBias = isDenseONNXConstant(BI8) && + isDenseONNXConstant(BZeroPointI8) && + isConstOf(BZeroPointI8, 0.0); + + // Stickify AI8, Transform AI8 into zTensor format. + int64_t rankA = getRank(AI8.getType()); + StringAttr aLayoutAttr = + rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS); + ZHighQuantizedStickOp qAOp = + rewriter.create(loc, AI8, ARecScale, + AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8)); + + // Stickify BI8. It is potentially folded at compile time. + int64_t rankB = getRank(BI8.getType()); + StringAttr bLayoutAttr = + rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS); + ZHighQuantizedStickOp qBOp = + rewriter.create(loc, BI8, BRecScale, + BZeroPointF32, bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS)); + + // Bias is none or precomputed. + Value qcTilde, qcTildeRecScale, qcTildeZeroPointF32; + if (canPreComputeBias) + preComputeBias(create, ARecScale, AZeroPointF32, BI8, BRecScale, + YRecScale, YZeroPointF32, qcTilde, qcTildeRecScale, + qcTildeZeroPointF32); + + // Emit zhigh.QuantizedMatMul. Bias is none. + // Do not dequantize, we want to keep the integer values that will be scaled + // outside this op. + // Do not clip the output values to i8, keep i32. + SmallVector resTypes; + resTypes.emplace_back(UnrankedTensorType::get(f16Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + ZHighQuantizedMatMulOp zhighQuantizedMatMulOp = + rewriter.create(loc, resTypes, + qAOp.getResult(0), qAOp.getResult(1), qAOp.getResult(2), + qBOp.getResult(0), qBOp.getResult(1), qBOp.getResult(2), + /*Bias*/ canPreComputeBias ? qcTilde : none, + /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none, + /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPointF32 : none, + /*OutRecScale*/ YRecScale, /*OutOffset*/ YZeroPointF32, + /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr, + /*DisableClipping*/ trueAttr, + /*DequantizeOutput*/ falseAttr); + (void)zhighQuantizedMatMulOp.inferShapes([](Region ®ion) {}); + + // Unstickify the matmul result that is int8-as-float. + Value resI8F32 = rewriter.create( + loc, zhighQuantizedMatMulOp.getResult(0)); + Value res = create.onnx.cast(resI8F32, outElemTy); + + rewriter.replaceOp(op, res); + return success(); + } + + static mlir::LogicalResult canBeRewritten( + PatternRewriter &rewriter, ONNXMatMulIntegerOp mmiOp) { + if (!isSuitableForZDNN(mmiOp)) + return rewriter.notifyMatchFailure( + mmiOp, "MatMulInteger is not suitable for zDNN"); + return success(); + } +}; + +// Replace by zhigh ops the following pattern: +// clang-format off +// func.func @pattern_in_bert(%X: tensor) : (tensor) -> tensor { +// %y = onnx.Constant dense_resource<__elided__> : tensor<768x768xi8> +// %y_scale = onnx.Constant dense<0.00656270096> : tensor +// %y_zero_point = onnx.Constant dense<0> : tensor +// +// %x, %x_scale, %x_zero_point = "onnx.DynamicQuantizeLinear"(%X) : (tensor) -> (tensor, tensor, tensor) +// +// %matmul = "onnx.MatMulInteger"(%x, %y, %x_zero_point, %y_zero_point) : (tensor, tensor<768x768xi8>, tensor, tensor) -> tensor +// %cast = "onnx.Cast"(%matmul) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// %mul_1= "onnx.Mul"(%cast, %x_scale) : (tensor, tensor) -> tensor +// %mul_2= "onnx.Mul"(%mul_1, %y_scale) : (tensor, tensor) -> tensor +// +// return %mul_2: tensor +// } +// clang-format on +class replaceMatMulIntegerSubGraphFromMulPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXMulOp mulOp, PatternRewriter &rewriter) const override { + Location loc = mulOp.getLoc(); + Operation *op = mulOp.getOperation(); + MultiDialectBuilder create(rewriter, loc); + + // Match + Value A, AI8, AScale, AZeroPointI8, BI8, BScale, BZeroPointI8; + if (failed(canBeRewritten(rewriter, mulOp, A, AI8, AScale, AZeroPointI8, + BI8, BScale, BZeroPointI8))) + return failure(); + + Type si64Ty = rewriter.getIntegerType(64, true); + Type f16Ty = rewriter.getF16Type(); + Type f32Ty = rewriter.getF32Type(); + IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1); + IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0); + Value none = create.onnx.none(); + + // Only pre-compute bias when BZeroPoint is zero. + bool canPreComputeBias = isDenseONNXConstant(BI8) && + isDenseONNXConstant(BZeroPointI8) && + isConstOf(BZeroPointI8, 0.0); + + // Stickify A. + int64_t rankA = getRank(A.getType()); + StringAttr aLayoutAttr = + rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS); + ZHighQuantizedStickOp qAOp; + if (nnpaUseDynamicQuantizeLinearOnCPU) { + Value zeroI64 = create.onnx.constantInt64({0}); + // Input A was quantized on CPU by onnx.DynamicQuantizedLinear: f32 to i8. + if (getRank(AZeroPointI8.getType()) == 1) { + // Normalize the zeropoint tensor to tensor. + AZeroPointI8 = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(AZeroPointI8.getType())), + AZeroPointI8, {zeroI64}); + } + AZeroPointI8 = getOrCastToI8(AZeroPointI8, create, true); + Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty); + Value ARecScale = create.onnx.reciprocal(AScale); + AI8 = getOrCastToI8(AI8, create, true); + // Stickify the quantized input A to ztensor format. + qAOp = rewriter.create(loc, AI8, ARecScale, + AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8)); + } else { + // Stickify input A to dlfloat16, and it will be quantized internally by + // the NNPA quantized matmul. + qAOp = rewriter.create(loc, A, none, none, + aLayoutAttr, rewriter.getStringAttr(QTYPE_DLFLOAT16)); + } + Value qA = qAOp.getResult(0); + Value ARecScale = qAOp.getResult(1); + Value AZeroPoint = qAOp.getResult(2); + + // Stickify B. It is potentially folded at compile time. + int64_t rankB = getRank(BI8.getType()); + StringAttr bLayoutAttr = + rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS); + Value BRecScale = create.onnx.reciprocal(BScale); + Value BZeroPoint = create.onnx.cast(BZeroPointI8, f32Ty); + ZHighQuantizedStickOp qBOp = + rewriter.create(loc, BI8, BRecScale, BZeroPoint, + bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS)); + Value qB = qBOp.getResult(0); + + // Output's rescale and zeropoint + auto cst0Attr = + DenseElementsAttr::get(RankedTensorType::get({}, f32Ty), (float)0); + auto cst1Attr = + DenseElementsAttr::get(RankedTensorType::get({}, f32Ty), (float)1); + Value OutRecScale = create.onnx.constant(cst1Attr); + Value OutZeroPoint = create.onnx.constant(cst0Attr); + + // Bias is none or precomputed. + Value qcTilde, qcTildeRecScale, qcTildeZeroPoint; + if (canPreComputeBias) + preComputeBias(create, ARecScale, AZeroPoint, BI8, BRecScale, OutRecScale, + OutZeroPoint, qcTilde, qcTildeRecScale, qcTildeZeroPoint); + + // Emit zhigh.QuantizedMatMul. + SmallVector resTypes; + resTypes.emplace_back(UnrankedTensorType::get(f16Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + ZHighQuantizedMatMulOp zhighQuantizedMatMulOp = + rewriter.create(loc, resTypes, qA, ARecScale, + AZeroPoint, qB, BRecScale, BZeroPoint, + /*Bias*/ canPreComputeBias ? qcTilde : none, + /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none, + /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPoint : none, + /*OutRecScale*/ OutRecScale, /*OutOffset*/ OutZeroPoint, + /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr, + /*DequantizeOutput*/ trueAttr); + (void)zhighQuantizedMatMulOp.inferShapes([](Region ®ion) {}); + + // Unstickify the matmul result. + Value res = rewriter.create( + loc, zhighQuantizedMatMulOp.getResult(0)); + + rewriter.replaceOp(op, res); + return success(); + } + + // clang-format off + // func.func @pattern_in_bert(%A) { + // // A is dynamically quantized. + // %a, %a_scale, %a_zero_point = "onnx.DynamicQuantizeLinear"(%A) + // + // // B is a constant and already quantized. + // %b = onnx.Constant + // %b_scale = onnx.Constant + // %b_zero_point = onnx.Constant + // + // + // %matmul = "onnx.MatMulInteger"(%b, %b, %b_zero_point, %b_zero_point) + // + // // Scale the output. + // %mm_f32 = "onnx.Cast"(%matmul) {to = f32} + // %mm_a_scale = "onnx.Mul"(%mm_f32, %a_scale) + // %mm_ab_scale = "onnx.Mul"(%mm_a_scale, %b_scale) + // + // return %mm_y_scale + // } + // clang-format on + static mlir::LogicalResult canBeRewritten(PatternRewriter &rewriter, + ONNXMulOp mulOp, Value &A, Value &AI8, Value &AScale, Value &AZeroPoint, + Value &BI8, Value &BScale, Value &BZeroPoint) { + + // Match `cast(mm_out) * a_scale * b_scale` to find two scales but we don't + // know yet which scale is for A or B. + Value scale1, scale2; + ONNXCastOp castOp; + ONNXMulOp mulScaleOp; + + Value opr1 = mulOp.getOperand(0); + Value opr2 = mulOp.getOperand(1); + + // Match cast(mm_out) * (a_scale * b_scale) + castOp = opr1.getDefiningOp(); + mulScaleOp = opr2.getDefiningOp(); + bool foundScales = false; + if (castOp && mulScaleOp && isScalarTensor(opr2)) { + Value lhs = mulScaleOp.getOperand(0); + Value rhs = mulScaleOp.getOperand(1); + if (isScalarTensor(lhs) && isScalarTensor(rhs)) { + // mulScaleOp is a_scale * b_scale; + foundScales = true; + scale1 = lhs; + scale2 = rhs; + } + } + // Match (a_scale * b_scale) * cast(mm_out) + if (!foundScales) { + mulScaleOp = opr1.getDefiningOp(); + castOp = opr2.getDefiningOp(); + if (mulScaleOp && isScalarTensor(opr1) && castOp) { + Value lhs = mulScaleOp.getOperand(0); + Value rhs = mulScaleOp.getOperand(1); + if (isScalarTensor(lhs) && isScalarTensor(rhs)) { + // mulScaleOp is a_scale * b_scale; + foundScales = true; + scale1 = lhs; + scale2 = rhs; + } + } + } + // Match [cast(mm_out) * a_scale] * b_scale + if (!foundScales & isScalarTensor(opr2)) { + scale1 = opr2; + mulScaleOp = opr1.getDefiningOp(); + if (mulScaleOp) { + Value lhs = mulScaleOp.getOperand(0); + Value rhs = mulScaleOp.getOperand(1); + castOp = lhs.getDefiningOp(); + if (castOp && isScalarTensor(rhs)) { + // Match cast(mm_out) * a_scale + scale2 = rhs; + foundScales = true; + } + if (!foundScales) { + // Match a_scale * cast(mm_out) + castOp = rhs.getDefiningOp(); + if (isScalarTensor(lhs) && castOp) { + scale2 = lhs; + foundScales = true; + } + } + } + // Match b_scale * [cast(mm_out) * a_scale] + if (!foundScales && isScalarTensor(opr1)) { + scale1 = opr1; + mulScaleOp = opr2.getDefiningOp(); + if (mulScaleOp) { + Value lhs = mulScaleOp.getOperand(0); + Value rhs = mulScaleOp.getOperand(1); + castOp = lhs.getDefiningOp(); + if (castOp && isScalarTensor(rhs)) { + // Match cast(mm_out) * a_scale + scale2 = rhs; + foundScales = true; + } + if (!foundScales) { + // Match a_scale * cast(mm_out) + castOp = rhs.getDefiningOp(); + if (isScalarTensor(lhs) && castOp) { + scale2 = lhs; + foundScales = true; + } + } + } + } + } + if (!foundScales) + return rewriter.notifyMatchFailure(mulOp, "Not found scale values"); + + // Identify a_scale and b_scale. + // a_scale is from DynamicQuantizeLinear. + if (scale1.getDefiningOp()) { + AScale = scale1; + BScale = scale2; + } else if (scale2.getDefiningOp()) { + AScale = scale2; + BScale = scale1; + } else { + return rewriter.notifyMatchFailure( + mulOp, "Could not identify a_scale and b_scale"); + } + + // Match cast. + // %cast = "onnx.Cast"(%matmul) {saturate = 1 : si64, to = f32} + Type castOutputType = castOp.getOutput().getType(); + Type castInputType = castOp.getInput().getType(); + if (isRankedShapedType(castInputType) && + isRankedShapedType(castOutputType)) { + if (!getElementType(castInputType).isInteger(32)) + return rewriter.notifyMatchFailure( + mulOp, "ONNXCast is not casting from i32"); + if (!getElementType(castOutputType).isF32()) + return rewriter.notifyMatchFailure( + mulOp, "ONNXCast is not casting to f32"); + } else { + return rewriter.notifyMatchFailure(mulOp, "ONNXCast is unranked"); + } + + // Match matmul to get BI8 and BZeroPoint. + ONNXMatMulIntegerOp matmulOp = + castOp.getInput().getDefiningOp(); + if (!matmulOp) + return rewriter.notifyMatchFailure( + mulOp, "The input of the CastOp is not defined by MatMulIntegerOp"); + if (!isSuitableForZDNN(matmulOp)) + return rewriter.notifyMatchFailure( + mulOp, "MatMulInteger is not suitable for zDNN"); + + AI8 = matmulOp->getOperand(0); + BI8 = matmulOp->getOperand(1); + AZeroPoint = matmulOp->getOperand(2); + BZeroPoint = matmulOp->getOperand(3); + if (!isDenseONNXConstant(BI8)) + return rewriter.notifyMatchFailure(mulOp, "Quantized Y is not constant"); + if (!isDenseONNXConstant(BZeroPoint)) + return rewriter.notifyMatchFailure(mulOp, "BZeroPoint is not constant"); + if (!(getElementType(BI8.getType()).isUnsignedInteger(8) || + getElementType(BI8.getType()).isSignlessInteger(8))) + return rewriter.notifyMatchFailure( + mulOp, "Quantized Y is not signed int8"); + + // Match dynamic quantize linear to get A. + if (auto dqlOp = + llvm::dyn_cast(AI8.getDefiningOp())) { + if (AScale != dqlOp.getResult(1)) + return rewriter.notifyMatchFailure(mulOp, "AScale is not used"); + if (AZeroPoint != dqlOp.getResult(2)) + return rewriter.notifyMatchFailure(mulOp, "AZeroPoint is not used"); + // return A. + A = dqlOp.getOperand(); + } else { + return rewriter.notifyMatchFailure( + mulOp, "Quantized A is not defined by DynamicQuantizeLinearOp"); + } + + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Fuse ZHighQuantizedMatMul and ONNXAdd +//===----------------------------------------------------------------------===// +// Rewrite this pattern: +// (ONNXAddOp +// $x, +// (ZHighUnstickOp +// (ZHighQuantizedMatMulOp:$mm_res +// $a, $Sa, $Za, +// $b, $Sb, $Zb, +// (ZHighQuantizedStick $c), $Sc, $Zb, +// $So, $Zo, +// $preComputed, $disableClipping, $dequantized))), +// +// into this pattern where $x is added to $c: +// +// (ZHighUnstickOp +// (ZHighQuantizedMatMulOp +// $a, $Sa, $Za, +// $b, $Sb, $Zb, +// (ZHighQuantizedStick (ONNXAddOp $x, $c)), $Sc, $Zb, +// $So, $Zo, +// $preComputed, $disableClipping, $dequantized)), +// +// Requirement: `preComputed` is true. + +class fuseZHighQuantizedMatMulONNXAddPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXAddOp addOp, PatternRewriter &rewriter) const override { + Location loc = addOp.getLoc(); + Operation *op = addOp.getOperation(); + MultiDialectBuilder create(rewriter, loc); + + ZHighUnstickOp unstickOp; + ZHighQuantizedMatMulOp mmOp; + ZHighQuantizedStickOp qstickOp; + Value addInput; + + // match + if (failed(canBeRewritten( + rewriter, addOp, unstickOp, mmOp, qstickOp, addInput))) + return failure(); + + // rewrite + Value newBias = create.onnx.add(addInput, qstickOp.getIn()); + ZHighQuantizedStickOp newQStickOp = rewriter.create( + loc, newBias, qstickOp.getInRecScale(), qstickOp.getInOffset(), + qstickOp.getLayoutAttr(), qstickOp.getQuantizedTypeAttr()); + + SmallVector resTypes; + resTypes.emplace_back(mmOp.getResult(0).getType()); + resTypes.emplace_back(mmOp.getResult(1).getType()); + resTypes.emplace_back(mmOp.getResult(2).getType()); + ZHighQuantizedMatMulOp newQMMOp = rewriter.create( + loc, resTypes, mmOp.getX(), mmOp.getXRecScale(), mmOp.getXOffset(), + mmOp.getY(), mmOp.getYRecScale(), mmOp.getYOffset(), + newQStickOp.getResult(0), newQStickOp.getResult(1), + newQStickOp.getResult(2), mmOp.getOutRecScaleIn(), + mmOp.getOutOffsetIn(), mmOp.getPreComputedBiasAttr(), + mmOp.getDisableClippingAttr(), mmOp.getDequantizeOutputAttr()); + ZHighUnstickOp newUnstickOp = + rewriter.create(loc, newQMMOp.getResult(0)); + + rewriter.replaceOp(op, newUnstickOp); + return success(); + } + + static mlir::LogicalResult canBeRewritten(PatternRewriter &rewriter, + ONNXAddOp addOp, ZHighUnstickOp &unstickOp, ZHighQuantizedMatMulOp &mmOp, + ZHighQuantizedStickOp &qstickOp, Value &addInput) { + Value lhs = addOp.getOperand(0); + Value rhs = addOp.getOperand(1); + bool found = false; + if (auto op1 = lhs.getDefiningOp()) { + addInput = rhs; + unstickOp = op1; + Value mmOutput = unstickOp.getIn(); + if (auto op2 = mmOutput.getDefiningOp()) { + mmOp = op2; + bool precomputed = (mmOp.getPreComputedBias() == -1); + if (!precomputed) + return rewriter.notifyMatchFailure( + addOp, "not precomputed quantized matmul"); + Value qBias = mmOp.getB(); + if (auto op3 = qBias.getDefiningOp()) { + qstickOp = op3; + Value bias = qstickOp.getIn(); + // Check rank. + if (getRank(bias.getType()) != getRank(addInput.getType())) + return rewriter.notifyMatchFailure(addOp, "rank mismatched"); + found = true; + } + } + } + if (found) + return success(); + + if (auto op1 = rhs.getDefiningOp()) { + addInput = lhs; + unstickOp = op1; + Value mmOutput = unstickOp.getIn(); + if (auto op2 = mmOutput.getDefiningOp()) { + mmOp = op2; + bool precomputed = (mmOp.getPreComputedBias() == -1); + if (!precomputed) + return rewriter.notifyMatchFailure( + addOp, "not precomputed quantized matmul"); + Value qBias = mmOp.getB(); + if (auto op3 = qBias.getDefiningOp()) { + qstickOp = op3; + Value bias = qstickOp.getIn(); + // Check rank. + if (getRank(bias.getType()) != getRank(addInput.getType())) + return rewriter.notifyMatchFailure(addOp, "rank mismatched"); + found = true; + } + } + } + if (found) + return success(); + + return rewriter.notifyMatchFailure(addOp, "unstick not found"); + } +}; + +class replaceONNXQLinearMatMulPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXQLinearMatMulOp qmmOp, PatternRewriter &rewriter) const override { + Location loc = qmmOp.getLoc(); + Operation *op = qmmOp.getOperation(); + MultiDialectBuilder create(rewriter, loc); + + // Match + if (failed(canBeRewritten(rewriter, qmmOp))) + return failure(); + + Type si64Ty = rewriter.getIntegerType(64, true); + Type f16Ty = rewriter.getF16Type(); + Type f32Ty = rewriter.getF32Type(); + IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1); + IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0); + + Value A = qmmOp.getA(); + Value AScale = qmmOp.getAScale(); + Value AZeroPoint = qmmOp.getAZeroPoint(); + Value B = qmmOp.getB(); + Value BScale = qmmOp.getBScale(); + Value BZeroPoint = qmmOp.getBZeroPoint(); + Value Y = qmmOp.getY(); + Value YScale = qmmOp.getYScale(); + Value YZeroPoint = qmmOp.getYZeroPoint(); + + // Only pre-compute bias when B is a constant and BZeroPoint is int8 zero. + bool canPreComputeBias = false; + if (isDenseONNXConstant(B) && isDenseONNXConstant(BZeroPoint)) { + if (getElementType(BZeroPoint.getType()).isUnsignedInteger()) + canPreComputeBias = isConstOf(BZeroPoint, 128.0); + else + canPreComputeBias = isConstOf(BZeroPoint, 0.0); + } + + // Emit some common values. + Value none = create.onnx.none(); + Value zero = create.onnx.constantInt64({0}); + + // Normalize scalar tensors to tensor. + if (getRank(AScale.getType()) == 1) { + AScale = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(AScale.getType())), AScale, + {zero}); + } + if (getRank(AZeroPoint.getType()) == 1) { + AZeroPoint = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(AZeroPoint.getType())), + AZeroPoint, {zero}); + } + if (getRank(BScale.getType()) == 1) { + BScale = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(BScale.getType())), BScale, + {zero}); + } + if (getRank(BZeroPoint.getType()) == 1) { + BZeroPoint = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(BZeroPoint.getType())), + BZeroPoint, {zero}); + } + if (getRank(YScale.getType()) == 1) { + YScale = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(YScale.getType())), YScale, + {zero}); + } + if (getRank(YZeroPoint.getType()) == 1) { + YZeroPoint = create.onnx.squeeze( + RankedTensorType::get({}, getElementType(YZeroPoint.getType())), + YZeroPoint, {zero}); + } + + // zdnn supports signed int8, convert unsigned int8 inputs to signed int8. + Value AI8 = getOrCastToI8(A, create); + Value BI8 = getOrCastToI8(B, create); + + Value ARecScale = create.onnx.reciprocal(AScale); + Value AZeroPointI8 = getOrCastToI8(AZeroPoint, create); + Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty); + + Value BRecScale = create.onnx.reciprocal(BScale); + Value BZeroPointI8 = getOrCastToI8(BZeroPoint, create); + Value BZeroPointF32 = create.onnx.cast(BZeroPointI8, f32Ty); + + Value YRecScale = create.onnx.reciprocal(YScale); + Value YZeroPointI8 = getOrCastToI8(YZeroPoint, create); + Value YZeroPointF32 = create.onnx.cast(YZeroPointI8, f32Ty); + + // Stickify AI8, Transform AI8 into zTensor format. + int64_t rankA = getRank(AI8.getType()); + StringAttr aLayoutAttr = + rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS); + ZHighQuantizedStickOp qAOp = + rewriter.create(loc, AI8, ARecScale, + AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8)); + + // Stickify BI8. It is potentially folded at compile time. + int64_t rankB = getRank(BI8.getType()); + StringAttr bLayoutAttr = + rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS); + ZHighQuantizedStickOp qBOp = + rewriter.create(loc, BI8, BRecScale, + BZeroPointF32, bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS)); + + // Bias is none or precomputed. + Value qcTilde, qcTildeRecScale, qcTildeZeroPointF32; + if (canPreComputeBias) + preComputeBias(create, ARecScale, AZeroPointF32, BI8, BRecScale, + YRecScale, YZeroPointF32, qcTilde, qcTildeRecScale, + qcTildeZeroPointF32); + + // Emit zhigh.QuantizedMatMul. Bias is none. + // DisableClipping gives the same output as the onnx backend test since the + // onnx backend test uses `astype` instead of `clipping` to cast the output + // to i8. + SmallVector resTypes; + resTypes.emplace_back(UnrankedTensorType::get(f16Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + resTypes.emplace_back(RankedTensorType::get({}, f32Ty)); + ZHighQuantizedMatMulOp zhighQuantizedMatMulOp = + rewriter.create(loc, resTypes, + qAOp.getResult(0), qAOp.getResult(1), qAOp.getResult(2), + qBOp.getResult(0), qBOp.getResult(1), qBOp.getResult(2), + /*Bias*/ canPreComputeBias ? qcTilde : none, + /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none, + /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPointF32 : none, + /*OutRecScale*/ YRecScale, /*OutOffset*/ YZeroPointF32, + /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr, + /*DisableClipping*/ trueAttr, + /*DequantizeOutput*/ falseAttr); + (void)zhighQuantizedMatMulOp.inferShapes([](Region ®ion) {}); + + // Unstickify the matmul result that is int8-as-float. + Value resI8F32 = rewriter.create( + loc, zhighQuantizedMatMulOp.getResult(0)); + Value res; + Type outElemTy = getElementType(Y.getType()); + if (outElemTy.isUnsignedInteger(8)) { + // The zdnn output is int8. Convert int8 to uint8. + // Use int16 to avoid integer overflow. + Type i16Ty = rewriter.getI16Type(); + Type ui16Ty = rewriter.getIntegerType(16, false); + auto cst128Attr = DenseElementsAttr::get( + RankedTensorType::get({}, i16Ty), static_cast(128)); + // clang-format off + Value resUI16 = + create.onnx.cast( + create.onnx.add(create.onnx.cast(resI8F32, i16Ty), + create.onnx.constant(cst128Attr)), + ui16Ty); + // clang-format on + res = create.onnx.cast(resUI16, outElemTy); + } else { + res = create.onnx.cast(resI8F32, outElemTy); + } + rewriter.replaceOp(op, res); + return success(); + } + + static mlir::LogicalResult canBeRewritten( + PatternRewriter &rewriter, ONNXQLinearMatMulOp qmmOp) { + if (!isSuitableForZDNN(qmmOp)) + return rewriter.notifyMatchFailure( + qmmOp, "QLinearMatMul is not suitable for zDNN"); + return success(); + } +}; + struct ONNXToZHighLoweringPass : public PassWrapper> { @@ -290,14 +1535,85 @@ struct ONNXToZHighLoweringPass ONNXToZHighLoweringPass() = default; ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass) : PassWrapper>() {} + ONNXToZHighLoweringPass(NNPAQuantType quantMode) { + this->quantMode = quantMode; + } void runOnOperation() final; + +public: + Option quantMode{*this, "quantization", + llvm::cl::desc("Enable quantization"), + llvm::cl::values( + clEnumVal(DynSymI8, + "Dynamic Quantization to signed integer 8. Asymmetric quant for " + "activations and symmetric quant for weights."), + clEnumVal(SymSymI8, + "Dynamic Quantization to signed integer 8. Symmetric quant for " + "activations and symmetric quant for weights."), + clEnumVal(QNONE, "No quantization (default).")), + llvm::cl::init(QNONE)}; }; } // end anonymous namespace. -void getONNXToZHighOneOpPatterns(RewritePatternSet &patterns) { +void getONNXToZHighOneOpPatterns( + RewritePatternSet &patterns, NNPAQuantType quantMode) { MLIRContext *context = patterns.getContext(); - populateWithGenerated(patterns); - patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + + // Pattern for i8 dynamic quantization, symmetric mode. + if (isCompatibleWithNNPALevel(NNPALevel::M15) && + (quantMode == NNPAQuantType::DynSymI8 || + quantMode == NNPAQuantType::SymSymI8)) { + // Bump up the pattern benefit to run these before non-quantization + // patterns. + PatternBenefit quantPriority(QUANT_PATTERN_BENEFIT); + patterns.insert( + context, quantPriority, quantMode == NNPAQuantType::SymSymI8); + patterns.insert( + context, quantPriority, quantMode == NNPAQuantType::SymSymI8); + } } void getONNXToZHighOneOpDynamicallyLegal( @@ -309,7 +1625,10 @@ void getONNXToZHighOneOpDynamicallyLegal( addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); @@ -319,18 +1638,42 @@ void getONNXToZHighOneOpDynamicallyLegal( addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); + addDynamicallyLegalOpFor(target, dimAnalysis); } -void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) { +void getONNXToZHighMultipleOpPatterns( + RewritePatternSet &patterns, NNPAQuantType quantMode) { MLIRContext *context = patterns.getContext(); patterns.insert(context); patterns.insert(context); patterns.insert(context); patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + + // Pattern for i8 dynamic quantization, symmetric mode. + if (isCompatibleWithNNPALevel(NNPALevel::M15) && + (quantMode == NNPAQuantType::DynSymI8 || + quantMode == NNPAQuantType::SymSymI8)) { + // Bump up the pattern benefit to run these before non-quantization + // patterns. + PatternBenefit quantPriority(QUANT_PATTERN_BENEFIT); + patterns.insert( + context, quantPriority, quantMode == NNPAQuantType::SymSymI8); + } + // Shape inference for newly-added operations. getShapeInferencePatterns(patterns); } @@ -363,7 +1706,8 @@ void ONNXToZHighLoweringPass::runOnOperation() { // a single ONNX Op, because the single op lowering might have conditions that // prohibit the combined ops lowering happened. RewritePatternSet combinedPatterns(&getContext()); - onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns); + onnx_mlir::getONNXToZHighMultipleOpPatterns( + combinedPatterns, this->quantMode); // It's ok to fail. (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns)); @@ -375,7 +1719,7 @@ void ONNXToZHighLoweringPass::runOnOperation() { // Single ONNX to ZHigh operation lowering. RewritePatternSet patterns(&getContext()); - onnx_mlir::getONNXToZHighOneOpPatterns(patterns); + onnx_mlir::getONNXToZHighOneOpPatterns(patterns, this->quantMode); // This is to make sure we don't want to alloc any MemRef at this high-level // representation. @@ -398,4 +1742,8 @@ std::unique_ptr createONNXToZHighPass() { return std::make_unique(); } +std::unique_ptr createONNXToZHighPass(NNPAQuantType quantMode) { + return std::make_unique(quantMode); +} + } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp index 034f92a6e3..d121058168 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp @@ -18,13 +18,16 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" namespace onnx_mlir { // Exports ONNXtoZHigh patterns. -void getONNXToZHighOneOpPatterns(mlir::RewritePatternSet &patterns); -void getONNXToZHighMultipleOpPatterns(mlir::RewritePatternSet &patterns); +void getONNXToZHighOneOpPatterns( + mlir::RewritePatternSet &patterns, NNPAQuantType quantMode); +void getONNXToZHighMultipleOpPatterns( + mlir::RewritePatternSet &patterns, NNPAQuantType quantMode); // Exports ONNXtoZHigh dynamically legal checks. void getONNXToZHighOneOpDynamicallyLegal( diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td index 075ae98e98..7e6f724c57 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td @@ -2,7 +2,7 @@ //===- ONNXToZHigh.td - Replacing ONNX Ops by ZHigh Ops -*- tablegen ------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -32,6 +32,8 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td" def IsEnableScalarBcastBinary: Constraint>; +def HasOneUse : Constraint, "op has exactly one use">; + def IsNotNoneType : Constraint(($_self).getType())">>; class HasRankOf : Constraint< @@ -91,6 +93,8 @@ def GetI64ZeroAttr : class GetI64ArrayAttr : NativeCodeCall<"$_builder.getI64ArrayAttr(" # n # ")">; +def GetFloatAttr : NativeCodeCall<"$_builder.getFloatAttr($0.getType(), $0.getValueAsDouble())">; + def GetUnrankedTensorTypeOf : NativeCodeCall< "UnrankedTensorType::get(mlir::cast($0.getType()).getElementType())" >; @@ -109,6 +113,36 @@ def IsF32ScalarConstantTensor: Constraint< def GetScalarF32AttrFromConstant : NativeCodeCall<"getScalarF32AttrFromConstant($0)">; +def IsConstOfHalf : Constraint< + CPred<"isDenseONNXConstant($_self) && isConstOf($_self, 0.5)">, + "Value is all halves (0.5) for a constant tensor">; + +def IsConstOfOnes : Constraint< + CPred<"isDenseONNXConstant($_self) && isConstOf($_self, 1.0)">, + "Value is an all-ones constant tensor">; + +def IsConstOfTwos : Constraint< + CPred<"isDenseONNXConstant($_self) && isConstOf($_self, 2.0)">, + "Value is all twos for a constant tensor">; + +def IsConstOfThrees : Constraint< + CPred<"isDenseONNXConstant($_self) && isConstOf($_self, 3.0)">, + "Value is all threes for a constant tensor">; + +def IsFloatType : Constraint< + CPred<"isFloatType($_self)">, + "Value has element type of float">; + +//===----------------------------------------------------------------------===// +// ONNXLeakyReluOp %X = ZHighUnstickOp (ZHighLeakyReluOp (ZHighStickOp %X)) +//===----------------------------------------------------------------------===// +def replaceONNXLeakyReluPattern : Pat< + (ONNXLeakyReluOp $x, $alpha), + (ZHighUnstickOp (ZHighLeakyReluOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr), (GetDefaultSaturation)), (GetFloatAttr $alpha), + (returnType $s_x))), + [(IsCompatibleWithNNPALevelArch15)] +>; + //===----------------------------------------------------------------------===// // ONNXReluOp %X = ZHighUnstickOp (ZHighReluOp (ZHighStickOp %X)) //===----------------------------------------------------------------------===// @@ -118,6 +152,17 @@ def replaceONNXReluPattern : Pat< (returnType $s_x))) >; +//===----------------------------------------------------------------------===// +// ONNXGeluOp %X = ZHighUnstickOp (ZHighGeluOp (ZHighStickOp %X), $approximate) +//===----------------------------------------------------------------------===// +def replaceONNXGeluPattern : Pat< + (ONNXGeluOp:$res $x, $approximate), + (ZHighUnstickOp (ZHighGeluOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr), (GetDefaultSaturation)), + $approximate, + (returnType $s_x))), + [(IsCompatibleWithNNPALevelArch15)] +>; + //===----------------------------------------------------------------------===// // ONNXTanhOp %X = ZHighUnstickOp (ZHighTanhOp (ZHighStickOp %X)) //===----------------------------------------------------------------------===// @@ -294,6 +339,38 @@ def replaceONNXMaxPattern : Pat< (returnType $s_x))) >; +//===----------------------------------------------------------------------===// +// ONNXDivOp(1, (ONNXSqrtOp %X)) = ZHighUnstickOp +// (ZHighInvSqrtOp (ZHighStickOp %X)) +//===----------------------------------------------------------------------===// +def replaceDiv1SqrtPattern : Pat< + (ONNXDivOp $a, (ONNXSqrtOp $x)), + (ZHighUnstickOp (ZHighInvSqrtOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr), (GetDefaultSaturation)), + (returnType $s_x))), + [(IsCompatibleWithNNPALevelArch15),(IsConstOfOnes:$a),(IsFloatType:$a)] +>; + +//===----------------------------------------------------------------------===// +// ONNXReciprocalOp(ONNXSqrtOp %X) = ZHighUnstickOp +// (ZHighInvSqrtOp (ZHighStickOp %X)) +//===----------------------------------------------------------------------===// +def replaceReciprocalSqrtPattern : Pat< + (ONNXReciprocalOp (ONNXSqrtOp $x)), + (ZHighUnstickOp (ZHighInvSqrtOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr), (GetDefaultSaturation)), + (returnType $s_x))), + [(IsCompatibleWithNNPALevelArch15)] +>; + +//===----------------------------------------------------------------------===// +// ONNXSqrtOp %X = ZHighUnstickOp (ZHighSqrtOp (ZHighStickOp %X)) +//===----------------------------------------------------------------------===// +def replaceONNXSqrtPattern : Pat< + (ONNXSqrtOp $x), + (ZHighUnstickOp (ZHighSqrtOp (ZHighStickOp:$s_x $x, (NoneLayoutAttr), (GetDefaultSaturation)), + (returnType $s_x))), + [(IsCompatibleWithNNPALevelArch15)] +>; + //===----------------------------------------------------------------------===// // ONNXSoftmaxOp %X = ONNXSqueezeOp // (ZHighUnstickOp @@ -378,13 +455,38 @@ def replaceONNXLogSoftmaxPattern : Pattern< //===----------------------------------------------------------------------===// // ONNXReduceMeanV13Op %X = (ZHighUnstickOp // (ZHighMeanReduce2DOp -// (ZHighStickOp %X)))) +// (ZHighStickOp %X))) //===----------------------------------------------------------------------===// def replaceONNXReduceMeanV13Pattern : Pat< (ONNXReduceMeanV13Op:$res $x, $_, $_), (ZHighUnstickOp (ZHighMeanReduce2DOp (ZHighStickOp $x, (NHWCLayoutAttr), (GetDefaultSaturation)))) >; +//===----------------------------------------------------------------------===// +// ONNXReduceMaxOp %X = +// (ZHighUnstickOp +// (ZHighReduceMaxOp +// (ZHighStickOp %X))) +//===----------------------------------------------------------------------===// +def replaceONNXReduceMaxPattern : Pat< + (ONNXReduceMaxOp:$res $data, $axes, $keepdims, $noop_with_empty_axes), + (ZHighUnstickOp (ZHighReduceMaxOp (ZHighStickOp:$s_x $data, (NoneLayoutAttr), + (GetDefaultSaturation)))), + [(IsCompatibleWithNNPALevelArch15)] +>; + +//===----------------------------------------------------------------------===// +// ONNXReduceMinOp %X = +// (ZHighUnstickOp +// (ZHighReduceMinOp +// (ZHighStickOp %X))) +//===----------------------------------------------------------------------===// +def replaceONNXReduceMinPattern : Pat< + (ONNXReduceMinOp:$res $data, $axes, $keepdims, $noop_with_empty_axes), + (ZHighUnstickOp (ZHighReduceMinOp (ZHighStickOp:$s_x $data, (NoneLayoutAttr), + (GetDefaultSaturation)))), + [(IsCompatibleWithNNPALevelArch15)] +>; //===----------------------------------------------------------------------===// // ONNXMaxPoolSingleOutOp %X = // (ZHighUnstickOp @@ -478,7 +580,7 @@ def GetMatMulLayoutStringAttr : NativeCodeCall< >; def GetMatMulBiasLayoutStringAttr : NativeCodeCall< - "$_builder.getStringAttr((($0 == 3) && ($1 == 3)) ? LAYOUT_2DS : LAYOUT_1D)" + "$_builder.getStringAttr(((($0 == 3) && ($1 == 3)) || (($0 == 2) && ($1 == 3))) ? LAYOUT_2DS : LAYOUT_1D)" >; //===----------------------------------------------------------------------===// @@ -495,7 +597,7 @@ def replaceONNXMatMulPattern : Pat< (ZHighMatMulOp (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), - (CreateNoneValue))) + (CreateNoneValue), (GetZeroI64Attr), (GetZeroI64Attr))) >; //===----------------------------------------------------------------------===// @@ -533,7 +635,8 @@ def replaceONNXMatMulAddPattern1 : Pat< (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x), - (GetRank $y)), (GetDefaultSaturation)))), + (GetRank $y)), (GetDefaultSaturation)), + (GetZeroI64Attr), (GetZeroI64Attr))), [(IsMatMulLegalForZDNN $m), (HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b)], [], (addBenefit 0) @@ -548,17 +651,113 @@ def replaceONNXMatMulAddPattern2 : Pat< (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x), - (GetRank $y)), (GetDefaultSaturation)))), + (GetRank $y)), (GetDefaultSaturation)), + (GetZeroI64Attr), (GetZeroI64Attr))), [(IsMatMulLegalForZDNN $m), (HasRankOf<2> $y), (HasRankOf<1> $b), (HaveSameLastDimR2R1 $y, $b)], [], (addBenefit 0) >; +//===----------------------------------------------------------------------===// +// Replace onnx.add and onnx.matmul with bcast1 tensors with ZHighMatMul +//===----------------------------------------------------------------------===// + +def replaceONNXMatMulAddPatternBcast1A : Pat< + // From Add $b, (MatMul $x, $y) + (ONNXAddOp $b, (ONNXMatMulOp:$m $x, $y)), + // To ZHighMatMulOp + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), + (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), + (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x), + (GetRank $y)), (GetDefaultSaturation)), + (GetZeroI64Attr), (GetZeroI64Attr))), + [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m), + (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [], + (addBenefit 0) +>; + +def replaceONNXMatMulAddPatternBcast1B : Pat< + // From Add (MatMul $x, $y), $b + (ONNXAddOp (ONNXMatMulOp:$m $x, $y), $b), + // To ZHighMatMulOp + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), + (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), + (ZHighStickOp $b, (GetMatMulBiasLayoutStringAttr (GetRank $x), + (GetRank $y)), (GetDefaultSaturation)), + (GetZeroI64Attr), (GetZeroI64Attr))), + [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m), + (HasRankOf<2> $x), (HasRankOf<3> $y), (HasRankOf<2> $b)], [], + (addBenefit 0) +>; + +//===----------------------------------------------------------------------===// +// Replace onnx.matmul when following onnx.transpose +// with ZHighMatMul with either A or B transposed +//===----------------------------------------------------------------------===// + +// TODO: This could likely be done in a cleaner way, such as comparing array vs array... +// Make sure the transpose permutation is the default as ZHighMatMul only supports default +def IsStandardTranspose : Constraint< + CPred<"((mlir::cast($0).size() == 2 ) " # + " && (mlir::cast(mlir::cast($0)[0]).getInt() == 1)" # + " && (mlir::cast(mlir::cast($0)[1]).getInt() == 0))" # + " || ((mlir::cast($0).size() == 3 ) " # + " && (mlir::cast(mlir::cast($0)[0]).getInt() == 0)" # + " && (mlir::cast(mlir::cast($0)[1]).getInt() == 2)" # + " && (mlir::cast(mlir::cast($0)[2]).getInt() == 1))"> +>; + +def replaceONNXTransBMatMulPattern : Pat< + (ONNXMatMulOp:$m $x,(ONNXTransposeOp $y,$perm)), + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), + (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), + (CreateNoneValue), (GetI64ZeroAttr), (GetI64NAttr<1>) + ) + ), + [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m), (IsStandardTranspose $perm)] +>; + +def replaceONNXTransAMatMulPattern : Pat< + (ONNXMatMulOp:$m (ONNXTransposeOp $x,$perm), $y), + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), + (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), + (CreateNoneValue), (GetI64NAttr<1>), (GetI64ZeroAttr) + ) + ), + [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m), (IsStandardTranspose $perm)] +>; + +def replaceONNXTransABMatMulPattern : Pat< + (ONNXMatMulOp:$m (ONNXTransposeOp $x,$perma), (ONNXTransposeOp $y,$permb)), + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $x, (GetMatMulLayoutStringAttr (GetRank $x)), (GetDefaultSaturation)), + (ZHighStickOp $y, (GetMatMulLayoutStringAttr (GetRank $y)), (GetDefaultSaturation)), + (CreateNoneValue), (GetI64NAttr<1>), (GetI64NAttr<1>) + ) + ), + [(IsCompatibleWithNNPALevelArch15),(IsMatMulLegalForZDNN $m), + (IsStandardTranspose $perma), (IsStandardTranspose $permb)] +>; + //===----------------------------------------------------------------------===// // GEMM //===----------------------------------------------------------------------===// def IsTransposed: Constraint($_self).getSInt() == 1)">>; +def IsAorBTransposed: Constraint< + CPred<"(mlir::cast($0).getSInt() == 1) ||" + "(mlir::cast($1).getSInt() == 1)" > +>; + def Transpose2D: NativeCodeCall< "emitONNXTranspose($_loc, $_builder, $0, SmallVector({1, 0}))">; @@ -604,7 +803,8 @@ def replaceONNXGemmBiasNoneOr1DPattern : Pat< (ZHighMatMulOp (ZHighStickOp $a, (_2DLayoutAttr), (GetDefaultSaturation)), (ZHighStickOp $b, (_2DLayoutAttr), (GetDefaultSaturation)), - (ZHighStickOp $c, (_1DLayoutAttr), (GetDefaultSaturation)))), + (ZHighStickOp $c, (_1DLayoutAttr), (GetDefaultSaturation)), + (GetZeroI64Attr), (GetZeroI64Attr))), [(IsBiasNoneOr1D:$c)], [], (addBenefit 0) >; @@ -616,13 +816,39 @@ def replaceONNXGemmBias2DPattern : Pat< (ZHighMatMulOp (ZHighStickOp $a, (_2DLayoutAttr), (GetDefaultSaturation)), (ZHighStickOp $b, (_2DLayoutAttr), (GetDefaultSaturation)), - (CreateNoneValue)), + (CreateNoneValue), + (GetZeroI64Attr), (GetZeroI64Attr)), (returnType $res)), $c), [(HasRankOf<2> $c)], [], (addBenefit 0) >; +//===----------------------------------------------------------------------===// +// Replace onnx.Gemm when transA or transB set to ZHighMatMulOp with transpose. +// +// ONNXGemmOp $a, $b, $c, $alpha, $beta, $transA, $transB = +// ZHighUnstickOp +// (ZHighMatMulOp +// (ZHighStickOp $a, (_2DLayoutAttr)), +// (ZHighStickOp $b, (_2DLayoutAttr)), +// (ZHighStickOp $c, (_1DLayoutAttr)), +// $transA, $transB)) +//===----------------------------------------------------------------------===// +def replaceONNXGemmTransPattern : Pat< + (ONNXGemmOp $a, $b, $c, $alpha, $beta, $transA, $transB), + (ZHighUnstickOp + (ZHighMatMulOp + (ZHighStickOp $a, (_2DLayoutAttr), (GetDefaultSaturation)), + (ZHighStickOp $b, (_2DLayoutAttr), (GetDefaultSaturation)), + (ZHighStickOp $c, (_1DLayoutAttr), (GetDefaultSaturation)), + $transA, $transB)), + [(IsAorBTransposed $transA, $transB), + (IsCompatibleWithNNPALevelArch15), + (IsBiasNoneOr1D:$c)], [], + (addBenefit 1) +>; + //===----------------------------------------------------------------------===// // LSTM //===----------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp index 5699b851e9..ce7c4160bd 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp @@ -73,18 +73,17 @@ ValueRange splitAlongAxis( return splits; } -bool isF32ScalarConstantTensor(mlir::Value v) { +bool isF32ScalarConstantTensor(Value v) { if (!isScalarConstantTensor(v)) return false; - auto t = dyn_cast(v.getType()); + auto t = mlir::dyn_cast(v.getType()); return t.getElementType().isF32(); } FloatAttr getScalarF32AttrFromConstant(Value v) { if (!isF32ScalarConstantTensor(v)) return nullptr; - DenseElementsAttr constElements = ElementsAttrBuilder::toDenseElementsAttr( - getElementAttributeFromONNXValue(v)); + ElementsAttr constElements = getElementAttributeFromONNXValue(v); return constElements.getSplatValue(); } @@ -93,7 +92,7 @@ Value getDynShape(Location loc, PatternRewriter &rewriter, Value x) { llvm_unreachable("The input must have shape and rank"); OnnxBuilder create(rewriter, loc); - auto t = dyn_cast(x.getType()); + auto t = mlir::dyn_cast(x.getType()); int64_t r = t.getRank(); SmallVector dims; for (int64_t i = 0; i < r; ++i) { diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index e8e68a0e37..382d596e35 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -2,8 +2,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -//===---------- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh -//---------===// +//===---- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh ---------===// // // Copyright 2019-2024 The IBM Research Authors. // @@ -117,4 +116,4 @@ mlir::Value getDynShape( mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value x); } // namespace onnx_mlir -#endif \ No newline at end of file +#endif diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td index 6ffdc29815..efd05e5fc8 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td @@ -2,7 +2,7 @@ //===- ONNXToZHigh.td - Replacing ONNX Ops by ZHigh Ops -*- tablegen ------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -64,4 +64,14 @@ def GetZeroI64Attr: NativeCodeCall< "IntegerAttr::get($_builder.getIntegerType(64, /*isSigned=*/true), APInt(64, 0, /*isSigned=*/true))" >; +def IsCompatibleWithNNPALevelArch14: Constraint< + CPred<"isCompatibleWithNNPALevel(NNPALevel::M14)">, + "Input level is compatible with NNPA level" +>; + +def IsCompatibleWithNNPALevelArch15: Constraint< + CPred<"isCompatibleWithNNPALevel(NNPALevel::M15)">, + "Input level is compatible with NNPA level" +>; + #endif // ONNX_TO_ZHIGH_COMMON diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp index d0acc5e2dd..ad445deae9 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp @@ -416,44 +416,44 @@ double estimateTimeForUnstickOp(Value oper) { bool estimateTimeForOpWithModel(Operation *op, const DimAnalysis *dimAnalysis, double &cpuEstimatedTime, double &nnpaEstimatedTime) { bool opHasModel = true; - if (auto addOp = dyn_cast(op)) + if (auto addOp = mlir::dyn_cast(op)) estimateTimeForOp(addOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto divOp = dyn_cast(op)) + else if (auto divOp = mlir::dyn_cast(op)) estimateTimeForOp(divOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto maxOp = dyn_cast(op)) + else if (auto maxOp = mlir::dyn_cast(op)) estimateTimeForOp(maxOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto minOp = dyn_cast(op)) + else if (auto minOp = mlir::dyn_cast(op)) estimateTimeForOp(minOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto mulOp = dyn_cast(op)) + else if (auto mulOp = mlir::dyn_cast(op)) estimateTimeForOp(mulOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto powOp = dyn_cast(op)) + else if (auto powOp = mlir::dyn_cast(op)) estimateTimeForOp(powOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto subOp = dyn_cast(op)) + else if (auto subOp = mlir::dyn_cast(op)) estimateTimeForOp(subOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); // Unary elementwise NNPA candidate ops. - else if (auto expOp = dyn_cast(op)) + else if (auto expOp = mlir::dyn_cast(op)) estimateTimeForOp(expOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto logOp = dyn_cast(op)) + else if (auto logOp = mlir::dyn_cast(op)) estimateTimeForOp(logOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto reluOp = dyn_cast(op)) + else if (auto reluOp = mlir::dyn_cast(op)) estimateTimeForOp(reluOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto sigmoidOp = dyn_cast(op)) + else if (auto sigmoidOp = mlir::dyn_cast(op)) estimateTimeForOp( sigmoidOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto softmaxOp = dyn_cast(op)) + else if (auto softmaxOp = mlir::dyn_cast(op)) estimateTimeForOp( softmaxOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto tanhOp = dyn_cast(op)) + else if (auto tanhOp = mlir::dyn_cast(op)) estimateTimeForOp(tanhOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); // Reduce - else if (auto reduceMeanOp = dyn_cast(op)) + else if (auto reduceMeanOp = mlir::dyn_cast(op)) estimateTimeForOp( reduceMeanOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); // Matmul. - else if (auto matMulOp = dyn_cast(op)) + else if (auto matMulOp = mlir::dyn_cast(op)) estimateTimeForOp( matMulOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); - else if (auto gemmOp = dyn_cast(op)) + else if (auto gemmOp = mlir::dyn_cast(op)) estimateTimeForOp(gemmOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime); else opHasModel = false; diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp index 9e7592ab69..bf1ebd37b6 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.cpp @@ -219,7 +219,7 @@ bool canInferencePadsForNNPAConv(ONNXConvOp op) { // Create an ArrayAttr of IntegerAttr(s) of zero values. // This function is used for padding attribute in Conv. ArrayAttr getPadsForNNPAConv(PatternRewriter &rewriter, Value ret) { - ONNXConvOp op = dyn_cast(ret.getDefiningOp()); + ONNXConvOp op = mlir::dyn_cast(ret.getDefiningOp()); ONNXConvOpShapeHelper shapeHelper(op.getOperation(), {}); shapeHelper.computeShapeAndAssertOnFailure(); SmallVector vals; @@ -240,7 +240,7 @@ ArrayAttr getPadsForNNPAConv(PatternRewriter &rewriter, Value ret) { // This function is used for padding attribute in Conv. DenseElementsAttr insertZerosForNonPaddedDims( PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) { - int nDims = (int)origAttrs.getValue().size() / 2; + int nDims = static_cast(origAttrs.getValue().size()) / 2; int nElements = (nDims + extensionLength) * 2; SmallVector pads(nElements, 0); for (int i = 0; i < nDims; ++i) { @@ -451,7 +451,7 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern { if (isa(B)) return false; bool BIsZero = false; - if (auto expandOp = dyn_cast(B.getDefiningOp())) { + if (auto expandOp = mlir::dyn_cast(B.getDefiningOp())) { Value input = expandOp.getInput(); if (isDenseONNXConstant(input)) { // Expand's input is 0? @@ -532,8 +532,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor( target, dimAnalysis, [](ONNXAddOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); // Check element type. if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true)) return true; @@ -547,8 +548,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor( target, dimAnalysis, [](ONNXDivOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); // Check element type. if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true)) return true; @@ -560,8 +562,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor( target, dimAnalysis, [](ONNXMulOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); // Check element type. if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true)) return true; @@ -573,8 +576,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor( target, dimAnalysis, [](ONNXSubOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); // Check element type. if (!isValidElementTypeAndRank(op.getOperation(), op.getA(), true)) return true; @@ -597,8 +601,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor( target, dimAnalysis, [](ONNXMatMulOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); Value A = op.getA(); Value B = op.getB(); @@ -660,6 +665,140 @@ void getRewriteONNXForZHighDynamicallyLegal( return true; }); + // Determine if the pattern `DequantizeLinear (QLinearMatMul inputs)` is + // already legal (no need to rewrite) or need to rewrite. Considering the + // inputs of QLinearMatMul, the following cases must be rewritten: + // - both inputs are *the same* N-D (N > 3) and there is no broadcasting, or + // - one input is N-D (N > 3) and the other is 2-D, or + // + // For such cases, rewrite patterns will be added to turn QLinearMatMulOp into + // the one where N-D will become 3-D. + // + // Starting from ONNXDequantizeLinearOp in order to move Reshape after + // DequantizeLinear, so that QLinearMatMul is still followed by + // DequantizeLinear, which makes optimization for the pattern + // QLinearMatMul-DequantizeLinear easier. In other words, the result pattern + // would look like: + // + // ``` + // A_3D = ReshapeTo3D A_ND + // B_3D = ReshapeTo3D B_ND + // QA_3D = Quantize (A_3D) + // QB_3D = Quantize (B_3D) + // QY_3D = QLinearMatMul QA_3D, QB_3D + // Y_3D = Dequantize QY_3D + // Y_ND = ReshapeToND Y_3D + // ``` + // + // instead of + // + // ``` + // QA_ND = Quantize (A_ND) + // QB_ND = Quantize (B_ND) + // QA_3D = ReshapeTo3D QA_ND + // QB_3D = ReshapeTo3D QB_ND + // QY_3D = QLinearMatMul QA_3D, QB_3D + // QY_ND = ReshapeToND QY_3D + // Y_3D = Dequantize QY_ND + // ``` + // + addDynamicallyLegalOpFor(target, dimAnalysis, + [](ONNXDequantizeLinearOp dlOp, const DimAnalysis *dimAnalysis) { + // Check NNPA level. + if (!isCompatibleWithNNPALevel(NNPALevel::M15)) + return !onnxToZHighInCompatibilityReport( + dlOp.getOperation(), NNPALevel::M15); + + ONNXQLinearMatMulOp op = + dlOp.getX().getDefiningOp(); + if (!op) + return !onnxToZHighUnsupportedReport(dlOp.getOperation(), + "Input is not defined by ONNXQLinearMatMulOp"); + + Value A = op.getA(); + Value AScale = op.getAScale(); + Value AZeroPoint = op.getAZeroPoint(); + Value B = op.getB(); + Value BScale = op.getBScale(); + Value BZeroPoint = op.getBZeroPoint(); + Value Y = op.getY(); + Value YScale = op.getYScale(); + Type aType = A.getType(); + Type bType = B.getType(); + Type yType = Y.getType(); + + if (!isRankedShapedType(aType) || !isRankedShapedType(bType)) { + std::string message = "A or B is not shaped type with rank"; + return !onnxToZHighUnsupportedReport(op.getOperation(), message); + } + + int64_t aRank = getRank(aType); + int64_t bRank = getRank(bType); + ArrayRef aShape = getShape(aType); + ArrayRef bShape = getShape(bType); + + // Only support float32 <-> int8/uint8. + Type elemTyA = getElementType(aType); + Type elemTyAScale = getElementType(AScale.getType()); + Type elemTyB = getElementType(bType); + Type elemTyBScale = getElementType(BScale.getType()); + Type elemTyY = getElementType(yType); + Type elemTyYScale = getElementType(YScale.getType()); + if (!elemTyAScale.isF32() || !elemTyBScale.isF32() || + !elemTyYScale.isF32()) + return !onnxToZHighUnsupportedReport( + op.getOperation(), "A or B or Y's scale is not f32"); + if (!(elemTyA.isInteger(8) || elemTyA.isUnsignedInteger(8))) + return !onnxToZHighUnsupportedReport( + op.getOperation(), "A is not i8 or ui8"); + if (!(elemTyB.isInteger(8) || elemTyB.isUnsignedInteger(8))) + return !onnxToZHighUnsupportedReport( + op.getOperation(), "B is not i8 or ui8"); + if (!(elemTyY.isInteger(8) || elemTyY.isUnsignedInteger(8))) + return !onnxToZHighUnsupportedReport( + op.getOperation(), "Y is not i8 or ui8"); + + // Only support per-tensor quantization. + if (!isScalarTensor(AScale) || !isScalarTensor(BScale) || + !isScalarTensor(AZeroPoint) || !isScalarTensor(BZeroPoint)) + return !onnxToZHighUnsupportedReport( + op.getOperation(), "Not per-tensor quantization"); + + // - one input is N-D (N > 3) and the other is 2-D. + if (aRank == 2 && bRank > 3) + return false; + + if (bRank == 2 && aRank > 3) + return false; + + // - both inputs are *the same* N-D, N > 3 and there is no broadcasting + if (aRank > 3 && (aRank == bRank)) { + bool sameBatchDims = true; + std::string message = ""; + for (int64_t i = 0; i < aRank - 2; ++i) { + sameBatchDims &= (aShape[i] == bShape[i]); + if (aShape[i] != bShape[i]) + message += "The dim " + std::to_string(i) + " of A and dim " + + std::to_string(i) + " of B are not the same."; + + if (sameBatchDims && ShapedType::isDynamic(aShape[i])) { + sameBatchDims &= + dimAnalysis->sameDynDim(op.getA(), i, op.getB(), i); + if (!sameBatchDims) + message += "The dynamic dimension analysis couldn't identify " + "that dim " + + std::to_string(i) + " of A and dim " + + std::to_string(i) + " of B are the same."; + } + } + return (!sameBatchDims) || + onnxToZHighUnsupportedReport(op.getOperation(), message); + } + + // Make other cases legal. + return true; + }); + // Illegalize SoftmaxOp if // - the NNPA level is not compatible, or // - axis is the last dimension. @@ -667,8 +806,9 @@ void getRewriteONNXForZHighDynamicallyLegal( addDynamicallyLegalOpFor(target, dimAnalysis, [](ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) { // Check NNPA level. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - return !onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16); + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + return !onnxToZHighInCompatibilityReport( + op.getOperation(), NNPALevel::M14); Value input = op.getInput(); // std::string message = "The `input` is not reshaped to 3D because it diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td index 854df5f0e5..2173000382 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td @@ -43,6 +43,11 @@ class HasRankOf : def GetSqrtResultBatchNormA : NativeCodeCall<"getSqrtResultBatchNormA($_loc, $_builder, $0, $1)">; +// Return a tensor type whose shape is from the 1st operand and element type is +// from the 2nd operand. +def GetShapeAndElemType: NativeCodeCall< + "RankedTensorType::get(getShape($0.getType()), getElementType($1.getType()))">; + //===----------------------------------------------------------------------===// // Rewrite // @@ -234,6 +239,140 @@ def rewriteMatMulNDto3D_Broadcast_2: Pat< [(HasRankOf<2> $a), (HasRankGT<3> $b)] >; +//===----------------------------------------------------------------------===// +// Rules to turn ONNXQLinearMatMulOp with N-D inputs into the one with 3-D inputs. +//===----------------------------------------------------------------------===// + +// Rewrite QLinearMatMul where N-dimension inputs are reshaped to 3-dimension +// ones. Start rewriting from a DequantizeLinear operation and hoist Reshape +// operations out of the pattern QuantizeLinear-QLinearMatMul-DequantizeLinear +// so that if there is any optimization for the pattern, it is easy to be +// applied. +// In other words, the result pattern would look like: +// +// ``` +// A_3D = ReshapeTo3D A_ND +// B_3D = ReshapeTo3D B_ND +// QA_3D = Quantize (A_3D) +// QB_3D = Quantize (B_3D) +// QY_3D = QLinearMatMul QA_3D, QB_3D +// Y_3D = Dequantize QY_3D +// Y_ND = ReshapeToND Y_3D +// ``` +// +// instead of +// +// ``` +// QA_ND = Quantize (A_ND) +// QB_ND = Quantize (B_ND) +// QA_3D = ReshapeTo3D QA_ND +// QB_3D = ReshapeTo3D QB_ND +// QY_3D = QLinearMatMul QA_3D, QB_3D +// QY_ND = ReshapeToND QY_3D +// Y_3D = Dequantize QY_ND +// ``` +// +def rewriteQLinearMatMulNDto3D_NonBroadcast: Pattern< + (ONNXDequantizeLinearOp:$y_nd + (ONNXQLinearMatMulOp:$qy_nd + (ONNXQuantizeLinearOp:$qa_nd $a_nd, $a_scale, $a_zeropoint, $a_axis, $a_saturate), $_, $_, + (ONNXQuantizeLinearOp:$qb_nd $b_nd, $b_scale, $b_zeropoint, $b_axis, $b_saturate), $_, $_, + $_, $_), + $y_scale, $y_zeropoint, $y_axis), + [ + // Reshape A to 3D and do quantization. + (ReshapeTo3D:$a_3d $a_nd), + (ONNXQuantizeLinearOp:$qa_3d + $a_3d, $a_scale, $a_zeropoint, $a_axis, $a_saturate, + (returnType (GetShapeAndElemType $a_3d, $qa_nd))), + + // Reshape B to 3D and do quantization. + (ReshapeTo3D:$b_3d $b_nd), + (ONNXQuantizeLinearOp:$qb_3d + $b_3d, $b_scale, $b_zeropoint, $b_axis, $b_saturate, + (returnType (GetShapeAndElemType $b_3d, $qb_nd))), + + // Call QLinearMatMul on 3D inputs. + (ONNXQLinearMatMulOp:$qy_3d + $qa_3d, $a_scale, $a_zeropoint, + $qb_3d, $b_scale, $b_zeropoint, + $y_scale, $y_zeropoint, + (returnType (GetMatMulResultType $qa_3d, $qb_3d))), + + // Dequantize the 3D output. + (ONNXDequantizeLinearOp:$y_3d $qy_3d, $y_scale, $y_zeropoint, $y_axis, + (returnType (GetShapeAndElemType $qy_3d, $y_nd))), + + // Reshape the output back to ND. + (ONNXReshapeOp $y_3d, (GetMatMulResultShape $a_nd, $b_nd), (GetZeroI64Attr)) + ], + [(HasRankGT<3> $a_nd), (HasRankGT<3> $b_nd)] +>; + +// A is ND, B is 2D. +def rewriteQLinearMatMulNDto3D_Broadcast1: Pattern< + (ONNXDequantizeLinearOp:$y_nd + (ONNXQLinearMatMulOp:$qy_nd + (ONNXQuantizeLinearOp:$qa_nd $a_nd, $a_scale, $a_zeropoint, $a_axis, $a_saturate), $_, $_, + $qb, $b_scale, $b_zeropoint, + $_, $_), + $y_scale, $y_zeropoint, $y_axis), + [ + // Reshape A to 3D and do quantization. + (ReshapeTo3D:$a_3d $a_nd), + (ONNXQuantizeLinearOp:$qa_3d + $a_3d, $a_scale, $a_zeropoint, $a_axis, $a_saturate, + (returnType (GetShapeAndElemType $a_3d, $qa_nd))), + + // Call QLinearMatMul on 3D inputs. + (ONNXQLinearMatMulOp:$qy_3d + $qa_3d, $a_scale, $a_zeropoint, + $qb, $b_scale, $b_zeropoint, // Keep B unchanged. + $y_scale, $y_zeropoint, + (returnType (GetMatMulResultType $qa_3d, $qb))), + + // Dequantize the 3D output. + (ONNXDequantizeLinearOp:$y_3d $qy_3d, $y_scale, $y_zeropoint, $y_axis, + (returnType (GetShapeAndElemType $qy_3d, $y_nd))), + + // Reshape the output back to ND. + (ONNXReshapeOp $y_3d, (GetMatMulResultShape $a_nd, $qb), (GetZeroI64Attr)) + ], + [(HasRankGT<3> $a_nd), (HasRankOf<2> $qb)] +>; + +// A is 2D, B is ND. +def rewriteQLinearMatMulNDto3D_Broadcast2: Pattern< + (ONNXDequantizeLinearOp:$y_nd + (ONNXQLinearMatMulOp:$qy_nd + $qa, $a_scale, $a_zeropoint, + (ONNXQuantizeLinearOp:$qb_nd $b_nd, $b_scale, $b_zeropoint, $b_axis, $b_saturate), $_, $_, + $_, $_), + $y_scale, $y_zeropoint, $y_axis), + [ + // Reshape B to 3D and do quantization. + (ReshapeTo3D:$b_3d $b_nd), + (ONNXQuantizeLinearOp:$qb_3d + $b_3d, $b_scale, $b_zeropoint, $b_axis, $b_saturate, + (returnType (GetShapeAndElemType $b_3d, $qb_nd))), + + // Call QLinearMatMul on 3D inputs. + (ONNXQLinearMatMulOp:$qy_3d + $qa, $a_scale, $a_zeropoint, // Keep A unchanged. + $qb_3d, $b_scale, $b_zeropoint, + $y_scale, $y_zeropoint, + (returnType (GetMatMulResultType $qa, $qb_3d))), + + // Dequantize the 3D output. + (ONNXDequantizeLinearOp:$y_3d $qy_3d, $y_scale, $y_zeropoint, $y_axis, + (returnType (GetShapeAndElemType $qy_3d, $y_nd))), + + // Reshape the output back to ND. + (ONNXReshapeOp $y_3d, (GetMatMulResultShape $qa, $b_nd), (GetZeroI64Attr)) + ], + [(HasRankOf<2> $qa), (HasRankGT<3> $b_nd)] +>; + //===----------------------------------------------------------------------===// // Rules to turn ONNXSoftmaxOp with N-D inputs into the one with 2-D inputs. //===----------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td index 2175d0ecc2..8fd410c98f 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td @@ -37,52 +37,30 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create($_loc, $0.getT // ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighAddPattern1 : Pat< - (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXAddOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)] ->; - -def replaceZHighAddPattern2 : Pat< - (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXAddOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)] +def replaceZHighAddPattern : Pat< + (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXAddOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMulPattern1 : Pat< - (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXMulOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMulPattern2 : Pat< - (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXMulOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [], - (addBenefit 0) +def replaceZHighMulPattern : Pat< + (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXMulOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighSubPattern1 : Pat< - (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (ONNXSubOp $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighSubPattern2 : Pat< - (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (ONNXSubOp (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighSubPattern : Pat< + (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (ONNXSubOp $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// @@ -90,54 +68,30 @@ def replaceZHighSubPattern2 : Pat< // %X),(ZHighStickOp %Y)) // Note: turn off this pattern since NNPA is faster at this moment. //===----------------------------------------------------------------------===// -//def replaceZHighDivPattern1 : Pat< -// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)), -// (ONNXDivOp $x, (ZHighUnstickOp $y)), -// [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], -// (addBenefit 1) -//>; -// -//def replaceZHighDivPattern2 : Pat< -// (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))), -// (ONNXDivOp (ZHighUnstickOp $x), $y), -// [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], -// (addBenefit 0) -//>; +// def replaceZHighDivPattern : Pat< +// (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), +// (ONNXDivOp $x, $y), +// [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] +// >; //===----------------------------------------------------------------------===// // ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMinPattern1 : Pat< - (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMinPattern2 : Pat< - (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighMinPattern : Pat< + (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (CreateONNXMinOp $u, $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// // ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X), // (ZHighStickOp %Y)) //===----------------------------------------------------------------------===// -def replaceZHighMaxPattern1 : Pat< - (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), $y)), - (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)), - [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ], - (addBenefit 1) ->; - -def replaceZHighMaxPattern2 : Pat< - (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_, $_))), - (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y), - [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ], - (addBenefit 0) +def replaceZHighMaxPattern : Pat< + (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_, $_), (ZHighStickOp:$s_y $y, $_, $_))), + (CreateONNXMaxOp $u, $x, $y), + [(NotBlockArgument:$x), (HasOneUse:$s_x), (NotBlockArgument:$y), (HasOneUse:$s_y)] >; //===----------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ZHighToZLow/CMakeLists.txt index fb412aa5e9..237ad53d0a 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/CMakeLists.txt +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/CMakeLists.txt @@ -1,8 +1,10 @@ add_onnx_mlir_library(OMZHighToZLow + ProcessStickData.cpp ZHighToZLow.cpp LINK_LIBS PUBLIC MLIRMemRefTransforms + OMKrnlToLLVM OMLayoutHelper OMONNXToKrnl OMStickify diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.cpp new file mode 100644 index 0000000000..9829892237 --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.cpp @@ -0,0 +1,170 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====---------- ProcessStickData.cpp - Process Stick data ----------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of ZHigh operations to Krnl/Affine/SCF +// operations that operates on stickified input/output data. +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp" +#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" +#include "src/Compiler/CompilerOptions.hpp" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" +#include "src/Dialect/Krnl/DialectBuilder.hpp" +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/SmallVectorHelper.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +// Implementation of quantize helper function. +void emitDynamicQuantizationLinearMinMaxFromStickifiedInput( + ConversionPatternRewriter &rewriter, Location loc, Operation *op, + Value input, StringAttr inputLayout, Value &inputMin, Value &inputMax, + bool enableSIMD, bool enableParallel) { + using MDBuilder = MultiDialectBuilder; + MDBuilder create(rewriter, loc); + + // Extract dims from input, set lbs/ubs. + DimsExpr dims; + create.krnlIE.getShapeAsSymbols(input, dims); + int64_t rank = dims.size(); + IndexExpr zero = LitIE(0); + DimsExpr lbs(rank, zero); + DimsExpr ubs = dims; + + // Decide parameters. + // UnrollVL decides how many vectors of 8 DLF16 will be processed at once. + int64_t unrollVL = 4; // Experimentally good unroll factor. + int64_t archVL = 8; // DLF16. + int64_t totVL = unrollVL * archVL; + + // If not parallel, threadNum = 1, forExplicitParallelLoopIE will simply pass + // through the lb/ub, so ok to have parID = 0 for the sequential cases. + int64_t parId = 0; + int64_t threadNum = 1; + if (enableParallel) { + if (findSuitableParallelDimension(lbs, ubs, 0, rank - 1, parId, 8)) { + threadNum = 8; // TODO use more flexible value. + onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], + "simd min/max for DQL in parallel "); + } else { + enableParallel = false; + onnxToKrnlParallelReport( + op, false, -1, -1, "not enough work in simd min/max for DQL"); + } + } + + // Alloc temp buffers (more when using parallel). + Type f32Type = rewriter.getF32Type(); + // For each thread, we can use totVL temp values for the current min/max. + // But to increase the compute ratio over mem, we will reuse the same tmp + // memory location for a pair of totVL values being processed. + int64_t tmpSizePerThread = totVL / 2; // Reduce pair in same tmp. + int64_t tmpSize = threadNum * tmpSizePerThread; + MemRefType redType = MemRefType::get({tmpSize}, f32Type); + VectorType vec8xF32Type = VectorType::get({archVL}, f32Type); + VectorType vec4xF32Type = VectorType::get({archVL / 2}, f32Type); + + Value minTmp = create.mem.alignedAlloc(redType); + Value maxTmp = create.mem.alignedAlloc(redType); + + // Init min and max. + Value minInit = create.math.positiveInf(f32Type); + Value splatMinInit = create.vec.splat(vec8xF32Type, minInit); + Value maxInit = create.math.negativeInf(f32Type); + Value splatMaxInit = create.vec.splat(vec8xF32Type, maxInit); + // Could parallelize init, here main thread do it all. Use SIMD of 8x. + for (int64_t u = 0; u < tmpSize; u += 8) { + IndexExpr offset = LitIE(u); + create.vec.storeIE(splatMinInit, minTmp, {offset}); + create.vec.storeIE(splatMaxInit, maxTmp, {offset}); + } + + // Reduction into these temps. + IndexExpr tNum = LitIE(threadNum); + create.krnl.forExplicitParallelLoopIE(lbs[parId], ubs[parId], tNum, + [&](const KrnlBuilder &ck, ValueRange loopInd) { + IndexExprScope scope(ck); + IndexExpr t = DimIE(loopInd[0]); + DimsExpr currDims = SymListIE(dims); + // Reduce lbs, ubs for parallel region, if any. + DimsExpr currLbs = SymListIE(lbs); + DimsExpr currUbs = SymListIE(ubs); + // In sequential cases (threadNum ==1, loopInd[1,2]== orig lb,ub). + currLbs[parId] = SymIE(loopInd[1]); + currUbs[parId] = SymIE(loopInd[2]); + // Cannot use krnl because we may not have affine bounds. + SCFBuilder sb(ck); + IterateOverStickInputData( + sb, op, currLbs, currUbs, currDims, inputLayout, input, nullptr, + unrollVL, /*enableParallel*/ false, + /*prefetch, disable as it causes issue with affine*/ false, + [&](const KrnlBuilder &b, SmallVectorImpl &vecOf4xF32Vals, + DimsExpr &loopIndices) { + MDBuilder create(b); + int64_t size = vecOf4xF32Vals.size(); + assert((size == 2 || size == 2 * unrollVL) && "unexpected size"); + // Since all threads share the same tmpMin/Max, needs to offset by + // t * . + IndexExpr threadOffset = SymIE(t) * tmpSizePerThread; + size = size / 2; // handle pairs of 2, so size=1 or unrollVL. + for (int i = 0; i < size; ++i) { + Value val0 = vecOf4xF32Vals[2 * i]; + Value val1 = vecOf4xF32Vals[2 * i + 1]; + // Load appropriate tmp, compute min/max, store in tmp. + IndexExpr offset = threadOffset + LitIE(4 * i); + Value currMin = + create.vec.loadIE(vec4xF32Type, minTmp, {offset}); + Value currMax = + create.vec.loadIE(vec4xF32Type, maxTmp, {offset}); + currMin = create.math.min(currMin, val0); + currMax = create.math.max(currMax, val0); + currMin = create.math.min(currMin, val1); + currMax = create.math.max(currMax, val1); + create.vec.storeIE(currMin, minTmp, {offset}); + create.vec.storeIE(currMax, maxTmp, {offset}); + } + }, + [&](const KrnlBuilder &b, Value scalarF32Val, + DimsExpr &loopIndices) { + MDBuilder create(b); + Value currMin = create.krnl.loadIE(minTmp, {zero}); + Value currMax = create.krnl.loadIE(maxTmp, {zero}); + currMin = create.math.min(currMin, scalarF32Val); + currMax = create.math.max(currMax, scalarF32Val); + create.krnl.storeIE(currMin, minTmp, {zero}); + create.krnl.storeIE(currMax, maxTmp, {zero}); + }); // Iterate over stick. + }); // Explicit parallel loop (sequential if threadNum==1). + + // Now we have all the partial min/max inside the minTmp/maxTmp: reduce each + // vectors with each others. Main thread reduces all the values. Use SIMD of + // 8x. + Value finalVecMin = create.vec.loadIE(vec8xF32Type, minTmp, {zero}); + Value finalVecMax = create.vec.loadIE(vec8xF32Type, maxTmp, {zero}); + for (int u = 8; u < tmpSize; u += 8) { + IndexExpr offset = LitIE(u); + Value currMin = create.vec.loadIE(vec8xF32Type, minTmp, {offset}); + Value currMax = create.vec.loadIE(vec8xF32Type, maxTmp, {offset}); + finalVecMin = create.math.min(finalVecMin, currMin); + finalVecMax = create.math.max(finalVecMax, currMax); + } + + // Horizontal reduction of the vectors into a scalar. + inputMin = create.vec.reduction(VectorBuilder::MIN, finalVecMin); + inputMax = create.vec.reduction(VectorBuilder::MAX, finalVecMax); +} + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp new file mode 100644 index 0000000000..c72bed046c --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====---------- ProcessStickData.cpp - Process Stick data ----------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of ZHigh operations to Krnl/Affine/SCF +// operations that operates on stickified input/output data. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_MLIR_PROCESS_STICK_DATA_H +#define ONNX_MLIR_PROCESS_STICK_DATA_H + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" + +namespace onnx_mlir { + +// By definition of the conversion from dlf16 to f32, vecOfF32Vals should always +// contain pairs of vectors. +using ContiguousVectorOfF32IterateBodyFn = std::function &vecOfF32Vals, + DimsExpr &loopIndices)>; + +using ScalarF32IterateBodyFn = std::function; + +// Iterate over each values in the input's sticks, processing vectors (of 4 F32) +// with processVectorOfF32Vals, and scalars (1 F32) with processScalarF32Val, By +// definition, processVectorOfF32Vals contains either 2 or 2*unrollVL vectors. +// And processScalarF32Val process only 1 scalar value. Output is only used for +// prefetching. If output is null, skip output prefetching. In general, we +// expects lbs={0...0} and ubs=dims. WHen parallelized outside of this loop, +// then lbs and ubs can reflect the subset of iterations assigned to this +// thread. Iterations cannot be partitioned on the innermost dim. +template +void IterateOverStickInputData(const BUILDER &b, mlir::Operation *op, + DimsExpr &lbs, DimsExpr &ubs, DimsExpr &dims, mlir::StringAttr layout, + mlir::Value input, mlir::Value output, int64_t unrollVL, + bool enableParallel, bool enablePrefetch, + ContiguousVectorOfF32IterateBodyFn processVectorOfF32Vals, + ScalarF32IterateBodyFn processScalarF32Val); + +// Compute min/max from stickified input. Currently support 2DS, 3D, 3DS, +// 4D formats. +void emitDynamicQuantizationLinearMinMaxFromStickifiedInput( + mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Operation *op, mlir::Value input, mlir::StringAttr inputLayout, + mlir::Value &inputMin, mlir::Value &inputMax, bool enableSIMD, + bool enableParallel); + +} // namespace onnx_mlir + +// Include template code. +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp.inc" + +#endif diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp.inc b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp.inc new file mode 100644 index 0000000000..3e8a51b46c --- /dev/null +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp.inc @@ -0,0 +1,261 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====-------- ProcessStickData.hpp.inc - Process Stick data --------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of ZHigh operations to Krnl/Affine/SCF +// operations that operates on stickified input/output data. +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp" +#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" +#include "src/Compiler/CompilerOptions.hpp" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" +#include "src/Dialect/Krnl/DialectBuilder.hpp" +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Support/SmallVectorHelper.hpp" + +namespace onnx_mlir { + +// Iterate over each stick, for an original size of dims, and cover the +// iterations from lbs to ubs. In most cases, lbs={0...} and ubs=dims, namely we +// cover all iterations. But we can parallelize the loops from the outside, in +// which case we expect lbs and ubs to reflect the iterations assigned to this +// thread. Note that we cannot tile in the innermost dim (as this is the +// dimension of the sticks). +template +void IterateOverStickInputData(const BUILDER &b, mlir::Operation *op, + DimsExpr &lbs, DimsExpr &ubs, DimsExpr &dims, mlir::StringAttr inputLayout, + mlir::Value input, mlir::Value output, int64_t unrollVL, + bool enableParallel, bool enablePrefetch, + ContiguousVectorOfF32IterateBodyFn processVectorOfF32Vals, + ScalarF32IterateBodyFn processScalarF32Val) { + // Init builder and scopes. + using MDBuilder = MultiDialectBuilder; + MDBuilder create(b); + //IndexExprScope initialScope(b); + // Get info and check some inputs. + int64_t rank = dims.size(); + int64_t d1 = rank - 1; + IndexExpr E1 = dims[d1]; + assert(lbs.size() == ubs.size() && "expected same sizes"); + assert(lbs.size() == dims.size() && "expected same sizes"); + assert((inputLayout.getValue().equals_insensitive("4D") || + inputLayout.getValue().equals_insensitive("3D") || + inputLayout.getValue().equals_insensitive("2D") || + inputLayout.getValue().equals_insensitive("3DS") || + inputLayout.getValue().equals_insensitive("NHWC")) && + "unsupported inputLayout"); + + // Info for SIMD Vector Length (VL). + int64_t archVL = 8; // FP16 archVL. + int64_t archVLHalf = archVL / 2; // FP32 archVL. + int64_t totVL = archVL * unrollVL; + int64_t stickLen = 64; + assert(stickLen % totVL == 0 && "bad unrollVL factor"); + mlir::Type f16Type = b.getBuilder().getF16Type(); + mlir::Type f32Type = b.getBuilder().getF32Type(); + mlir::VectorType vecF16Type = mlir::VectorType::get({archVL}, f16Type); + mlir::MemRefType bufferF32Type = mlir::MemRefType::get({archVL}, f32Type); + + // Useful constants. + IndexExpr litZero = LitIE(0); + IndexExpr lit1 = LitIE(1); + IndexExpr lit2 = LitIE(2); + IndexExpr litArchVLHalf = LitIE(archVLHalf); + IndexExpr litArchVL = LitIE(archVL); + IndexExpr litStickLen = LitIE(stickLen); + + // Create loop iterations. We iterate over E1 as sticks of 64 elements. Lbs + // and ubs reflect the iteration over the sticks (tiled data points). + DimsExpr tiledLbs = lbs; + DimsExpr tiledUbs = ubs; + tiledUbs[d1] = E1.ceilDiv(litStickLen); + + // Predicates used to avoid creating code that is never used. + bool neverHas64 = E1.isLiteralAndSmallerThan(stickLen); + bool neverHas8 = E1.isLiteralAndSmallerThan(archVL); + bool hasOnly64 = E1.isLiteral() && (E1.getLiteral() % stickLen == 0); + bool hasOnly8 = E1.isLiteral() && (E1.getLiteral() % archVL == 0); + + // Parallel... Should not be turned on when parallelized in the outside. + int64_t parId = 0; + if (enableParallel) { + // TODO: may want to check if ub of rank makes sense here. + // Its ok here even to partition rank-1, included in (0..rank(, because + // rank-1 is tiled. So we are still dealing with multiple of sticks. + if (findSuitableParallelDimension(tiledLbs, tiledUbs, 0, rank, parId, 8)) { + onnxToKrnlParallelReport(op, true, parId, tiledLbs[parId], + tiledUbs[parId], "compiler-generated stickify"); + } else { + enableParallel = false; + onnxToKrnlParallelReport(op, false, -1, -1, + "no dim with enough work in compiler-generated stickify"); + } + } + + // Compute max sticks (tiles of 64 values). It is actually not easy to compute + // the max number of sticks. Since we don't allocate, it is just a "view", we + // only need to index by the "stick size", it is sufficient to assume 2 or + // more. + DimsExpr reallocStickDims = {lit2, litStickLen}; + mlir::Value inputAsSticks = + create.mem.reinterpretCast(input, reallocStickDims); + + llvm::SmallVector steps(rank, 1); + llvm::SmallVector useParallel(rank, false); + if (enableParallel) + useParallel[parId] = true; + b.forLoopsIE(tiledLbs, tiledUbs, steps, useParallel, + [&](const BUILDER &b, mlir::ValueRange tiledLoopInd) { + MDBuilder create(b); + IndexExprScope outerScope(b); + DimsExpr tiledOuterIndices = DimListIE(tiledLoopInd); + // Computation for accessing data (not tiled, actual indices). + DimsExpr outerIndices = tiledOuterIndices; + IndexExpr E1 = SymIE(dims[d1]); // Original upper bound in d1. + IndexExpr e1 = outerIndices[d1] = tiledOuterIndices[d1] * litStickLen; + // Translate the tile index t1 to the actual targetted data. Have to + // give the actual indices, not the tiled ones. + mlir::Value inputOffset = + create.krnl.getLinearOffsetIndexIE(input, outerIndices); + // Offset in inputAsSticks's first dim is as multiple of litStickLen, so + // divide by it. + IndexExpr inputStickOffset = SymIE(inputOffset).floorDiv(litStickLen); + // Buffer for small leftovers (used when E1 % 8 != 0) + mlir::Value bufferF32; + if (!hasOnly8) + bufferF32 = create.mem.alignedAlloc(bufferF32Type); + if (enablePrefetch) { + // Prefetch current line + create.krnl.prefetchIE(input, outerIndices, /*write*/ false, + /*locality*/ 1); + if (output) + create.krnl.prefetchIE(output, outerIndices, /*write*/ true, + /*locality*/ 1); + } + // Check if we have a full stick (aka end of stick is not beyond UB). + IndexExpr hasFullStick; + if (hasOnly64) { + hasFullStick = PredIE(true); // Has only full sicks. + } else if (neverHas64) { + hasFullStick = PredIE(false); // Doesn't even has 1 stick. + } else { + IndexExpr isFull = create.krnlIE.isTileFull(e1, litStickLen, E1); + hasFullStick = (isFull >= 0); + } + create.scf.ifThenElse( + hasFullStick.getValue(), + // If is full. + [&](const SCFBuilder b) { + if (neverHas64) + return; // Nothing to do here. Avoid generating dead code. + MDBuilder create(b); + // Iterate through stick by totVL (aka 8 * unroll). + create.scf.forLoopIE(litZero, litStickLen, totVL, /*par*/ false, + [&](const SCFBuilder b, mlir::ValueRange loopInd) { + MDBuilder create(b); + IndexExprScope innerScope(b, &outerScope); + IndexExpr l = DimIE(loopInd[0]); + DimsExpr innerIndices = SymListIE(outerIndices); + innerIndices[d1] = innerIndices[d1] + l; + mlir::SmallVector vecOfF32Vals; + // Load archVL (8) f16 values from input via reinterpreted + // data tile, and then convert them into f32 and enqueue in + // vecOfF32Vals. + for (int64_t u = 0; u < unrollVL; ++u) { + mlir::Value vecOfF16 = + create.vec.loadIE(vecF16Type, inputAsSticks, + {SymIE(inputStickOffset), l + (u * archVL)}); + auto convertOp = + b.getBuilder() + .create( + b.getLoc(), vecOfF16); + vecOfF32Vals.emplace_back(convertOp.getResult(0)); + vecOfF32Vals.emplace_back(convertOp.getResult(1)); + } + processVectorOfF32Vals( + create.krnl, vecOfF32Vals, innerIndices); + }); + }, + // Else, we don't have a full (64 e1) tile. + [&](SCFBuilder b) { + if (hasOnly64) + return; // Do not generate dead code. + MDBuilder create(b); + IndexExprScope middleScope(b, &outerScope); + IndexExpr tripCount = SymIE(E1) - SymIE(e1); + if (!neverHas8) { + // Note: if we only have multiple of VL, loop below will + // handle all as we subtract (VL-1). Aka if VL=8 and tripCount + // = 16, tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we + // iterate over i=0 & i=8 as both are < 9. + IndexExpr tripCountWithoutPartialLastVL = + tripCount - (archVL - 1); + create.scf.forLoopIE(litZero, tripCountWithoutPartialLastVL, + archVL, /*par*/ false, + [&](SCFBuilder b, mlir::ValueRange loopInd) { + MDBuilder create(b); + IndexExprScope innerScope(b, &middleScope); + IndexExpr l = DimIE(loopInd[0]); + DimsExpr innerIndices = SymListIE(outerIndices); + innerIndices[d1] = innerIndices[d1] + l; + mlir::SmallVector vecOfF32Vals; + // Load f16 values from input via reinterpreted data + // tile. + mlir::Value vecOfF16 = create.vec.loadIE(vecF16Type, + inputAsSticks, {SymIE(inputStickOffset), l}); + // Convert back to f32. + auto convertOp = + b.getBuilder() + .create( + b.getLoc(), vecOfF16); + vecOfF32Vals.emplace_back(convertOp.getResult(0)); + vecOfF32Vals.emplace_back(convertOp.getResult(1)); + processVectorOfF32Vals( + create.krnl, vecOfF32Vals, innerIndices); + }); + } + if (!hasOnly8) { + // Deal with the last <8 values: compute f32 using simd. + IndexExpr remainingScalarValues = tripCount % archVL; + IndexExpr lastL = tripCount - remainingScalarValues; + mlir::Value vecOfF16 = create.vec.loadIE(vecF16Type, + inputAsSticks, {SymIE(inputStickOffset), lastL}); + // Convert back to f32. + auto convertOp = + b.getBuilder().create( + b.getLoc(), vecOfF16); + mlir::Value vecF32H = convertOp.getResult(0); + mlir::Value vecF32L = convertOp.getResult(1); + // Save into archVL value buffer. + create.vec.storeIE(vecF32H, bufferF32, {litZero}); + create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}); + create.scf.forLoopIE(litZero, remainingScalarValues, 1, + /*par*/ false, [&](SCFBuilder b, mlir::ValueRange loopInd) { + MDBuilder create(b); + IndexExprScope innerScope(b, &middleScope); + IndexExpr l = DimIE(loopInd[0]); + // Load converted value. + mlir::Value f32 = create.krnl.loadIE(bufferF32, {l}); + + DimsExpr innerIndices = SymListIE(outerIndices); + innerIndices[d1] = innerIndices[d1] + SymIE(lastL); + innerIndices[d1] = innerIndices[d1] + l; + processScalarF32Val(create.krnl, f32, innerIndices); + }); + } + }); + }); +} + +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index d5b3910730..2cdc850e02 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -19,15 +19,19 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/DialectResourceBlobManager.h" +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp" #include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" +#include "src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Accelerators/NNPA/Support/Stickify/Convert.hpp" +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Support/TypeUtilities.hpp" @@ -49,7 +53,8 @@ using MDBuilder = MultiDialectBuilder dims, - ZTensorEncodingAttr::DataLayout layout, Operation *op, + ZTensorEncodingAttr::DataLayout layout, + ZTensorEncodingAttr::QuantizedType qtype, Operation *op, PatternRewriter &rewriter, int64_t alignment = gAlignment) { // Construct a MemRefType for the given dimensions and element type. SmallVector shape; @@ -57,7 +62,7 @@ Value insertAllocForZMemRefByDim(ArrayRef dims, shape.emplace_back((d.isLiteral() ? d.getLiteral() : ShapedType::kDynamic)); RankedTensorType tensorType = RankedTensorType::get(shape, rewriter.getF32Type(), - ZTensorEncodingAttr::get(op->getContext(), layout)); + ZTensorEncodingAttr::get(op->getContext(), layout, qtype)); ZMemRefType zMemRefType = convertZTensorToMemRefType(tensorType); // Insert alloc. @@ -168,12 +173,33 @@ static Value insertAllocForWorkAreaForRNNOps(IndexExprBuilderForKrnl &createIE, return create.mem.alignedAlloc(resultType, dims, gAlignment); } +/// Get a dense resource attribute to store stickified data of a given i8 value. +/// Attribute type: tensor +DenseResourceElementsAttr getDenseResourceElementsAttrOfValue( + PatternRewriter &rewriter, ZHighStickifiedConstantOp stickifiedConstant, + int8_t val, int64_t sizeInBytes) { + char *rawData = static_cast(malloc(sizeInBytes)); + assert(rawData && "failed to allocate memory for stickified data"); + memset(rawData, val, sizeInBytes); + DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( + RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), + stickifiedConstant.getOperation() + ->getDialect() + ->getNamespace(), // use the dialect as the blob "hint" + HeapAsmResourceBlob::allocateAndCopyWithAlign( + llvm::ArrayRef(rawData, sizeInBytes), alignof(char))); + free(rawData); + return valueAttr; +} + /// This function emits a buffer of zero elements for the given dimensions and /// layout. If the given dimensions are static, then a stickified constant is /// returned. Value insertAllocOrEmitZeroConstant(ArrayRef dims, ZTensorEncodingAttr::DataLayout layout, Operation *op, - PatternRewriter &rewriter, Location loc) { + PatternRewriter &rewriter, Location loc, + ZTensorEncodingAttr::QuantizedType qtype = + ZTensorEncodingAttr::QuantizedType::UNDEFINED) { Value res; bool allStaticDims = llvm::all_of(dims, [](IndexExpr ie) { return ie.isLiteral(); }); @@ -182,9 +208,12 @@ Value insertAllocOrEmitZeroConstant(ArrayRef dims, SmallVector shape; for (IndexExpr d : dims) shape.emplace_back(d.getLiteral()); - RankedTensorType tensorType = - RankedTensorType::get(shape, rewriter.getF32Type(), - ZTensorEncodingAttr::get(op->getContext(), layout)); + Type elemType = rewriter.getF32Type(); + if (qtype == ZTensorEncodingAttr::QuantizedType::WEIGHTS || + qtype == ZTensorEncodingAttr::QuantizedType::INT8) + elemType = rewriter.getI8Type(); + RankedTensorType tensorType = RankedTensorType::get(shape, elemType, + ZTensorEncodingAttr::get(op->getContext(), layout, qtype)); ZMemRefType zMemRefType = convertZTensorToMemRefType(tensorType); MemRefType resType = affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); @@ -199,24 +228,14 @@ Value insertAllocOrEmitZeroConstant(ArrayRef dims, // Attribute type: tensor int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(resType).value(); - char *rawData = (char *)malloc(sizeInBytes); - assert(rawData && "failed to allocate memory for stickified data"); - memset(rawData, 0, sizeInBytes); - DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( - RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), - stickifiedConstant.getOperation() - ->getDialect() - ->getNamespace(), // use the dialect as the blob "hint" - HeapAsmResourceBlob::allocateAndCopyWithAlign( - llvm::ArrayRef(rawData, sizeInBytes), alignof(char))); + DenseResourceElementsAttr valueAttr = getDenseResourceElementsAttrOfValue( + rewriter, stickifiedConstant, 0, sizeInBytes); stickifiedConstant.setValueAttr(valueAttr); - free(rawData); - res = stickifiedConstant.getResult(); } else { MultiDialectBuilder create(rewriter, loc); - res = insertAllocForZMemRefByDim(dims, layout, op, rewriter); - Value initValue = create.math.constant(rewriter.getF16Type(), 0); + res = insertAllocForZMemRefByDim(dims, layout, qtype, op, rewriter); + Value initValue = create.math.constant(getElementType(res.getType()), 0); create.krnl.memset(res, initValue, /*delayed=*/true); } return res; @@ -228,7 +247,7 @@ Value insertShapeMemRefI64( MultiDialectBuilder create( rewriter, loc); MemRefType shapeMemRefType = MemRefType::get( - {(int64_t)originalDims.size()}, rewriter.getIntegerType(64)); + {static_cast(originalDims.size())}, rewriter.getIntegerType(64)); Value shapeMemRef = create.mem.alignedAlloc(shapeMemRefType); for (uint64_t i = 0; i < originalDims.size(); ++i) { Value dim = @@ -249,11 +268,31 @@ ZMemRefType convertZTensorToMemRefType(Type type) { Type elementType = tensorType.getElementType(); int64_t rank = shape.size(); if (tensorType.getEncoding()) { - // Obtain element type and affine map. + // Obtain element type. + ZTensorEncodingAttr::QuantizedType qtype = getZTensorQuantizedType(type); + if (qtype == ZTensorEncodingAttr::QuantizedType::DLFLOAT16) + elementType = b.getF16Type(); + else if (qtype == ZTensorEncodingAttr::QuantizedType::INT8) + elementType = b.getI8Type(); + else if (qtype == ZTensorEncodingAttr::QuantizedType::WEIGHTS) + elementType = b.getI8Type(); + else + elementType = b.getF16Type(); + // Obtain affine map. AffineExpr constExpr0 = getAffineConstantExpr(0, b.getContext()); AffineExpr constExpr31 = getAffineConstantExpr(31, b.getContext()); - AffineExpr constExpr32 = getAffineConstantExpr(32, b.getContext()); - AffineExpr constExpr64 = getAffineConstantExpr(64, b.getContext()); + AffineExpr constE2Block = getAffineConstantExpr(32, b.getContext()); + AffineExpr constE1Block = getAffineConstantExpr(64, b.getContext()); + if (qtype == ZTensorEncodingAttr::QuantizedType::INT8) { + // For quantized i8, 128 cells per stick. + constE1Block = getAffineConstantExpr(128, b.getContext()); + } else if (qtype == ZTensorEncodingAttr::QuantizedType::WEIGHTS) { + // WEIGHTS has two vectors interleaved, therefore only 64 cells vs 128 + // Due to this interleaving, number_of_sticks is halved, but must be + // rounded up to stay even for proper interleaving. + constE2Block = getAffineConstantExpr(64, b.getContext()); + } + unsigned e4, e3, e2, e1; AffineExpr n, c, h, w, res32, res64; SmallVector dimExpr; @@ -263,22 +302,22 @@ ZMemRefType convertZTensorToMemRefType(Type type) { // (e1) -> (1, 1, 1, e1) -> (1, ceil(e1/64), 1, 1, 32, 64) e1 = 0; n = constExpr0; - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = constExpr0; c = constExpr0; res32 = constExpr31; - res64 = b.getAffineDimExpr(e1) % constExpr64; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_2D) { // (e2, e1) -> (1, 1, e2, e1) -> (1, ceil(e1/64), 1, ceil(e2/32), 32 // 64) e2 = 0; e1 = 1; n = constExpr0; - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = constExpr0; - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_3D) { // (e3, e2, e1) -> (1, e3, e2, e1) // -> (1, ceil(e1/64), e3, ceil(e2/32), 32, 64) @@ -286,11 +325,11 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e2 = 1; e1 = 2; n = constExpr0; - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = b.getAffineDimExpr(e3); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_4D) { // (e4, e3, e2, e1) -> (e4, ceil(e1/64), e3, ceil(e2/32), 32, 64) e4 = 0; @@ -298,21 +337,21 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e2 = 2; e1 = 3; n = b.getAffineDimExpr(e4); - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = b.getAffineDimExpr(e3); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_2DS) { // (e4, e1) -> (e4, 1, 1, e1) -> (e4, ceil(e1/64), 1, 1, 32, 64) e4 = 0; e1 = 1; n = b.getAffineDimExpr(e4); - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = constExpr0; c = constExpr0; res32 = constExpr31; - res64 = b.getAffineDimExpr(e1) % constExpr64; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_3DS) { // (e4, e2, e1) -> (e4, 1, e2, e1) // -> (e4, ceil(e1/64), 1, ceil(e2/32), 32, 64) @@ -320,11 +359,11 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e2 = 1; e1 = 2; n = b.getAffineDimExpr(e4); - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = constExpr0; - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::_4DS) { // for normal // (e4, e3, e2, e1) @@ -340,16 +379,16 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e1 = 3; n = b.getAffineDimExpr(e4); if (shape[1] == 1) { - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); } else { AffineExpr padded_e1 = - b.getAffineDimExpr(e1).ceilDiv(constExpr64) * constExpr64; - h = (2 * padded_e1).floorDiv(constExpr64); + b.getAffineDimExpr(e1).ceilDiv(constE1Block) * constE1Block; + h = (2 * padded_e1).floorDiv(constE1Block); } w = b.getAffineDimExpr(e3); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::NHWC) { // (e4, e3, e2, e1) -> (e4, ceil(e1/64), e3, ceil(e2/32), 32, 64) e4 = 0; @@ -357,11 +396,11 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e2 = 2; e1 = 3; n = b.getAffineDimExpr(e4); - h = b.getAffineDimExpr(e1).floorDiv(constExpr64); + h = b.getAffineDimExpr(e1).floorDiv(constE1Block); w = b.getAffineDimExpr(e3); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::NCHW) { // (e4, e3, e2, e1) -> (e4, ceil(e2/64), e1, ceil(e3/32), 32, 64) llvm_unreachable("Not tested yet"); @@ -370,11 +409,11 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e2 = 2; e1 = 3; n = b.getAffineDimExpr(e4); - h = b.getAffineDimExpr(e2).floorDiv(constExpr64); + h = b.getAffineDimExpr(e2).floorDiv(constE1Block); w = b.getAffineDimExpr(e1); - c = b.getAffineDimExpr(e3).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e3) % constExpr32; - res64 = b.getAffineDimExpr(e2) % constExpr64; + c = b.getAffineDimExpr(e3).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e3) % constE2Block; + res64 = b.getAffineDimExpr(e2) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::HWCK) { // HWCK (e4, e3, e2, e1) -> KHWC (ceil(e1/64), e4,, e3, ceil(e2/32), // 32, 64) @@ -382,12 +421,12 @@ ZMemRefType convertZTensorToMemRefType(Type type) { e3 = 1; e2 = 2; e1 = 3; - n = b.getAffineDimExpr(e1).floorDiv(constExpr64); + n = b.getAffineDimExpr(e1).floorDiv(constE1Block); h = b.getAffineDimExpr(e4); w = b.getAffineDimExpr(e3); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; - res64 = b.getAffineDimExpr(e1) % constExpr64; + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; + res64 = b.getAffineDimExpr(e1) % constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::FICO) { // (e4, e3, e2, e1) -> (e4, 4*ceil(e1/4/64), e3, ceil(e2/32), 32, 64) assert(!ShapedType::isDynamic(shape[rank - 1]) && @@ -395,7 +434,7 @@ ZMemRefType convertZTensorToMemRefType(Type type) { "wrong concatenated dimension size"); int64_t s = shape[rank - 1] / 4; // ((s + 64 - 1) / 64) * 64; - int64_t s_pad = ceil((double)s / 64) * 64; + int64_t s_pad = ceil(static_cast(s) / 64) * 64; int64_t pad_size = s_pad - s; AffineExpr constExprS = getAffineConstantExpr(s, b.getContext()); if (rank == 2) { @@ -417,12 +456,12 @@ ZMemRefType convertZTensorToMemRefType(Type type) { h = (((rank == 2) ? shape[0] : 1) * (b.getAffineDimExpr(e1) + pad_size * (b.getAffineDimExpr(e1).floorDiv(constExprS)))) - .floorDiv(constExpr64); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; + .floorDiv(constE1Block); + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; res64 = (b.getAffineDimExpr(e1) + pad_size * (b.getAffineDimExpr(e1).floorDiv(constExprS))) % - constExpr64; + constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::ZRH) { // (e4, e3, e2, e1) -> (e4, 3*ceil(e1/4/64), e3, ceil(e2/32), 32, 64) int64_t hidden_size = shape[rank - 1]; @@ -431,7 +470,8 @@ ZMemRefType convertZTensorToMemRefType(Type type) { "in affine_map generation."); assert((hidden_size % 3) == 0 && "wrong concatenated dimension size."); int64_t s = hidden_size / 3; - int64_t s_pad = ceil((float)s / 64) * 64; // ((s + 64 - 1) / 64) * 64; + int64_t s_pad = + ceil(static_cast(s) / 64) * 64; // ((s + 64 - 1) / 64) * 64; int64_t pad_size = s_pad - s; AffineExpr constExprS = getAffineConstantExpr(s, b.getContext()); if (rank == 2) { @@ -453,12 +493,12 @@ ZMemRefType convertZTensorToMemRefType(Type type) { h = (((rank == 2) ? shape[0] : 1) * (b.getAffineDimExpr(e1) + pad_size * (b.getAffineDimExpr(e1).floorDiv(constExprS)))) - .floorDiv(constExpr64); - c = b.getAffineDimExpr(e2).floorDiv(constExpr32); - res32 = b.getAffineDimExpr(e2) % constExpr32; + .floorDiv(constE1Block); + c = b.getAffineDimExpr(e2).floorDiv(constE2Block); + res32 = b.getAffineDimExpr(e2) % constE2Block; res64 = (b.getAffineDimExpr(e1) + pad_size * (b.getAffineDimExpr(e1).floorDiv(constExprS))) % - constExpr64; + constE1Block; } else if (layout == ZTensorEncodingAttr::DataLayout::BFICO) { llvm_unreachable("Unsupported layout yet"); } else if (layout == ZTensorEncodingAttr::DataLayout::BZRH) { @@ -474,7 +514,7 @@ ZMemRefType convertZTensorToMemRefType(Type type) { dimExpr.emplace_back(res64); AffineMap smap = AffineMap::get(rank, 0, dimExpr, b.getContext()); // Output type is F16 for zAIU. - MemRefType outType = MemRefType::get(shape, b.getF16Type()); + MemRefType outType = MemRefType::get(shape, elementType); resZMemRefType.value = MemRefType::Builder(outType).setLayout(AffineMapAttr::get(smap)); resZMemRefType.layout = convertZTensorDataLayoutToStringAttr(b, layout); @@ -501,7 +541,7 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); - ZHighStickOp stickOp = cast(op); + ZHighStickOp stickOp = mlir::cast(op); ZHighStickOpAdaptor operandAdaptor(operands); Value input = operandAdaptor.getIn(); @@ -612,6 +652,137 @@ struct ZHighToZLowStickForGRUOpLowering : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// Lower ZHigh QuantizedStick to ZLow QuantizedStick +//===----------------------------------------------------------------------===// + +struct ZHighToZLowQuantizedStickOpLowering : public ConversionPattern { + ZHighToZLowQuantizedStickOpLowering(TypeConverter &typeConverter, + MLIRContext *ctx, bool enableSIMD, bool enableParallel) + : ConversionPattern( + typeConverter, ZHighQuantizedStickOp::getOperationName(), 1, ctx) { + this->enableSIMD = enableSIMD; + this->enableParallel = enableParallel; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + auto qstickOp = cast(op); + + ZHighQuantizedStickOpAdaptor operandAdaptor(operands); + Value X = operandAdaptor.getIn(); + Value XRecScale = operandAdaptor.getInRecScale(); + Value XOffset = operandAdaptor.getInOffset(); + StringAttr layout = qstickOp.getLayoutAttr(); + StringAttr quantizedType = qstickOp.getQuantizedTypeAttr(); + bool symmetricMode = qstickOp.getSymMode() != 0; + + MultiDialectBuilder + create(rewriter, loc); + ZHighQuantizedStickOpShapeHelper shapeHelper(op, operands, &create.krnlIE); + shapeHelper.computeShapeAndAssertOnFailure(); + + // Convert ZTensor type to MemRefType. + ZMemRefType zMemRefType = + convertZTensorToMemRefType(*op->result_type_begin()); + + Type si64Ty = rewriter.getIntegerType(64, true); + Type i8Ty = rewriter.getIntegerType(8); + Type f32Ty = rewriter.getF32Type(); + MemRefType scalarF32MemRefTy = MemRefType::get({}, f32Ty); + MemRefType scalarI8MemRefTy = MemRefType::get({}, i8Ty); + + // Attributes. + IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1); + + // Compute rec_scale and offset. + Value recScale = nullptr; + Value offset = nullptr; + if (!isNoneValue(XRecScale)) + recScale = create.krnl.load(XRecScale); + if (!isNoneValue(XOffset)) + offset = create.krnl.load(XOffset); + + // Find out more about the original input tensor. + Type inputOriginalType = op->getOperand(0).getType(); + StringAttr xLayout = getZTensorLayoutAttr(rewriter, inputOriginalType); + bool xIsZTensorOfDLF16 = (xLayout != nullptr); + + if (!recScale && !offset) { + if (symmetricMode) { + if (xIsZTensorOfDLF16) { + llvm_unreachable("Does not support symmetric quantization for a " + "ztensor at this moment"); + } + offset = create.math.constant(f32Ty, 0.0); + emitSymmetricQuantRecscaleToScalar( + rewriter, loc, op, X, 8, recScale, enableSIMD, enableParallel); + } else { + // Get layout of the defining operation of X. Do not checking that we + // have a supported z tensor, as this checking will be performed in + // emitDynamicQuantizationLinearMinMaxFromStickifiedInput, when called. + // Compute min/max. + Value inputMin, inputMax; + if (xIsZTensorOfDLF16) { + // Call will test that we can handle the specific xLayout. + emitDynamicQuantizationLinearMinMaxFromStickifiedInput(rewriter, loc, + op, X, xLayout, inputMin, inputMax, enableSIMD, enableParallel); + } else { + // Proceed with computing min/max using normal tensor of normal types. + assert(xLayout == nullptr && "expected no layout"); + emitDynamicQuantizationLinearMinMax(rewriter, loc, op, X, inputMin, + inputMax, enableSIMD, enableParallel); + } + // Compute scale & zero point. NNPA uses signed i8 so QMax is 127 and + // QMin is -128. + Value scale, quantizedOffset; + Value qMax = create.math.constant(f32Ty, 127.0); + Value qMin = create.math.constant(f32Ty, -128.0); + emitDynamicQuantizationLinearScalarParametersFromMinMax(rewriter, loc, + op, scalarF32MemRefTy, scalarI8MemRefTy, inputMin, inputMax, qMin, + qMax, scale, offset, quantizedOffset, /*want zero point*/ true, + enableParallel); + // Compute recScale. + Value one = create.math.constant(f32Ty, 1.0); + recScale = create.math.div(one, scale); + } + } + + // MemRefs for recScale and offset. + Value memrefRecScale = create.mem.alignedAlloc(scalarF32MemRefTy); + create.krnl.store(recScale, memrefRecScale); + Value memrefOffset = create.mem.alignedAlloc(scalarF32MemRefTy); + create.krnl.store(offset, memrefOffset); + + if (xIsZTensorOfDLF16) { + // Already stickified. + rewriter.replaceOp(op, {X, memrefRecScale, memrefOffset}); + return success(); + } + + // Allocate a buffer for the result MemRef. + Value alloc = insertAllocForZMemRef( + zMemRefType, shapeHelper.getOutputDims(), op, rewriter); + // Emit a ZLow operation. + if (quantizedType.getValue().equals_insensitive(QTYPE_DLFLOAT16)) { + // Use normal stickification for dlfloat16 type so that we can flexibly + // switch between compiler-generated and zdnn stick. + create.zlow.stick(X, alloc, layout, trueAttr); + } else { + create.zlow.quantizedStick( + X, memrefRecScale, memrefOffset, alloc, layout, quantizedType); + } + rewriter.replaceOp(op, {alloc, memrefRecScale, memrefOffset}); + return success(); + } + +private: + bool enableSIMD = false; + bool enableParallel = false; +}; + //===----------------------------------------------------------------------===// // Lower ZHigh Unstick to ZLow Unstick //===----------------------------------------------------------------------===// @@ -712,19 +883,36 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { affine::normalizeMemRefType(mlir::cast(zMemRefType.value)); ArrayRef normalizedShape = normalizedType.getShape(); - // Get dense resource attribute. - auto blob = mlir::cast( - stickifiedConstOp.getValue().value()) - .getRawHandle() - .getBlob(); - assert(blob && "Expecting dense resource with a valid blob"); - ArrayRef data = blob->getData(); - // Validate the stickified tensor. - int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType); - memRefSizeInBytes *= normalizedType.getNumElements(); - assert((data.size() == (uint64_t)memRefSizeInBytes) && - "The stickified tensor's buffer size and MemRef's size mismatched"); + Attribute valueAttr = stickifiedConstOp.getValueAttr(); + int64_t sizeInBytes = getMemRefEltSizeInBytes(normalizedType); + sizeInBytes *= normalizedType.getNumElements(); + if (auto denseAttr = mlir::dyn_cast_or_null(valueAttr)) { + ArrayRef data = denseAttr.getRawData(); + if (denseAttr.isSplat()) { + // Constant ztensor's buffer is tensor. + int8_t v = denseAttr.getSplatValue(); + // NNPA does not work with a splat buffer. + // Expand the memory buffer for NNPA by using DenseResourceElementsAttr. + valueAttr = getDenseResourceElementsAttrOfValue( + rewriter, stickifiedConstOp, v, sizeInBytes); + } else { + assert( + (data.size() == static_cast(sizeInBytes)) && + "The stickified tensor's buffer size and MemRef's size mismatched"); + } + } else if (auto resourceAttr = + mlir::dyn_cast_or_null( + valueAttr)) { + auto blob = resourceAttr.getRawHandle().getBlob(); + assert(blob && "Expecting dense resource with a valid blob"); + ArrayRef data = blob->getData(); + assert( + (data.size() == static_cast(sizeInBytes)) && + "The stickified tensor's buffer size and MemRef's size mismatched"); + } else { + llvm_unreachable("Unsupported ElementsAttr"); + } // Create a KrnlGlobalOp. KrnlGlobalOp constantGlobal = @@ -734,7 +922,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern { /*name=*/ rewriter.getStringAttr( "constant_stickify_" + std::to_string(constantID)), - /*value=*/stickifiedConstOp.getValueAttr(), + /*value=*/valueAttr, /*offset=*/nullptr, /*alignment=*/stickifiedConstOp.getAlignmentAttr()); @@ -839,11 +1027,26 @@ struct ZLowOpFor { using Op = ZLowExpOp; }; +template <> +struct ZLowOpFor { + using Op = ZLowSqrtOp; +}; + +template <> +struct ZLowOpFor { + using Op = ZLowInvSqrtOp; +}; + template <> struct ZLowOpFor { using Op = ZLowReluOp; }; +template <> +struct ZLowOpFor { + using Op = ZLowGeluOp; +}; + template <> struct ZLowOpFor { using Op = ZLowTanhOp; @@ -890,6 +1093,69 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin +//===----------------------------------------------------------------------===// +template +struct ZLowReduceOpFor { + using Op = void; +}; + +template <> +struct ZLowReduceOpFor { + using Op = ZLowReduceMaxOp; +}; + +template <> +struct ZLowReduceOpFor { + using Op = ZLowReduceMinOp; +}; + +template +struct ZHighToZLowReduceOpLowering : public ConversionPattern { + ZHighToZLowReduceOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + MLIRContext *context = rewriter.getContext(); + OP_TYPE reduceOp = mlir::cast(op); + Location loc = op->getLoc(); + Value data = operands[0]; + + // Helper builders. + MultiDialectBuilder + create(rewriter, loc); + + // Convert ZTensor type to MemRefType. + ZMemRefType zMemRefType = + convertZTensorToMemRefType(*op->result_type_begin()); + + // Shape helper. + ZHighReductionOpShapeHelper shapeHelper( + op, operands, &create.krnlIE); + shapeHelper.computeShapeAndAssertOnFailure(); + + // Allocate a buffer for the result MemRef. + Value alloc = insertAllocForZMemRef( + zMemRefType, shapeHelper.getOutputDims(), op, rewriter); + + // Get the original shape before it is vanished by lower passes. + DimsExpr dataDims; + create.krnlIE.getShapeAsDims(data, dataDims); + Value shape = insertShapeMemRefI64(rewriter, loc, dataDims); + + // Emit 'alloc' for work_area that is of 4K-aligned 8K bytes. + Value workArea = create.mem.alignedAlloc( + MemRefType::get({8 * 1024}, rewriter.getIntegerType(8)), gAlignment); + + // Emit a ZLow operation. + rewriter.create::Op>( + loc, data, workArea, shape, alloc, zMemRefType.layout); + rewriter.replaceOp(op, alloc); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Lower ZHigh Softmax to ZLow Softmax //===----------------------------------------------------------------------===// @@ -1018,6 +1284,50 @@ struct ZHighToZLowPool2DOpLowering : public ConversionPattern { } }; +//===----------------------------------------------------------------------===// +// Lower ZHigh LeakyRelu to ZLow LeakyRelu +//===----------------------------------------------------------------------===// + +struct ZHighToZLowLeakyReluOpLowering : public ConversionPattern { + ZHighToZLowLeakyReluOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern( + typeConverter, ZHighLeakyReluOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + ZHighLeakyReluOp leakyreluOp = llvm::dyn_cast(op); + ZHighLeakyReluOpAdaptor operandAdaptor(operands); + + // Helper builders. + MultiDialectBuilder create(rewriter, loc); + + // Convert ZTensor type to MemRefType. + ZMemRefType zMemRefType = + convertZTensorToMemRefType(*op->result_type_begin()); + + // Shape helper. + ZHighUnaryOpShapeHelper shapeHelper(op, operands, &create.krnlIE); + shapeHelper.computeShapeAndAssertOnFailure(); + SmallVector &dims = shapeHelper.getOutputDims(); + + // Allocate a buffer for the result MemRef. + Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter); + + // Get the original shape before it is vanished by lower passes. + Value shape = insertShapeMemRefI64(rewriter, loc, dims); + + // Attributes. + FloatAttr alphaVal = leakyreluOp.getAlphaAttr(); + + // Emit zlow.leakyrelu. + rewriter.create( + loc, operandAdaptor.getX(), shape, alloc, alphaVal, zMemRefType.layout); + rewriter.replaceOp(op, alloc); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Lower ZHigh MatMul to ZLow MatMul //===----------------------------------------------------------------------===// @@ -1030,6 +1340,7 @@ struct ZHighToZLowMatMulOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); + ZHighMatMulOp matmulOp = llvm::dyn_cast(op); ZHighMatMulOpAdaptor operandAdaptor(operands); // Helper builders. @@ -1055,7 +1366,8 @@ struct ZHighToZLowMatMulOpLowering : public ConversionPattern { // - 2nd item: n // - 3rd item: p // - In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) - // or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) + // or broadcasting1: X(m, n) * Y(s, n, p) + Bias(s, p) + // or broadcasting23: X(s, m, n) * Y(n, p) + Bias(p) // shape is a 1D MemRef (memref<4xindex>) whose items are: // - 1st item: s // - 2nd item: m @@ -1071,7 +1383,7 @@ struct ZHighToZLowMatMulOpLowering : public ConversionPattern { SmallVector resDims, biasDims; create.krnlIE.getShapeAsDims(alloc, resDims); ZTensorEncodingAttr::DataLayout biasLayout; - if (shapeHelper.isStacked) { + if (shapeHelper.isStacked || shapeHelper.isBroadcasted1) { // Bias type is 2DS. biasDims.emplace_back(resDims[0]); biasDims.emplace_back(resDims[2]); @@ -1087,22 +1399,196 @@ struct ZHighToZLowMatMulOpLowering : public ConversionPattern { } // Attributes. - int64_t bcast = (shapeHelper.isBroadcasted) ? -1 : 0; + int64_t bcast1 = (shapeHelper.isBroadcasted1) ? -1 : 0; + int64_t bcast23 = (shapeHelper.isBroadcasted23) ? -1 : 0; int64_t stacked = (shapeHelper.isStacked) ? -1 : 0; - IntegerAttr is_bcastAttr = - rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), bcast); + int64_t transposeA = (matmulOp.getTransposeA() != 0) ? 1 : 0; + int64_t transposeB = (matmulOp.getTransposeB() != 0) ? 1 : 0; + IntegerAttr is_bcast1Attr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), bcast1); + IntegerAttr is_bcast23Attr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), bcast23); IntegerAttr is_stackedAttr = rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), stacked); + IntegerAttr transposeAAttr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), transposeA); + IntegerAttr transposeBAttr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), transposeB); // Emit zlow.matmul. rewriter.create(loc, operandAdaptor.getX(), - operandAdaptor.getY(), bias, shapeMemRef, alloc, is_bcastAttr, - is_stackedAttr); + operandAdaptor.getY(), bias, shapeMemRef, alloc, is_bcast1Attr, + is_bcast23Attr, is_stackedAttr, transposeAAttr, transposeBAttr); rewriter.replaceOp(op, alloc); return success(); } }; +//===----------------------------------------------------------------------===// +// Lower ZHigh QuantizedMatMul to ZLow QuantizedMatMul +//===----------------------------------------------------------------------===// + +struct ZHighToZLowQuantizedMatMulOpLowering : public ConversionPattern { + ZHighToZLowQuantizedMatMulOpLowering( + TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(typeConverter, + ZHighQuantizedMatMulOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + ZHighQuantizedMatMulOp matmulOp = + llvm::dyn_cast(op); + ZHighQuantizedMatMulOpAdaptor operandAdaptor(operands); + + // Helper builders. + MultiDialectBuilder + create(rewriter, loc); + + // Compute shape. + ZHighQuantizedMatMulOpShapeHelper shapeHelper(op, operands, &create.krnlIE); + shapeHelper.computeShapeAndAssertOnFailure(); + + // Convert ZTensor type to MemRefType. + ZMemRefType zMemRefType = + convertZTensorToMemRefType(*op->result_type_begin()); + Type f32Ty = rewriter.getF32Type(); + MemRefType scalarF32MemRefTy = MemRefType::get({}, f32Ty); + + Value zero = create.math.constant(f32Ty, 0.0); + Value one = create.math.constant(f32Ty, 1.0); + + Value alloc = insertAllocForZMemRef( + zMemRefType, shapeHelper.getOutputDims(), op, rewriter); + Value outRecScale = operandAdaptor.getOutRecScaleIn(); + if (mlir::isa(outRecScale.getType())) { + outRecScale = create.mem.alignedAlloc( + MemRefType::get({}, rewriter.getF32Type()), {}); + create.krnl.store(one, outRecScale); + } + Value outOffset = operandAdaptor.getOutOffsetIn(); + if (mlir::isa(outOffset.getType())) { + outOffset = create.mem.alignedAlloc( + MemRefType::get({}, rewriter.getF32Type()), {}); + create.krnl.store(zero, outOffset); + } + + // Get the original shape before it is vanished by lower passes. + // Create a 1D MemRef containing necessary dimensions for constructing + // original shapes. + // - In case of unstacked: X(m, n) * Y(n, p) + Bias(p) + // shape is a 1D MemRef (memref<3xindex>) whose items are: + // - 1st item: m + // - 2nd item: n + // - 3rd item: p + // - In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) + // or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) + // shape is a 1D MemRef (memref<4xindex>) whose items are: + // - 1st item: s + // - 2nd item: m + // - 3rd item: n + // - 4th item: p + Value shapeMemRef = + insertShapeMemRefI64(rewriter, loc, shapeHelper.allOriginalDims); + + // Attributes. + int64_t bcast = (shapeHelper.isBroadcasted) ? -1 : 0; + int64_t stacked = (shapeHelper.isStacked) ? -1 : 0; + IntegerAttr is_bcastAttr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), bcast); + IntegerAttr is_stackedAttr = + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), stacked); + // QuantizedType attributes. + StringAttr xQTypeAttr = convertZTensorQuantizedTypeToStringAttr( + rewriter, getZTensorQuantizedType(matmulOp.getX().getType())); + StringAttr yQTypeAttr = convertZTensorQuantizedTypeToStringAttr( + rewriter, getZTensorQuantizedType(matmulOp.getY().getType())); + StringAttr outQTypeAttr = + StringAttr::get(rewriter.getContext(), QTYPE_DLFLOAT16); + StringAttr bQTypeAttr; + + // Prepare optional bias. + SmallVector resDims; + create.krnlIE.getShapeAsDims(alloc, resDims); + Value bias = operandAdaptor.getB(); + Value biasRecScale = operandAdaptor.getBRecScale(); + Value biasOffset = operandAdaptor.getBOffset(); + SmallVector bDims; + if (shapeHelper.isStacked) { + // Bias type is 2DS. + bDims.emplace_back(resDims[0]); + bDims.emplace_back(resDims[2]); + } else { + // Bias type is 1D. Get the last dim size. + bDims.emplace_back(resDims[resDims.size() - 1]); + } + ZTensorEncodingAttr::DataLayout bLayout; + ZTensorEncodingAttr::QuantizedType bQType; + if (mlir::isa(bias.getType())) { + if (shapeHelper.isStacked) { + // Bias type is 2DS. + bLayout = ZTensorEncodingAttr::DataLayout::_2DS; + } else { + // Bias type is 1D. Get the last dim size. + bLayout = ZTensorEncodingAttr::DataLayout::_1D; + } + bool preCompute = matmulOp.getPreComputedBias() != 0; + // Allocate bias. + if (preCompute) + bQType = ZTensorEncodingAttr::QuantizedType::DLFLOAT16; + else + bQType = ZTensorEncodingAttr::QuantizedType::INT8; + bQTypeAttr = convertZTensorQuantizedTypeToStringAttr(rewriter, bQType); + bias = insertAllocOrEmitZeroConstant( + bDims, bLayout, op, rewriter, loc, bQType); + } else { + Type bTensorType = matmulOp.getB().getType(); + bLayout = getZTensorLayout(bTensorType); + ZTensorEncodingAttr::QuantizedType qtype = + getZTensorQuantizedType(bTensorType); + if (qtype == ZTensorEncodingAttr::QuantizedType::UNDEFINED) { + // Bias is a non-quantized or normal ztensor. Use DLFLOAT16 type. + qtype = ZTensorEncodingAttr::QuantizedType::DLFLOAT16; + } + bQTypeAttr = convertZTensorQuantizedTypeToStringAttr(rewriter, qtype); + bQType = convertStringAttrToZTensorQuantizedType(bQTypeAttr); + } + if (mlir::isa(biasRecScale.getType())) { + biasRecScale = create.mem.alignedAlloc(scalarF32MemRefTy); + create.krnl.store(one, biasRecScale); + } + if (mlir::isa(biasOffset.getType())) { + biasOffset = create.mem.alignedAlloc(scalarF32MemRefTy); + create.krnl.store(zero, biasOffset); + } + + // Prepare a buffer for work_area. + // Work area has the same layout as bias but dlfloat16 type. + if (bDims.empty()) + create.krnlIE.getShapeAsDims(bias, bDims); + Value workArea = insertAllocForZMemRefByDim(bDims, bLayout, + ZTensorEncodingAttr::QuantizedType::DLFLOAT16, op, rewriter); + + // Emit zlow.quantizedMatmul. + // clang-format off + create.zlow.quantizedMatMul( + operandAdaptor.getX(), operandAdaptor.getXRecScale(), operandAdaptor.getXOffset(), + operandAdaptor.getY(), operandAdaptor.getYRecScale(), operandAdaptor.getYOffset(), + bias, biasRecScale, biasOffset, + workArea, shapeMemRef, + alloc, outRecScale, outOffset, + xQTypeAttr, yQTypeAttr, bQTypeAttr, outQTypeAttr, + is_bcastAttr, is_stackedAttr, + matmulOp.getPreComputedBiasAttr(), + matmulOp.getDisableClippingAttr(), + matmulOp.getDequantizeOutputAttr()); + // clang-format on + rewriter.replaceOp(op, {alloc, outRecScale, outOffset}); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Lower ZHigh LSTM to ZLow LSTM //===----------------------------------------------------------------------===// @@ -1314,7 +1800,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern { Value iZero = create.math.constantIndex(0); ValueRange batchLoop = create.krnl.defineLoops(1); create.krnl.iterate(batchLoop, batchLoop, {iZero}, {create.mem.dim(Y, 2)}, - [&](KrnlBuilder &createKrnl, ValueRange batchIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange batchIndices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value bs = batchIndices[0]; @@ -1337,7 +1823,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern { rewriter.setInsertionPointToStart(®ionOp.getBodyRegion().front()); ValueRange loops = create.krnl.defineLoops(yRank - 1); create.krnl.iterate(loops, loops, yLbs, yUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value sequenceIV(indices[0]); Value directionIV(indices[1]); Value hs(indices[2]); @@ -1365,7 +1851,7 @@ struct ZHighToZLowFixGRUYOpLowering : public ConversionPattern { ValueRange loops = create.krnl.defineLoops(yRank); create.krnl.iterate(loops, loops, yLbs, yUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value sequenceIV(indices[0]); @@ -1434,7 +1920,7 @@ struct ZHighToZLowFixGRUYhOpLowering : public ConversionPattern { Value seqSize = create.mem.dim(Y, 0); ValueRange loops = create.krnl.defineLoops(htRank); create.krnl.iterate(loops, loops, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value bs(indices[1]), hs(indices[2]); @@ -1557,7 +2043,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering Location loc = op->getLoc(); MDBuilder create(rewriter, loc); - auto stickOp = cast(op); + auto stickOp = mlir::cast(op); FloatAttr value = stickOp.getValueAttr(); Type i16Ty = rewriter.getI16Type(); Type i64Ty = rewriter.getI64Type(); @@ -1565,7 +2051,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering // Convert the scalar value to dlfloat16. // Use uint16_t as container. - float valueF32 = (float)value.getValueAsDouble(); + float valueF32 = static_cast(value.getValueAsDouble()); uint16_t valueDLF16; fp32_to_dlf16(&valueF32, &valueDLF16, 1); @@ -1611,7 +2097,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering SmallVector lbs(rank, LitIE(0)); SmallVector ubs = shapeHelper.getOutputDims(); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { // Keep this load inside the loop to tweak LLVM. Value valueF16 = createKrnl.load(memrefF16); createKrnl.store(valueF16, res, indices); @@ -1700,22 +2186,19 @@ struct ZHighToZLowDataConversionLowering SmallVector flattenedOutputDims; Value flatOutput = create.mem.reshapeToFlatInnermost( alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops); - DimsExpr lbs(1, LitIE(0)); // Create loop iteration (flattened to 1D) and block it by totVL. - ValueRange loopDef = create.krnl.defineLoops(1); - ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL); - SmallVector optimizedLoopDef(1, blockedLoopDef[0]); - + DimsExpr lbs = {LitIE(0)}; + bool useParallel = false; if (enableParallel) { int64_t parId; - int64_t tripCount = - flattenedOutputDims[0].isLiteral() - ? std::ceil(flattenedOutputDims[0].getLiteral() / (float)archVL) - : -1; + int64_t tripCount = flattenedOutputDims[0].isLiteral() + ? std::ceil(flattenedOutputDims[0].getLiteral() / + static_cast(archVL)) + : -1; if (findSuitableParallelDimension(lbs, flattenedOutputDims, 0, 1, parId, /*min iter for going parallel*/ 1024)) { - create.krnl.parallel(blockedLoopDef[0]); + useParallel = true; onnxToKrnlParallelReport(op, /*successful*/ true, 0, tripCount, "dlf16-f32 conversion fully parallelized"); } else { @@ -1728,8 +2211,8 @@ struct ZHighToZLowDataConversionLowering : -1, "dlf16-f32 conversion fully flattened"); - create.krnl.iterateIE(loopDef, optimizedLoopDef, lbs, flattenedOutputDims, - [&](KrnlBuilder &b, ValueRange loopInd) { + create.krnl.forLoopIE(lbs[0], flattenedOutputDims[0], totVL, useParallel, + [&](const KrnlBuilder &b, ValueRange loopInd) { MDBuilder create(b); // Manually unrolled loop, add archVL offset at each iterations. for (int64_t u = 0; u < unrollVL; ++u) { @@ -1770,7 +2253,7 @@ struct ZHighToZLowDataConversionLowering }; void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx, + mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx, bool enableSIMD, bool enableParallel) { // Stickify and unstickify operations. patterns.insert(typeConverter, ctx); @@ -1794,13 +2277,22 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, // Activations patterns.insert>(typeConverter, ctx); patterns.insert>(typeConverter, ctx); + patterns.insert>( + typeConverter, ctx); patterns.insert>(typeConverter, ctx); + patterns.insert>(typeConverter, ctx); + patterns.insert>(typeConverter, ctx); patterns.insert>(typeConverter, ctx); patterns.insert>( typeConverter, ctx); // Neural network operations. + patterns.insert>( + typeConverter, ctx); + patterns.insert>( + typeConverter, ctx); patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); patterns.insert(typeConverter, ctx); @@ -1814,6 +2306,10 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, patterns .insert>( typeConverter, ctx); + // Quantized operations. + patterns.insert( + typeConverter, ctx, enableSIMD, enableParallel); + patterns.insert(typeConverter, ctx); } } // namespace zhigh diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp index 021f47deb3..e8e21eefd3 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.hpp @@ -19,6 +19,8 @@ #include "mlir/IR/BuiltinTypes.h" #include "src/Dialect/Mlir/IndexExpr.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" + /// Default 4K alignment for sticked tensors. static constexpr int64_t gAlignment = 4096; @@ -42,8 +44,9 @@ mlir::Value insertShapeMemRefI64(mlir::PatternRewriter &rewriter, /// Insert an allocation for the given dimensions and layout. /// By default, set alignment to 4K. mlir::Value insertAllocForZMemRefByDim(mlir::ArrayRef dims, - mlir::Type layoutType, mlir::Operation *op, mlir::PatternRewriter &rewriter, - int64_t alignment); + ZTensorEncodingAttr::DataLayout layout, + ZTensorEncodingAttr::QuantizedType qtype, mlir::Operation *op, + mlir::PatternRewriter &rewriter, int64_t alignment); /// Insert an allocation for the given ZMemRefType. /// By default, set alignment to 4K. @@ -53,7 +56,7 @@ mlir::Value insertAllocForZMemRef(ZMemRefType zType, /// Populate all conversion patterns for ZHigh Ops. void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx, + mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx, bool enableSIMD, bool enableParallel); } // namespace zhigh diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/CMakeLists.txt index fe421a2a25..50ac04750b 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/CMakeLists.txt +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/CMakeLists.txt @@ -6,6 +6,7 @@ add_onnx_mlir_library(OMZLowToLLVM libzdnn LINK_LIBS PUBLIC + OMCompilerOptions MLIRLLVMCommonConversion OMKrnlToLLVM OMLayoutHelper diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp index a9cfd73a30..2c3f8fa768 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp @@ -4,7 +4,7 @@ //===---------- ZLowToLLVM.cpp - Lowering from ZLow to LLVM ---------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -24,6 +24,8 @@ #include "src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" +#include "src/Accelerators/NNPA/Support/NNPALimit.hpp" +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "zdnn.h" @@ -37,7 +39,7 @@ namespace zlow { static bool FUNC_CALL_FOR_DLF16_CONVERSION = false; static bool SIMD_FOR_DLF16_CONVERSION = true; -zdnn_data_layouts UNDEFINED_ZDNN_LAYOUT = (zdnn_data_layouts)255; +zdnn_data_layouts UNDEFINED_ZDNN_LAYOUT = static_cast(255); // Obtain a zDNN API for an elementwise ZLow operation. template <> @@ -73,10 +75,18 @@ API APIFor() { return API::ZDNN_EXP; } template <> +API APIFor() { + return API::ZDNN_INVSQRT; +} +template <> API APIFor() { return API::ZDNN_RELU; } template <> +API APIFor() { + return API::ZDNN_GELU; +} +template <> API APIFor() { return API::ZDNN_TANH; } @@ -85,6 +95,11 @@ API APIFor() { return API::ZDNN_SIGMOID; } +template <> +API APIFor() { + return API::ZDNN_SQRT; +} + class ZLowStickLowering : public mlir::ConvertToLLVMPattern { public: explicit ZLowStickLowering(MLIRContext *context, LLVMTypeConverter &lowering_, @@ -98,7 +113,9 @@ class ZLowStickLowering : public mlir::ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowStickOp stickOp = cast(op); + ZLowStickOp stickOp = mlir::cast(op); + std::optional saturationOpt = stickOp.getSaturation(); + bool saturation = saturationOpt.has_value() && saturationOpt.value() != 0; ZLowStickOpAdaptor operandAdaptor(operands); // Do not get element type from adaptor since the type can be opaque. @@ -130,8 +147,96 @@ class ZLowStickLowering : public mlir::ConvertToLLVMPattern { // Ready to stickify. Value unstickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getX()); - callApi(rewriter, loc, module, apiRegistry, API::ZDNN_TRANSFORM_ZTENSOR, - {toOpaquePtr(rewriter, loc, module, zTensor.val), unstickI8Ptr}); + if (saturation) + callApi(rewriter, loc, module, apiRegistry, + API::ZDNN_TRANSFORM_ZTENSOR_WITH_SATURATION, + {toOpaquePtr(rewriter, loc, module, zTensor.val), unstickI8Ptr}); + else + callApi(rewriter, loc, module, apiRegistry, API::ZDNN_TRANSFORM_ZTENSOR, + {toOpaquePtr(rewriter, loc, module, zTensor.val), unstickI8Ptr}); + + rewriter.eraseOp(op); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + +class ZLowQuantizedStickLowering : public mlir::ConvertToLLVMPattern { +public: + explicit ZLowQuantizedStickLowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + ZLowQuantizedStickOp::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ModuleOp module = op->getParentOfType(); + MultiDialectBuilder create(rewriter, loc); + + ZLowQuantizedStickOp stickOp = cast(op); + ZLowQuantizedStickOpAdaptor operandAdaptor(operands); + Value recScale = operandAdaptor.getRecScale(); + Value offset = operandAdaptor.getOffset(); + StringRef transformTypeStr = stickOp.getQType(); + + // Do not get element type from adaptor since the type can be opaque. + Type llvmElementTy = typeConverter->convertType( + mlir::cast(stickOp.getX().getType()).getElementType()); + Type llvmI64Ty = rewriter.getI64Type(); + Type llvmF32Ty = rewriter.getF32Type(); + + ZTensorHelper zTensorHelper = + ZTensorHelper(rewriter, loc, module, apiRegistry); + + // Get the dimensions of the original shape (the shape before stickifying) + // used for creating a zTensor. For 'zLow.quantizedStick', the original + // shape is obtained from the first argument. + SmallVector dims; + getDimsFromMemRef(rewriter, loc, module, operandAdaptor.getX(), dims); + + // Get zDNN data type. + zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); + + // Get zDNN data layout. + zdnn_data_layouts zDNNDataLayout = + convertLayoutAttrToZDNNDataLayout(dims.size(), stickOp.getLayoutAttr()); + + // Get zDNN transform type. + zdnn_quantized_transform_types transformType = + getQuantizedTransformType(transformTypeStr); + + // Create a zTensor. + Value stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getOut()); + Value recScaleF32 = loadFromMemRef(create.llvm, llvmF32Ty, recScale, 0); + Value offsetF32 = loadFromMemRef(create.llvm, llvmF32Ty, offset, 0); + ZTensor zTensor = + zTensorHelper.getQuantizedZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/zDNNDataLayout, /*transformType=*/transformType, + /*originalDims=*/dims, /*recScale=*/recScaleF32, + /*offset=*/offsetF32, + /*isTransformed=*/false); + + // Always saturate. + Value saturationVal = + create.llvm.constant(llvmI64Ty, static_cast(1)); + + // Min, Max clip values. + Value clipMIN = + create.llvm.constant(llvmI64Ty, static_cast(INT8_MIN)); + Value clipMAX = + create.llvm.constant(llvmI64Ty, static_cast(INT8_MAX)); + + // Ready to stickify. + Value unstickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getX()); + callApi(rewriter, loc, module, apiRegistry, + API::ZDNN_TRANSFORM_QUANTIZED_ZTENSOR, + {toOpaquePtr(rewriter, loc, module, zTensor.val), saturationVal, + clipMIN, clipMAX, unstickI8Ptr}); rewriter.eraseOp(op); return success(); @@ -154,7 +259,7 @@ class ZLowStickForLSTMLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowStickForLSTMOp stickForLSTMOp = cast(op); + ZLowStickForLSTMOp stickForLSTMOp = mlir::cast(op); ZLowStickForLSTMOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( @@ -240,7 +345,7 @@ class ZLowStickForGRULowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowStickForGRUOp stickForGRUOp = cast(op); + ZLowStickForGRUOp stickForGRUOp = mlir::cast(op); ZLowStickForGRUOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( @@ -324,7 +429,7 @@ class ZLowLSTMLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowLSTMOp lstmOp = cast(op); + ZLowLSTMOp lstmOp = mlir::cast(op); MultiDialectBuilder create(rewriter, loc); ZLowLSTMOpAdaptor operandAdaptor(operands); @@ -336,7 +441,7 @@ class ZLowLSTMLowering : public ConvertToLLVMPattern { // Some frequently used types and constants. Type llvmI64Ty = rewriter.getI64Type(); - Value oneI64 = create.llvm.constant(llvmI64Ty, (int64_t)1); + Value oneI64 = create.llvm.constant(llvmI64Ty, static_cast(1)); // Get the dimensions of the original shape (the shape before stickifying) // used for creating zTensors. @@ -429,11 +534,11 @@ class ZLowLSTMLowering : public ConvertToLLVMPattern { Value direction; StringRef directionStr = lstmOp.getDirection(); if (directionStr.equals_insensitive("forward")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)FWD); + direction = create.llvm.constant(llvmI64Ty, static_cast(FWD)); } else if (directionStr.equals_insensitive("reverse")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)BWD); + direction = create.llvm.constant(llvmI64Ty, static_cast(BWD)); } else if (directionStr.equals_insensitive("bidirectional")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)BIDIR); + direction = create.llvm.constant(llvmI64Ty, static_cast(BIDIR)); } else llvm_unreachable("Unsupported direction"); @@ -520,7 +625,7 @@ class ZLowGRULowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowGRUOp gruOp = cast(op); + ZLowGRUOp gruOp = mlir::cast(op); MultiDialectBuilder create(rewriter, loc); ZLowGRUOpAdaptor operandAdaptor(operands); @@ -532,7 +637,7 @@ class ZLowGRULowering : public ConvertToLLVMPattern { // Some frequently used types and constants. Type llvmI64Ty = rewriter.getI64Type(); - Value oneI64 = create.llvm.constant(llvmI64Ty, (int64_t)1); + Value oneI64 = create.llvm.constant(llvmI64Ty, static_cast(1)); // Get the dimensions of the original shape (the shape before stickifying) // used for creating zTensors. @@ -604,11 +709,11 @@ class ZLowGRULowering : public ConvertToLLVMPattern { Value direction; StringRef directionStr = gruOp.getDirection(); if (directionStr.equals_insensitive("forward")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)FWD); + direction = create.llvm.constant(llvmI64Ty, static_cast(FWD)); } else if (directionStr.equals_insensitive("reverse")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)BWD); + direction = create.llvm.constant(llvmI64Ty, static_cast(BWD)); } else if (directionStr.equals_insensitive("bidirectional")) { - direction = create.llvm.constant(llvmI64Ty, (int64_t)BIDIR); + direction = create.llvm.constant(llvmI64Ty, static_cast(BIDIR)); } else llvm_unreachable("Unsupported direction"); @@ -675,7 +780,7 @@ class ZLowUnstickLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowUnstickOp unstickOp = cast(op); + ZLowUnstickOp unstickOp = mlir::cast(op); ZLowUnstickOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( @@ -732,7 +837,7 @@ class ZLowUnaryElementwiseOpLowering : public ConvertToLLVMPattern { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); MLIRContext *context = rewriter.getContext(); - UnaryElementwiseOp unaryOp = cast(op); + UnaryElementwiseOp unaryOp = mlir::cast(op); typename UnaryElementwiseOp::Adaptor operandAdaptor(operands); MultiDialectBuilder create(rewriter, loc); @@ -782,6 +887,14 @@ class ZLowUnaryElementwiseOpLowering : public ConvertToLLVMPattern { callApi(rewriter, loc, module, apiRegistry, APIFor(), {toOpaquePtr(rewriter, loc, module, inputZTensor.val), nullpointer, toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); + } else if (APIFor() == API::ZDNN_INVSQRT) { + MultiDialectBuilder create(rewriter, loc); + // Create a float for the epsilon value. + Value epsilon = create.llvm.constant(rewriter.getF32Type(), nnpaEpsilon); + // Pass to ZDNN. + callApi(rewriter, loc, module, apiRegistry, APIFor(), + {toOpaquePtr(rewriter, loc, module, inputZTensor.val), epsilon, + toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); } else { callApi(rewriter, loc, module, apiRegistry, APIFor(), {toOpaquePtr(rewriter, loc, module, inputZTensor.val), @@ -810,7 +923,7 @@ class ZLowBinaryElementwiseOpLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - BinaryElementwiseOp binaryOp = cast(op); + BinaryElementwiseOp binaryOp = mlir::cast(op); typename BinaryElementwiseOp::Adaptor operandAdaptor(operands); Value input1 = operandAdaptor.getX(); @@ -888,7 +1001,7 @@ class ZLowSoftmaxOpLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowSoftmaxOp softmaxOp = cast(op); + ZLowSoftmaxOp softmaxOp = mlir::cast(op); MultiDialectBuilder create(rewriter, loc); ZLowSoftmaxOpAdaptor operandAdaptor(operands); @@ -925,8 +1038,8 @@ class ZLowSoftmaxOpLowering : public ConvertToLLVMPattern { actType = NNPA_SOFTMAX_LOG; else llvm_unreachable("Unsupported activation function"); - Value actFunc = - create.llvm.constant(rewriter.getI64Type(), (int64_t)actType); + Value actFunc = create.llvm.constant( + rewriter.getI64Type(), static_cast(actType)); // Create the output zTensor. stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getOut()); @@ -958,6 +1071,188 @@ class ZLowSoftmaxOpLowering : public ConvertToLLVMPattern { ApiRegistry apiRegistry; }; +class ZLowLeakyReluLowering : public ConvertToLLVMPattern { +public: + explicit ZLowLeakyReluLowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + ZLowLeakyReluOp::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + Location loc = op->getLoc(); + ZLowLeakyReluOp leakyreluOp = cast(op); + MultiDialectBuilder create(rewriter, loc); + MLIRContext *context = rewriter.getContext(); + typename ZLowLeakyReluOp::Adaptor operandAdaptor(operands); + + Value input = operandAdaptor.getX(); + Value shape = operandAdaptor.getShape(); + Value output = operandAdaptor.getOut(); + Type llvmElementTy = typeConverter->convertType( + mlir::cast(op->getOperand(0).getType()).getElementType()); + + ZTensorHelper zTensorHelper = + ZTensorHelper(rewriter, loc, module, apiRegistry); + + // Get zDNN data type. + zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); + + // Get zDNN data layout. + zdnn_data_layouts zDNNDataLayout = + convertLayoutAttrToZDNNDataLayout(0, leakyreluOp.getLayoutAttr()); + + // Get the dimensions of the original shape (the shape before stickifying) + // used for creating a zTensor. + std::vector dims = + getDimsFromShapeMemRef(rewriter, loc, module, shape, + /*layout=*/zDNNDataLayout); + + // Create an input zTensor. + Value stickI8Ptr = zTensorHelper.getAlignedI8Ptr(input); + ZTensor inputZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/zDNNDataLayout, /*originalDims=*/dims, + /*isTransformed=*/true); + + // Create an output zTensor. + stickI8Ptr = zTensorHelper.getAlignedI8Ptr(output); + ZTensor outputZTensor = zTensorHelper.getZTensor( + /*preTransformedDescPtr=*/inputZTensor.preTransformedDescPtr, + /*transformedDescPtr=*/inputZTensor.transformedDescPtr, + /*bufferSize=*/inputZTensor.bufferSize, + /*alignedBuffer=*/stickI8Ptr, + /*isTransformed=*/true); + + // Create the clipping value as null because the zDNN LeakyRelu API does not + // use it. + Value clippingVal = create.llvm.null(krnl::getI8PointerType(context)); + + // Create the adjustment factor value from the input alpha attribute. + FloatAttr alphaAttr = leakyreluOp.getAlphaAttr(); + float alphaFloat = (float)alphaAttr.getValueAsDouble(); + Value adjustmentFactorVal = + create.llvm.constant(rewriter.getF32Type(), alphaFloat); + + // Call the zDNN LeakyRelu API. + callApi(rewriter, loc, module, apiRegistry, API::ZDNN_LEAKY_RELU, + {toOpaquePtr(rewriter, loc, module, inputZTensor.val), clippingVal, + adjustmentFactorVal, + toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); + + rewriter.eraseOp(op); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + +template +zdnn_reduce_ops getZDNNReduceOpType() { + return REDUCE_OP_MAXIMUM; +} + +template <> +zdnn_reduce_ops getZDNNReduceOpType() { + return REDUCE_OP_MAXIMUM; +} + +template <> +zdnn_reduce_ops getZDNNReduceOpType() { + return REDUCE_OP_MINIMUM; +} + +template +class ZLowReduceLowering : public ConvertToLLVMPattern { +public: + explicit ZLowReduceLowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + REDUCE_OP::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + Location loc = op->getLoc(); + REDUCE_OP reduceOp = mlir::cast(op); + MultiDialectBuilder create(rewriter, loc); + typename REDUCE_OP::Adaptor operandAdaptor(operands); + + Value data = operandAdaptor.getX(); + Value shape = operandAdaptor.getShape(); + Value output = operandAdaptor.getOut(); + Type llvmElementTy = typeConverter->convertType( + mlir::cast(op->getOperand(0).getType()).getElementType()); + + ZTensorHelper zTensorHelper = + ZTensorHelper(rewriter, loc, module, apiRegistry); + + // Get zDNN data type. + zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); + + // Get zDNN data layout. + zdnn_data_layouts zDNNDataLayout = + convertLayoutAttrToZDNNDataLayout(0, reduceOp.getLayoutAttr()); + + // Get the dimensions of the original shape (the shape before stickifying) + // used for creating a zTensor. + std::vector dims = + getDimsFromShapeMemRef(rewriter, loc, module, shape, + /*layout=*/zDNNDataLayout); + + Type llvmI64Ty = rewriter.getI64Type(); + Value one = create.llvm.constant(llvmI64Ty, static_cast(1)); + + // Calculation for the output dimension + int64_t axis = dims.size() - 1; + SmallVector outputDims; + for (int64_t i = 0; i < axis; ++i) { + outputDims.emplace_back(dims[i]); + } + outputDims.emplace_back(one); + + // Create an input zTensor. + Value stickI8Ptr = zTensorHelper.getAlignedI8Ptr(data); + ZTensor inputZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/zDNNDataLayout, /*originalDims=*/dims, + /*isTransformed=*/true); + + // Create an output zTensor. + stickI8Ptr = zTensorHelper.getAlignedI8Ptr(output); + ZTensor outputZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/zDNNDataLayout, /*originalDims=*/outputDims, + /*isTransformed=*/true); + + // work_area. + Value workArea = + zTensorHelper.getAlignedI8Ptr(operandAdaptor.getWorkArea()); + + // op_type + zdnn_reduce_ops zdnnOpType = getZDNNReduceOpType(); + Value opType = create.llvm.constant( + rewriter.getI64Type(), static_cast(zdnnOpType)); + + // Call the zDNN ReduceMax/ReduceMin API. + callApi(rewriter, loc, module, apiRegistry, API::ZDNN_REDUCE, + {toOpaquePtr(rewriter, loc, module, inputZTensor.val), workArea, opType, + toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); + + rewriter.eraseOp(op); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + class ZLowMatMulLowering : public ConvertToLLVMPattern { public: explicit ZLowMatMulLowering(MLIRContext *context, @@ -971,22 +1266,25 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowMatMulOp matmulOp = cast(op); + ZLowMatMulOp matmulOp = mlir::cast(op); MultiDialectBuilder create(rewriter, loc); ZLowMatMulOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( mlir::cast(matmulOp.getX().getType()).getElementType()); - bool stacked, broadcasting; + bool stacked = false, broadcasting1 = false, broadcasting23 = false, + transposeA = false, transposeB = false; if (matmulOp.getIsStacked() == -1) stacked = true; - else - stacked = false; - if (matmulOp.getIsBcast() == -1) - broadcasting = true; - else - broadcasting = false; + if (matmulOp.getIsBcast1() == -1) + broadcasting1 = true; + else if (matmulOp.getIsBcast23() == -1) + broadcasting23 = true; + if (matmulOp.getTransposeA() != 0) + transposeA = true; + if (matmulOp.getTransposeB() != 0) + transposeB = true; ZTensorHelper zTensorHelper = ZTensorHelper(rewriter, loc, module, apiRegistry); @@ -997,17 +1295,22 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { // Get the dimensions of the original shape (the shape before stickifying) // used for creating zTensors. int dimCount = 3; - if (stacked || broadcasting) + if (stacked || broadcasting1 || broadcasting23) dimCount = 4; std::vector dims = getDimsFromShapeMemRefBySize( rewriter, loc, module, operandAdaptor.getShape(), /*size=*/dimCount); // Dimensions: s, m, n, p; Value S, M, N, P; - if (stacked || broadcasting) { + if (stacked || broadcasting23) { S = dims[0]; M = dims[1]; N = dims[2]; P = dims[3]; + } else if (broadcasting1) { + M = dims[0]; + N = dims[1]; + S = dims[2]; + P = dims[3]; } else { M = dims[0]; N = dims[1]; @@ -1019,49 +1322,111 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { // Create zTensors. ZTensor xZTensor, yZTensor, biasZTensor, outputZTensor; + + // clang-format off + // Requirements + // Type X Y Bias Output + // ---------------------------------------------------------------------------------------- + // unstacked ZDNN_2D (m, n) ZDNN_2D (n, p) ZDNN_1D (p) ZDNN_2D (m, p) + // stacked ZDNN_3DS (s, m, n) ZDNN_3DS (s, n, p) ZDNN_2DS (s, p) ZDNN_3DS (s, m, p) + // bcast1 ZDNN_2D (m, n) ZDNN_3DS (s, n, p) ZDNN_2DS (s, p) ZDNN_3DS (s, m, p) + // bcast23 ZDNN_3DS (s, m, n) ZDNN_2D (n, p) ZDNN_1D (p) ZDNN_3DS (s, m, p) + // clang-format on + // X Value stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getX()); - if (stacked || broadcasting) - xZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, - /*layout=*/ZDNN_3DS, /*originalDims=*/{S, M, N}, - /*isTransformed=*/true); - else - xZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, - /*layout=*/ZDNN_2D, /*originalDims=*/{M, N}, - /*isTransformed=*/true); + if (stacked || broadcasting23) { + if (transposeA) + // ZDNN_3DS (s, n, m) + xZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_3DS, /*originalDims=*/{S, N, M}, + /*isTransformed=*/true); + else + // ZDNN_3DS (s, m, n) + xZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_3DS, /*originalDims=*/{S, M, N}, + /*isTransformed=*/true); + } else { /* unstacked || broadcasting1 */ + if (transposeA) + // ZDNN_2D (n, m) + xZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2D, /*originalDims=*/{N, M}, + /*isTransformed=*/true); + else + // ZDNN_2D (m, n) + xZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2D, /*originalDims=*/{M, N}, + /*isTransformed=*/true); + } // Y stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getY()); - if (stacked) - yZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, - /*layout=*/ZDNN_3DS, /*originalDims=*/{S, N, P}, - /*isTransformed=*/true); - else - yZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, - /*layout=*/ZDNN_2D, /*originalDims=*/{N, P}, - /*isTransformed=*/true); + if (stacked || broadcasting1) { + if (transposeB) + // ZDNN_3DS (s, p, n) + yZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_3DS, /*originalDims=*/{S, P, N}, + /*isTransformed=*/true); + else + // ZDNN_3DS (s, n, p) + yZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_3DS, /*originalDims=*/{S, N, P}, + /*isTransformed=*/true); + } else { /* unstacked || broadcasting23 */ + if (transposeB) + // ZDNN_2D (p, n) + yZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2D, /*originalDims=*/{P, N}, + /*isTransformed=*/true); + else + // ZDNN_2D (n, p) + yZTensor = + zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2D, /*originalDims=*/{N, P}, + /*isTransformed=*/true); + } // Bias stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getBias()); - if (stacked) + if (stacked || broadcasting1) + // ZDNN_2D (s, p) biasZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, /*layout=*/ZDNN_2DS, /*originalDims=*/{S, P}, /*isTransformed=*/true); else + // ZDNN_1D (p) biasZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, /*layout=*/ZDNN_1D, /*originalDims=*/{P}, /*isTransformed=*/true); // Op_type - Value op_type; - if (broadcasting) - op_type = create.llvm.constant( - llvmI64Ty, (int64_t)NNPA_MATMUL_BCAST_OP_ADDITION); + Value opType; + if (broadcasting23 || broadcasting1) + opType = create.llvm.constant( + llvmI64Ty, static_cast(NNPA_MATMUL_BCAST_OP_ADDITION)); + else + opType = create.llvm.constant( + llvmI64Ty, static_cast(NNPA_MATMUL_OP_ADDITION)); + // Transposing + Value transposeAVal; + if (transposeA) + transposeAVal = create.llvm.constant(llvmI64Ty, static_cast(1)); else - op_type = - create.llvm.constant(llvmI64Ty, (int64_t)NNPA_MATMUL_OP_ADDITION); + transposeAVal = create.llvm.constant(llvmI64Ty, static_cast(0)); + Value transposeBVal; + if (transposeB) + transposeBVal = create.llvm.constant(llvmI64Ty, static_cast(1)); + else + transposeBVal = create.llvm.constant(llvmI64Ty, static_cast(0)); // Output stickI8Ptr = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getOut()); - if (stacked || broadcasting) + if (stacked || broadcasting23 || broadcasting1) outputZTensor = zTensorHelper.getZTensor(stickI8Ptr, /*dataType=*/zDNNDataType, /*layout=*/ZDNN_3DS, /*originalDims=*/{S, M, P}, @@ -1073,17 +1438,24 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { /*isTransformed=*/true); // Ready to call zDNN MatMul. - if (broadcasting) { + if (transposeA || transposeB) { + callApi(rewriter, loc, module, apiRegistry, API::ZDNN_MATMUL_TRANSPOSE_OP, + {toOpaquePtr(rewriter, loc, module, xZTensor.val), + toOpaquePtr(rewriter, loc, module, yZTensor.val), + toOpaquePtr(rewriter, loc, module, biasZTensor.val), + transposeAVal, transposeBVal, opType, + toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); + } else if (broadcasting23 || broadcasting1) { callApi(rewriter, loc, module, apiRegistry, API::ZDNN_MATMUL_BCAST_OP, {toOpaquePtr(rewriter, loc, module, xZTensor.val), toOpaquePtr(rewriter, loc, module, yZTensor.val), - toOpaquePtr(rewriter, loc, module, biasZTensor.val), op_type, + toOpaquePtr(rewriter, loc, module, biasZTensor.val), opType, toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); } else { callApi(rewriter, loc, module, apiRegistry, API::ZDNN_MATMUL_OP, {toOpaquePtr(rewriter, loc, module, xZTensor.val), toOpaquePtr(rewriter, loc, module, yZTensor.val), - toOpaquePtr(rewriter, loc, module, biasZTensor.val), op_type, + toOpaquePtr(rewriter, loc, module, biasZTensor.val), opType, toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); } @@ -1095,6 +1467,230 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern { ApiRegistry apiRegistry; }; +class ZLowQuantizedMatMulLowering : public ConvertToLLVMPattern { +public: + explicit ZLowQuantizedMatMulLowering(MLIRContext *context, + LLVMTypeConverter &lowering_, ApiRegistry apiRegistry) + : ConvertToLLVMPattern( + ZLowQuantizedMatMulOp::getOperationName(), context, lowering_) { + this->apiRegistry = apiRegistry; + } + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ModuleOp module = op->getParentOfType(); + MLIRContext *context = module.getContext(); + MultiDialectBuilder create(rewriter, loc); + + ZLowQuantizedMatMulOp matmulOp = cast(op); + ZLowQuantizedMatMulOpAdaptor operandAdaptor(operands); + + // Inputs. + // X + Value X = operandAdaptor.getX(); + Value XRecScale = operandAdaptor.getXRecScale(); + Value XOffset = operandAdaptor.getXOffset(); + StringRef XQType = matmulOp.getXQType(); + // Y + Value Y = operandAdaptor.getY(); + Value YRecScale = operandAdaptor.getYRecScale(); + Value YOffset = operandAdaptor.getYOffset(); + StringRef YQType = matmulOp.getYQType(); + // Bias + Value Bias = operandAdaptor.getBias(); + Value BiasRecScale = operandAdaptor.getBiasRecScale(); + Value BiasOffset = operandAdaptor.getBiasOffset(); + StringRef BiasQType = matmulOp.getBiasQType(); + // Out + Value Out = operandAdaptor.getOut(); + Value OutRecScale = operandAdaptor.getOutRecScale(); + Value OutOffset = operandAdaptor.getOutOffset(); + StringRef OutQType = matmulOp.getOutQType(); + + // Types. + Type llvmXElementTy = typeConverter->convertType( + mlir::cast(matmulOp.getX().getType()).getElementType()); + Type llvmYElementTy = typeConverter->convertType( + mlir::cast(matmulOp.getY().getType()).getElementType()); + Type llvmBiasElementTy = typeConverter->convertType( + mlir::cast(matmulOp.getBias().getType()).getElementType()); + Type llvmOutElementTy = typeConverter->convertType( + mlir::cast(matmulOp.getOut().getType()).getElementType()); + Type llvmF32Ty = rewriter.getF32Type(); + Type llvmI64Ty = rewriter.getI64Type(); + Type llvmZTensorTy = getZTensorStructTy(context); + Type llvmZTensorPtrTy = krnl::getPointerType(context, llvmZTensorTy); + + bool stacked, broadcasting; + if (matmulOp.getIsStacked() == -1) + stacked = true; + else + stacked = false; + if (matmulOp.getIsBcast() == -1) + broadcasting = true; + else + broadcasting = false; + + // Get the dimensions of the original shape (the shape before stickifying) + // used for creating zTensors. + int dimCount = 3; + if (stacked || broadcasting) + dimCount = 4; + std::vector dims = getDimsFromShapeMemRefBySize( + rewriter, loc, module, operandAdaptor.getShape(), /*size=*/dimCount); + // Dimensions: s, m, n, p; + Value S, M, N, P; + if (stacked || broadcasting) { + S = dims[0]; + M = dims[1]; + N = dims[2]; + P = dims[3]; + } else { + M = dims[0]; + N = dims[1]; + P = dims[2]; + } + + // Create zTensors. + ZTensorHelper zTensorHelper = + ZTensorHelper(rewriter, loc, module, apiRegistry); + ZTensor xZTensor, yZTensor, biasZTensor, outputZTensor; + // X + zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmXElementTy); + zdnn_quantized_transform_types zDNNQType = + getQuantizedTransformType(XQType); + Value recScale = loadFromMemRef(create.llvm, llvmF32Ty, XRecScale, 0); + Value offset = loadFromMemRef(create.llvm, llvmF32Ty, XOffset, 0); + Value stickI8Ptr = zTensorHelper.getAlignedI8Ptr(X); + if (stacked || broadcasting) + xZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, /*layout=*/ZDNN_3DS, + /*transformType=*/zDNNQType, + /*originalDims=*/{S, M, N}, + /*recScale=*/recScale, /*offset=*/offset, + /*isTransformed=*/true); + else + xZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, /*layout=*/ZDNN_2D, + /*transformType=*/zDNNQType, + /*originalDims=*/{M, N}, /*recScale=*/recScale, /*offset=*/offset, + /*isTransformed=*/true); + // Y + zDNNDataType = llvmTypeToZDNNType(llvmYElementTy); + zDNNQType = getQuantizedTransformType(YQType); + recScale = loadFromMemRef(create.llvm, llvmF32Ty, YRecScale, 0); + offset = loadFromMemRef(create.llvm, llvmF32Ty, YOffset, 0); + stickI8Ptr = zTensorHelper.getAlignedI8Ptr(Y); + if (stacked) + yZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, /*layout=*/ZDNN_3DS, + /*transformType=*/zDNNQType, /*originalDims=*/{S, N, P}, + /*recScale=*/recScale, /*offset=*/offset, /*isTransformed=*/true); + else + yZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, /*layout=*/ZDNN_2D, + /*transformType=*/zDNNQType, /*originalDims=*/{N, P}, + /*recScale=*/recScale, /*offset=*/offset, /*isTransformed=*/true); + // Bias + zDNNDataType = llvmTypeToZDNNType(llvmBiasElementTy); + zDNNQType = getQuantizedTransformType(BiasQType); + recScale = loadFromMemRef(create.llvm, llvmF32Ty, BiasRecScale, 0); + offset = loadFromMemRef(create.llvm, llvmF32Ty, BiasOffset, 0); + stickI8Ptr = zTensorHelper.getAlignedI8Ptr(Bias); + if (stacked) + biasZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2DS, + /*transformType=*/zDNNQType, /*originalDims=*/{S, P}, + /*recScale=*/recScale, /*offset=*/offset, /*isTransformed=*/true); + else + biasZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, /*layout=*/ZDNN_1D, + /*transformType=*/zDNNQType, /*originalDims=*/{P}, + /*recScale=*/recScale, /*offset=*/offset, /*isTransformed=*/true); + + // Op_type + Value opType = create.llvm.constant( + llvmI64Ty, static_cast(NNPA_MATMUL_OP_ADDITION)); + + // Min, Max clip values. + Value clipMIN = + create.llvm.constant(llvmI64Ty, static_cast(INT8_MIN)); + Value clipMAX = + create.llvm.constant(llvmI64Ty, static_cast(INT8_MAX)); + + // work_area. + Value workArea; + if (mlir::isa(matmulOp.getWorkArea().getType())) + workArea = create.llvm.null(krnl::getI8PointerType(context)); + else + workArea = zTensorHelper.getAlignedI8Ptr(operandAdaptor.getWorkArea()); + + // Output + zDNNDataType = llvmTypeToZDNNType(llvmOutElementTy); + zDNNQType = getQuantizedTransformType(OutQType); + recScale = loadFromMemRef(create.llvm, llvmF32Ty, OutRecScale, 0); + offset = loadFromMemRef(create.llvm, llvmF32Ty, OutOffset, 0); + stickI8Ptr = zTensorHelper.getAlignedI8Ptr(Out); + if (stacked || broadcasting) + outputZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_3DS, + /*transformType=*/zDNNQType, + /*originalDims=*/{S, M, P}, + /*recScale=*/recScale, /*offset=*/offset, + /*isTransformed=*/true); + else + outputZTensor = zTensorHelper.getQuantizedZTensor(stickI8Ptr, + /*dataType=*/zDNNDataType, + /*layout=*/ZDNN_2D, + /*transformType=*/zDNNQType, + /*originalDims=*/{M, P}, + /*recScale=*/recScale, /*offset=*/offset, + /*isTransformed=*/true); + + // Ready to call zDNN MatMul. + Value disableClipping = create.llvm.constant( + llvmI64Ty, static_cast(matmulOp.getDisableClipping())); + Value dequantizeOutput = create.llvm.constant( + llvmI64Ty, static_cast(matmulOp.getDequantizeOutput())); + Value preComputedBias = create.llvm.constant( + llvmI64Ty, static_cast(matmulOp.getPreComputedBias())); + zlow::API apiName = API::ZDNN_QUANTIZED_MATMUL_OP; + callApi(rewriter, loc, module, apiRegistry, apiName, + {/*input_a=*/toOpaquePtr(rewriter, loc, module, xZTensor.val), + /*input_b=*/toOpaquePtr(rewriter, loc, module, yZTensor.val), + /*input_c=*/toOpaquePtr(rewriter, loc, module, biasZTensor.val), + /*op_type=*/opType, + /*clip_min=*/clipMIN, + /*clip_max=*/clipMAX, + /*disable_clipping=*/disableClipping, + /*dequantized=*/dequantizeOutput, + /*pre_computed=*/preComputedBias, + /*work_area=*/workArea, + /*output=*/ + toOpaquePtr(rewriter, loc, module, outputZTensor.val)}); + + // Store the output rec_scale. + Value recScalePtr = create.llvm.getElemPtr(llvmZTensorPtrTy, llvmZTensorTy, + outputZTensor.val, ArrayRef{0, 6}); + Value outRecScale = create.llvm.load(llvmF32Ty, recScalePtr); + storeToMemRef(create.llvm, outRecScale, OutRecScale, 0); + // Store the output offset. + Value offsetPtr = create.llvm.getElemPtr(llvmZTensorPtrTy, llvmZTensorTy, + outputZTensor.val, ArrayRef{0, 7}); + Value outOffset = create.llvm.load(llvmF32Ty, offsetPtr); + storeToMemRef(create.llvm, outOffset, OutOffset, 0); + + rewriter.eraseOp(op); + return success(); + } + +private: + ApiRegistry apiRegistry; +}; + class ZLowConv2DLowering : public ConvertToLLVMPattern { public: explicit ZLowConv2DLowering(MLIRContext *context, @@ -1109,7 +1705,7 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); MLIRContext *context = rewriter.getContext(); - ZLowConv2DOp convOp = cast(op); + ZLowConv2DOp convOp = mlir::cast(op); ZLowConv2DOpAdaptor operandAdaptor(operands); MultiDialectBuilder create(rewriter, loc); @@ -1145,10 +1741,12 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { convOp.getKernelShape().getValue(); // kernel height Value KH = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(kernelShapeArrayAttr[0]).getInt()); + static_cast( + mlir::cast(kernelShapeArrayAttr[0]).getInt())); // kernel width Value KW = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(kernelShapeArrayAttr[1]).getInt()); + static_cast( + mlir::cast(kernelShapeArrayAttr[1]).getInt())); // Get zDNN data type. zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); @@ -1178,28 +1776,30 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern { Value paddingType; if (convOp.getPaddingType().equals_insensitive("SAME_PADDING")) paddingType = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_pool_padding::SAME_PADDING); + llvmI64Ty, static_cast(zdnn_pool_padding::SAME_PADDING)); else if (convOp.getPaddingType().equals_insensitive("VALID_PADDING")) paddingType = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_pool_padding::VALID_PADDING); + llvmI64Ty, static_cast(zdnn_pool_padding::VALID_PADDING)); else llvm_unreachable("Unsupported padding type"); // Strides ArrayRef strideArrayAttr = convOp.getStrides().getValue(); - Value strideHeight = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(strideArrayAttr[0]).getInt()); - Value strideWidth = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(strideArrayAttr[1]).getInt()); + Value strideHeight = create.llvm.constant( + llvmI64Ty, static_cast( + mlir::cast(strideArrayAttr[0]).getInt())); + Value strideWidth = create.llvm.constant( + llvmI64Ty, static_cast( + mlir::cast(strideArrayAttr[1]).getInt())); // Activation function. Value actFunc; if (convOp.getActFunc().equals_insensitive("ACT_NONE")) actFunc = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_conv2d_act::CONV2D_ACT_NONE); + llvmI64Ty, static_cast(zdnn_conv2d_act::CONV2D_ACT_NONE)); else if (convOp.getActFunc().equals_insensitive("ACT_RELU")) actFunc = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_conv2d_act::CONV2D_ACT_RELU); + llvmI64Ty, static_cast(zdnn_conv2d_act::CONV2D_ACT_RELU)); else llvm_unreachable("Unsupported activation function"); @@ -1256,7 +1856,7 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - POOLOP poolOp = cast(op); + POOLOP poolOp = mlir::cast(op); typename POOLOP::Adaptor operandAdaptor(operands); MultiDialectBuilder create(rewriter, loc); @@ -1293,10 +1893,12 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { poolOp.getKernelShape().getValue(); // kernel height Value KH = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(kernelShapeArrayAttr[0]).getInt()); + static_cast( + mlir::cast(kernelShapeArrayAttr[0]).getInt())); // kernel width Value KW = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(kernelShapeArrayAttr[1]).getInt()); + static_cast( + mlir::cast(kernelShapeArrayAttr[1]).getInt())); // Get zDNN data type. zdnn_data_types zDNNDataType = llvmTypeToZDNNType(llvmElementTy); @@ -1312,19 +1914,21 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern { Value paddingType; if (poolOp.getPaddingType().equals_insensitive("SAME_PADDING")) paddingType = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_pool_padding::SAME_PADDING); + llvmI64Ty, static_cast(zdnn_pool_padding::SAME_PADDING)); else if (poolOp.getPaddingType().equals_insensitive("VALID_PADDING")) paddingType = create.llvm.constant( - llvmI64Ty, (int64_t)zdnn_pool_padding::VALID_PADDING); + llvmI64Ty, static_cast(zdnn_pool_padding::VALID_PADDING)); else llvm_unreachable("Unsupported padding type"); // Strides ArrayRef strideArrayAttr = poolOp.getStrides().getValue(); - Value strideHeight = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(strideArrayAttr[0]).getInt()); - Value strideWidth = create.llvm.constant(llvmI64Ty, - (int64_t)mlir::cast(strideArrayAttr[1]).getInt()); + Value strideHeight = create.llvm.constant( + llvmI64Ty, static_cast( + mlir::cast(strideArrayAttr[0]).getInt())); + Value strideWidth = create.llvm.constant( + llvmI64Ty, static_cast( + mlir::cast(strideArrayAttr[1]).getInt())); // Create zTensor for output. stickI8Ptr = zTensorHelper.getAlignedI8Ptr(output); @@ -1360,7 +1964,7 @@ class ZLowMeanReduce2DLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowMeanReduce2DOp meanOp = cast(op); + ZLowMeanReduce2DOp meanOp = mlir::cast(op); MultiDialectBuilder create(rewriter, loc); ZLowMeanReduce2DOpAdaptor operandAdaptor(operands); @@ -1372,7 +1976,7 @@ class ZLowMeanReduce2DLowering : public ConvertToLLVMPattern { // Some frequently used types and constants. Type llvmI64Ty = rewriter.getI64Type(); - Value oneI64 = create.llvm.constant(llvmI64Ty, (int64_t)1); + Value oneI64 = create.llvm.constant(llvmI64Ty, static_cast(1)); // Get the dimensions of the original shape (the shape before stickifying) // used for creating zTensors. @@ -1429,7 +2033,7 @@ class ZLowBatchNormLowering : public ConvertToLLVMPattern { ConversionPatternRewriter &rewriter) const override { ModuleOp module = op->getParentOfType(); Location loc = op->getLoc(); - ZLowBatchNormOp batchnormOp = cast(op); + ZLowBatchNormOp batchnormOp = mlir::cast(op); ZLowBatchNormOpAdaptor operandAdaptor(operands); Type llvmElementTy = typeConverter->convertType( @@ -1534,10 +2138,12 @@ class ZLowDLF16ToF32Lowering : public ConvertToLLVMPattern { Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4); Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4); - // SIMD instruction in string for z/Linux. + // SIMD instruction in string for z/Linux and z/OS. // Convert and lengthen from DLF16: VCLFN(H/L) V1,V2,M3,M4 // M3 = 2 = FP32, M4 = 0 = DLF16 - const char *asmStr = "VCLFNH $0,$2,2,0 \n\t VCLFNL $1,$2,2,0 \n\t"; + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " VCLFNH $0,$2,2,0 \n\t" + " VCLFNL $1,$2,2,0 \n\t"; const char *asmConstraints = "=&v,=v,v"; // Prepare the input vector. @@ -1571,10 +2177,10 @@ class ZLowDLF16ToF32Lowering : public ConvertToLLVMPattern { // https://github.com/tungld/onnx-mlir-tools/blob/main/convert_dlf16_to_f32.cpp Value inputI32 = create.llvm.zext(i32Ty, inputI16); // ~DLF16_SIGN - Value c32767 = create.llvm.constant(i32Ty, (int64_t)32767); + Value c32767 = create.llvm.constant(i32Ty, static_cast(32767)); // dlf16 & ~DLF16_SIGN Value v19 = create.llvm.andi(inputI32, c32767); - Value c0 = create.llvm.constant(i32Ty, (int64_t)0); + Value c0 = create.llvm.constant(i32Ty, static_cast(0)); // Split the block right before the current op into two blocks. Block *currentBlock = rewriter.getInsertionBlock(); @@ -1602,9 +2208,11 @@ class ZLowDLF16ToF32Lowering : public ConvertToLLVMPattern { // Emit code for zero case. rewriter.setInsertionPointToEnd(trueBlock); - Value cf0 = create.llvm.constant(f32Ty, (float)0.000000e+00); - Value cfm0 = create.llvm.constant(f32Ty, (float)-0.000000e+00); - Value c32768 = create.llvm.constant(i32Ty, (int64_t)32768); + Value cf0 = + create.llvm.constant(f32Ty, static_cast(0.000000e+00)); + Value cfm0 = + create.llvm.constant(f32Ty, static_cast(-0.000000e+00)); + Value c32768 = create.llvm.constant(i32Ty, static_cast(32768)); Value v20 = create.llvm.andi(inputI32, c32768); Value v21 = create.llvm.icmp(LLVM::ICmpPredicate::eq, v20, c0); Value v22 = create.llvm.select(v21, cf0, cfm0); @@ -1618,21 +2226,25 @@ class ZLowDLF16ToF32Lowering : public ConvertToLLVMPattern { condBlock->splitBlock(rewriter.getInsertionPoint()); rewriter.setInsertionPointToEnd(condBlock); - Value nan = create.llvm.constant(f32Ty, (float)0x7FC00000); - Value inf = create.llvm.constant(i32Ty, (int64_t)32767); + Value nan = + create.llvm.constant(f32Ty, static_cast(0x7FC00000)); + Value inf = create.llvm.constant(i32Ty, static_cast(32767)); Value v19Inf = create.llvm.icmp(LLVM::ICmpPredicate::eq, v19, inf); // Emit `if (v19 == inf) then endBlock(nan) else defaultBlock` create.llvm.condBr(v19Inf, endBlock, {nan}, defaultBlock, {}); // Emit code for non-infinity case. rewriter.setInsertionPointToEnd(defaultBlock); - Value c14 = create.llvm.constant(i32Ty, (int64_t)14); - Value c16 = create.llvm.constant(i32Ty, (int64_t)16); + Value c14 = create.llvm.constant(i32Ty, static_cast(14)); + Value c16 = create.llvm.constant(i32Ty, static_cast(16)); Value cm2147483648 = - create.llvm.constant(i32Ty, (int64_t)-2147483648); - Value c528482304 = create.llvm.constant(i32Ty, (int64_t)528482304); - Value c805306368 = create.llvm.constant(i32Ty, (int64_t)805306368); - Value c8372224 = create.llvm.constant(i32Ty, (int64_t)8372224); + create.llvm.constant(i32Ty, static_cast(-2147483648)); + Value c528482304 = + create.llvm.constant(i32Ty, static_cast(528482304)); + Value c805306368 = + create.llvm.constant(i32Ty, static_cast(805306368)); + Value c8372224 = + create.llvm.constant(i32Ty, static_cast(8372224)); Value v23 = create.llvm.shl(inputI32, c16); Value v24 = create.llvm.andi(v23, cm2147483648); Value v25 = create.llvm.shl(inputI32, c14); @@ -1696,10 +2308,11 @@ class ZLowF32ToDLF16Lowering : public ConvertToLLVMPattern { Type vecTypeI16 = LLVM::getFixedVectorType(i16Ty, 8); Type vecTypeF16 = LLVM::getFixedVectorType(f16Ty, 8); - // SIMD instruction in string for z/Linux. + // SIMD instruction in string for z/Linux and z/OS. // Convert and round to DLF16: VCRNF V1,V2,V3,M4,M5 // M4 = 0 = DLF16, M5 = 2 = FP32 - const char *asmStr = "VCRNF $0,$1,$2,0,2"; + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " VCRNF $0,$1,$2,0,2 \n\t"; const char *asmConstraints = "=v,v,v"; // Prepare two input vectors: each for left/right four elements. @@ -1731,23 +2344,27 @@ class ZLowF32ToDLF16Lowering : public ConvertToLLVMPattern { // `clang -emit-llvm convert_f32_to_dlf16.cpp -S -O3` // where `convert_f32_to_dlf16.cpp` can be found at // https://github.com/tungld/onnx-mlir-tools/blob/main/convert_f32_to_dlf16.cpp - Value c0 = create.llvm.constant(i32Ty, (int64_t)0); - Value c9 = create.llvm.constant(i32Ty, (int64_t)9); - Value c14 = create.llvm.constant(i32Ty, (int64_t)14); - Value c16 = create.llvm.constant(i32Ty, (int64_t)16); - Value c23 = create.llvm.constant(i32Ty, (int64_t)23); - Value c255 = create.llvm.constant(i32Ty, (int64_t)255); - Value c8192 = create.llvm.constant(i32Ty, (int64_t)8192); - Value c32767 = create.llvm.constant(i32Ty, (int64_t)32767); - Value c32768 = create.llvm.constant(i32Ty, (int64_t)32768); - Value c32256 = create.llvm.constant(i32Ty, (int64_t)32256); - Value c8388607 = create.llvm.constant(i32Ty, (int64_t)8388607); - Value c8380415 = create.llvm.constant(i32Ty, (int64_t)8380415); - Value c1342152704 = create.llvm.constant(i32Ty, (int64_t)1342152704); - Value c2147475456 = create.llvm.constant(i32Ty, (int64_t)2147475456); - Value cm1 = create.llvm.constant(i32Ty, (int64_t)-1); - Value cm95 = create.llvm.constant(i32Ty, (int64_t)-95); - Value cm96 = create.llvm.constant(i32Ty, (int64_t)-96); + Value c0 = create.llvm.constant(i32Ty, static_cast(0)); + Value c9 = create.llvm.constant(i32Ty, static_cast(9)); + Value c14 = create.llvm.constant(i32Ty, static_cast(14)); + Value c16 = create.llvm.constant(i32Ty, static_cast(16)); + Value c23 = create.llvm.constant(i32Ty, static_cast(23)); + Value c255 = create.llvm.constant(i32Ty, static_cast(255)); + Value c8192 = create.llvm.constant(i32Ty, static_cast(8192)); + Value c32767 = create.llvm.constant(i32Ty, static_cast(32767)); + Value c32768 = create.llvm.constant(i32Ty, static_cast(32768)); + Value c32256 = create.llvm.constant(i32Ty, static_cast(32256)); + Value c8388607 = + create.llvm.constant(i32Ty, static_cast(8388607)); + Value c8380415 = + create.llvm.constant(i32Ty, static_cast(8380415)); + Value c1342152704 = + create.llvm.constant(i32Ty, static_cast(1342152704)); + Value c2147475456 = + create.llvm.constant(i32Ty, static_cast(2147475456)); + Value cm1 = create.llvm.constant(i32Ty, static_cast(-1)); + Value cm95 = create.llvm.constant(i32Ty, static_cast(-95)); + Value cm96 = create.llvm.constant(i32Ty, static_cast(-96)); Value inputI32 = create.llvm.bitcast(i32Ty, input); Value v24 = create.llvm.lshr(inputI32, c23); Value v25 = create.llvm.andi(v24, c255); @@ -1846,10 +2463,15 @@ class ZLowDLF16ToF32VectorLowering : public ConvertToLLVMPattern { Value inputVecI16 = create.llvm.bitcast(vecTypeI16, operandAdaptor.getInput()); - // Emit SIMD instruction for conversion. - // TODO: check if z/OS uses the same or different instruction. - const char *asmStr = ".insn vrr,0xe60000000056,$0,$2,0,2,0,0 \n\t" - ".insn vrr,0xe6000000005E,$1,$2,0,2,0,0 \n\t"; + // SIMD instruction in string for z/Linux and z/OS. + // Note this .insn version of asmStr was used previously for z/Linux. + // const char *asmStr = ".insn vrr,0xe60000000056,$0,$2,0,2,0,0 \n\t" + // ".insn vrr,0xe6000000005E,$1,$2,0,2,0,0 \n\t"; + // Convert and lengthen from DLF16: VCLFN(H/L) V1,V2,M3,M4 + // M3 = 2 = FP32, M4 = 0 = DLF16 + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " VCLFNH $0,$2,2,0 \n\t" + " VCLFNL $1,$2,2,0 \n\t"; const char *asmConstraints = "=&v,=v,v"; SmallVector asmVals{inputVecI16}; Value outVecI32Struct = @@ -1902,9 +2524,13 @@ class ZLowF32ToDLF16VectorLowering : public ConvertToLLVMPattern { Value vecI32H = create.llvm.bitcast(vecTypeI32, operandAdaptor.getInput1()); Value vecI32L = create.llvm.bitcast(vecTypeI32, operandAdaptor.getInput2()); - // Emit SIMD instruction for conversion. - // TODO: check if z/OS uses the same or different instruction. - const char *asmStr = ".insn vrr,0xe60000000075,$0,$1,$2,0,2,0"; + // SIMD instruction in string for z/Linux and z/OS. + // Note this .insn version of asmStr was used previously for z/Linux. + // asmStr = ".insn vrr,0xe60000000075,$0,$1,$2,0,2,0"; + // Convert and round to DLF16: VCRNF V1,V2,V3,M4,M5 + // M4 = 0 = DLF16, M5 = 2 = FP32 + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " VCRNF $0,$1,$2,0,2 \n\t"; const char *asmConstraints = "=v,v,v"; SmallVector asmVals{vecI32H, vecI32L}; @@ -1935,6 +2561,7 @@ void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, // clang-format off patterns.insert< ZLowStickLowering, + ZLowQuantizedStickLowering, ZLowUnstickLowering, ZLowStickForLSTMLowering, ZLowStickForGRULowering, @@ -1945,9 +2572,11 @@ void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, ZLowGRULowering, // Other operations ZLowMatMulLowering, + ZLowQuantizedMatMulLowering, ZLowConv2DLowering, ZLowMeanReduce2DLowering, ZLowBatchNormLowering, + ZLowLeakyReluLowering, // Scalar operations ZLowDLF16ToF32Lowering, ZLowF32ToDLF16Lowering, @@ -1966,13 +2595,18 @@ void populateZLowToLLVMConversionPattern(mlir::RewritePatternSet &patterns, // Unary operations ZLowUnaryElementwiseOpLowering, ZLowUnaryElementwiseOpLowering, + ZLowUnaryElementwiseOpLowering, // Activation operations ZLowUnaryElementwiseOpLowering, + ZLowUnaryElementwiseOpLowering, ZLowUnaryElementwiseOpLowering, ZLowUnaryElementwiseOpLowering, + ZLowUnaryElementwiseOpLowering, // Other operations ZLowPool2DLowering, - ZLowPool2DLowering + ZLowPool2DLowering, + ZLowReduceLowering, + ZLowReduceLowering >(ctx, typeConverter, apiRegistry); // clang-format on } diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index e253389293..114c19d618 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -4,7 +4,7 @@ //===---------- ZLowToLLVMCommon.hpp - Lowering from ZLow to LLVM ---------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -19,6 +19,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp" +#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "zdnn.h" @@ -52,9 +53,12 @@ ApiRegistry RegisterAllApis(MLIRContext *context) { ApiSpec(API::ZDNN_INIT_PRE_TRANSFORMED_DESC, "zdnn_init_pre_transformed_desc", voidTy, {int64Ty, int64Ty, opaquePtrTy}, true), ApiSpec(API::ZDNN_GENERATE_TRANSFORMED_DESC, "zdnn_generate_transformed_desc", int32Ty, {opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_GENERATE_TRANSFORMED_DESC_CONCATENATED, "zdnn_generate_transformed_desc_concatenated", int32Ty, {opaquePtrTy, int64Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_GENERATE_QUANTIZED_TRANSFORMED_DESC, "zdnn_generate_quantized_transformed_desc", int32Ty, {opaquePtrTy, int64Ty, opaquePtrTy}, false), ApiSpec(API::ZDNN_GETSIZE_ZTENSOR, "zdnn_getsize_ztensor", int64Ty, {opaquePtrTy}, false), ApiSpec(API::ZDNN_TRANSFORM_ZTENSOR, "zdnn_transform_ztensor", int32Ty, {opaquePtrTy}, true), + ApiSpec(API::ZDNN_TRANSFORM_ZTENSOR_WITH_SATURATION, "zdnn_transform_ztensor_with_saturation", int32Ty, {opaquePtrTy}, true), ApiSpec(API::ZDNN_TRANSFORM_ORIGTENSOR, "zdnn_transform_origtensor", int32Ty, {opaquePtrTy, opaquePtrTy}, false), + ApiSpec(API::ZDNN_TRANSFORM_QUANTIZED_ZTENSOR, "zdnn_transform_quantized_ztensor", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), // Elementwise operations ApiSpec(API::ZDNN_ADD, "zdnn_add_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_SUB, "zdnn_sub_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), @@ -64,17 +68,24 @@ ApiRegistry RegisterAllApis(MLIRContext *context) { ApiSpec(API::ZDNN_MAX, "zdnn_max_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_LOG, "zdnn_log_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_EXP, "zdnn_exp_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), + ApiSpec(API::ZDNN_INVSQRT, "zdnn_invsqrt_ext", int32Ty, {opaquePtrTy, float32Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_REDUCE, "zdnn_reduce_ext", int32Ty, {opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false), // Activation operations + ApiSpec(API::ZDNN_LEAKY_RELU, "zdnn_leaky_relu_ext", int32Ty, {opaquePtrTy, opaquePtrTy, float32Ty, opaquePtrTy}, false), ApiSpec(API::ZDNN_RELU, "zdnn_relu_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), + ApiSpec(API::ZDNN_GELU, "zdnn_gelu_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_TANH, "zdnn_tanh_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_SIGMOID, "zdnn_sigmoid_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_SOFTMAX, "zdnn_softmax_ext", int32Ty, {opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_SQRT, "zdnn_sqrt_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false), // RNN operations ApiSpec(API::ZDNN_LSTM, "zdnn_lstm", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_GRU, "zdnn_gru", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy, opaquePtrTy}, false), // Other operations ApiSpec(API::ZDNN_MATMUL_OP, "zdnn_matmul_op_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false), ApiSpec(API::ZDNN_MATMUL_BCAST_OP, "zdnn_matmul_bcast_op_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_MATMUL_TRANSPOSE_OP, "zdnn_matmul_transpose_op_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), + ApiSpec(API::ZDNN_QUANTIZED_MATMUL_OP, "zdnn_quantized_matmul_op", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_CONV2D, "zdnn_conv2d", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy, opaquePtrTy}, false), ApiSpec(API::ZDNN_AVGPOOL2D, "zdnn_avgpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), ApiSpec(API::ZDNN_MAXPOOL2D, "zdnn_maxpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false), @@ -109,8 +120,10 @@ Value ZTensorHelper::getPreTransformedDescPtr(zdnn_data_types zDNNDataType, Type llvmI64Ty = rewriter.getI64Type(); Type llvmZTensorDescStructTy = getZTensorDescStructTy(context); - Value one = create.llvm.constant(llvmI64Ty, (int64_t)1); + Value one = create.llvm.constant(llvmI64Ty, static_cast(1)); + // Alloca is fine for LLVM structs; if we were to use alloc, we would also to + // manually insert free calls. So alloca makes total sense here. Value preTransformedDescPtr = create.llvm._alloca( krnl::getPointerType(context, llvmZTensorDescStructTy), llvmZTensorDescStructTy, one, @@ -120,10 +133,12 @@ Value ZTensorHelper::getPreTransformedDescPtr(zdnn_data_types zDNNDataType, // descriptor. SmallVector operands; // 1. Data layout. - Value dataLayout = create.llvm.constant(llvmI64Ty, (int64_t)zDNNDataLayout); + Value dataLayout = + create.llvm.constant(llvmI64Ty, static_cast(zDNNDataLayout)); operands.emplace_back(dataLayout); // 2. Data type. - Value dataType = create.llvm.constant(llvmI64Ty, (int64_t)zDNNDataType); + Value dataType = + create.llvm.constant(llvmI64Ty, static_cast(zDNNDataType)); operands.emplace_back(dataType); // 3. Tensor descriptor. operands.emplace_back( @@ -150,7 +165,7 @@ Value ZTensorHelper::getTransformedDescPtr( Type llvmI64Ty = rewriter.getI64Type(); Type llvmZTensorDescStructTy = getZTensorDescStructTy(context); - Value one = create.llvm.constant(llvmI64Ty, (int64_t)1); + Value one = create.llvm.constant(llvmI64Ty, static_cast(1)); Value transformedDescPtr = create.llvm._alloca( krnl::getPointerType(context, llvmZTensorDescStructTy), @@ -158,7 +173,8 @@ Value ZTensorHelper::getTransformedDescPtr( /*alignment=*/0); if (isConcat) { - Value concatLayout = create.llvm.constant(llvmI64Ty, (int64_t)concatInfo); + Value concatLayout = + create.llvm.constant(llvmI64Ty, static_cast(concatInfo)); callApi(rewriter, loc, module, apiRegistry, API::ZDNN_GENERATE_TRANSFORMED_DESC_CONCATENATED, {toOpaquePtr(rewriter, loc, module, preTransformedDescPtr), @@ -172,6 +188,31 @@ Value ZTensorHelper::getTransformedDescPtr( return transformedDescPtr; } +// Get a transformed descriptor. +Value ZTensorHelper::getQuantizedTransformedDescPtr(Value preTransformedDescPtr, + zdnn_quantized_transform_types transformedType) { + MultiDialectBuilder create(rewriter, loc); + MLIRContext *context = module.getContext(); + + Type llvmI64Ty = rewriter.getI64Type(); + Type llvmZTensorDescStructTy = getZTensorDescStructTy(context); + Value one = create.llvm.constant(llvmI64Ty, (int64_t)1); + + Value transformedDescPtr = create.llvm._alloca( + krnl::getPointerType(context, llvmZTensorDescStructTy), + llvmZTensorDescStructTy, one, + /*alignment=*/0); + + Value transformedTyVal = + create.llvm.constant(llvmI64Ty, (int64_t)transformedType); + callApi(rewriter, loc, module, apiRegistry, + API::ZDNN_GENERATE_QUANTIZED_TRANSFORMED_DESC, + {toOpaquePtr(rewriter, loc, module, preTransformedDescPtr), + transformedTyVal, + toOpaquePtr(rewriter, loc, module, transformedDescPtr)}); + return transformedDescPtr; +} + // Get the pointer to memref. Value ZTensorHelper::getAlignedI8Ptr(Value memRef) { MLIRContext *context = rewriter.getContext(); @@ -202,7 +243,8 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType, // LLVM types for zTensor and zTensor descriptor. Type llvmZTensorStructTy = getZTensorStructTy(context); // Some frequently used constants. - Value one = create.llvm.constant(rewriter.getI64Type(), (int64_t)1); + Value one = + create.llvm.constant(rewriter.getI64Type(), static_cast(1)); // Create a pre transformed descriptor. Value preTransformedDescPtr = @@ -223,7 +265,9 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType, /*transformedDescPtr=*/transformedDescPtr, /*isTransformed=*/isTransformed, /*bufferSize=*/bufferSize, - /*alignedBuffer=*/bufferPtr); + /*alignedBuffer=*/bufferPtr, + /*recScale=*/nullptr, + /*offset=*/nullptr); // clang-format on zTensor.val = alloc; @@ -235,6 +279,55 @@ ZTensor ZTensorHelper::getZTensor(Value bufferPtr, zdnn_data_types dataType, return zTensor; } +/// Create a quantized zTensor. +ZTensor ZTensorHelper::getQuantizedZTensor(Value bufferPtr, + zdnn_data_types dataType, zdnn_data_layouts layout, + zdnn_quantized_transform_types transformType, ArrayRef originalDims, + Value recScale, Value offset, bool isTransformed) { + MultiDialectBuilder create(rewriter, loc); + MLIRContext *context = module.getContext(); + ZTensor zTensor; + + // LLVM types for zTensor and zTensor descriptor. + Type llvmZTensorStructTy = getZTensorStructTy(context); + // Some frequently used constants. + Value one = create.llvm.constant(rewriter.getI64Type(), (int64_t)1); + + // Create a pre transformed descriptor. + Value preTransformedDescPtr = + getPreTransformedDescPtr(dataType, layout, originalDims); + // Create a transformed descriptor. + Value transformedDescPtr = + getQuantizedTransformedDescPtr(preTransformedDescPtr, transformType); + // Create the input zTensor. + Value alloc = + create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy), + llvmZTensorStructTy, one, + /*alignment=*/0); + // Buffer size. + Value bufferSize = getBufferSize(transformedDescPtr); + // clang-format off + fillInZTensor(rewriter, loc, module, alloc, + /*preTransformedDescPtr=*/preTransformedDescPtr, + /*transformedDescPtr=*/transformedDescPtr, + /*isTransformed=*/isTransformed, + /*bufferSize=*/bufferSize, + /*alignedBuffer=*/bufferPtr, + /*recScale=*/recScale, + /*offset=*/offset); + // clang-format on + + zTensor.val = alloc; + zTensor.preTransformedDescPtr = preTransformedDescPtr; + zTensor.transformedDescPtr = transformedDescPtr; + zTensor.isTransformed = isTransformed; + zTensor.bufferSize = bufferSize; + zTensor.bufferPtr = bufferPtr; + zTensor.recScale = recScale; + zTensor.offset = offset; + return zTensor; +} + /// Create a zTensor from existing descriptors. ZTensor ZTensorHelper::getZTensor(Value preTransformedDescPtr, Value transformedDescPtr, Value bufferSize, Value bufferPtr, @@ -244,7 +337,8 @@ ZTensor ZTensorHelper::getZTensor(Value preTransformedDescPtr, ZTensor zTensor; Type llvmZTensorStructTy = getZTensorStructTy(context); - Value one = create.llvm.constant(rewriter.getI64Type(), (int64_t)1); + Value one = + create.llvm.constant(rewriter.getI64Type(), static_cast(1)); Value alloc = create.llvm._alloca(krnl::getPointerType(context, llvmZTensorStructTy), llvmZTensorStructTy, one, @@ -364,7 +458,7 @@ std::vector getDimsFromShapeMemRefBySize(PatternRewriter &rewriter, bitcastOp.getArg().getDefiningOp()); if (addressOfOp) { LLVM::GlobalOp globalOp = - dyn_cast_or_null(SymbolTable::lookupSymbolIn( + mlir::dyn_cast_or_null(SymbolTable::lookupSymbolIn( module, addressOfOp.getGlobalNameAttr())); if (globalOp) { DenseElementsAttr valueAttr = @@ -385,7 +479,7 @@ std::vector getDimsFromShapeMemRefBySize(PatternRewriter &rewriter, for (int64_t i = 0; i < size; ++i) { Value alignedGep = create.llvm.getElemPtr(krnl::getPointerType(context, int64Ty), int64Ty, - alignedPtr, ArrayRef{(int32_t)i}); + alignedPtr, ArrayRef{static_cast(i)}); Value dimI64 = create.llvm.load(int64Ty, alignedGep); dims.emplace_back(dimI64); } @@ -433,6 +527,8 @@ zdnn_data_types llvmTypeToZDNNType(Type elemType) { return FP16; else if (mlir::isa(elemType)) return FP32; + else if (elemType.isInteger(8)) + return INT8; else llvm_unreachable("Unexpected LLVM type, cannot be converted to zDNN type."); } @@ -474,7 +570,9 @@ Type getZTensorStructTy(MLIRContext *context) { Type llvmI64Ty = IntegerType::get(context, 64); Type llvmI1Ty = IntegerType::get(context, 1); Type llvmI8Ty = IntegerType::get(context, 8); - Type llvmArrayI8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 31); + Type llvmF32Ty = FloatType::getF32(context); + Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3); + Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20); Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty); Type llvmZTensorDescStructTy = getZTensorDescStructTy(context); @@ -491,8 +589,14 @@ Type getZTensorStructTy(MLIRContext *context) { zTensorTypeElements.emplace_back(llvmI8PtrTy); // indicator if data in buffer has been transformed zTensorTypeElements.emplace_back(llvmI1Ty); - // reserved[31], not currently used, exploiter should not touch - zTensorTypeElements.emplace_back(llvmArrayI8Ty); + // reserved[3], not currently used, should contain zeros + zTensorTypeElements.emplace_back(llvmArray3I8Ty); + // the scale factor for quantization, stored as reciprocal + zTensorTypeElements.emplace_back(llvmF32Ty); + // the offset for quantization + zTensorTypeElements.emplace_back(llvmF32Ty); + // reserved[20], not currently used, should contain zeros + zTensorTypeElements.emplace_back(llvmArray20I8Ty); Type zTensorStructTy = LLVM::LLVMStructType::getLiteral(context, /*elements=*/zTensorTypeElements, @@ -510,7 +614,8 @@ Value toOpaquePtr( void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, Value zTensor, Value preTransformedDescPtr, Value transformedDescPtr, - bool isTransformed, Value bufferSize, Value alignedBuffer) { + bool isTransformed, Value bufferSize, Value alignedBuffer, Value recScale, + Value offset) { MLIRContext *context = module.getContext(); MultiDialectBuilder create(rewriter, loc); @@ -540,13 +645,73 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, create.llvm.store(alignedBuffer, bufferPtr); // 5. Set is_transformed. - Value isTransformedVal = - create.llvm.constant(llvmI1Ty, (int64_t)((isTransformed) ? 1 : 0)); + Value isTransformedVal = create.llvm.constant( + llvmI1Ty, static_cast(((isTransformed) ? 1 : 0))); Value isTransformedDescPtr = create.llvm.getElemPtr( llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef{0, 4}); create.llvm.store(isTransformedVal, isTransformedDescPtr); - // 6. Set reserved (not currently used), not touch + // 6. Set reserved1 (3 bytes), not currently used. + + // 7. Set rec_scale. + Value recScalePtr = create.llvm.getElemPtr( + llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef{0, 6}); + if (recScale) { + Type scaleTy = recScale.getType(); + assert( + scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float"); + create.llvm.store(recScale, recScalePtr); + } else { + Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + create.llvm.store(zero, recScalePtr); + } + + // 8. Set offset + Value offsetPtr = create.llvm.getElemPtr( + llvmZTensorPtrTy, llvmZTensorTy, zTensor, ArrayRef{0, 7}); + if (offset) { + Type offsetTy = offset.getType(); + assert( + offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float"); + create.llvm.store(offset, offsetPtr); + } else { + Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + create.llvm.store(zero, offsetPtr); + } + + // 9. Set reserved2 (20 bytes), not currently used. +} + +Value loadFromMemRef( + LLVMBuilder &create, Type elementTy, Value llvmMemRef, int32_t index) { + MLIRContext *context = create.getBuilder().getContext(); + MemRefDescriptor mrd(llvmMemRef); + Value alignedPtr = mrd.alignedPtr(create.getBuilder(), create.getLoc()); + Value alignedGep = create.getElemPtr(krnl::getPointerType(context, elementTy), + elementTy, alignedPtr, ArrayRef{index}); + return create.load(elementTy, alignedGep); +} + +void storeToMemRef( + LLVMBuilder &create, Value val, Value llvmMemRef, int32_t index) { + MLIRContext *context = create.getBuilder().getContext(); + Type elementTy = val.getType(); + MemRefDescriptor mrd(llvmMemRef); + Value alignedPtr = mrd.alignedPtr(create.getBuilder(), create.getLoc()); + Value alignedGep = create.getElemPtr(krnl::getPointerType(context, elementTy), + elementTy, alignedPtr, ArrayRef{index}); + create.store(val, alignedGep); +} + +zdnn_quantized_transform_types getQuantizedTransformType(mlir::StringRef str) { + if (str.equals_insensitive(QTYPE_DLFLOAT16)) + return QUANTIZED_DLFLOAT16; + else if (str.equals_insensitive(QTYPE_INT8)) + return QUANTIZED_INT8; + else if (str.equals_insensitive(QTYPE_WEIGHTS)) + return QUANTIZED_WEIGHTS_INT8; + else + llvm_unreachable("Invalid transform type"); } } // namespace zlow diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp index 9e9c251b73..fc427e5c87 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.hpp @@ -19,6 +19,8 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/Dialect/Mlir/DialectBuilder.hpp" + #include "zdnn.h" namespace onnx_mlir { @@ -30,9 +32,12 @@ enum class API { ZDNN_INIT_PRE_TRANSFORMED_DESC, ZDNN_GENERATE_TRANSFORMED_DESC, ZDNN_GENERATE_TRANSFORMED_DESC_CONCATENATED, + ZDNN_GENERATE_QUANTIZED_TRANSFORMED_DESC, ZDNN_GETSIZE_ZTENSOR, ZDNN_TRANSFORM_ZTENSOR, + ZDNN_TRANSFORM_ZTENSOR_WITH_SATURATION, ZDNN_TRANSFORM_ORIGTENSOR, + ZDNN_TRANSFORM_QUANTIZED_ZTENSOR, // Elementwise operations ZDNN_ADD, ZDNN_SUB, @@ -42,22 +47,30 @@ enum class API { ZDNN_MAX, ZDNN_LOG, ZDNN_EXP, + ZDNN_INVSQRT, + // Reduction operations + ZDNN_REDUCE, + ZDNN_MEANREDUCE2D, // Activation operations ZDNN_RELU, + ZDNN_GELU, ZDNN_TANH, ZDNN_SIGMOID, ZDNN_SOFTMAX, + ZDNN_SQRT, // RNN operations ZDNN_LSTM, ZDNN_GRU, // Other operations ZDNN_MATMUL_OP, ZDNN_MATMUL_BCAST_OP, + ZDNN_MATMUL_TRANSPOSE_OP, + ZDNN_QUANTIZED_MATMUL_OP, ZDNN_CONV2D, ZDNN_AVGPOOL2D, ZDNN_MAXPOOL2D, - ZDNN_MEANREDUCE2D, ZDNN_BATCHNORM, + ZDNN_LEAKY_RELU, // Scalar operations. DLF16_TO_F32, F32_TO_DLF16, @@ -100,6 +113,8 @@ struct ZTensor { mlir::Value bufferSize; mlir::Value bufferPtr; bool isTransformed; + mlir::Value recScale; + mlir::Value offset; }; /// A helper class to create a zTensor. @@ -117,6 +132,10 @@ class ZTensorHelper { bool isConcat = false, zdnn_concat_info concatInfo = RNN_TYPE_GRU | USAGE_WEIGHTS | PREV_LAYER_NONE); + // Get a quantized transformed descriptor. + mlir::Value getQuantizedTransformedDescPtr(mlir::Value preTransformedDescPtr, + zdnn_quantized_transform_types transform_type); + // Get the pointer to memref. mlir::Value getAlignedI8Ptr(mlir::Value memRef); // Get buffer size from a transformed descriptor. @@ -127,6 +146,11 @@ class ZTensorHelper { bool isTransformed, bool isConcat = false, zdnn_concat_info concatInfo = RNN_TYPE_GRU | USAGE_WEIGHTS | PREV_LAYER_NONE); + // Create a quantized zTensor. + ZTensor getQuantizedZTensor(mlir::Value bufferPtr, zdnn_data_types dataType, + zdnn_data_layouts layout, zdnn_quantized_transform_types transformType, + mlir::ArrayRef originalDims, mlir::Value recScale, + mlir::Value offset, bool isTransformed); // Create a zTensor from existing descriptors. ZTensor getZTensor(mlir::Value preTransformedDescPtr, mlir::Value transformedDescPtr, mlir::Value bufferSize, @@ -197,7 +221,19 @@ mlir::Value toOpaquePtr(mlir::PatternRewriter &rewriter, mlir::Location loc, void fillInZTensor(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::ModuleOp module, mlir::Value zTensor, mlir::Value preTransformedDescPtr, mlir::Value transformedDescPtr, - bool isTransformed, mlir::Value bufferSize, mlir::Value alignedBuffer); + bool isTransformed, mlir::Value bufferSize, mlir::Value alignedBuffer, + mlir::Value recScale = nullptr, mlir::Value offset = nullptr); + +/// Function to load a value from a LLVM Struct of MemRef. +mlir::Value loadFromMemRef(onnx_mlir::LLVMBuilder &create, mlir::Type elementTy, + mlir::Value llvmMemRef, int32_t index); + +/// Function to store a value to a LLVM Struct of MemRef. +void storeToMemRef(onnx_mlir::LLVMBuilder &create, mlir::Value val, + mlir::Value llvmMemRef, int32_t index); + +/// Function to get a quantized tranform type from a string. +zdnn_quantized_transform_types getQuantizedTransformType(mlir::StringRef str); } // namespace zlow } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt index 915ed61717..2c3e6ca953 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt @@ -23,6 +23,9 @@ add_onnx_mlir_library(OMZHighOps ZHighOps/MatMul/MatMul.cpp ZHighOps/MeanReduce2D/MeanReduce2D.cpp ZHighOps/Pooling/Pooling.cpp + ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp + ZHighOps/QuantizedStick/QuantizedStick.cpp + ZHighOps/Reduction/Reduction.cpp ZHighOps/Softmax/Softmax.cpp ZHighOps/Stick/Stick.cpp ZHighOps/StickForGRU/StickForGRU.cpp diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td index d2624138c0..7bbcd02c87 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td @@ -2,7 +2,7 @@ //===----- ZHigh.td -- ZHigh Dialect Operation Definitions -*- tablegen ----==// // -// Copyright 2019-2020 The IBM Research Authors +// Copyright 2019-2024 The IBM Research Authors // // ============================================================================= // @@ -67,7 +67,8 @@ def ZTensorEncodingAttr : ZHigh_Attr<"ZTensorEncoding"> { // Data in ztensor encoding. let parameters = (ins // A original data layout of the pre-stickified data. - "ZTensorEncodingAttr::DataLayout":$dataLayout + "ZTensorEncodingAttr::DataLayout":$dataLayout, + "ZTensorEncodingAttr::QuantizedType":$quantizedType ); let extraClassDeclaration = [{ @@ -76,6 +77,13 @@ def ZTensorEncodingAttr : ZHigh_Attr<"ZTensorEncoding"> { NCHW, NHWC, HWCK, FICO, ZRH, BFICO, BZRH }; + enum class QuantizedType { + UNDEFINED, DLFLOAT16, INT8, WEIGHTS + }; + // QuantizedType is optional. + static ZTensorEncodingAttr get(::mlir::MLIRContext *context, ZTensorEncodingAttr::DataLayout dataLayout) { + return get(context, dataLayout, ZTensorEncodingAttr::QuantizedType::UNDEFINED); + } }]; let cppNamespace = "::onnx_mlir::zhigh"; @@ -89,6 +97,18 @@ class DataLayoutOfPred : And<[ " == ZTensorEncodingAttr::DataLayout::" # layout # ")"> ]>; +// Whether a ztensor type has the specified quantized type. +class QuantizedTypeOfPred : And<[ + CPred<"(mlir::cast<::mlir::RankedTensorType>($_self)) &&" + "(mlir::dyn_cast_or_null(mlir::cast<::mlir::RankedTensorType>($_self).getEncoding())) &&" + "(mlir::cast(mlir::cast<::mlir::RankedTensorType>($_self).getEncoding()).getQuantizedType()" + " == ZTensorEncodingAttr::QuantizedType::" # qtype # ")"> +]>; + +// Whether a shaped type has one of the specified quantized type. +class HasAnyQuantizedTypeOfPred qtypes> : And<[ + Or)>]>; + // So far ZTensor supports only F16 for stickified data. class ZTensorOf ranks> : Type.predicate, HasAnyRankOfPred, @@ -97,8 +117,21 @@ class ZTensorOf ranks> : TensorOf<[F16]>.summary # " with layout " # layout, "::mlir::TensorType">; +// Quantized ZTensor. +class QZTensorOf ranks> : + Type.predicate, + DataLayoutOfPred, + HasAnyQuantizedTypeOfPred<["DLFLOAT16", "INT8", "WEIGHTS"]>, + HasAnyRankOfPred + ]>, + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # + TensorOf<[I8, F16]>.summary # " with layout " # layout, + "::mlir::TensorType">; + def UnrankedZTensor : UnrankedTensorOf<[F16]>; +def UnrankedQZTensor : UnrankedTensorOf<[I8, F16]>; + def ZTensor_1D: AnyTypeOf<[UnrankedZTensor, ZTensorOf<"_1D", [1]>]>; def ZTensor_2D: AnyTypeOf<[UnrankedZTensor, ZTensorOf<"_2D", [2]>]>; def ZTensor_2DS: AnyTypeOf<[UnrankedZTensor, ZTensorOf<"_2DS", [2]>]>; @@ -119,6 +152,15 @@ def AnyZTensor: AnyTypeOf<[ZTensor_1D, ZTensor_2D, ZTensor_3D, ZTensor_4D, ZTensor_NCHW, ZTensor_NHWC, ZTensor_HWCK, ZTensor_FICO, ZTensor_ZRH, ZTensor_BFICO, ZTensor_BZRH]>; +def QZTensor_1D: AnyTypeOf<[UnrankedQZTensor, QZTensorOf<"_1D", [1]>]>; +def QZTensor_2D: AnyTypeOf<[UnrankedQZTensor, QZTensorOf<"_2D", [2]>]>; +def QZTensor_2DS: AnyTypeOf<[UnrankedQZTensor, QZTensorOf<"_2DS", [2]>]>; +def QZTensor_3D: AnyTypeOf<[UnrankedQZTensor, QZTensorOf<"_3D", [3]>]>; +def QZTensor_3DS: AnyTypeOf<[UnrankedQZTensor, QZTensorOf<"_3DS", [3]>]>; + +def AnyQZTensor: AnyTypeOf<[QZTensor_1D, QZTensor_2D, QZTensor_3D, QZTensor_2DS, + QZTensor_3DS]>; + //===----------------------------------------------------------------------===// // ZHigh Operations //===----------------------------------------------------------------------===// @@ -174,6 +216,47 @@ def ZHighStickOp:ZHigh_Op<"Stick", [Pure, }]; } +def ZHighQuantizedStickOp:ZHigh_Op<"QuantizedStick", [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh QuantizedStick operation"; + let description = [{ + ZHigh operation to perform a quantized Stick. + Type is one of values: dlfloat16, int8, and weights. + `sym_mode` indicates whether to use symmetric quantization or not to compute the output rescale and offset. + `sym_mode` is only effective when the input rescale and offset are None. + By default, asymmetric quantization is used. + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I8]>, + ZTensor_3D, ZTensor_2DS, ZTensor_3DS]>:$In, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$InRecScale, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$InOffset, + StrAttr:$layout, + StrAttr:$quantized_type, + DefaultValuedAttr:$sym_mode); + let results = (outs AnyTypeOf<[QZTensor_1D, QZTensor_2D, QZTensor_3D, + QZTensor_2DS, QZTensor_3DS, + NoneType]>:$Out, + 0DTensorOf<[F32]>:$RecScale, + 0DTensorOf<[F32]>:$Offset); + let hasCanonicalizer = 1; + let builders = [ + OpBuilder<(ins "::mlir::Value":$In, "::mlir::Value":$InRecScale, "::mlir::Value":$InOffset, + "::mlir::StringAttr":$layout, "::mlir::StringAttr":$quantized_type)>, + OpBuilder<(ins "::mlir::Value":$In, "::mlir::Value":$InRecScale, "::mlir::Value":$InOffset, + "::mlir::StringAttr":$layout, "::mlir::StringAttr":$quantized_type, + "::mlir::IntegerAttr":$sym_mode)> + ]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighQuantizedStickOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighQuantizedStickOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ZHighUnstickOp:ZHigh_Op<"Unstick", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -417,6 +500,45 @@ def ZHighExpOp:ZHigh_Op<"Exp", [Pure, SameOperandsAndResultLayout, }]; } +def ZHighLeakyReluOp:ZHigh_Op<"LeakyRelu", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh LeakyRelu operation"; + let description = [{ + "ZHigh operation to perform a LeakyRelu." + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X, + DefaultValuedAttr:$alpha); + let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighLeakyReluOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + +def ZHighInvSqrtOp:ZHigh_Op<"InvSqrt", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh InvSqrt operation"; + let description = [{ + ZHigh operation to perform a InvSqrt. + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); + let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighInvSqrtOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ZHighReluOp:ZHigh_Op<"Relu", [Pure, SameOperandsAndResultLayout, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -436,6 +558,25 @@ def ZHighReluOp:ZHigh_Op<"Relu", [Pure, SameOperandsAndResultLayout, }]; } +def ZHighGeluOp:ZHigh_Op<"Gelu", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh Gelu operation"; + let description = [{ + "ZHigh operation to perform a Gelu." + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X, DefaultValuedStrAttr:$approximate); + let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighGeluOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighUnaryOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ZHighTanhOp:ZHigh_Op<"Tanh", [Pure, SameOperandsAndResultLayout, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -495,6 +636,70 @@ def ZHighSoftmaxOp:ZHigh_Op<"Softmax", [Pure, SameOperandsAndResultLayout, }]; } +def ZHighSqrtOp:ZHigh_Op<"Sqrt", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh Sqrt operation"; + let description = [{ + ZHigh operation to perform a Sqrt. + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$X); + let results = (outs AnyTypeOf<[AnyZTensor]>:$Out); +} + +def ZHighReduceMaxOp:ZHigh_Op<"ReduceMax", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh ReduceMax operation"; + let description = [{ + ZHigh operation to perform a ReduceMax. + op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM. + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data); + let results = (outs AnyTypeOf<[AnyZTensor]>:$output); + let builders = [ + OpBuilder<(ins "::mlir::Value":$data), [{ + Type elementType = mlir::cast(data.getType()).getElementType(); + UnrankedTensorType resType = UnrankedTensorType::get(elementType); + build($_builder, $_state, resType, data); + }]> + ]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighReduceMaxOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReductionOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + +def ZHighReduceMinOp:ZHigh_Op<"ReduceMin", [Pure, SameOperandsAndResultLayout, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh ReduceMin operation"; + let description = [{ + ZHigh operation to perform a ReduceMin. + op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM. + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data); + let results = (outs AnyTypeOf<[AnyZTensor]>:$output); + let builders = [ + OpBuilder<(ins "::mlir::Value":$data), [{ + Type elementType = mlir::cast(data.getType()).getElementType(); + UnrankedTensorType resType = UnrankedTensorType::get(elementType); + build($_builder, $_state, resType, data); + }]> + ]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighReduceMinOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReductionOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ZHighMeanReduce2DOp:ZHigh_Op<"MeanReduce2d", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -595,14 +800,17 @@ def ZHighMatMulOp:ZHigh_Op<"MatMul", [Pure, }]; let arguments = (ins AnyTypeOf<[ZTensor_2D, ZTensor_3DS]>:$X, AnyTypeOf<[ZTensor_2D, ZTensor_3DS]>:$Y, - AnyTypeOf<[ZTensor_1D, ZTensor_2DS, NoneType]>:$B); + AnyTypeOf<[ZTensor_1D, ZTensor_2DS, NoneType]>:$B, + DefaultValuedAttr:$transposeA, + DefaultValuedAttr:$transposeB); let results = (outs AnyTypeOf<[ZTensor_2D, ZTensor_3DS]>:$Out); let builders = [ - OpBuilder<(ins "::mlir::Value":$X, "::mlir::Value":$Y, "::mlir::Value":$B), [{ + OpBuilder<(ins "::mlir::Value":$X, "::mlir::Value":$Y, "::mlir::Value":$B, + "::mlir::IntegerAttr":$transposeA, "::mlir::IntegerAttr":$transposeB), [{ Type elementType = mlir::cast(X.getType()).getElementType(); UnrankedTensorType resType = UnrankedTensorType::get(elementType); - build($_builder, $_state, resType, X, Y, B); + build($_builder, $_state, resType, X, Y, B, transposeA, transposeB); }]> ]; let hasVerifier = 1; @@ -616,6 +824,54 @@ def ZHighMatMulOp:ZHigh_Op<"MatMul", [Pure, }]; } +def ZHighQuantizedMatMulOp:ZHigh_Op<"QuantizedMatMul", [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh QuantizedMatMul operation"; + let description = [{ + ZHigh operation to perform a quantized MatMul. + + `OutRecScaleIn` and `OutOffsetIn` are recscale and offset for the output. + If `OutRecScaleIn` is given, it will be passed to `OutRecScale`. If it is + None, `OutRescScale` is set to 1.0. + If `OutOffsetIn` is given, it will be passed to `OutOffset`. If it is + None, `OutOffset` is set to 0.0. + + * PreComputedBias: -1 bias is re-computed, 0: bias is not pre-computed. + + `DequantizeOutput` indicates if the output + is dequantized to real dfloat16 or not. If not, the output is int8 but stored in dlfloat (int8-as-dlfloat). + * DequantizeOutput: -1 output is dequantized, 0: output is not dequantized. + }]; + let arguments = (ins AnyTypeOf<[QZTensor_2D, QZTensor_3DS]>:$X, + 0DTensorOf<[F32]>:$XRecScale, + 0DTensorOf<[F32]>:$XOffset, + AnyTypeOf<[QZTensor_2D, QZTensor_3DS]>:$Y, + 0DTensorOf<[F32]>:$YRecScale, + 0DTensorOf<[F32]>:$YOffset, + AnyTypeOf<[ZTensor_1D, ZTensor_2DS, QZTensor_1D, QZTensor_2DS, NoneType]>:$B, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$BRecScale, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$BOffset, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$OutRecScaleIn, + AnyTypeOf<[0DTensorOf<[F32]>, NoneType]>:$OutOffsetIn, + DefaultValuedAttr:$PreComputedBias, + DefaultValuedAttr:$DisableClipping, + DefaultValuedAttr:$DequantizeOutput); + + let results = (outs AnyTypeOf<[QZTensor_2D, QZTensor_3DS, ZTensor_2D, ZTensor_3DS]>:$Out, + 0DTensorOf<[F32]>:$OutRecScale, + 0DTensorOf<[F32]>:$OutOffset); + let hasVerifier = 1; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighQuantizedMatMulOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighQuantizedMatMulOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ZHighLSTMOp:ZHigh_Op<"LSTM", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -868,7 +1124,7 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> { }]; let arguments = (ins OptionalAttr:$value, DefaultValuedAttr:$alignment); - let results = (outs AnyZTensor:$output); + let results = (outs AnyTypeOf<[AnyZTensor, AnyQZTensor]>:$output); } def ZHighStickifiedConstantOfShapeOp:ZHigh_Op<"StickifiedConstantOfShape", [Pure, diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp index 8e70fe364f..de587aecba 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp @@ -71,10 +71,10 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultLayout(Operation *op) { namespace onnx_mlir { namespace zhigh { -std::vector getZHighAuxSplitResultType( +std::vector getZHighAuxSplitResultType( Value input, int64_t axis, ArrayAttr split) { Type elementType = mlir::cast(input.getType()).getElementType(); - std::vector outputTypes; + std::vector outputTypes; if (split.size() == 0) { llvm_unreachable("Unsupported split (size==0)"); } else { @@ -111,6 +111,9 @@ Attribute ZTensorEncodingAttr::parse(AsmParser &parser, Type type) { ZTensorEncodingAttr::DataLayout dataLayout = ZTensorEncodingAttr::DataLayout::UNDEFINED; + ZTensorEncodingAttr::QuantizedType quantizedType = + ZTensorEncodingAttr::QuantizedType::UNDEFINED; + // Process the data from the parsed dictionary value into struct-like data. for (const NamedAttribute &attr : dict) { if (attr.getName() == "dataLayout") { @@ -155,6 +158,27 @@ Attribute ZTensorEncodingAttr::parse(AsmParser &parser, Type type) { << strVal; return {}; } + } else if (attr.getName() == "quantizedType") { + StringAttr qtypeAttr = mlir::dyn_cast(attr.getValue()); + if (!qtypeAttr) { + parser.emitError( + parser.getNameLoc(), "expected a string value for quantized type"); + return {}; + } + StringRef strVal = qtypeAttr.getValue(); + if (strVal.equals_insensitive(QTYPE_DLFLOAT16)) { + quantizedType = ZTensorEncodingAttr::QuantizedType::DLFLOAT16; + } else if (strVal.equals_insensitive(QTYPE_INT8)) { + quantizedType = ZTensorEncodingAttr::QuantizedType::INT8; + } else if (strVal.equals_insensitive(QTYPE_WEIGHTS)) { + quantizedType = ZTensorEncodingAttr::QuantizedType::WEIGHTS; + } else if (strVal.equals_insensitive(QTYPE_UNDEFINED)) { + quantizedType = ZTensorEncodingAttr::QuantizedType::UNDEFINED; + } else { + parser.emitError(parser.getNameLoc(), "unexpected quantized type: ") + << strVal; + return {}; + } } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().str(); @@ -163,7 +187,7 @@ Attribute ZTensorEncodingAttr::parse(AsmParser &parser, Type type) { } // Construct struct-like storage for attribute. return parser.getChecked( - parser.getContext(), dataLayout); + parser.getContext(), dataLayout, quantizedType); } void ZTensorEncodingAttr::print(AsmPrinter &printer) const { @@ -216,6 +240,27 @@ void ZTensorEncodingAttr::print(AsmPrinter &printer) const { llvm_unreachable("Unexpected data layout"); break; } + + // QuantizedType is optional. + switch (getQuantizedType()) { + case QuantizedType::DLFLOAT16: + printer << ", quantizedType = "; + printer << "\"" << QTYPE_DLFLOAT16 << "\""; + break; + case QuantizedType::INT8: + printer << ", quantizedType = "; + printer << "\"" << QTYPE_INT8 << "\""; + break; + case QuantizedType::WEIGHTS: + printer << ", quantizedType = "; + printer << "\"" << QTYPE_WEIGHTS << "\""; + break; + case QuantizedType::UNDEFINED: + break; + default: + llvm_unreachable("Unexpected quantized type"); + break; + } printer << "}>"; } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/BatchNorm/BatchNorm.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/BatchNorm/BatchNorm.cpp index 50e05541a5..e4f7e1b51b 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/BatchNorm/BatchNorm.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/BatchNorm/BatchNorm.cpp @@ -24,7 +24,7 @@ namespace zhigh { //===----------------------------------------------------------------------===// LogicalResult ZHighBatchNormOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp index a05e73890e..5af7488498 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Conv2D/Conv2D.cpp @@ -139,7 +139,7 @@ LogicalResult ZHighConv2DOp::verify() { //===----------------------------------------------------------------------===// LogicalResult ZHighConv2DOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput()) || !hasRankedType(getInputKernel())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/DLF16ToF32.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/DLF16ToF32.cpp index 616c1f4de2..c2a488d302 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/DLF16ToF32.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/DLF16ToF32/DLF16ToF32.cpp @@ -33,7 +33,7 @@ void ZHighDLF16ToF32Op::build( Type elementType = builder.getF32Type(); Type resType = UnrankedTensorType::get(elementType); - if (auto inType = dyn_cast(input.getType())) + if (auto inType = mlir::dyn_cast(input.getType())) resType = RankedTensorType::get(inType.getShape(), elementType); build(builder, state, resType, input); @@ -44,7 +44,7 @@ void ZHighDLF16ToF32Op::build( //===----------------------------------------------------------------------===// LogicalResult ZHighDLF16ToF32Op::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Elementwise/Elementwise.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Elementwise/Elementwise.cpp index 17463045a3..429d662693 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Elementwise/Elementwise.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Elementwise/Elementwise.cpp @@ -4,7 +4,7 @@ //===------------------ Elementwise.cpp - ZHigh Operations ----------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -22,7 +22,7 @@ namespace zhigh { //===----------------------------------------------------------------------===// // AddOp LogicalResult ZHighAddOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -30,7 +30,7 @@ LogicalResult ZHighAddOp::inferShapes( // SubOp LogicalResult ZHighSubOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -38,7 +38,7 @@ LogicalResult ZHighSubOp::inferShapes( // MulOp LogicalResult ZHighMulOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -46,7 +46,7 @@ LogicalResult ZHighMulOp::inferShapes( // DivOp LogicalResult ZHighDivOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -54,7 +54,7 @@ LogicalResult ZHighDivOp::inferShapes( // MinOp LogicalResult ZHighMinOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -62,7 +62,7 @@ LogicalResult ZHighMinOp::inferShapes( // MaxOp LogicalResult ZHighMaxOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -70,7 +70,7 @@ LogicalResult ZHighMaxOp::inferShapes( // LogOp LogicalResult ZHighLogOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -78,6 +78,22 @@ LogicalResult ZHighLogOp::inferShapes( // ExpOp LogicalResult ZHighExpOp::inferShapes( + std::function doShapeInference) { + return inferShapeForUnaryOps(this->getOperation()); +} + +//===----------------------------------------------------------------------===// +// InvSqrtOp + +LogicalResult ZHighInvSqrtOp::inferShapes( + std::function doShapeInference) { + return inferShapeForUnaryOps(this->getOperation()); +} + +//===----------------------------------------------------------------------===// +// LeakyReluOp + +LogicalResult ZHighLeakyReluOp::inferShapes( std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -86,6 +102,14 @@ LogicalResult ZHighExpOp::inferShapes( // ReluOp LogicalResult ZHighReluOp::inferShapes( + std::function doShapeInference) { + return inferShapeForUnaryOps(this->getOperation()); +} + +//===----------------------------------------------------------------------===// +// GeluOp + +LogicalResult ZHighGeluOp::inferShapes( std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -94,7 +118,7 @@ LogicalResult ZHighReluOp::inferShapes( // TanhOp LogicalResult ZHighTanhOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } @@ -102,6 +126,14 @@ LogicalResult ZHighTanhOp::inferShapes( // SigmoiOp LogicalResult ZHighSigmoidOp::inferShapes( + std::function doShapeInference) { + return inferShapeForUnaryOps(this->getOperation()); +} + +//===----------------------------------------------------------------------===// +// SqrtOp + +LogicalResult ZHighSqrtOp::inferShapes( std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/F32ToDLF16.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/F32ToDLF16.cpp index b36a9ecf4e..1fcd52c53a 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/F32ToDLF16.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/F32ToDLF16/F32ToDLF16.cpp @@ -33,7 +33,7 @@ void ZHighF32ToDLF16Op::build(OpBuilder &builder, OperationState &state, Type elementType = builder.getF16Type(); Type resType = UnrankedTensorType::get(elementType); - if (auto inType = dyn_cast(input.getType())) + if (auto inType = mlir::dyn_cast(input.getType())) resType = RankedTensorType::get(inType.getShape(), elementType); build(builder, state, resType, input, saturation); @@ -44,7 +44,7 @@ void ZHighF32ToDLF16Op::build(OpBuilder &builder, OperationState &state, //===----------------------------------------------------------------------===// LogicalResult ZHighF32ToDLF16Op::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUY.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUY.cpp index 22ddf1e8d6..a4af02d7c7 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUY.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUY.cpp @@ -24,7 +24,7 @@ namespace zhigh { //===----------------------------------------------------------------------===// LogicalResult ZHighFixGRUYOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUYh.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUYh.cpp index 12312ac2e8..8b06993d54 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUYh.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/FixGRUYh.cpp @@ -41,7 +41,7 @@ LogicalResult ZHighFixGRUYhOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighFixGRUYhOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getY())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp index 6c2cebe5c9..bfbcf60827 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/GRU/GRU.cpp @@ -51,14 +51,14 @@ LogicalResult ZHighGRUOpShapeHelper::computeShape() { if (isAllTimesteps) hnOutputDims.emplace_back(S); else - hnOutputDims.emplace_back(LiteralIndexExpr(1)); + hnOutputDims.emplace_back(LitIE(1)); hnOutputDims.emplace_back(D); hnOutputDims.emplace_back(B); hnOutputDims.emplace_back(H); // Shape for cf_ouput : [1, B, H] DimsExpr cfOutputDims; - cfOutputDims.emplace_back(LiteralIndexExpr(1)); + cfOutputDims.emplace_back(LitIE(1)); cfOutputDims.emplace_back(B); cfOutputDims.emplace_back(H); @@ -137,7 +137,7 @@ LogicalResult ZHighGRUOp::verify() { //===----------------------------------------------------------------------===// LogicalResult ZHighGRUOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput()) || !hasRankedType(getHiddenWeights())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp index 41648352e0..8fdd8fc816 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/LSTM/LSTM.cpp @@ -51,14 +51,14 @@ LogicalResult ZHighLSTMOpShapeHelper::computeShape() { if (isAllTimesteps) hnOutputDims.emplace_back(S); else - hnOutputDims.emplace_back(LiteralIndexExpr(1)); + hnOutputDims.emplace_back(LitIE(1)); hnOutputDims.emplace_back(D); hnOutputDims.emplace_back(B); hnOutputDims.emplace_back(H); // Shape for cf_ouput : [1, D, B, H] DimsExpr cfOutputDims; - cfOutputDims.emplace_back(LiteralIndexExpr(1)); + cfOutputDims.emplace_back(LitIE(1)); cfOutputDims.emplace_back(D); cfOutputDims.emplace_back(B); cfOutputDims.emplace_back(H); @@ -139,7 +139,7 @@ LogicalResult ZHighLSTMOp::verify() { //===----------------------------------------------------------------------===// LogicalResult ZHighLSTMOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput()) || !hasRankedType(getHiddenWeights())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp index 6d556ea36f..a37dce3671 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MatMul/MatMul.cpp @@ -4,7 +4,7 @@ //===------------------ MatMul.cpp - ZHigh Operations ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -24,6 +24,7 @@ namespace zhigh { //===----------------------------------------------------------------------===// LogicalResult ZHighMatMulOpShapeHelper::computeShape() { + ZHighMatMulOp matmulOp = llvm::dyn_cast(op); ZHighMatMulOp::Adaptor operandAdaptor(operands); // Output dims of result. DimsExpr outputDims; @@ -32,6 +33,10 @@ LogicalResult ZHighMatMulOpShapeHelper::computeShape() { Value X = operandAdaptor.getX(); Value Y = operandAdaptor.getY(); + // Get transpose attributes. + int64_t transposeA = (matmulOp.getTransposeA() != 0) ? 1 : 0; + int64_t transposeB = (matmulOp.getTransposeB() != 0) ? 1 : 0; + // Get bounds SmallVector XDims, YDims; createIE->getShapeAsDims(X, XDims); @@ -42,46 +47,125 @@ LogicalResult ZHighMatMulOpShapeHelper::computeShape() { if (!(xRank == 2 || xRank == 3)) return failure(); + // Determine the dimensions of the output tensor. if (xRank == 2) { // X :: MxN - // Y :: NxP - outputDims.emplace_back(XDims[0]); - outputDims.emplace_back(YDims[1]); + int64_t xI = 0; + if (transposeA) + // X :: NxM + xI = 1; + if (yRank == 2) { + // Y :: NxP + int64_t yI = 1; + if (transposeB) + // Y :: PxN + yI = 0; + // Unstacked case: X:2D (m,n) - Y:2D (n,p) - Bias:1D (p) - Out:2D (m,p) + outputDims.emplace_back(XDims[xI]); + outputDims.emplace_back(YDims[yI]); + } else if (yRank == 3) { + // Y :: SxNxP + int64_t yI1 = 0; + int64_t yI2 = 2; + if (transposeB) { + // Y :: SxPxN + yI2 = 1; + } + // Broadcast 1 case: X:2D (m,n) - Y:3DS (s,n,p) - Bias:2DS (s,p) - Out:3DS + // (s,m,p) + outputDims.emplace_back(YDims[yI1]); + outputDims.emplace_back(XDims[xI]); + outputDims.emplace_back(YDims[yI2]); + isBroadcasted1 = true; + } } else if (xRank == 3) { // X :: SxMxN - outputDims.emplace_back(XDims[0]); - outputDims.emplace_back(XDims[1]); + int64_t xI1 = 0; + int64_t xI2 = 1; + if (transposeA) + // X :: SxNxM + xI2 = 2; if (yRank == 2) { // Y :: NxP - outputDims.emplace_back(YDims[1]); - isBroadcasted = true; + int64_t yI = 1; + if (transposeB) + // Y :: PxN + yI = 0; + // Broadcast 23 case: X:3DS (s,m,n) - Y:2D (n,p) - Bias:1D (p) - Out:3DS + // (s,m,p) + outputDims.emplace_back(XDims[xI1]); + outputDims.emplace_back(XDims[xI2]); + outputDims.emplace_back(YDims[yI]); + isBroadcasted23 = true; } else if (yRank == 3) { // Y :: SxNxP - outputDims.emplace_back(YDims[2]); + int64_t yI = 2; + if (transposeB) + // Y :: SxPxN + yI = 1; + // Stacked case: X:3DS (s,m,n) - Y:3DS (s,n,p) - Bias:2DS (s,p) - Out:3DS + // (s,m,p) + outputDims.emplace_back(XDims[xI1]); + outputDims.emplace_back(XDims[xI2]); + outputDims.emplace_back(YDims[yI]); isStacked = true; } } // Keep all original dimensions: M, N, P if 2D or S, M, N, P if 3D. if (xRank == 2) { - // M - allOriginalDims.emplace_back(XDims[0]); - // N - allOriginalDims.emplace_back(XDims[1]); - // P - allOriginalDims.emplace_back(YDims[1]); + if (transposeA) { + // M + allOriginalDims.emplace_back(XDims[1]); + // N + allOriginalDims.emplace_back(XDims[0]); + } else { + // M + allOriginalDims.emplace_back(XDims[0]); + // N + allOriginalDims.emplace_back(XDims[1]); + } + if (yRank == 2) { + // P + if (transposeB) + allOriginalDims.emplace_back(YDims[0]); + else + allOriginalDims.emplace_back(YDims[1]); + } else if (yRank == 3) { + // S + allOriginalDims.emplace_back(YDims[0]); + // P + if (transposeB) + allOriginalDims.emplace_back(YDims[1]); + else + allOriginalDims.emplace_back(YDims[2]); + } } else if (xRank == 3) { // S allOriginalDims.emplace_back(XDims[0]); - // M - allOriginalDims.emplace_back(XDims[1]); - // N - allOriginalDims.emplace_back(XDims[2]); + if (transposeA) { + // M + allOriginalDims.emplace_back(XDims[2]); + // N + allOriginalDims.emplace_back(XDims[1]); + } else { + // M + allOriginalDims.emplace_back(XDims[1]); + // N + allOriginalDims.emplace_back(XDims[2]); + } // P if (yRank == 2) - allOriginalDims.emplace_back(YDims[1]); - else if (yRank == 3) - allOriginalDims.emplace_back(YDims[2]); + if (transposeB) + allOriginalDims.emplace_back(YDims[0]); + else + allOriginalDims.emplace_back(YDims[1]); + else if (yRank == 3) { + if (transposeB) + allOriginalDims.emplace_back(YDims[1]); + else + allOriginalDims.emplace_back(YDims[2]); + } } // Save the final result. @@ -94,7 +178,7 @@ LogicalResult ZHighMatMulOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighMatMulOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getX()) || !hasRankedType(getY())) return success(); @@ -138,12 +222,20 @@ LogicalResult ZHighMatMulOp::verify() { (xLayout == ZTensorEncodingAttr::DataLayout::_3DS))) return failure(); - // If X is 2D, Y must be 2D and B must be 1D + // If X is 2D, Y must be 2D or 3DS. + // If X is 2D and Y is 2D, B must be 1D. + // If X is 2D and Y is 3DS, B must be 2DS. if (xLayout == ZTensorEncodingAttr::DataLayout::_2D) { - if (!(yLayout == ZTensorEncodingAttr::DataLayout::_2D)) - return failure(); - if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_1D)) + if (!((yLayout == ZTensorEncodingAttr::DataLayout::_2D) || + (yLayout == ZTensorEncodingAttr::DataLayout::_3DS))) return failure(); + if (yLayout == ZTensorEncodingAttr::DataLayout::_2D) { + if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_1D)) + return failure(); + } else if (yLayout == ZTensorEncodingAttr::DataLayout::_3DS) { + if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_2DS)) + return failure(); + } } // X is 3DS, valid types for (X, Y, B) are (3DS, 3DS, 2DS) or (3DS, 2D, 1D) diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp index 5153fbd6bf..1834be5055 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/MeanReduce2D/MeanReduce2D.cpp @@ -38,8 +38,8 @@ LogicalResult ZHighMeanReduce2DOpShapeHelper::computeShape() { // Input is NHWC, and H and W are reduction dimensions. outputDims.emplace_back(inputDims[0]); - outputDims.emplace_back(LiteralIndexExpr(1)); - outputDims.emplace_back(LiteralIndexExpr(1)); + outputDims.emplace_back(LitIE(1)); + outputDims.emplace_back(LitIE(1)); outputDims.emplace_back(inputDims[3]); // Save the final result. @@ -52,7 +52,7 @@ LogicalResult ZHighMeanReduce2DOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighMeanReduce2DOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp index 36affcd07c..c120bc6b44 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp @@ -4,7 +4,7 @@ //===-------- OpHelper.cpp - NNPA ZHigh Helper Functions ------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -140,6 +140,45 @@ StringAttr convertZTensorDataLayoutToStringAttr( return attr; } +/// Get a ztensor quantized type by StringAttr. +ZTensorEncodingAttr::QuantizedType convertStringAttrToZTensorQuantizedType( + StringAttr qtypeAttr) { + if (qtypeAttr) { + StringRef qtypeStr = qtypeAttr.getValue(); + if (qtypeStr.equals_insensitive(QTYPE_DLFLOAT16)) + return ZTensorEncodingAttr::QuantizedType::DLFLOAT16; + else if (qtypeStr.equals_insensitive(QTYPE_INT8)) + return ZTensorEncodingAttr::QuantizedType::INT8; + else if (qtypeStr.equals_insensitive(QTYPE_WEIGHTS)) + return ZTensorEncodingAttr::QuantizedType::WEIGHTS; + else if (qtypeStr.equals_insensitive(QTYPE_UNDEFINED)) + return ZTensorEncodingAttr::QuantizedType::UNDEFINED; + else + llvm_unreachable("Invalid quantized type string"); + } else + llvm_unreachable("Could not get quantized type by an empty StringAttr"); +} + +/// Convert a quantized type to StringAttr. +StringAttr convertZTensorQuantizedTypeToStringAttr( + OpBuilder &builder, ZTensorEncodingAttr::QuantizedType qtype) { + StringAttr attr; + switch (qtype) { + case ZTensorEncodingAttr::QuantizedType::DLFLOAT16: + attr = builder.getStringAttr(QTYPE_DLFLOAT16); + break; + case ZTensorEncodingAttr::QuantizedType::INT8: + attr = builder.getStringAttr(QTYPE_INT8); + break; + case ZTensorEncodingAttr::QuantizedType::WEIGHTS: + attr = builder.getStringAttr(QTYPE_WEIGHTS); + break; + default: + break; + } + return attr; +} + //===----------------------------------------------------------------------===// // Utility functions to query ztensor information. @@ -169,6 +208,12 @@ StringAttr getZTensorLayoutAttr(OpBuilder &builder, Type type) { return nullptr; } +ZTensorEncodingAttr::QuantizedType getZTensorQuantizedType(Type type) { + if (auto encoding = getZTensorEncoding(type)) + return encoding.getQuantizedType(); + return ZTensorEncodingAttr::QuantizedType::UNDEFINED; +} + //===----------------------------------------------------------------------===// // Utility functions. @@ -190,7 +235,7 @@ Value getConstantOfType( Type elementType = shapedType.getElementType(); DenseElementsAttr denseAttr; if (mlir::isa(elementType)) - denseAttr = DenseElementsAttr::get(shapedType, (int64_t)val); + denseAttr = DenseElementsAttr::get(shapedType, static_cast(val)); else if (mlir::isa(elementType)) denseAttr = DenseElementsAttr::get(shapedType, val); else @@ -217,7 +262,7 @@ bool oneIsOfLayout(Type t1, Type t2, /// Check if ONNXReshapeOp is reshaping 2D to 4D by tiling each input dimension. bool isTiling2DTo4D(Value val) { - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -246,7 +291,7 @@ bool isTiling2DTo4D(Value val) { /// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling the first input /// dimension. bool isTiling3DTo4D(Value val) { - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -276,7 +321,7 @@ bool isTiling3DTo4D(Value val) { /// Check if a 4D tensor is collapsed into 2D by merging the each two /// dimensions. bool isCollapsing4DTo2D(Value val) { - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -305,7 +350,7 @@ bool isCollapsing4DTo2D(Value val) { /// Check if a 4D tensor is collapsed into 3D by merging the first two /// dimensions. bool isCollapsing4DTo3D(Value val) { - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -336,7 +381,7 @@ AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) { assert(isTiling2DTo4D(val) && "ONNXReshapeOp is not suitable for getting a tiling affine map"); - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); Value output = reshapeOp.getReshaped(); Type outputType = output.getType(); ArrayRef outputShape = getShape(outputType); @@ -361,7 +406,7 @@ AffineMapAttr getTiling3DTo4DMap(OpBuilder &b, Value val) { assert(isTiling3DTo4D(val) && "ONNXReshapeOp is not suitable for getting a tiling affine map"); - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); Value output = reshapeOp.getReshaped(); Type outputType = output.getType(); ArrayRef outputShape = getShape(outputType); @@ -384,7 +429,7 @@ AffineMapAttr getCollapsing4DTo2DMap(OpBuilder &b, Value val) { assert(isCollapsing4DTo2D(val) && "ONNXReshapeOp is not suitable for getting a collapsing affine map"); - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); Value input = reshapeOp.getData(); Type inputType = input.getType(); ArrayRef inputShape = getShape(inputType); @@ -409,7 +454,7 @@ AffineMapAttr getCollapsing4DTo3DMap(OpBuilder &b, Value val) { assert(isCollapsing4DTo3D(val) && "ONNXReshapeOp is not suitable for getting a collapsing affine map"); - auto reshapeOp = dyn_cast(val.getDefiningOp()); + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); Value input = reshapeOp.getData(); Type inputType = input.getType(); ArrayRef inputShape = getShape(inputType); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp index def0813d7b..cc346ef17d 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp @@ -34,6 +34,14 @@ ZTensorEncodingAttr::DataLayout getZTensorDataLayoutByRank(int64_t rank); mlir::StringAttr convertZTensorDataLayoutToStringAttr( mlir::OpBuilder &builder, ZTensorEncodingAttr::DataLayout layout); +/// Get a ztensor quantized type by StringAttr. +ZTensorEncodingAttr::QuantizedType convertStringAttrToZTensorQuantizedType( + mlir::StringAttr qtypeAttr); + +/// Convert a quantized type to StringAttr. +mlir::StringAttr convertZTensorQuantizedTypeToStringAttr( + mlir::OpBuilder &builder, ZTensorEncodingAttr::QuantizedType qtype); + //===----------------------------------------------------------------------===// // Convenience method to query information of a ztensor @@ -51,6 +59,9 @@ ZTensorEncodingAttr::DataLayout getZTensorLayout(mlir::Type type); mlir::StringAttr getZTensorLayoutAttr( mlir::OpBuilder &builder, mlir::Type type); +/// Get the quantized type of a ztensor. +ZTensorEncodingAttr::QuantizedType getZTensorQuantizedType(mlir::Type type); + /// Get a minus value. mlir::Value getMinusBcastConst(mlir::OpBuilder &builder, mlir::Location loc, mlir::FloatAttr floatAttr, mlir::Value input); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp index 5f9a11a5ce..65948e90d5 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Pooling/Pooling.cpp @@ -94,7 +94,7 @@ template struct ZHighPoolingOpShapeHelper; //===----------------------------------------------------------------------===// LogicalResult ZHighMaxPool2DOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput())) return success(); @@ -110,7 +110,7 @@ LogicalResult ZHighMaxPool2DOp::inferShapes( //===----------------------------------------------------------------------===// LogicalResult ZHighAvgPool2DOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getInput())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp new file mode 100644 index 0000000000..bc9a34696f --- /dev/null +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp @@ -0,0 +1,177 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------- QuantizedMatMul.cpp - ZHigh Operations ----------------===// +// +// Copyright 2023 The IBM Research Authors. +// +// ============================================================================= +// +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace zhigh { + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// + +LogicalResult ZHighQuantizedMatMulOpShapeHelper::computeShape() { + ZHighQuantizedMatMulOp::Adaptor operandAdaptor(operands); + // Output dims of result. + DimsExpr outputDims; + + // Get operands. + Value X = operandAdaptor.getX(); + Value Y = operandAdaptor.getY(); + + // Get bounds + SmallVector XDims, YDims; + createIE->getShapeAsDims(X, XDims); + createIE->getShapeAsDims(Y, YDims); + int64_t xRank = XDims.size(); + int64_t yRank = YDims.size(); + + if (!(xRank == 2 || xRank == 3)) + return failure(); + + if (xRank == 2) { + // X :: MxN + int64_t xI = 0; + outputDims.emplace_back(XDims[xI]); + // Y :: NxP + int64_t yI = 1; + outputDims.emplace_back(YDims[yI]); + } else if (xRank == 3) { + // X :: SxMxN + outputDims.emplace_back(XDims[0]); + outputDims.emplace_back(XDims[1]); + if (yRank == 2) { + // Y :: NxP + outputDims.emplace_back(YDims[1]); + isBroadcasted = true; + } else if (yRank == 3) { + // Y :: SxNxP + outputDims.emplace_back(YDims[2]); + isStacked = true; + } + } + + // Keep all original dimensions: M, N, P if 2D or S, M, N, P if 3D. + if (xRank == 2) { + // M + allOriginalDims.emplace_back(XDims[0]); + // N + allOriginalDims.emplace_back(XDims[1]); + // P + allOriginalDims.emplace_back(YDims[1]); + } else if (xRank == 3) { + // S + allOriginalDims.emplace_back(XDims[0]); + // M + allOriginalDims.emplace_back(XDims[1]); + // N + allOriginalDims.emplace_back(XDims[2]); + // P + if (yRank == 2) + allOriginalDims.emplace_back(YDims[1]); + else if (yRank == 3) + allOriginalDims.emplace_back(YDims[2]); + } + + // Save the final result. + setOutputDims(outputDims); + return success(); +} + +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// + +LogicalResult ZHighQuantizedMatMulOp::inferShapes( + std::function doShapeInference) { + if (!hasRankedType(getX()) || !hasRankedType(getY())) + return success(); + + bool dequantizeOutput = (getDequantizeOutput() == -1); + ZHighQuantizedMatMulOpShapeHelper shapeHelper(getOperation()); + shapeHelper.computeShapeAndAssertOnFailure(); + + SmallVector outputDims; + IndexExpr::getShape(shapeHelper.getOutputDims(), outputDims); + Type elementType = + mlir::cast(getResult(0).getType()).getElementType(); + ZTensorEncodingAttr encoding; + ZTensorEncodingAttr::QuantizedType qtype = + ZTensorEncodingAttr::QuantizedType::DLFLOAT16; + if (dequantizeOutput) + qtype = ZTensorEncodingAttr::QuantizedType::UNDEFINED; + if (outputDims.size() == 2) + encoding = ZTensorEncodingAttr::get( + this->getContext(), ZTensorEncodingAttr::DataLayout::_2D, qtype); + else if (outputDims.size() == 3) + encoding = ZTensorEncodingAttr::get( + this->getContext(), ZTensorEncodingAttr::DataLayout::_3DS, qtype); + + updateType(getOperation(), getResult(0), outputDims, elementType, encoding); + return success(); +} + +LogicalResult ZHighQuantizedMatMulOp::verify() { + ZHighQuantizedMatMulOpAdaptor operandAdaptor(*this); + // Get operands. + Value X = operandAdaptor.getX(); + Value Y = operandAdaptor.getY(); + Value B = operandAdaptor.getB(); + + if (!hasRankedType(X) || !hasRankedType(Y)) + return success(); + + // Get layouts. + ZTensorEncodingAttr::DataLayout xLayout = getZTensorLayout(X.getType()); + ZTensorEncodingAttr::DataLayout yLayout = getZTensorLayout(Y.getType()); + // Bias can be None. + ZTensorEncodingAttr::DataLayout bLayout; + bool hasBias = !mlir::isa(B.getType()); + if (hasBias) { + if (!hasRankedType(B)) + return success(); + bLayout = getZTensorLayout(B.getType()); + } + + // X must be 2D or 3DS. + if (!((xLayout == ZTensorEncodingAttr::DataLayout::_2D) || + (xLayout == ZTensorEncodingAttr::DataLayout::_3DS))) + return failure(); + + // If X is 2D, Y must be 2D and B must be 1D + if (xLayout == ZTensorEncodingAttr::DataLayout::_2D) { + if (!(yLayout == ZTensorEncodingAttr::DataLayout::_2D)) + return failure(); + if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_1D)) + return failure(); + } + + // X is 3DS, valid types for (X, Y, B) are (3DS, 3DS, 2DS) or (3DS, 2D, 1D) + if (xLayout == ZTensorEncodingAttr::DataLayout::_3DS) { + if (yLayout == ZTensorEncodingAttr::DataLayout::_3DS) { + if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_2DS)) + return failure(); + } else if (yLayout == ZTensorEncodingAttr::DataLayout::_2D) { + if (hasBias && !(bLayout == ZTensorEncodingAttr::DataLayout::_1D)) + return failure(); + } + } + + return success(); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedStick/QuantizedStick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedStick/QuantizedStick.cpp new file mode 100644 index 0000000000..9e8f9515f3 --- /dev/null +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/QuantizedStick/QuantizedStick.cpp @@ -0,0 +1,205 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ Stick.cpp - ZHigh Operations ----------------------===// +// +// Copyright 2023 The IBM Research Authors. +// +// ============================================================================= +// +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace zhigh { + +//===----------------------------------------------------------------------===// +// Custom builders +//===----------------------------------------------------------------------===// + +void ZHighQuantizedStickOp::build(OpBuilder &builder, OperationState &state, + Value input, Value recScale, Value offset, StringAttr layout, + StringAttr qtype, IntegerAttr symMode) { + // Quantized type. + auto quantizedType = convertStringAttrToZTensorQuantizedType(qtype); + + Type resElementType; + if (quantizedType == ZTensorEncodingAttr::QuantizedType::DLFLOAT16) + resElementType = builder.getF16Type(); + else if (quantizedType == ZTensorEncodingAttr::QuantizedType::INT8) + resElementType = builder.getI8Type(); + else if (quantizedType == ZTensorEncodingAttr::QuantizedType::WEIGHTS) + resElementType = builder.getI8Type(); + else + llvm_unreachable("Unsupported quantized transform type"); + + Type resType = builder.getNoneType(); + if (!mlir::isa(input.getType())) { + ShapedType inputType = mlir::cast(input.getType()); + int64_t rank = -1; + if (inputType.hasRank()) { + rank = inputType.getRank(); + ZTensorEncodingAttr::DataLayout dataLayout; + if (layout) + dataLayout = convertStringAttrToZTensorDataLayout(layout); + else { + dataLayout = getZTensorDataLayoutByRank(rank); + // Create a layout attribute. + layout = convertZTensorDataLayoutToStringAttr(builder, dataLayout); + } + // Compute shape. + ArrayRef inputShape = inputType.getShape(); + SmallVector resShape(inputShape.begin(), inputShape.end()); + resType = RankedTensorType::get(resShape, resElementType, + ZTensorEncodingAttr::get( + builder.getContext(), dataLayout, quantizedType)); + } else { + resType = UnrankedTensorType::get(resElementType); + } + } + RankedTensorType scalarTensorF32Type = + RankedTensorType::get({}, builder.getF32Type()); + build(builder, state, {resType, scalarTensorF32Type, scalarTensorF32Type}, + input, recScale, offset, layout, qtype, symMode); +} + +void ZHighQuantizedStickOp::build(OpBuilder &builder, OperationState &state, + Value input, Value recScale, Value offset, StringAttr layout, + StringAttr qtype) { + // By default, sym_mode is off. + IntegerAttr symMode = builder.getIntegerAttr(builder.getI64Type(), 0); + build(builder, state, input, recScale, offset, layout, qtype, symMode); +} + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// + +LogicalResult ZHighQuantizedStickOpShapeHelper::computeShape() { + ZHighQuantizedStickOp::Adaptor operandAdaptor(operands); + Value input = operandAdaptor.getIn(); + + // Output dims of result. + DimsExpr outputDims; + + // Get operands and bounds. + SmallVector inputDims; + createIE->getShapeAsDims(input, inputDims); + int64_t rank = inputDims.size(); + + for (int64_t i = 0; i < rank; ++i) + outputDims.emplace_back(inputDims[i]); + + // Save the final result. + setOutputDims(outputDims); + return success(); +} + +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// + +LogicalResult ZHighQuantizedStickOp::inferShapes( + std::function doShapeInference) { + Operation *op = getOperation(); + OpBuilder builder(op); + + Value input = getIn(); + if (isa(input.getType()) || !hasRankedType(input)) + return success(); + + auto inputType = mlir::cast(input.getType()); + StringAttr layout = getLayoutAttr(); + StringAttr qtype = getQuantizedTypeAttr(); + int64_t rank = inputType.getRank(); + + ZTensorEncodingAttr::DataLayout dataLayout; + if (layout) + dataLayout = convertStringAttrToZTensorDataLayout(layout); + else + dataLayout = getZTensorDataLayoutByRank(rank); + ZTensorEncodingAttr::QuantizedType quantizedType = + convertStringAttrToZTensorQuantizedType(qtype); + auto encoding = + ZTensorEncodingAttr::get(this->getContext(), dataLayout, quantizedType); + + Type resElementType; + if (quantizedType == ZTensorEncodingAttr::QuantizedType::DLFLOAT16) + resElementType = builder.getF16Type(); + else if (quantizedType == ZTensorEncodingAttr::QuantizedType::INT8) + resElementType = builder.getI8Type(); + else if (quantizedType == ZTensorEncodingAttr::QuantizedType::WEIGHTS) + resElementType = builder.getI8Type(); + else + llvm_unreachable("Unsupported quantized transform type"); + + ZHighQuantizedStickOpShapeHelper shapeHelper(getOperation()); + shapeHelper.computeShapeAndAssertOnFailure(); + SmallVector outputDims; + IndexExpr::getShape(shapeHelper.getOutputDims(0), outputDims); + + updateType(op, getResults()[0], outputDims, resElementType, encoding); + getResults()[1].setType(RankedTensorType::get({}, builder.getF32Type())); + getResults()[2].setType(RankedTensorType::get({}, builder.getF32Type())); + return success(); +} + +//===----------------------------------------------------------------------===// +// Canonicalization patterns +//===----------------------------------------------------------------------===// + +class QuantizedStickUnstickRemovalPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ZHighQuantizedStickOp qStickOp, + PatternRewriter &rewriter) const override { + Location loc = qStickOp.getLoc(); + Value input = qStickOp.getIn(); + StringAttr quantizedType = qStickOp.getQuantizedTypeAttr(); + + // ZHighQuantizedStickOp's type is dlfloat16. + if (!quantizedType.getValue().equals_insensitive(QTYPE_DLFLOAT16)) + return failure(); + + // ZHighQuantizedStickOp's input was defined by ZHighUnstickOp. + auto unstickOp = input.getDefiningOp(); + if (!unstickOp) + return failure(); + // Stickified input's layout is 3D, 2DS or 3DS. + Value stickInput = unstickOp.getIn(); + StringAttr stickLayout = + getZTensorLayoutAttr(rewriter, stickInput.getType()); + if (!(stickLayout.getValue().equals_insensitive("3D") || + stickLayout.getValue().equals_insensitive("2DS") || + stickLayout.getValue().equals_insensitive("3DS"))) + return failure(); + // Match layout. + StringAttr qStickLayout = qStickOp.getLayoutAttr(); + if (stickLayout != qStickLayout) + return failure(); + + // Rewrite by passing the stickified input directly to ZHighQuantizedStick. + ZHighQuantizedStickOp newQStickOp = rewriter.create( + loc, stickInput, qStickOp.getInRecScale(), qStickOp.getInOffset(), + qStickOp.getLayoutAttr(), qStickOp.getQuantizedTypeAttr()); + rewriter.replaceOp(qStickOp, newQStickOp.getResults()); + return success(); + } +}; + +void ZHighQuantizedStickOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + if (nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset) + results.insert(context); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reduction/Reduction.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reduction/Reduction.cpp new file mode 100644 index 0000000000..02daf76ad0 --- /dev/null +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reduction/Reduction.cpp @@ -0,0 +1,96 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ Reduction.cpp - ZHigh Operations ----------------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace zhigh { + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// +template +LogicalResult ZHighReductionOpShapeHelper::computeShape() { + typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); + + // Get operand. + Value data = operandAdaptor.getData(); + + // Output dims of result. + DimsExpr outputDims; + + // Get operands and bounds. + SmallVector inputDims; + createIE->getShapeAsDims(data, inputDims); + + // NNPA only supports reduction over the innermost dimension. + // So set the innermost dimension of the output to one. + int64_t axis = inputDims.size() - 1; + LiteralIndexExpr one(1); + // Copy the input until the second to last dimension + for (int64_t i = 0; i < axis; ++i) { + outputDims.emplace_back(inputDims[i]); + } + outputDims.emplace_back(one); + + // Save the final result. + setOutputDims(outputDims); + return success(); +} + +//===----------------------------------------------------------------------===// +// ZHigh Shape Helper template instantiation +// Keep template instantiation at the end of the file. +//===----------------------------------------------------------------------===// + +template struct ZHighReductionOpShapeHelper; +template struct ZHighReductionOpShapeHelper; + +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// +template +static LogicalResult inferShapeForReductionOps(OP_TYPE &op) { + typename OP_TYPE::Adaptor operandAdaptor(op); + if (!hasRankedType(operandAdaptor.getData())) + return success(); + RankedTensorType dataType = + mlir::cast(operandAdaptor.getData().getType()); + ZHighReductionOpShapeHelper shapeHelper(op.getOperation(), {}); + return shapeHelper.computeShapeAndUpdateType( + dataType.getElementType(), dataType.getEncoding()); +} + +//===----------------------------------------------------------------------===// +// ReduceMax +//===----------------------------------------------------------------------===// + +LogicalResult ZHighReduceMaxOp::inferShapes( + std::function doShapeInference) { + return inferShapeForReductionOps(*this); +} + +//===----------------------------------------------------------------------===// +// ReduceMin +//===----------------------------------------------------------------------===// + +LogicalResult ZHighReduceMinOp::inferShapes( + std::function doShapeInference) { + return inferShapeForReductionOps(*this); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp index cb8194f408..f9427116af 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp @@ -54,6 +54,7 @@ DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickForGRUOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickForLSTMOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickifiedConstantOfShapeOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickOpShapeHelper) +DECLARE_SHAPE_HELPER_ZHIGH(ZHighQuantizedStickOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighUnstickOpShapeHelper) #undef DECLARE_SHAPE_HELPER_ZHIGH @@ -68,6 +69,27 @@ struct ZHighMatMulOpShapeHelper : public ONNXOpShapeHelper { : ONNXOpShapeHelper(op, operands, ieBuilder, scope) {} virtual ~ZHighMatMulOpShapeHelper() {} mlir::LogicalResult computeShape() final; + // Broadcast 1 case: X:2D - Y:3DS + bool isBroadcasted1 = false; + // Broadcast 23 case: X:3DS - Y:2D + bool isBroadcasted23 = false; + // Stack case: X:3DS - Y:3DS + bool isStacked = false; + // Keep original dimensions in this order: m, n, p if 2D or s, m, n, p if 3D. + DimsExpr allOriginalDims; +}; + +//===----------------------------------------------------------------------===// +// Shape helper for QuantizedMatMulOp. +//===----------------------------------------------------------------------===// + +struct ZHighQuantizedMatMulOpShapeHelper : public ONNXOpShapeHelper { + ZHighQuantizedMatMulOpShapeHelper(mlir::Operation *op, + mlir::ArrayRef operands = {}, + IndexExprBuilder *ieBuilder = nullptr, IndexExprScope *scope = nullptr) + : ONNXOpShapeHelper(op, operands, ieBuilder, scope) {} + virtual ~ZHighQuantizedMatMulOpShapeHelper() {} + mlir::LogicalResult computeShape() final; // Broadcast case: X:3DS - Y:2D bool isBroadcasted = false; // Stack case: X:3DS - Y:3DS @@ -145,6 +167,25 @@ struct ZHighPoolingOpShapeHelper : public ONNXOpShapeHelper { DimsExpr allOriginalDims; }; +//===----------------------------------------------------------------------===// +// Shape helper for ReductionOp. +//===----------------------------------------------------------------------===// + +template +struct ZHighReductionOpShapeHelper : public ONNXOpShapeHelper { + ZHighReductionOpShapeHelper(mlir::Operation *op, + mlir::ArrayRef operands = {}, + IndexExprBuilder *ieBuilder = nullptr, IndexExprScope *scope = nullptr) + : ONNXOpShapeHelper(op, operands, ieBuilder, scope) {} + virtual ~ZHighReductionOpShapeHelper() {} + mlir::LogicalResult computeShape() final; +}; + +using ZHighReduceMaxOpShapeHelper = + ZHighReductionOpShapeHelper; +using ZHighReduceMinOpShapeHelper = + ZHighReductionOpShapeHelper; + //===----------------------------------------------------------------------===// // Shape helper for UnaryOp. //===----------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Softmax/Softmax.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Softmax/Softmax.cpp index 1fd28cc079..dbeb197a62 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Softmax/Softmax.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Softmax/Softmax.cpp @@ -20,7 +20,7 @@ namespace onnx_mlir { namespace zhigh { LogicalResult ZHighSoftmaxOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { return inferShapeForUnaryOps(this->getOperation()); } diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp index 327a6ffd9c..179cfe139b 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp @@ -106,7 +106,7 @@ LogicalResult ZHighStickOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighStickOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { Value input = getIn(); if (isa(input.getType()) || !hasRankedType(input)) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp index 911d343c02..4c0d1cb031 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForGRU/StickForGRU.cpp @@ -37,7 +37,7 @@ LogicalResult ZHighStickForGRUOpShapeHelper::computeShape() { for (int64_t i = 0; i < rank - 1; ++i) outputDims.emplace_back(zGateDims[i]); - IndexExpr lastDim = zGateDims[rank - 1] * LiteralIndexExpr(3); + IndexExpr lastDim = zGateDims[rank - 1] * LitIE(3); outputDims.emplace_back(lastDim); // Save the final result. @@ -50,7 +50,7 @@ LogicalResult ZHighStickForGRUOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighStickForGRUOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getZGate()) && !hasRankedType(getRGate()) && !hasRankedType(getHGate())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp index 8f1b4a07a1..ed92a620e7 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickForLSTM/StickForLSTM.cpp @@ -37,7 +37,7 @@ LogicalResult ZHighStickForLSTMOpShapeHelper::computeShape() { for (int64_t i = 0; i < rank - 1; ++i) outputDims.emplace_back(fGateDims[i]); - IndexExpr lastDim = fGateDims[rank - 1] * LiteralIndexExpr(4); + IndexExpr lastDim = fGateDims[rank - 1] * LitIE(4); outputDims.emplace_back(lastDim); // Save the final result. @@ -50,7 +50,7 @@ LogicalResult ZHighStickForLSTMOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighStickForLSTMOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getFGate()) && !hasRankedType(getIGate()) && !hasRankedType(getCGate()) && !hasRankedType(getOGate())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp index c46f97dd79..46bf1e943e 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/StickifiedConstantOfShape/StickifiedConstantOfShape.cpp @@ -92,7 +92,7 @@ LogicalResult ZHighStickifiedConstantOfShapeOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighStickifiedConstantOfShapeOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { Value shape = getShape(); if (!hasRankedType(shape)) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp index 77152ba81c..5e4b84fdf1 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Unstick/Unstick.cpp @@ -100,7 +100,7 @@ LogicalResult ZHighUnstickOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ZHighUnstickOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { if (!hasRankedType(getIn())) return success(); diff --git a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.cpp b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.cpp index cc981bff67..c08aca9e87 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.cpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.cpp @@ -15,26 +15,53 @@ #include "llvm/ADT/TypeSwitch.h" #include "src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp" +#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" using namespace mlir; namespace onnx_mlir { +// ============================================================================= +// ZLow Builder for building ZLow operations +// ============================================================================= + +void ZLowBuilder::stick( + Value x, Value out, StringAttr layout, IntegerAttr saturation) const { + b().create(loc(), x, out, layout, saturation); +} + +void ZLowBuilder::quantizedStick(Value x, Value recScale, Value offset, + Value out, StringAttr layout, StringAttr qType) const { + b().create( + loc(), x, recScale, offset, out, layout, qType); +} + +void ZLowBuilder::quantizedMatMul(Value x, Value xRecScale, Value xOffset, + Value y, Value yRecScale, Value yOffset, Value bias, Value biasRecScale, + Value biasOffset, Value workArea, Value shape, Value out, Value outRecScale, + Value outOffset, StringAttr xQType, StringAttr yQType, StringAttr biasQType, + StringAttr outQType, IntegerAttr isBcast, IntegerAttr isStacked, + IntegerAttr preComputedBias, IntegerAttr disableClipping, + IntegerAttr dequantizeOutput) const { + b().create(loc(), x, xRecScale, xOffset, y, + yRecScale, yOffset, bias, biasRecScale, biasOffset, workArea, shape, out, + outRecScale, outOffset, xQType, yQType, biasQType, outQType, isBcast, + isStacked, preComputedBias, disableClipping, dequantizeOutput); +} + // ============================================================================= // IndexExpr Builder for Analysis // ============================================================================= // Return null if none is found. -ElementsAttr IndexExprBuilderForZLow::getConst(mlir::Value value) { - return nullptr; -} +ElementsAttr IndexExprBuilderForZLow::getConst(Value value) { return nullptr; } Value IndexExprBuilderForZLow::getVal(Value intArrayVal, uint64_t i) { MultiDialectBuilder create(*this); uint64_t rank = getShapedTypeRank(intArrayVal); if (rank == 0) - return create.affine.load(intArrayVal, {}); + return create.affine.load(intArrayVal); uint64_t size = getArraySize(intArrayVal); assert(i < size && "out of bound reference"); Value iVal = create.math.constantIndex(i); diff --git a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp index b3310f0373..5af5db24c3 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/DialectBuilder.hpp @@ -37,10 +37,49 @@ struct IndexExprBuilderForZLow : IndexExprBuilder { mlir::Value getShapeVal(mlir::Value tensorOrMemrefValue, uint64_t i) final; }; +// ============================================================================= +// ZLow Builder for building ZLow operations +// ============================================================================= + +struct ZLowBuilder : public DialectBuilder { + ZLowBuilder(mlir::Location loc) : DialectBuilder(loc) {} + ZLowBuilder(mlir::OpBuilder &b, mlir::Location loc) + : DialectBuilder(b, loc) {} + ZLowBuilder(const DialectBuilder &db) : DialectBuilder(db) {} + virtual ~ZLowBuilder() {} + + void stick(mlir::Value x, mlir::Value out, mlir::StringAttr layout, + mlir::IntegerAttr saturation) const; + + void quantizedStick(mlir::Value x, mlir::Value xRecScale, mlir::Value xOffset, + mlir::Value out, mlir::StringAttr layout, mlir::StringAttr qType) const; + + void quantizedMatMul(mlir::Value x, mlir::Value xRecScale, + mlir::Value xOffset, mlir::Value y, mlir::Value yRecScale, + mlir::Value yOffset, mlir::Value b, mlir::Value bRecScale, + mlir::Value bOffset, mlir::Value workArea, mlir::Value shape, + mlir::Value out, mlir::Value outRecScale, mlir::Value outOffset, + mlir::StringAttr xQType, mlir::StringAttr yQType, mlir::StringAttr bQType, + mlir::StringAttr outQType, mlir::IntegerAttr isBcast, + mlir::IntegerAttr isStacked, mlir::IntegerAttr preComputedBias, + mlir::IntegerAttr disableClipping, + mlir::IntegerAttr dequantizeOutput) const; +}; + // ============================================================================= // MultiDialectBuilder for ZLow // ============================================================================= +// Recursive class specialized for ZLowBuilder referred to as krnl. +template +struct MultiDialectBuilder : MultiDialectBuilder { + MultiDialectBuilder(mlir::OpBuilder &b, mlir::Location loc) + : MultiDialectBuilder(b, loc), zlow(b, loc) {} + MultiDialectBuilder(const DialectBuilder &db) + : MultiDialectBuilder(db), zlow(db) {} + ZLowBuilder zlow; +}; + // Recursive class specialized for IndexExprBuilderForZLow referred to as // zlowIE. template diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td index 63fcb0704d..a66cb8273f 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td @@ -2,7 +2,7 @@ //===-- ZLowOps.td -- ZLow Dialect Operation Definitions -*- tablegen ------==// // -// Copyright 2019-2020 The IBM Research Authors +// Copyright 2019-2024 The IBM Research Authors // // ============================================================================= // @@ -36,8 +36,13 @@ class ZLow_Op traits = []> : def DLF16 : Type, "dlfloat16 type">, BuildableType<"$_builder.getF16Type()">; +// 0-rank MemRef for scalar. +def ODMemRefF32: MemRefRankOf<[F32], [0]>; + // MemRef-like type for zTensor. def ZMemRef : MemRefOf<[DLF16]>; +// Quantized zTensor. +def ZQMemRef : MemRefOf<[DLF16, I8]>; //===----------------------------------------------------------------------===// // ZLow Operations @@ -121,6 +126,19 @@ def ZLowExpOp:ZLow_Op<"exp", [MemRefsNormalizable, StrAttr:$layout); } + +def ZLowInvSqrtOp:ZLow_Op<"invsqrt", [MemRefsNormalizable]> { + let summary = "ZLow invsqrt operation"; + let description = [{ + ZLow operation to perform a invsqrt. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + StrAttr:$layout); +} + + def ZLowMinOp:ZLow_Op<"min", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "ZLow min operation"; @@ -147,6 +165,18 @@ def ZLowMaxOp:ZLow_Op<"max", [MemRefsNormalizable, StrAttr:$layout); } +def ZLowLeakyReluOp:ZLow_Op<"leakyrelu", [MemRefsNormalizable]> { + let summary = "ZLow leakyrelu operation"; + let description = [{ + ZLow operation to perform a leakyrelu. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + DefaultValuedAttr:$alpha, + StrAttr:$layout); +} + def ZLowReluOp:ZLow_Op<"relu", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "ZLow relu operation"; @@ -159,6 +189,17 @@ def ZLowReluOp:ZLow_Op<"relu", [MemRefsNormalizable, StrAttr:$layout); } +def ZLowGeluOp:ZLow_Op<"gelu", [MemRefsNormalizable]> { + let summary = "ZLow gelu operation"; + let description = [{ + ZLow operation to perform a gelu. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + StrAttr:$layout); +} + def ZLowTanhOp:ZLow_Op<"tanh", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "ZLow tanh operation"; @@ -198,6 +239,41 @@ def ZLowSoftmaxOp:ZLow_Op<"softmax", [MemRefsNormalizable, StrAttr:$act_func); } +def ZLowSqrtOp:ZLow_Op<"sqrt", [MemRefsNormalizable]> { + let summary = "ZLow sqrt operation"; + let description = [{ + ZLow operation to perform a sqrt. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + StrAttr:$layout); +} + +def ZLowReduceMaxOp:ZLow_Op<"reducemax", [MemRefsNormalizable]> { + let summary = "ZLow reducemax operation"; + let description = [{ + ZLow operation to perform a reducemax. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I8]>:$work_area, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + StrAttr:$layout); +} + +def ZLowReduceMinOp:ZLow_Op<"reducemin", [MemRefsNormalizable]> { + let summary = "ZLow reducemin operation"; + let description = [{ + ZLow operation to perform a reducemin. + }]; + let arguments = (ins ZMemRef:$X, + MemRefOf<[I8]>:$work_area, + MemRefOf<[I64]>:$shape, + ZMemRef:$Out, + StrAttr:$layout); +} + def ZLowMatMulOp:ZLow_Op<"matmul", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "ZLow matmul operation"; @@ -209,22 +285,71 @@ def ZLowMatMulOp:ZLow_Op<"matmul", [MemRefsNormalizable, * 2nd item: n * 3rd item: p * In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) - or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) + or broadcasting1: X(m, n) * Y(s, n, p) + Bias(s, p) + or broadcasting23: X(s, m, n) * Y(n, p) + Bias(p) shape is a 1D MemRef (memref<4xi64>) whose items are: * 1st item: s * 2nd item: m * 3rd item: n * 4th item: p - * is_bcast: -1 broadcasting, 0: no broadcasting. + * is_bcast1: -1 broadcasting1, 0: no broadcasting1. + * is_bcast23: -1 broadcasting23, 0: no broadcasting23. * is_stacked: -1 stacked, 0: unstacked. + * transposeA: !0 transpose A, 0: do not transpose A. + * transposeB: !0 transpose B, 0: do not transpose B. }]; let arguments = (ins ZMemRef:$X, ZMemRef:$Y, ZMemRef:$Bias, MemRefOf<[I64]>:$shape, ZMemRef:$Out, + DefaultValuedAttr:$is_bcast1, + DefaultValuedAttr:$is_bcast23, + DefaultValuedAttr:$is_stacked, + DefaultValuedAttr:$transposeA, + DefaultValuedAttr:$transposeB); +} + +def ZLowQuantizedMatMulOp:ZLow_Op<"quantizedMatmul", [MemRefsNormalizable]> { + let summary = "ZLow quantized matmul operation"; + let description = [{ + ZLow operation to perform a matmul. + work_area: a 4K-aligned buffer having the same layout as bias but dlfloat16 type. + * In case of unstacked: X(m, n) * Y(n, p) + Bias(p) + shape is a 1D MemRef (memref<3xi64>) whose items are: + * 1st item: m + * 2nd item: n + * 3rd item: p + * In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p) + or broadcasting: X(s, m, n) * Y(n, p) + Bias(p) + shape is a 1D MemRef (memref<4xi64>) whose items are: + * 1st item: s + * 2nd item: m + * 3rd item: n + * 4th item: p + * is_bcast: -1 broadcasting, 0: no broadcasting. + * is_stacked: -1 stacked, 0: unstacked. + * DequantizeOutput: -1 output is dequantized, 0: output is not dequantized. + * PreComputedBias: -1 bias is re-computed, 0: bias is not pre-computed. + + Values for `q_type` are "DLFLOAT16", "INT8", "WEIGHTS", "UNDEFINED". + + }]; + let arguments = (ins ZQMemRef:$X, ODMemRefF32:$x_rec_scale, ODMemRefF32:$x_offset, + ZQMemRef:$Y, ODMemRefF32:$y_rec_scale, ODMemRefF32:$y_offset, + ZQMemRef:$Bias, ODMemRefF32:$bias_rec_scale, ODMemRefF32:$bias_offset, + AnyTypeOf<[ZQMemRef, NoneType]>:$work_area, + MemRefOf<[I64]>:$shape, + ZQMemRef:$Out, ODMemRefF32:$out_rec_scale, ODMemRefF32:$out_offset, + StrAttr:$x_q_type, + StrAttr:$y_q_type, + StrAttr:$bias_q_type, + StrAttr:$out_q_type, DefaultValuedAttr:$is_bcast, - DefaultValuedAttr:$is_stacked); + DefaultValuedAttr:$is_stacked, + DefaultValuedAttr:$pre_computed_bias, + DefaultValuedAttr:$disable_clipping, + DefaultValuedAttr:$dequantize_output); } def ZLowLSTMOp:ZLow_Op<"lstm", [MemRefsNormalizable, @@ -340,6 +465,21 @@ def ZLowStickForGRUOp:ZLow_Op<"stickForGRU", [MemRefsNormalizable, DefaultValuedStrAttr:$prev_layer); } +def ZLowQuantizedStickOp:ZLow_Op<"quantizedStick", [MemRefsNormalizable]> { + let summary = "ZLow stick operation for quantization"; + let description = [{ + "ZLow operation to perform a quantization stick." + "Type is one of values: dlfloat16, int8, and weights." + }]; + let arguments = (ins MemRefOf<[I8, F32]>:$X, + MemRefRankOf<[F32], [0]>:$rec_scale, + MemRefRankOf<[F32], [0]>:$offset, + ZQMemRef:$out, + StrAttr:$layout, + StrAttr:$q_type); + let hasVerifier = 1; +} + def ZLowUnstickOp:ZLow_Op<"unstick", [MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "ZLow unstick operation"; diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp index 7526933777..677c666bcc 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp @@ -357,6 +357,48 @@ void ZLowBatchNormOp::getEffects( effects.emplace_back(MemoryEffects::Read::get(), &getShapeMutable(), SideEffects::DefaultResource::get()); } +//===----------------------------------------------------------------------===// +// ZLowOps methods +//===----------------------------------------------------------------------===// + +LogicalResult ZLowQuantizedStickOp::verify() { + ZLowQuantizedStickOp::Adaptor operandAdaptor(*this); + Value recScale = operandAdaptor.getRecScale(); + Value offset = operandAdaptor.getOffset(); + Value output = operandAdaptor.getOut(); + auto outputType = llvm::dyn_cast(output.getType()); + if (!outputType) + return failure(); + + // Verify quantized type. + StringRef quantizedType = getQType(); + if (!(quantizedType.equals_insensitive("dlfloat16") || + quantizedType.equals_insensitive("int8") || + quantizedType.equals_insensitive("weights"))) + return emitOpError("q_type must be one of dlfloat16, int8, and weights"); + + // Verify element type of the output. + // TODO: should we have a more stricted contraint, e.g. signed integer? + Type elementType = outputType.getElementType(); + if (quantizedType.equals_insensitive("dfloat16") && !elementType.isF16()) + return emitOpError("q_type and element type mismatched"); + if (quantizedType.equals_insensitive("int8") && !elementType.isInteger(8)) + return emitOpError("q_type and element type mismatched"); + if (quantizedType.equals_insensitive("weights") && !elementType.isInteger(8)) + return emitOpError("q_type and element type mismatched"); + + // Verify recScale and offset. + if (auto ty = llvm::dyn_cast(recScale.getType())) { + if (!ty.getElementType().isF32()) + return emitOpError("recScale must be f32"); + } + if (auto ty = llvm::dyn_cast(offset.getType())) { + if (!ty.getElementType().isF32()) + return emitOpError("offset must be f32"); + } + + return success(); +} } // namespace zlow } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index afd1bf5596..2e4a06c477 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -4,7 +4,7 @@ //===-------------------------- NNPAAccelerator.cpp -----------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -52,9 +52,9 @@ NNPAAccelerator::NNPAAccelerator() : Accelerator(Accelerator::Kind::NNPA) { LLVM_DEBUG(llvm::dbgs() << "Creating an NNPA accelerator\n"); // Print a warning if mcpu is not set or < z16. - if (!isCompatibleWithNNPALevel(NNPA_Z16)) - llvm::outs() << "Warning: No NNPA code is generated because --mcpu is not " - "set or < z16.\n"; + if (!isCompatibleWithNNPALevel(NNPALevel::M14)) + llvm::outs() << "\nWarning: No NNPA code is generated because:\n" + " --march is not set/older than z16.\n\n"; acceleratorTargets.push_back(this); // Order is important! libRuntimeNNPA depends on libzdnn @@ -63,7 +63,12 @@ NNPAAccelerator::NNPAAccelerator() : Accelerator(Accelerator::Kind::NNPA) { NNPAAccelerator::~NNPAAccelerator() { delete instance; } -uint64_t NNPAAccelerator::getVersionNumber() const { return ZDNN_VERNUM; } +// Return accelerator version number based on compile NNPA version +uint64_t NNPAAccelerator::getVersionNumber() const { + if (isCompatibleWithNNPALevel(NNPALevel::M15)) + return NNPA_ZDNN_VERSIONS[NNPALevel::M15]; + return NNPA_ZDNN_VERSIONS[NNPALevel::M14]; +} void NNPAAccelerator::addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, @@ -162,8 +167,10 @@ void NNPAAccelerator::conversionTargetONNXToKrnl( void NNPAAccelerator::rewritePatternONNXToKrnl( mlir::RewritePatternSet &patterns, mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const { - onnx_mlir::zhigh::populateZHighToZLowConversionPattern( - patterns, typeConverter, ctx, enableParallel); + onnx_mlir::zhigh::populateZHighToZLowConversionPattern(patterns, + typeConverter, ctx, + /*enableSIMD*/ OptimizationLevel >= 3 && !disableSimdOption, + enableParallel); } void NNPAAccelerator::conversionTargetKrnlToLLVM( diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index b15e7f165d..f00fcdedff 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -30,6 +30,7 @@ std::unique_ptr createDevicePlacementPass( /// Add pass for lowering ONNX ops to ZHigh ops. std::unique_ptr createONNXToZHighPass(); +std::unique_ptr createONNXToZHighPass(NNPAQuantType quantMode); void configureOnnxToZHighLoweringPass(bool reportOnNNPAUnsupportedOps); /// Add pass for rewriting ONNX ops for ZHigh. @@ -49,6 +50,10 @@ std::unique_ptr createZHighLayoutPropagationPass(); /// Pass for constant propagation at ZHighIR. std::unique_ptr createZHighConstPropagationPass(); +/// Pass for scrubbing constants at ZHighIR. +std::unique_ptr createZHighScrubDisposablePass( + bool closeAfter = true); + /// Pass for clipping values to dlfloat before stickification at ZHighIR. std::unique_ptr createZHighClipToDLFloatPass(); diff --git a/src/Accelerators/NNPA/Runtime/OMRuntimeNNPA.c b/src/Accelerators/NNPA/Runtime/OMRuntimeNNPA.c index d2d8877f1e..f55c30e795 100644 --- a/src/Accelerators/NNPA/Runtime/OMRuntimeNNPA.c +++ b/src/Accelerators/NNPA/Runtime/OMRuntimeNNPA.c @@ -24,7 +24,6 @@ #include #include "zDNNExtension/zDNNExtension.h" -#include "zdnn.h" #ifdef __cplusplus extern "C" { @@ -114,6 +113,22 @@ void OMInitAccelNNPA() { } } +/*! + * \brief Function to obtain the zDNN versions from the input versionNum. + * + * The zDNN major, minor, and patch versions are extracted from the input + * versionNum and set in *ver_major, *ver_minor, and *ver_patch. + * + * See the zDNN documentation for the definition of the major, minor, and + * patch versions. + */ +void getZDNNVersions(uint32_t versionNum, unsigned long long *ver_major, + unsigned long long *ver_minor, unsigned long long *ver_patch) { + *ver_major = versionNum >> 16; + *ver_minor = (versionNum >> 8) & 0xff; + *ver_patch = versionNum & 0xff; +} + /*! * \brief Function that performs the initialization of the NNPA device and * check that the NNPA version that the program was compiled for is compatible @@ -174,15 +189,37 @@ uint64_t OMInitCompatibleAccelNNPA(uint64_t versionNum) { pthread_mutex_unlock(&OMMutexForInitShutdownNNPA); if (!isCompatible) { /* Code below has to agree with zdnn.h convention. */ - unsigned long long ver_major = versionNum >> 16; - unsigned long long ver_minor = (versionNum >> 8) & 0xff; - unsigned long long ver_patch = versionNum & 0xff; + /* Create and initialize variables to 0 to avoid code scan error. */ + unsigned long long mod_ver_major = 0; + unsigned long long mod_ver_minor = 0; + unsigned long long mod_ver_patch = 0; + /* Invoke getZDNNVersions() to extract the zDNN major, minor, and patch + * version numbers from the model's version number. */ + getZDNNVersions( + versionNum, &mod_ver_major, &mod_ver_minor, &mod_ver_patch); + uint32_t zDNNLibaryVersion = zdnn_get_library_version(); + unsigned long long lib_ver_major = 0; + unsigned long long lib_ver_minor = 0; + unsigned long long lib_ver_patch = 0; + /* Invoke getZDNNVersions() to extract the zDNN major, minor, and patch + * version numbers from the zDNN library version number. */ + getZDNNVersions( + zDNNLibaryVersion, &lib_ver_major, &lib_ver_minor, &lib_ver_patch); + uint32_t zDNNAPIMaxVersion = zdnn_get_max_runnable_version(); + unsigned long long api_ver_major = 0; + unsigned long long api_ver_minor = 0; + unsigned long long api_ver_patch = 0; + /* Invoke getZDNNVersions() to extract the zDNN major, minor, and patch + * version numbers from the zDNN maximum API version number. */ + getZDNNVersions( + zDNNAPIMaxVersion, &api_ver_major, &api_ver_minor, &api_ver_patch); fprintf(stderr, - "Model is running on hardware that is not compatible with " - "the zDNN library that this model was compiled for " - "(version num %llu.%llu.%llu). Please ensure a compatible zDNN " - "library is available.\n ", - ver_major, ver_minor, ver_patch); + "Model requires zDNN API version %llu.%llu.%llu. The system has " + "zDNN library version %llu.%llu.%llu and supports up to zDNN API" + " version %llu.%llu.%llu.\n", + mod_ver_major, mod_ver_minor, mod_ver_patch, lib_ver_major, + lib_ver_minor, lib_ver_patch, api_ver_major, api_ver_minor, + api_ver_patch); errno = EPERM; return false; } diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c index 276c5ee87d..4a2beedf7a 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/Elementwise.c @@ -28,7 +28,6 @@ #include #include "zDNNExtension.h" -#include "zdnn.h" #ifdef __cplusplus extern "C" { @@ -304,6 +303,46 @@ zdnn_status zdnn_tanh_ext(const zdnn_ztensor *input, zdnn_ztensor *output) { return status; } +// ----------------------------------------------------------------------------- +// Extension Functions for arch15 +// arch15 specific zdnn functions but with the `_ext` postfix. +// Retrieve the zdnn status message +// ----------------------------------------------------------------------------- + +zdnn_status zdnn_gelu_ext(const zdnn_ztensor *input, zdnn_ztensor *output) { + zdnn_status status = zdnn_gelu(input, output); + CHECK_ZDNN_STATUS(status, "zdnn_gelu"); + return status; +} + +zdnn_status zdnn_invsqrt_ext( + const zdnn_ztensor *input, float epsilon, zdnn_ztensor *output) { + zdnn_status status = zdnn_invsqrt(input, epsilon, output); + CHECK_ZDNN_STATUS(status, "zdnn_invsqrt"); + return status; +} + +zdnn_status zdnn_leaky_relu_ext(const zdnn_ztensor *input, + const void *clipping_value, float adjustment_factor, zdnn_ztensor *output) { + zdnn_status status = + zdnn_leaky_relu(input, clipping_value, adjustment_factor, output); + CHECK_ZDNN_STATUS(status, "zdnn_leaky_relu"); + return status; +} + +zdnn_status zdnn_reduce_ext(const zdnn_ztensor *input, void *save_area, + int opType, zdnn_ztensor *output) { + zdnn_status status = zdnn_reduce(input, save_area, opType, output); + CHECK_ZDNN_STATUS(status, "zdnn_reduce"); + return status; +} + +zdnn_status zdnn_sqrt_ext(const zdnn_ztensor *input, zdnn_ztensor *output) { + zdnn_status status = zdnn_sqrt(input, output); + CHECK_ZDNN_STATUS(status, "zdnn_sqrt"); + return status; +} + #ifdef __cplusplus } #endif diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c index c15c8666b0..311555ce9e 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c @@ -31,7 +31,6 @@ #include #include "zDNNExtension.h" -#include "zdnn.h" #ifdef __cplusplus extern "C" { @@ -173,6 +172,19 @@ zdnn_status zdnn_matmul_bcast_op_ext(const zdnn_ztensor *inputA, return status; } +// transpose_a and transpose_b are actually boolean values but we will represent +// these values in terms of integer values 0 or 1 for consistency. +zdnn_status zdnn_matmul_transpose_op_ext(const zdnn_ztensor *inputA, + const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int transpose_a, + int transpose_b, int opType, zdnn_ztensor *output) { + zdnn_status status = zdnn_matmul_transpose_op( + inputA, inputB, inputC, transpose_a, transpose_b, opType, output); + // Compiler does not check the return result at this moment. Thus, check it + // here. + CHECK_ZDNN_STATUS(status, "zdnn_matmul_transpose"); + return status; +} + #ifdef __cplusplus } #endif diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c b/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c index 69192260a5..e59143fdbb 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/Softmax.c @@ -28,7 +28,6 @@ #include #include "zDNNExtension.h" -#include "zdnn.h" #ifdef __cplusplus extern "C" { diff --git a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h index 3123dd2e76..e967bdb9fd 100644 --- a/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h +++ b/src/Accelerators/NNPA/Runtime/zDNNExtension/zDNNExtension.h @@ -135,7 +135,7 @@ void zDNNExtensionInit(); // Helper Functions // ----------------------------------------------------------------------------- -inline void omUnreachable() { +static inline void omUnreachable() { // Uses compiler specific extensions if possible. // Even if no extension is used, undefined behavior is still raised by // an empty function body and the noreturn attribute. @@ -329,6 +329,23 @@ zdnn_status zdnn_softmax_ext(const zdnn_ztensor *input, void *save_area, zdnn_softmax_act act_func, zdnn_ztensor *output); zdnn_status zdnn_tanh_ext(const zdnn_ztensor *input, zdnn_ztensor *output); +// ----------------------------------------------------------------------------- +// Extension Functions for arch15 +// arch15 specific zdnn functions but with the `_ext` postfix. +// ----------------------------------------------------------------------------- + +zdnn_status zdnn_gelu_ext(const zdnn_ztensor *input, zdnn_ztensor *output); +zdnn_status zdnn_invsqrt_ext( + const zdnn_ztensor *input, float epsilon, zdnn_ztensor *output); +zdnn_status zdnn_leaky_relu_ext(const zdnn_ztensor *input, + const void *clipping_value, float adjustment_factor, zdnn_ztensor *output); +zdnn_status zdnn_sqrt_ext(const zdnn_ztensor *input, zdnn_ztensor *output); +zdnn_status zdnn_matmul_transpose_op_ext(const zdnn_ztensor *inputA, + const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int transpose_a, + int transpose_b, int opType, zdnn_ztensor *output); +zdnn_status zdnn_reduce_ext(const zdnn_ztensor *input, void *save_area, + int op_type, zdnn_ztensor *output); + // ----------------------------------------------------------------------------- // Misc Utility Functions // ----------------------------------------------------------------------------- diff --git a/src/Accelerators/NNPA/Support/LayoutHelper.hpp b/src/Accelerators/NNPA/Support/LayoutHelper.hpp index fb512fc90a..26ca4a6801 100644 --- a/src/Accelerators/NNPA/Support/LayoutHelper.hpp +++ b/src/Accelerators/NNPA/Support/LayoutHelper.hpp @@ -35,6 +35,12 @@ const std::string LAYOUT_ZRH = "ZRH"; const std::string LAYOUT_BFICO = "BFICO"; const std::string LAYOUT_BZRH = "BZRH"; +// Quantized transform type. +const std::string QTYPE_DLFLOAT16 = "DLFLOAT16"; +const std::string QTYPE_INT8 = "INT8"; +const std::string QTYPE_WEIGHTS = "WEIGHTS"; +const std::string QTYPE_UNDEFINED = "UNDEFINED"; + zdnn_data_layouts convertLayoutAttrToZDNNDataLayout( int64_t rank, mlir::StringAttr layoutAttr); diff --git a/src/Accelerators/NNPA/Support/NNPALimit.cpp b/src/Accelerators/NNPA/Support/NNPALimit.cpp index 5206442803..df14d4a4ce 100644 --- a/src/Accelerators/NNPA/Support/NNPALimit.cpp +++ b/src/Accelerators/NNPA/Support/NNPALimit.cpp @@ -18,29 +18,87 @@ #include #include +using namespace onnx_mlir; + //===----------------------------------------------------------------------===// -// Compatibility checks +// Scan mcpu and march flags into NNPALevel -/// Convert the input NNPA level, ie. "z16", to a integer value representing the -/// level, ie. "16". When unkown / out of bounds, returns 0. -int64_t convertNNPALevel(std::string inputNNPALevel) { - if (inputNNPALevel.size() != 3 || inputNNPALevel[0] != 'z') - return 0; - if (inputNNPALevel[1] == '1') { - if (inputNNPALevel[2] == '6') - return 16; +static NNPALevel getNNPAFromTargetFlag(std::string str) { + // Coded it efficiently as it is called over and over again. + if (str.size() == 3) { + if (str[0] == 'z') { + if (str[1] == '1') { + if (str[2] == '6') + return NNPALevel::M14; + } + } + } else if (str.size() == 6) { + if (str[0] == 'a' && str[1] == 'r' && str[2] == 'c' && str[3] == 'h') { + if (str[4] == '1') { + if (str[5] == '4') + return NNPALevel::M14; + if (str[5] == '5') + return NNPALevel::M15; + } + } } - return 0; + return NNPALevel::NONE; +} + +// Read march flag, and if undefined, then read mcpu. +NNPALevel getNNPAFromFlags() { + NNPALevel level = getNNPAFromTargetFlag(march); + if (level == NNPALevel::NONE) + level = getNNPAFromTargetFlag(mcpu); + return level; +} + +//===----------------------------------------------------------------------===// +// Print NNPALevel as a string (depending on which option was given) + +// Print level using mcpu, march, or both depending on the options that were +// given to the compiler. Favor the zYY names below over the archXX names. +std::string getNNPAString(NNPALevel level) { + std::string val; + if (!mcpu.empty()) { + // The mcpu compiler option is defined, give an answer + if (level == NNPALevel::M14) + val = "--mcpu=z16"; // Note: --mcpu is deprecated. + else if (level == NNPALevel::M15) + val = "--mcpu=arch15"; // Note: --mcpu is deprecated. + else + assert(level == NNPALevel::NONE && "unknown mcpu option"); + } + if (!march.empty()) { + if (!val.empty() && level != NNPALevel::NONE) + val = val.append(" "); + // The march compiler option is defined, give an answer + if (level == NNPALevel::M14) + val = val.append("--march=z16"); + else if (level == NNPALevel::M15) + val = val.append("--march=arch15"); + else + assert(level == NNPALevel::NONE && "unknown march option"); + } + return val; } /// A function to check whether the input NNPA level, ie. "z16", is compatible /// with the current NNPA level. -bool isCompatibleWithNNPALevel(std::string inputNNPALevel) { - int64_t inLevel = convertNNPALevel(inputNNPALevel); - int64_t mcpuLevel = convertNNPALevel(onnx_mlir::mcpu); - if (inLevel == 0 && mcpuLevel == 0) +bool isCompatibleWithNNPALevel(NNPALevel level) { + NNPALevel flagLevel = getNNPAFromFlags(); + if (level == NNPALevel::NONE && flagLevel == NNPALevel::NONE) return false; - return inLevel <= mcpuLevel; + return level <= flagLevel; +} + +/// A function to check whether the current --march, ie. "z16", is less than or +/// equal to the given NNPA level. +bool isLessEqualNNPALevel(NNPALevel level) { + NNPALevel flagLevel = getNNPAFromFlags(); + if (level == NNPALevel::NONE && flagLevel == NNPALevel::NONE) + return false; + return flagLevel <= level; } //===----------------------------------------------------------------------===// @@ -48,14 +106,41 @@ bool isCompatibleWithNNPALevel(std::string inputNNPALevel) { // The NNPA maximum supported dimension index size value by using // zdnn_get_nnpa_max_dim_idx_size() This value depends on HW. -static constexpr int64_t NNPA_Z16_MAXIMUM_DIMENSION_INDEX_SIZE = 32768; +static constexpr int64_t NNPA_ARCH14_MAXIMUM_DIMENSION_INDEX_SIZE = 32768; + +/* + ARCH15 sizes are dimension dependent: + for(int i=1; i<=4; ++i) { + uint32_t maxDimSize = zdnn_get_max_for_dim((uint8_t) i); + printf(" max size for dim e%i: %i\n", i, (int) maxDimSize); + } + + max size for dim e1: 2097152 + max size for dim e2: 1048576 + max size for dim e3: 32768 + max size for dim e4: 32768 +*/ +static constexpr int64_t NNPA_ARCH15_MAXIMUM_DIMENSION_INDEX_SIZES[] = { + /*e1*/ 2097152, /*e2*/ 1048576, /*e3*/ 32768, /*e4*/ 32768}; int64_t NNPAGetMaxForDim(int64_t dim, int64_t rank) { assert(rank >= 0 && "expected positive rank"); assert(dim >= 0 && dim < rank && "dim outside range [0..rank)"); if (rank > 4) return 0; - if (isCompatibleWithNNPALevel(NNPA_Z16)) - return NNPA_Z16_MAXIMUM_DIMENSION_INDEX_SIZE; + // rank 4: (index from memref = 0, 1, 2, 3) -> e (4, 3, 2, 1) + // rank 3: (index from memref = 0, 1, 2) -> e (3, 2, 1) + // rank 2: (index from memref = 0, 1) -> e (2, 1) + // rank 1: (index from memref = 0) -> e (1) + int64_t e = rank - dim; + + // List from newest NNPA to oldest, to select the most recent compatible + // one. + if (isCompatibleWithNNPALevel(NNPALevel::M15)) + return NNPA_ARCH15_MAXIMUM_DIMENSION_INDEX_SIZES[e - 1]; + + if (isCompatibleWithNNPALevel(NNPALevel::M14)) + return NNPA_ARCH14_MAXIMUM_DIMENSION_INDEX_SIZE; + return 0; } diff --git a/src/Accelerators/NNPA/Support/NNPALimit.hpp b/src/Accelerators/NNPA/Support/NNPALimit.hpp index fdf43a65e3..25f7d71c70 100644 --- a/src/Accelerators/NNPA/Support/NNPALimit.hpp +++ b/src/Accelerators/NNPA/Support/NNPALimit.hpp @@ -16,6 +16,7 @@ #define ONNX_MLIR_NNPA_LIMIT_H #include +#include // Get maximum number of element for a given NNPA tensor. Dim is a tensor/memref // index (from 0 to rank-1), with dim=0 being the outermost dimension and @@ -32,8 +33,28 @@ static constexpr int64_t NNPA_MAXIMUM_TENSOR_SIZE = 4294967296; static constexpr int64_t MAXIMUM_NUM_HIDDEN_SIZE_LSTM = 8192; static constexpr int64_t MAXIMUM_NUM_HIDDEN_SIZE_GRU = 10880; -// The NNPA levels. -static constexpr const char *NNPA_Z16 = "z16"; +// The NNPA levels. Newer versions must have larger numbers than older versions. +typedef enum NNPALevel { + NONE = 0, + M14 = 1, // Associated with march=arch14 | z16. + M15 = 2, // Associated with march=arch15. +} NNPALevel; + +// The NNPA ZDNN versions. Keep in sync with enum NNPALevel. +static constexpr uint64_t NNPA_ZDNN_VERSIONS[3] = { + /*NONE*/ 0x0, /*M14*/ 0x010001, /*M15*/ 0x010101}; + +// Scan to NNPALevel and print from NNPALevel. +NNPALevel getNNPAFromFlags(); +std::string getNNPAString(NNPALevel level); + +/// A function to check whether the input NNPA level, ie. "z16" or "arch14", is +/// compatible with the current NNPA level. +bool isCompatibleWithNNPALevel(NNPALevel level); + +/// A function to check whether the current --march (or deprecated --mcpu), ie. +/// "z16" or "arch14", is less than or equal to the given NNPA level. +bool isLessEqualNNPALevel(NNPALevel level); // Maximum/Minimum value in dlfloat16. // dlfloat value = (-1)^s * 2^(e-31) * (1 + m/512), e=[0, 63], m=[0, 511], @@ -43,4 +64,5 @@ static constexpr const char *NNPA_Z16 = "z16"; // and (s=1,e=63,m=510) as the minimum value. static constexpr float DLF16_MAX = (1L << 32) * (1.0 + (510.0 / 512.0)); static constexpr float DLF16_MIN = -1 * (1L << 32) * (1.0 + (510.0 / 512.0)); + #endif diff --git a/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp b/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp index d2ddc767b5..646ae0ec91 100644 --- a/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp +++ b/src/Accelerators/NNPA/Support/Stickify/Stickify.cpp @@ -30,28 +30,54 @@ #pragma export(zdnn_get_library_version) #endif +/// Verify the transformed descriptor zdnn_status verify_transformed_descriptor(const zdnn_tensor_desc *tfrmd_desc); +zdnn_status set_zdnn_status(zdnn_status status, const char *func_name, + const char *file_name, int line_no, const char *format, ...); + +#define ZDNN_STATUS(status, format, ...) \ + set_zdnn_status(status, __func__, __FILE__, __LINE__, format, __VA_ARGS__) + +#define ZDNN_STATUS_NO_MSG(status) ZDNN_STATUS(status, NULL, NO_ARG) +#ifndef ZDNN_CONFIG_DEBUG +#define ZDNN_STATUS_OK ZDNN_OK +#else +#define ZDNN_STATUS_OK ZDNN_STATUS_NO_MSG(ZDNN_OK) +#endif + /// Macros from third_party/zdnn-lib/zdnn/zdnn_private.h #define AIU_BYTES_PER_STICK 128 +#define AIU_1BYTE_CELLS_PER_STICK 128 #define AIU_2BYTE_CELLS_PER_STICK 64 +#define AIU_4BYTE_CELLS_PER_STICK 32 #define AIU_2BYTE_CELL_SIZE 2 #define AIU_STICKS_PER_PAGE 32 #define AIU_PAGESIZE_IN_BYTES 4096 #define ZDNN_MAX_DIMS 4 // number of dims in AIU's Tensor Descriptor -#define CEIL(a, b) (uint64_t)(((a) + (b)-1) / (b)) // positive numbers only +// From status.c +// maximum size for the format string, including the prepended STATUS_STR_XXX +#define MAX_STATUS_FMTSTR_SIZE 1024 + +// ----------------------------------------------------------------------------- +// Misc Macros +// ----------------------------------------------------------------------------- +#define CEIL(a, b) \ + static_cast(((a) + (b)-1) / (b)) // positive numbers only #define MIN(a, b) (((a) > (b)) ? (b) : (a)) #define MAX(a, b) (((a) < (b)) ? (b) : (a)) #define BIT_SIZEOF(a) (sizeof(a) * 8) // padded = next multiple of AIU_2BYTE_CELLS_PER_STICK #define PADDED(x) \ - ((uint32_t)CEIL((x), AIU_2BYTE_CELLS_PER_STICK) * AIU_2BYTE_CELLS_PER_STICK) + (static_cast(CEIL((x), AIU_2BYTE_CELLS_PER_STICK)) * \ + AIU_2BYTE_CELLS_PER_STICK) #define ZDNN_STATUS_OK ZDNN_OK +// From zdnn_private.h typedef enum elements_mode { ELEMENTS_AIU, ELEMENTS_PRE, @@ -59,11 +85,14 @@ typedef enum elements_mode { ELEMENTS_PRE_ALL_GATES } elements_mode; -typedef /*vector*/ unsigned int vec_float32; -typedef /*vector*/ unsigned short vec_int16; -typedef /*vector*/ unsigned char vec_char8; // End - Macros from third_party/zdnn-lib/zdnn/zdnn_private.h +// Functions from third_party/zdnn-lib/zdnn/status.h +zdnn_status set_zdnn_status(zdnn_status status, const char *func_name, + const char *file_name, int line_no, const char *format, ...) { + return status; +} + // Functions from third_party/zdnn-lib/zdnn/get.c #define DECLARE_DATA_LAYOUT_STR(a) static const char *DATA_LAYOUT_STR_##a = #a; @@ -217,6 +246,7 @@ short get_data_type_size(zdnn_data_types type) { CASE_RTN_SIZE(FP16, 2); CASE_RTN_SIZE(FP32, 4); CASE_RTN_SIZE(ZDNN_DLFLOAT16, 2); + CASE_RTN_SIZE(INT8, 1); } #undef CASE_RTN_SIZE @@ -243,10 +273,11 @@ void *malloc_aligned_4k(size_t size) { } // find the 4k boundary after ptr - void *aligned_ptr = (void *)(((uintptr_t)ptr + extra_allocation) & - ~(AIU_PAGESIZE_IN_BYTES - 1)); + void *aligned_ptr = reinterpret_cast( + ((reinterpret_cast(ptr) + extra_allocation) & + ~(AIU_PAGESIZE_IN_BYTES - 1))); // put the original malloc'd address right before aligned_ptr - ((void **)aligned_ptr)[-1] = ptr; + (static_cast(aligned_ptr))[-1] = ptr; return aligned_ptr; } @@ -254,7 +285,7 @@ void *malloc_aligned_4k(size_t size) { void free_aligned_4k(void *aligned_ptr) { if (aligned_ptr) { // get the original malloc'd address from where we put it and free it - void *original_ptr = ((void **)aligned_ptr)[-1]; + void *original_ptr = (static_cast(aligned_ptr))[-1]; free(original_ptr); } } @@ -289,7 +320,7 @@ uint64_t get_num_elements(const zdnn_ztensor *ztensor, elements_mode mode) { // Multiply by the size of each expected dimension for (; i < ZDNN_MAX_DIMS; i++) { - num_elements *= (uint64_t)dims_ptr[i]; + num_elements *= static_cast(dims_ptr[i]); } if (mode == ELEMENTS_PRE_ALL_GATES) { @@ -302,11 +333,33 @@ uint64_t get_num_elements(const zdnn_ztensor *ztensor, elements_mode mode) { // Functions from third_party/zdnn-lib/zdnn/allochelper.c uint64_t getsize_ztensor(const zdnn_tensor_desc *tfrmd_desc) { - // same formula for 4DFEATURE and 4DKERNEL tensors - return (uint64_t)(tfrmd_desc->dim4) * tfrmd_desc->dim3 * - CEIL(tfrmd_desc->dim2, AIU_STICKS_PER_PAGE) * - CEIL(tfrmd_desc->dim1, AIU_2BYTE_CELLS_PER_STICK) * - AIU_PAGESIZE_IN_BYTES; + uint32_t cells_per_stick; + uint32_t number_of_sticks; + switch (tfrmd_desc->type) { + case ZDNN_BINARY_INT8: + if (tfrmd_desc->format == ZDNN_FORMAT_4DWEIGHTS) { + // 4DWEIGHTS has two vectors interleaved, therefore only 64 cells vs 128 + // Due to this interleaving, number_of_sticks is halved, but must be + // rounded up to stay even for proper interleaving. + cells_per_stick = AIU_2BYTE_CELLS_PER_STICK; + number_of_sticks = CEIL(tfrmd_desc->dim2, 2); + } else { + cells_per_stick = AIU_1BYTE_CELLS_PER_STICK; + number_of_sticks = tfrmd_desc->dim2; + } + break; + case ZDNN_BINARY_INT32: + cells_per_stick = AIU_4BYTE_CELLS_PER_STICK; + number_of_sticks = tfrmd_desc->dim2; + break; + case ZDNN_DLFLOAT16: /* fallthrough */ + default: + cells_per_stick = AIU_2BYTE_CELLS_PER_STICK; + number_of_sticks = tfrmd_desc->dim2; + } + return static_cast(tfrmd_desc->dim4) * tfrmd_desc->dim3 * + CEIL(number_of_sticks, AIU_STICKS_PER_PAGE) * + CEIL(tfrmd_desc->dim1, cells_per_stick) * AIU_PAGESIZE_IN_BYTES; } zdnn_status allochelper_ztensor_alloc(zdnn_ztensor *ztensor) { @@ -322,7 +375,7 @@ zdnn_status allochelper_ztensor_alloc(zdnn_ztensor *ztensor) { // get the size and allocate space aligned on a 4k boundary. If the malloc // fails, return error. - size = getsize_ztensor(ztensor->transformed_desc); + size = getsize_ztensor(ztensor->transformed_desc); // Modified if (!(ztensor->buffer = malloc_aligned_4k(size))) { return ZDNN_ALLOCATION_FAILURE; } @@ -339,7 +392,9 @@ void allochelper_ztensor_free(zdnn_ztensor *ztensor) { free_aligned_4k(ztensor->buffer); ztensor->buffer = NULL; ztensor->buffer_size = 0; -} // End - Functions from third_party/zdnn-lib/zdnn/allochelper.c +} + +/* End - Functions from third_party/zdnn-lib/zdnn/allochelper.c */ // Functions from third_party/zdnn-lib/zdnn/tensor_desc.c zdnn_status verify_pre_transformed_descriptor( @@ -368,6 +423,7 @@ zdnn_status verify_pre_transformed_descriptor( case BFLOAT: case FP16: case FP32: + case INT8: // all of these are good cases break; default: @@ -392,35 +448,71 @@ zdnn_status verify_transformed_descriptor(const zdnn_tensor_desc *tfrmd_desc) { case ZDNN_BIDIR_ZRH: break; default: - return ZDNN_INVALID_LAYOUT; + return ZDNN_STATUS(ZDNN_INVALID_LAYOUT, "Format is %s but layout is %s", + get_data_format_str(tfrmd_desc->format), + get_data_layout_str(tfrmd_desc->layout)); } break; case ZDNN_FORMAT_4DKERNEL: if (tfrmd_desc->layout != ZDNN_HWCK) { - return ZDNN_INVALID_LAYOUT; + return ZDNN_STATUS(ZDNN_INVALID_LAYOUT, "Format is %s but layout is %s", + get_data_format_str(tfrmd_desc->format), + get_data_layout_str(tfrmd_desc->layout)); + } + break; + case ZDNN_FORMAT_4DWEIGHTS: + if (tfrmd_desc->layout != ZDNN_NHWC) { + return ZDNN_STATUS(ZDNN_INVALID_LAYOUT, "Format is %s but layout is %s", + get_data_format_str(tfrmd_desc->format), + get_data_layout_str(tfrmd_desc->layout)); } break; + default: + // unrecognized + return ZDNN_STATUS(ZDNN_INVALID_FORMAT, "Invalid format: %d (%s)", + tfrmd_desc->format, get_data_format_str(tfrmd_desc->format)); } - - // for right now only ZDNN_DLFLOAT16 is valid - if (tfrmd_desc->type != ZDNN_DLFLOAT16) { + // Only ZDNN_DLFLOAT16, ZDNN_BINARY_INT8, and ZDNN_BINARY_INT32 are currently + // supported. + if (tfrmd_desc->type != ZDNN_DLFLOAT16 && + tfrmd_desc->type != ZDNN_BINARY_INT8 && + tfrmd_desc->type != ZDNN_BINARY_INT32) { return ZDNN_INVALID_TYPE; } const uint32_t *dims_ptr = &(tfrmd_desc->dim4); + /* ToFix: the nnpa_query_result is not set up with onnx-mlir + * Temporarily commented out. + * Refer to issue #3034 + */ + +#if 0 // is the dimension above the limit or zero? // transformed layout uses all dim* entries, so we'll check them all for (int i = 0; i < ZDNN_MAX_DIMS; i++) { if (!dims_ptr[i] || dims_ptr[i] > NNPAGetMaxForDim(i, ZDNN_MAX_DIMS)) { return ZDNN_INVALID_SHAPE; } + if (dims_ptr[i] > zdnn_get_max_for_dim(ZDNN_MAX_DIMS - i)) { + + if (!zdnn_get_max_for_dim(ZDNN_MAX_DIMS - i)) { + return ZDNN_UNSUPPORTED_AIU_EXCEPTION; + } else { + return ZDNN_STATUS( + ZDNN_INVALID_SHAPE, + "Invalid shape for dim%d. (reason: dimension value %d exceeds %d)", + ZDNN_MAX_DIMS - i, dims_ptr[i], + zdnn_get_max_for_dim(ZDNN_MAX_DIMS - i)); + } + } } // is stick area size above the limit? - if (getsize_ztensor(tfrmd_desc) > NNPA_MAXIMUM_TENSOR_SIZE) { + if (getsize_ztensor(tfrmd_desc) > zdnn_get_nnpa_max_tensor_size()) { return ZDNN_INVALID_SHAPE; } +#endif return ZDNN_STATUS_OK; } @@ -545,6 +637,36 @@ zdnn_status generate_transformed_desc( return status; } +zdnn_status generate_quantized_transformed_desc( + const zdnn_tensor_desc *pre_tfrmd_desc, + zdnn_quantized_transform_types transform_type, + zdnn_tensor_desc *tfrmd_desc) { + + zdnn_status status; + if ((status = generate_transformed_desc(pre_tfrmd_desc, tfrmd_desc)) != + ZDNN_OK) { + return status; + } + switch (transform_type) { + case QUANTIZED_DLFLOAT16: + tfrmd_desc->format = ZDNN_FORMAT_4DFEATURE; + tfrmd_desc->type = ZDNN_DLFLOAT16; + return ZDNN_STATUS_OK; + case QUANTIZED_INT8: + tfrmd_desc->format = ZDNN_FORMAT_4DFEATURE; + tfrmd_desc->type = ZDNN_BINARY_INT8; + return ZDNN_STATUS_OK; + case QUANTIZED_WEIGHTS_INT8: + tfrmd_desc->format = ZDNN_FORMAT_4DWEIGHTS; + tfrmd_desc->type = ZDNN_BINARY_INT8; + return ZDNN_STATUS_OK; + default: + return ZDNN_INVALID_TRANSFORM_TYPE; + // return ZDNN_STATUS(ZDNN_INVALID_TRANSFORM_TYPE, + // "Invalid transform type: %d", transform_type); + } +} + zdnn_status generate_transformed_desc_concatenated( const zdnn_tensor_desc *pre_tfrmd_desc, zdnn_concat_info info, zdnn_tensor_desc *tfrmd_desc) { @@ -628,6 +750,9 @@ void init_ztensor(zdnn_tensor_desc *pre_tfrmd_desc, output->transformed_desc = tfrmd_desc; output->is_transformed = false; memset(&output->reserved, 0, sizeof(output->reserved)); + output->rec_scale = 0; + output->offset = 0; + memset(&output->reserved2, 0, sizeof(output->reserved2)); } // End - Functions from third_party/zdnn-lib/zdnn/init_ztensor.c // Functions from third_party/zdnn-lib/zdnn/stickify.c @@ -652,8 +777,8 @@ uint32_t convert_data_format(void *input_data, zdnn_data_types in_data_fmt, if (out_data_fmt == ZDNN_DLFLOAT16) { switch (in_data_fmt) { case FP32: - num_fields_converted = fp32_to_dlf16( - (float *)input_data, (uint16_t *)output_data, num_fields); + num_fields_converted = fp32_to_dlf16(static_cast(input_data), + static_cast(output_data), num_fields); break; default: break; // something really wrong, get out and return 0 @@ -662,8 +787,8 @@ uint32_t convert_data_format(void *input_data, zdnn_data_types in_data_fmt, } else if (in_data_fmt == ZDNN_DLFLOAT16) { switch (out_data_fmt) { case FP32: - num_fields_converted = dlf16_to_fp32( - (uint16_t *)input_data, (float *)output_data, num_fields); + num_fields_converted = dlf16_to_fp32(static_cast(input_data), + static_cast(output_data), num_fields); break; default: break; // something really wrong, get out and return 0 @@ -726,7 +851,7 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // loop invariant values uint64_t bytes_all_h = - (uint64_t)ztensor->transformed_desc->dim3 * + static_cast(ztensor->transformed_desc->dim3) * CEIL(ztensor->transformed_desc->dim2, AIU_STICKS_PER_PAGE) * AIU_PAGESIZE_IN_BYTES; uint64_t bytes_per_n = bytes_all_h * CEIL(ztensor->transformed_desc->dim1, @@ -749,9 +874,13 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // "notice" our sequential accesses and continue them, so we won't // need to aggressively prefetch here. #if defined(__MVS__) - __dcbt((void *)((uintptr_t)in_buf + input_offset)); + __dcbt(reinterpret_cast( + reinterpret_cast(in_buf) + input_offset)); #else - __builtin_prefetch((void *)((uintptr_t)in_buf + input_offset), 0); + __builtin_prefetch( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + 0); #endif // used for pushing out_offset from w to w+1 (i.e., + // AIU_BYTES_PER_STICK) @@ -764,18 +893,26 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // Prefetch to L1 newest offset to write that HW wouldn't // know about #if defined(__MVS__) - __dcbtst((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbtst(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset)); #else __builtin_prefetch( - (void *)((uintptr_t)ztensor->buffer + output_offset), 1); + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset), + 1); #endif fields_to_convert = MIN((ztensor->transformed_desc->dim1 - e1x), AIU_2BYTE_CELLS_PER_STICK); nbr_fields_converted = convert_data_format( - (void *)((uintptr_t)in_buf + input_offset), + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), ztensor->pre_transformed_desc->type, - (void *)((uintptr_t)ztensor->buffer + output_offset), + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset), ztensor->transformed_desc->type, fields_to_convert); if (nbr_fields_converted == 0) { @@ -785,7 +922,9 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // Release L1 cacheline for stick. The next "touch" will be // from NNPA, and it doesn't need L1 caching. #if defined(__MVS__) - __dcbf((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbf(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset)); #else // No known equivalent fn without dropping to ASM.... #endif @@ -846,15 +985,20 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // "notice" our sequential accesses and continue them, so we won't // need to aggressively prefetch here. #if defined(__MVS__) - __dcbt((void *)((uintptr_t)in_buf + input_offset)); + __dcbt(reinterpret_cast( + reinterpret_cast(in_buf) + input_offset)); #else - __builtin_prefetch((void *)((uintptr_t)in_buf + input_offset), 0); + __builtin_prefetch( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + 0); #endif - nbr_fields_converted = - convert_data_format((void *)((uintptr_t)in_buf + input_offset), - ztensor->pre_transformed_desc->type, temp_buff, - ztensor->transformed_desc->type, fields_to_convert); + nbr_fields_converted = convert_data_format( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + ztensor->pre_transformed_desc->type, temp_buff, + ztensor->transformed_desc->type, fields_to_convert); if (nbr_fields_converted == 0) { return ZDNN_CONVERT_FAILURE; @@ -867,14 +1011,20 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // Prefetch to L1 newest offset to write that HW wouldn't // know about #if defined(__MVS__) - __dcbtst((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbtst(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset)); #else __builtin_prefetch( - (void *)((uintptr_t)ztensor->buffer + output_offset), 1); + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset), + 1); #endif - *(uint16_t *)((uintptr_t)ztensor->buffer + output_offset) = - temp_buff[w]; + *reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset) = temp_buff[w]; // go to same C location of the next stick output_offset += AIU_BYTES_PER_STICK; } @@ -931,21 +1081,32 @@ zdnn_status transform_ztensor(const void *in_buf, zdnn_ztensor *ztensor) { // Also, Prefetch the new output offset to write that HW wouldn't // know about. #if defined(__MVS__) - __dcbt((void *)((uintptr_t)in_buf + input_offset)); - __dcbtst((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbt(reinterpret_cast( + reinterpret_cast(in_buf) + input_offset)); + __dcbtst(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + output_offset)); #else - __builtin_prefetch((void *)((uintptr_t)in_buf + input_offset), 0); __builtin_prefetch( - (void *)((uintptr_t)ztensor->buffer + output_offset), 1); + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + 0); + __builtin_prefetch( + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset), + 1); #endif fields_to_convert = MIN((ztensor->transformed_desc->dim1 - e1x), AIU_2BYTE_CELLS_PER_STICK); - nbr_fields_converted = - convert_data_format((void *)((uintptr_t)in_buf + input_offset), - ztensor->pre_transformed_desc->type, - (void *)((uintptr_t)ztensor->buffer + output_offset), - ztensor->transformed_desc->type, fields_to_convert); + nbr_fields_converted = convert_data_format( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + ztensor->pre_transformed_desc->type, + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + + output_offset), + ztensor->transformed_desc->type, fields_to_convert); if (nbr_fields_converted == 0) { return ZDNN_CONVERT_FAILURE; @@ -1040,33 +1201,43 @@ zdnn_status transform_bidir_weight_ztensor( for (uint32_t e2x = 0; e2x < real_dim2; e2x++) { #if defined(__MVS__) - __dcbt((void *)((uintptr_t)in_buf + input_offset)); + __dcbt(reinterpret_cast( + reinterpret_cast(in_buf) + input_offset)); #else - __builtin_prefetch((void *)((uintptr_t)in_buf + input_offset), 0); + __builtin_prefetch( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + 0); #endif uint64_t out_offset_w = output_offset; for (uint32_t e1x = 0; e1x < ztensor->transformed_desc->dim1; e1x += AIU_2BYTE_CELLS_PER_STICK) { #if defined(__MVS__) - __dcbtst((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbtst(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + output_offset)); #else __builtin_prefetch( - (void *)((uintptr_t)ztensor->buffer + output_offset), 1); + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + output_offset), + 1); #endif fields_to_convert = MIN( (ztensor->transformed_desc->dim1 - e1x), AIU_2BYTE_CELLS_PER_STICK); - nbr_fields_converted = - convert_data_format((void *)((uintptr_t)in_buf + input_offset), - ztensor->pre_transformed_desc->type, - (void *)((uintptr_t)ztensor->buffer + output_offset), - ztensor->transformed_desc->type, fields_to_convert); + nbr_fields_converted = convert_data_format( + reinterpret_cast( + reinterpret_cast(in_buf) + input_offset), + ztensor->pre_transformed_desc->type, + reinterpret_cast( + reinterpret_cast(ztensor->buffer) + output_offset), + ztensor->transformed_desc->type, fields_to_convert); if (nbr_fields_converted == 0) return ZDNN_CONVERT_FAILURE; #if defined(__MVS__) - __dcbf((void *)((uintptr_t)ztensor->buffer + output_offset)); + __dcbf(reinterpret_cast( + reinterpret_cast(ztensor->buffer) + output_offset)); #else #endif input_offset += (nbr_fields_converted << input_cell_shift); @@ -1132,7 +1303,8 @@ zdnn_status stickify(zdnn_ztensor *ztensor, ...) { * b) buffer does not start on a 4k boundary * c) buffer_size is smaller than what's needed */ - if (!ztensor->buffer || (uintptr_t)ztensor->buffer & 0xFFF || + if (!ztensor->buffer || + reinterpret_cast(ztensor->buffer) & 0xFFF || ztensor->buffer_size < getsize_ztensor(ztensor->transformed_desc)) { return ZDNN_INVALID_BUFFER; } @@ -1290,9 +1462,9 @@ zdnn_status stickify(zdnn_ztensor *ztensor, ...) { for (uint32_t slice = 0; slice < num_slices; slice++) { for (uint8_t gate = 0; gate < num_gates; gate++) { // Points to a single slice of a single gate data. - const void *gate_data_slice = - (void *)((uintptr_t)gate_data[gate] + - (slice * sliced_gate_data_size)); + const void *gate_data_slice = reinterpret_cast( + reinterpret_cast(gate_data[gate]) + + (slice * sliced_gate_data_size)); // Transform the current slice of the current gate into final // ztensor @@ -1310,8 +1482,9 @@ zdnn_status stickify(zdnn_ztensor *ztensor, ...) { // Increment the temp_ztensor buffer by one sliced gate size // so we write to the correct location in the final output // ztensor. - temp_ztensor.buffer = (void *)((uintptr_t)(temp_ztensor.buffer) + - sliced_gate_buffer_size); + temp_ztensor.buffer = reinterpret_cast( + reinterpret_cast(temp_ztensor.buffer) + + sliced_gate_buffer_size); // Reset temp_ztensor is_transformed so we can recursively // call zdnn_transform_ztensor to process each slice of each @@ -1338,6 +1511,152 @@ zdnn_status stickify(zdnn_ztensor *ztensor, ...) { return status; } // End - Functions from third_party/zdnn-lib/zdnn/stickify.c +#define AIU_STICKS_PER_PAGE 32 +#define AIU_BYTES_PER_STICK 128 +#define AIU_1BYTE_CELLS_PER_STICK 128 +#define AIU_PAGESIZE_IN_BYTES 4096 + +#define VECPERM_MAX_INT8_ENTRIES 8 + +// The scalar version of transform_quantized_weights_ztensor() +zdnn_status transform_quantized_weights_ztensor_element_wise( + const void *in_buf, zdnn_ztensor *output) { + + // moving position as the input is processed, in BYTES + uint64_t input_offset = 0; + // moving position as the output is processed, in BYTES + uint64_t output_offset = 0; + + // loop invariant values + uint64_t bytes_all_h = + (uint64_t)output->transformed_desc->dim3 * + CEIL(CEIL(output->transformed_desc->dim2, 2), AIU_STICKS_PER_PAGE) * + AIU_PAGESIZE_IN_BYTES; + + uint64_t bytes_per_n = bytes_all_h * CEIL(output->transformed_desc->dim1, + (AIU_1BYTE_CELLS_PER_STICK / 2)); + + // N + for (uint32_t e4x = 0; e4x < output->transformed_desc->dim4; e4x++) { + + // used for pushing out_offset from n to n+1 (i.e., + bytes_per_n) + uint64_t out_offset_n = output_offset; + + // H + for (uint32_t e3x = 0; e3x < output->transformed_desc->dim3; e3x++) { + + // W, sticks are processed in pairs + for (uint32_t e2x = 0; e2x < output->transformed_desc->dim2; + e2x = e2x + 2) { + + // used for pushing out_offset from w to w+1 (i.e., + + // AIU_BYTES_PER_STICK) + uint64_t out_offset_w = output_offset; + + // true when dim2 is odd number and we're at the last w + bool no_stick2 = ((output->transformed_desc->dim2 - e2x) == 1); + + int8_t *stick1 = (int8_t *)in_buf + input_offset; + int8_t *stick2 = no_stick2 ? stick1 + // duplicate stick1 entries if no stick2 + : stick1 + output->transformed_desc->dim1; + + // this C loop takes care of the full VECPERM_MAX_INT8_ENTRIES-entries + // groups + for (uint32_t i = 0; + i < output->transformed_desc->dim1 / VECPERM_MAX_INT8_ENTRIES; + i++) { + ((int8_t *)output->buffer + output_offset)[0] = stick1[0]; + ((int8_t *)output->buffer + output_offset)[1] = stick2[0]; + ((int8_t *)output->buffer + output_offset)[2] = stick1[1]; + ((int8_t *)output->buffer + output_offset)[3] = stick2[1]; + ((int8_t *)output->buffer + output_offset)[4] = stick1[2]; + ((int8_t *)output->buffer + output_offset)[5] = stick2[2]; + ((int8_t *)output->buffer + output_offset)[6] = stick1[3]; + ((int8_t *)output->buffer + output_offset)[7] = stick2[3]; + + ((int8_t *)output->buffer + output_offset)[8] = stick1[4]; + ((int8_t *)output->buffer + output_offset)[9] = stick2[4]; + ((int8_t *)output->buffer + output_offset)[10] = stick1[5]; + ((int8_t *)output->buffer + output_offset)[11] = stick2[5]; + ((int8_t *)output->buffer + output_offset)[12] = stick1[6]; + ((int8_t *)output->buffer + output_offset)[13] = stick2[6]; + ((int8_t *)output->buffer + output_offset)[14] = stick1[7]; + ((int8_t *)output->buffer + output_offset)[15] = stick2[7]; + + stick1 += VECPERM_MAX_INT8_ENTRIES; + stick2 += VECPERM_MAX_INT8_ENTRIES; + output_offset += VECPERM_MAX_INT8_ENTRIES * 2; + + if ((i + 1) % + (AIU_BYTES_PER_STICK / (VECPERM_MAX_INT8_ENTRIES * 2)) == + 0) { + // we need to jump to the next c-stick of the same super c-stick + // + // roll-back to the beginning and jump to bytes_all_h number of + // bytes away + output_offset = output_offset - AIU_BYTES_PER_STICK + bytes_all_h; + } + } + + // takes care of the leftover c entries + for (uint32_t i = 0; + i < output->transformed_desc->dim1 % VECPERM_MAX_INT8_ENTRIES; + i++) { + ((int8_t *)output->buffer + output_offset)[0] = stick1[i]; + ((int8_t *)output->buffer + output_offset)[1] = stick2[i]; + + output_offset += 2; + } + + // move on to the next set + input_offset += output->transformed_desc->dim1 * (no_stick2 ? 1 : 2); + // output_offset was pushed around in dim1 loops, so reset it to + // the next w + output_offset = out_offset_w + AIU_BYTES_PER_STICK; + } + + // after processing all the w-entries, go to the next 4k-boundary + // location (aka stick padding) + output_offset = (output_offset + (AIU_PAGESIZE_IN_BYTES - 1)) & + (-AIU_PAGESIZE_IN_BYTES); + } + + // output_offset was pushed around in the dims[2-0] loops, so reset it + // to the next n + output_offset = out_offset_n + bytes_per_n; + } + + // Update the tensor's format to indicate it has been stickified + output->is_transformed = true; + return ZDNN_STATUS_OK; +} + +zdnn_status quantized_stickify(zdnn_ztensor *ztensor, const void *in_buf) { + /* It is supposed to use zdnn_transform_quantized_ztensor here. + * return zdnn_transform_quantized_ztensor(ztensor, 0, 0, in_buf); + * The clip_min and clip_max will not be used when + * transform_quantized_weights_ztensor() is called in this transform. + * The reason that zdnn_transform_quantized_ztensor can't be called + * is that the variable, nnpa_query_result, in the zdnn library built with + * onnx-mlir has not been properly set up. Therefore, the check on + * dimension size will fail. verify_transformed_descriptor() is called + * by zdnn_transform_quantized_ztensor(). + * Tried to call zdnn_refresh_nnpa_query_result(), but failed. + * In the copied verify_transformed_descriptor code, the code for checking + * has been commented out. + * Refer to issue #3034 + */ + + zdnn_status status; + if ((status = verify_transformed_descriptor(ztensor->transformed_desc)) != + ZDNN_OK) { + return status; + } + + return transform_quantized_weights_ztensor_element_wise(in_buf, ztensor); +} + /// Set information for a pre transformed descriptor. void set_info_pre_transformed_desc(zdnn_tensor_desc *pre_tfrmd_desc, zdnn_data_layouts layout, zdnn_data_types type, @@ -1350,7 +1669,7 @@ void set_info_pre_transformed_desc(zdnn_tensor_desc *pre_tfrmd_desc, // we do not need to set the unused dim vars to 1 for pre-transformed int startIdx = ZDNN_MAX_DIMS - get_data_layout_dims(layout); for (int i = startIdx; i < ZDNN_MAX_DIMS; i++) { - dims_ptr[i] = (uint32_t)shape[i - startIdx]; + dims_ptr[i] = static_cast(shape[i - startIdx]); } pre_tfrmd_desc->layout = layout; pre_tfrmd_desc->format = diff --git a/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp b/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp index 9bc1284f0c..304893a14a 100644 --- a/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp +++ b/src/Accelerators/NNPA/Support/Stickify/Stickify.hpp @@ -15,7 +15,10 @@ #ifndef ONNX_MLIR_STICKIFY_H #define ONNX_MLIR_STICKIFY_H +extern "C" { #include "zdnn.h" +} + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -28,6 +31,10 @@ void set_info_pre_transformed_desc(zdnn_tensor_desc *pre_tfrmd_desc, zdnn_status generate_transformed_desc( const zdnn_tensor_desc *pre_tfrmd_desc, zdnn_tensor_desc *tfrmd_desc); +zdnn_status generate_quantized_transformed_desc( + const zdnn_tensor_desc *pre_tfrmd_desc, zdnn_quantized_transform_types, + zdnn_tensor_desc *tfrmd_desc); + /// Generate a concatenated transformed descriptor. zdnn_status generate_transformed_desc_concatenated( const zdnn_tensor_desc *pre_tfrmd_desc, zdnn_concat_info concat_info, @@ -66,4 +73,5 @@ void allochelper_ztensor_free(zdnn_ztensor *ztensor); /// ZDNN_CONVERT_FAILURE /// zdnn_status stickify(zdnn_ztensor *ztensor, ...); +zdnn_status quantized_stickify(zdnn_ztensor *ztensor, const void *in_buf); #endif diff --git a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp index 475b2f03b7..3c63406d4b 100644 --- a/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp +++ b/src/Accelerators/NNPA/Transform/FoldStdAlloc.cpp @@ -1,6 +1,6 @@ //===-------- FoldStdAlloc.cpp - Fold std.alloc ---------------------------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -150,9 +150,9 @@ class FoldStdAlloc : public OpRewritePattern { // There must be exactly N stores to N different locations, where N is the // number of elements. - if ((int)storeOps.size() != numElements) + if (static_cast(storeOps.size()) != numElements) return failure(); - if ((int)indexToValueMap.size() != numElements) + if (static_cast(indexToValueMap.size()) != numElements) return failure(); // 2. Rewrite. @@ -211,7 +211,8 @@ class FoldStdAllocPass RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - (void)applyPatternsAndFoldGreedily(function, std::move(patterns)); + static_cast( + applyPatternsAndFoldGreedily(function, std::move(patterns))); } }; diff --git a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt index 7f9bfe05ec..30378c0f9e 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt @@ -84,3 +84,16 @@ add_onnx_mlir_library(OMZHighRecomposeToStickUnstick ACCEL_INCLUDE_DIRS PRIVATE ${NNPA_INCLUDE_PATH} ) + +add_onnx_mlir_library(OMZHighScrubDisposable + ZHighScrubDisposablePass.cpp + + LINK_LIBS PUBLIC + MLIRRewrite + MLIRTransformUtils + OMZHighOps + OMONNXOps + + ACCEL_INCLUDE_DIRS PRIVATE + ${NNPA_INCLUDE_PATH} + ) diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp index 1f7889eba8..9006c36669 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp @@ -70,7 +70,7 @@ bool valueFromZTensor(Value tensor) { return valueFromZTensor(op->getOperand(0)); // PadOp - if (auto padOp = dyn_cast(op)) { + if (auto padOp = mlir::dyn_cast(op)) { Value padVal = padOp.getConstantValue(); // Only support default constant value that is 0 at this moment. if (isNoneValue(padVal)) @@ -96,7 +96,7 @@ class ZHighClipToDLFloatPattern : public OpRewritePattern { Type inputElementType = getElementType(input.getType()); // Only clip if the input is in float > 16 bit. - auto floatType = dyn_cast(inputElementType); + auto floatType = mlir::dyn_cast(inputElementType); if (!floatType) return failure(); if (floatType.getWidth() <= 16) diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp index a32bacb4c4..e5a87008eb 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp @@ -21,11 +21,13 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#include "src/Dialect/ONNX/OnnxElementsAttrBuilder.hpp" using namespace mlir; using namespace onnx_mlir; @@ -35,15 +37,35 @@ namespace onnx_mlir { namespace zhigh { /// Get raw data from a dense attribute. -static void getRawData(DenseElementsAttr denseAttr, std::vector &data) { - if (!denseAttr.isSplat()) { - data = denseAttr.getRawData(); - } else { - ShapedType denseShapeType = mlir::cast(denseAttr.getType()); - std::vector rawData = denseAttr.getRawData(); - int64_t numElements = denseShapeType.getNumElements(); +static void getRawData(ElementsAttr attr_, std::vector &data) { + ShapedType tensorType = mlir::cast(attr_.getType()); + Type elemTy = tensorType.getElementType(); + int64_t numElements = tensorType.getNumElements(); + + // Use DenseElementsAttr for boolean values. DisposableElementsAttr handles + // bool differently. + ElementsAttr attr = attr_; + if (elemTy.isInteger(1)) + attr = ElementsAttrBuilder::toDenseElementsAttr(attr_); + + auto denseAttr = mlir::dyn_cast_or_null(attr); + auto disposalAttr = mlir::dyn_cast_or_null(attr); + assert((denseAttr || disposalAttr) && + "Must be DenseElementsAttr or DisposableElementsAttr"); + + if (disposalAttr) { + ArrayBuffer dstBytes = disposalAttr.getRawBytes(); + data = dstBytes.get(); + return; + } + + ArrayRef rawData = denseAttr.getRawData(); + if (denseAttr.isSplat()) { + // Broadcast the splat value. for (int i = 0; i < numElements; i++) data.insert(data.end(), rawData.begin(), rawData.end()); + } else { + data = rawData; } } @@ -57,6 +79,8 @@ zdnn_data_types mlirTypeToZDNNType(Type elementType) { return FP32; } else llvm_unreachable("Unsupported data type."); + } else if (elementType.isInteger(8)) { + return INT8; // INT8 is accepted by verify_pre_transformed_descriptor } else llvm_unreachable("Unsupported data type."); } @@ -71,18 +95,39 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter, /*value=*/nullptr, /*alignment=*/rewriter.getI64IntegerAttr(4096)); - // Use an dense resource attribute to store stickified data. // Attribute type: tensor int64_t sizeInBytes = ztensor->buffer_size; - DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( - RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), - stickifiedConstant.getOperation() - ->getDialect() - ->getNamespace(), // use the dialect as the blob "hint" - HeapAsmResourceBlob::allocateAndCopyWithAlign( - llvm::ArrayRef((char *)ztensor->buffer, sizeInBytes), alignof(char))); - stickifiedConstant.setValueAttr(valueAttr); + // Currently, using DenseResourceElementsAttr leads to less memory consumption + // at compile time. + // In the future, if there is a need to do constant prop for ZHigh Ops whose + // inputs are stickified data, then using ElementsAttr is potentially better. + // In this case, to print or parse ElementsAttr in lit tests, + // ZHighStickifiedConstantOp would be updated to support custom printer and + // parser. + bool useDenseResourceElementsAttr = true; + if (useDenseResourceElementsAttr) { + DenseResourceElementsAttr valueAttr = DenseUI8ResourceElementsAttr::get( + RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()), + stickifiedConstant.getOperation() + ->getDialect() + ->getNamespace(), // use the dialect as the blob "hint" + HeapAsmResourceBlob::allocateAndCopyWithAlign( + llvm::ArrayRef((char *)ztensor->buffer, sizeInBytes), + alignof(char))); + allochelper_ztensor_free(ztensor); + stickifiedConstant.setValueAttr(valueAttr); + } else { + RankedTensorType dataType = + RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()); + std::unique_ptr memBuf = + llvm::MemoryBuffer::getMemBuffer( + StringRef((char *)ztensor->buffer, sizeInBytes), "", + /*RequiresNullTerminator*/ false); + ElementsAttr valueAttr = OnnxElementsAttrBuilder(rewriter.getContext()) + .fromMemoryBuffer(dataType, std::move(memBuf)); + stickifiedConstant.setValueAttr(valueAttr); + } return stickifiedConstant; } @@ -90,14 +135,12 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter, ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter, Value replacingValue, Value input, StringAttr layout) { Location loc = replacingValue.getLoc(); - Operation *op = input.getDefiningOp(); ArrayRef shape = mlir::cast(input.getType()).getShape(); Type elementType = mlir::cast(input.getType()).getElementType(); int rank = shape.size(); // Read dense attributes. - DenseElementsAttr dataAttr = mlir::dyn_cast_or_null( - op->getAttrOfType<::mlir::Attribute>("value")); + ElementsAttr dataAttr = getElementAttributeFromONNXValue(input); assert(dataAttr && "Attribute is null"); // Read attributes's raw data. std::vector rawData; @@ -132,14 +175,71 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter, return constantOp; } +bool isFoldableQuantizedStickOp(Value res) { + ZTensorEncodingAttr::QuantizedType qtype = + getZTensorQuantizedType(res.getType()); + return (qtype == ZTensorEncodingAttr::QuantizedType::WEIGHTS || + qtype == ZTensorEncodingAttr::QuantizedType::INT8); +} + +ZHighStickifiedConstantOp createQuantizedConstantForStick( + PatternRewriter &rewriter, Value replacingValue, Value input, + Value recScale, Value offset, StringAttr layout, StringAttr quantizeType) { + Location loc = replacingValue.getLoc(); + ArrayRef shape = mlir::cast(input.getType()).getShape(); + Type elementType = mlir::cast(input.getType()).getElementType(); + int rank = shape.size(); + + // Read dense attributes. + ElementsAttr dataAttr = getElementAttributeFromONNXValue(input); + assert(dataAttr && "Attribute is null"); + // Read attributes's raw data. + std::vector rawData; + getRawData(dataAttr, rawData); + // assert((rawData.size() == (uint64_t)getMemRefSizeInBytes(input)) && + // "Data size mismatched"); + + // Call stickify. + zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc; + // pre-transformed desc. + zdnn_data_layouts zDNNLayout = + convertLayoutAttrToZDNNDataLayout(rank, layout); + // If zDNNLayout is NHWC, we stickify directly from NCHW. + if (zDNNLayout == ZDNN_NHWC) + zDNNLayout = ZDNN_NCHW; + zdnn_data_types zDNNType = mlirTypeToZDNNType(elementType); + set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType, shape); + // Check the condition for transformed desc. + // Currently, only QUANTIZED_WEIGHTS_INT8 is supported. + // The condition of being the weight for QuantizedMatMul has been checked + // by the matching pattern. + assert(zDNNType == INT8); + zdnn_quantized_transform_types transform_type = QUANTIZED_WEIGHTS_INT8; + zdnn_status status = generate_quantized_transformed_desc( + &pre_tfrmd_desc, transform_type, &tfrmd_desc); + assert(status == ZDNN_OK); + // Stick data using the software stickify. + zdnn_ztensor ztensor; + // init_quantized_ztensor can be used if the constant value for recScale and + // offset is extracted at compile time. However, in the following + // transformation for the quantized weight tensor, the recScale and offset + // is not used. The parameters are kept for possible future use. + init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor); + status = allochelper_ztensor_alloc(&ztensor); + assert(status == ZDNN_OK); + status = quantized_stickify(&ztensor, rawData.data()); + assert(status == ZDNN_OK); + // Emit a constant global in ZHigh dialect. + ZHighStickifiedConstantOp constantOp = emitZHighStickifiedConstant( + rewriter, loc, &ztensor, replacingValue.getType()); + + return constantOp; +} + ZHighStickifiedConstantOp createConstantForStickForLSTM( PatternRewriter &rewriter, Value replacingValue, Value inputF, Value inputI, Value inputC, Value inputO) { Location loc = replacingValue.getLoc(); - Operation *fOp = inputF.getDefiningOp(); - Operation *iOp = inputI.getDefiningOp(); - Operation *cOp = inputC.getDefiningOp(); - Operation *oOp = inputO.getDefiningOp(); ArrayRef fShape = mlir::cast(inputF.getType()).getShape(); @@ -147,14 +247,10 @@ ZHighStickifiedConstantOp createConstantForStickForLSTM( Type elementType = mlir::cast(inputF.getType()).getElementType(); // Read dense attributes. - DenseElementsAttr fDataAttr = mlir::dyn_cast_or_null( - fOp->getAttrOfType<::mlir::Attribute>("value")); - DenseElementsAttr iDataAttr = mlir::dyn_cast_or_null( - iOp->getAttrOfType<::mlir::Attribute>("value")); - DenseElementsAttr cDataAttr = mlir::dyn_cast_or_null( - cOp->getAttrOfType<::mlir::Attribute>("value")); - DenseElementsAttr oDataAttr = mlir::dyn_cast_or_null( - oOp->getAttrOfType<::mlir::Attribute>("value")); + ElementsAttr fDataAttr = getElementAttributeFromONNXValue(inputF); + ElementsAttr iDataAttr = getElementAttributeFromONNXValue(inputI); + ElementsAttr cDataAttr = getElementAttributeFromONNXValue(inputC); + ElementsAttr oDataAttr = getElementAttributeFromONNXValue(inputO); assert((fDataAttr && iDataAttr && cDataAttr && oDataAttr) && "Attribute is null"); // Read attributes's raw data. @@ -198,9 +294,6 @@ ZHighStickifiedConstantOp createConstantForStickForGRU( PatternRewriter &rewriter, Value replacingValue, Value inputZ, Value inputR, Value inputH) { Location loc = replacingValue.getLoc(); - Operation *zOp = inputZ.getDefiningOp(); - Operation *rOp = inputR.getDefiningOp(); - Operation *hOp = inputH.getDefiningOp(); ArrayRef zShape = mlir::cast(inputZ.getType()).getShape(); @@ -208,12 +301,9 @@ ZHighStickifiedConstantOp createConstantForStickForGRU( Type elementType = mlir::cast(inputZ.getType()).getElementType(); // Read dense attributes. - DenseElementsAttr zDataAttr = mlir::dyn_cast_or_null( - zOp->getAttrOfType<::mlir::Attribute>("value")); - DenseElementsAttr rDataAttr = mlir::dyn_cast_or_null( - rOp->getAttrOfType<::mlir::Attribute>("value")); - DenseElementsAttr hDataAttr = mlir::dyn_cast_or_null( - hOp->getAttrOfType<::mlir::Attribute>("value")); + ElementsAttr zDataAttr = getElementAttributeFromONNXValue(inputZ); + ElementsAttr rDataAttr = getElementAttributeFromONNXValue(inputR); + ElementsAttr hDataAttr = getElementAttributeFromONNXValue(inputH); assert((zDataAttr && rDataAttr && hDataAttr) && "Attribute is null"); // Read attributes's raw data. std::vector rawZData, rawHData, rawRData, rawOData; @@ -261,10 +351,139 @@ namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Accelerators/NNPA/Transform/ZHigh/ONNXZHighConstPropagation.inc" +static void replaceOpAndGC( + PatternRewriter &rewriter, Operation *op, ValueRange newValues) { + for (Value v : op->getOperands()) { + // v is consumed by only the current stick op. + if (!v.hasOneUse()) + continue; + if (auto cop = v.getDefiningOp()) { + if (auto disposableAttr = + mlir::dyn_cast(cop.getValueAttr())) { + // Since the current op is the only consummer of the constant, + // this constant op will be dead soon after the current op is replaced + // (but the attribute's buffer is not disposed automatically until the + // next call of garbage collector). So, it's safe to dispose the + // attribute's buffer now in order to eagerly save memory. + // + // Once the buffer is dispose, any touch to the attribute would be + // invalid. So we just remove it from the constant operation. + disposableAttr.dispose(); + cop.removeValueAttr(); + } + } + } + rewriter.replaceOp(op, newValues); +} + +// zhigh.Stick (c) = krnl.global(c1), where c1 is stickified data. +// Always saturate constants. +struct ConstantStickPattern : public OpRewritePattern { + ConstantStickPattern(MLIRContext *context) : OpRewritePattern(context) {} + LogicalResult matchAndRewrite( + ZHighStickOp stickOp, PatternRewriter &rewriter) const override { + Value input = stickOp.getIn(); + Value output = stickOp.getOut(); + StringAttr layout = stickOp.getLayoutAttr(); + + // Match + if (!isDenseONNXConstant(input)) { + return failure(); + } + + // Rewrite + Value stickifiedVal = + createConstantForStick(rewriter, output, input, layout); + replaceOpAndGC(rewriter, stickOp, stickifiedVal); + return success(); + } +}; + +// zhigh.StickForGRU (c1, c2, c3) = krnl.global(c) +// where c is stickified data. +struct ConstantStickForGRUPattern + : public OpRewritePattern { + ConstantStickForGRUPattern(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite( + ZHighStickForGRUOp stickOp, PatternRewriter &rewriter) const override { + Value zGate = stickOp.getZGate(); + Value rGate = stickOp.getRGate(); + Value hGate = stickOp.getHGate(); + Value output = stickOp.getOut(); + + // Match + if (!isDenseONNXConstant(zGate) || !isDenseONNXConstant(rGate) || + !isDenseONNXConstant(hGate)) { + return failure(); + } + + // Rewrite + Value stickifiedVal = + createConstantForStickForGRU(rewriter, output, zGate, rGate, hGate); + replaceOpAndGC(rewriter, stickOp, stickifiedVal); + return success(); + } +}; + +// zhigh.StickForLSTM (c1, c2, c3, c4) = krnl.global(c) +// where c is stickified data. +struct ConstantStickForLSTMPattern + : public OpRewritePattern { + ConstantStickForLSTMPattern(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite( + ZHighStickForLSTMOp stickOp, PatternRewriter &rewriter) const override { + Value fGate = stickOp.getFGate(); + Value iGate = stickOp.getIGate(); + Value cGate = stickOp.getCGate(); + Value oGate = stickOp.getOGate(); + Value output = stickOp.getOut(); + + // Match + if (!isDenseONNXConstant(fGate) || !isDenseONNXConstant(iGate) || + !isDenseONNXConstant(cGate) || !isDenseONNXConstant(oGate)) { + return failure(); + } + + // Rewrite + Value stickifiedVal = createConstantForStickForLSTM( + rewriter, output, fGate, iGate, cGate, oGate); + replaceOpAndGC(rewriter, stickOp, stickifiedVal); + return success(); + } +}; + +// zhigh.QuantizedStick (c) = krnl.global(c1), where c1 is stickified data. +// Always saturate constants. +struct ConstantQuantizedStickPattern + : public OpRewritePattern { + ConstantQuantizedStickPattern(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite( + ZHighQuantizedStickOp stickOp, PatternRewriter &rewriter) const override { + Value input = stickOp.getIn(); + Value recscale = stickOp.getRecScale(); + Value offset = stickOp.getOffset(); + Value output = stickOp.getOut(); + StringAttr layout = stickOp.getLayoutAttr(); + StringAttr quantizedType = stickOp.getQuantizedTypeAttr(); + + // Match + if (!isDenseONNXConstant(input) || !isFoldableQuantizedStickOp(output)) { + return failure(); + } + + // Rewrite + Value stickifiedVal = createQuantizedConstantForStick( + rewriter, output, input, recscale, offset, layout, quantizedType); + replaceOpAndGC(rewriter, stickOp, {stickifiedVal, recscale, offset}); + return success(); + } +}; + struct ZHighConstPropagationPass - //: public PassWrapper> { - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ZHighConstPropagationPass) @@ -275,11 +494,14 @@ struct ZHighConstPropagationPass } void runOnOperation() override { - auto function = getOperation(); + ModuleOp moduleOp = getOperation(); ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); - populateWithGenerated(patterns); - (void)applyPatternsAndFoldGreedily(function, std::move(patterns)); + patterns.insert(patterns.getContext()); + patterns.insert(patterns.getContext()); + patterns.insert(patterns.getContext()); + patterns.insert(patterns.getContext()); + (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)); } }; } // anonymous namespace diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.td b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.td index 699d425780..f8f822ab6c 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.td +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.td @@ -28,63 +28,5 @@ include "src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td" /// dag benefitsAdded = (addBenefit 0) /// >; -//===----------------------------------------------------------------------===// -// Pattern-Match and Rewrite -//===----------------------------------------------------------------------===// - -// Useful test definitions. - -// Check an ONNXConstantOp is using dense attribute or not. -def IsFromDenseONNXConstantOp: - Constraint, - "Value is produced by a dense ONNXConstantOp">; - -// Constant propagation for stickify -def CreateConstantForStick: NativeCodeCall< - "createConstantForStick($_builder, $0, $1, $2)" ->; - -def CreateConstantForStickForLSTM : NativeCodeCall< - "createConstantForStickForLSTM($_builder, $0, $1, $2, $3, $4)" ->; - -def CreateConstantForStickForGRU : NativeCodeCall< - "createConstantForStickForGRU($_builder, $0, $1, $2, $3)" ->; - -// zhigh.Stick (c) = krnl.global(c1), where c1 is stickified data. -// Always saturate constants. -def ConstantStickPattern : Pat< - (ZHighStickOp:$stickOp - (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), - $layout, $_), - (CreateConstantForStick $stickOp, $c, $layout), - [(IsFromDenseONNXConstantOp:$c)] ->; - -// zhigh.StickForLSTM (c1, c2, c3, c4) = krnl.global(c) -// where c is stickified data. -def ConstantStickForLSTMPattern : Pat< - (ZHighStickForLSTMOp:$stickOp - (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$c3 $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$c4 $_, $_, $_, $_, $_, $_, $_, $_)), - (CreateConstantForStickForLSTM $stickOp, $c1, $c2, $c3, $c4), - [(IsFromDenseONNXConstantOp:$c1), (IsFromDenseONNXConstantOp:$c2), - (IsFromDenseONNXConstantOp:$c3), (IsFromDenseONNXConstantOp:$c4)] ->; - -// zhigh.StickForGRU (c1, c2, c3) = krnl.global(c) -// where c is stickified data. -def ConstantStickForGRUPattern : Pat< - (ZHighStickForGRUOp:$stickOp - (ONNXConstantOp:$c1 $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$c2 $_, $_, $_, $_, $_, $_, $_, $_), - (ONNXConstantOp:$c3 $_, $_, $_, $_, $_, $_, $_, $_)), - (CreateConstantForStickForGRU $stickOp, $c1, $c2, $c3), - [(IsFromDenseONNXConstantOp:$c1), (IsFromDenseONNXConstantOp:$c2), - (IsFromDenseONNXConstantOp:$c3)] ->; #endif // ZHIGH_CONST_PROPAGATION diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp index 0c997fcada..ceb4d6459a 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighLayoutPropagation.cpp @@ -47,7 +47,7 @@ std::pair areProducedByUnstickOpSameLayout( !isa(first.getDefiningOp())) return std::make_pair(false, nullptr); Value firstStickifiedVal = - cast(first.getDefiningOp()).getIn(); + mlir::cast(first.getDefiningOp()).getIn(); StringAttr firstLayout = convertZTensorDataLayoutToStringAttr( rewriter, getZTensorLayout(firstStickifiedVal.getType())); @@ -56,7 +56,7 @@ std::pair areProducedByUnstickOpSameLayout( using namespace onnx_mlir::zhigh; if (mlir::isa(v) || !isa(v.getDefiningOp())) return false; - Value stickifiedVal = cast(v.getDefiningOp()).getIn(); + Value stickifiedVal = mlir::cast(v.getDefiningOp()).getIn(); StringAttr nextLayout = convertZTensorDataLayoutToStringAttr( rewriter, getZTensorLayout(stickifiedVal.getType())); return (nextLayout == firstLayout); @@ -127,7 +127,7 @@ class ONNXUnaryOpLayoutPropPattern : public OpRewritePattern { return failure(); // Input is a CPU tensor, do nothing. - auto unstickOp = dyn_cast(input.getDefiningOp()); + auto unstickOp = mlir::dyn_cast(input.getDefiningOp()); if (!unstickOp) return failure(); @@ -182,8 +182,8 @@ class ONNXBinaryOpLayoutPropPattern : public OpRewritePattern { return failure(); // Input is a CPU tensor, do nothing. - auto unstickAOp = dyn_cast(A.getDefiningOp()); - auto unstickBOp = dyn_cast(B.getDefiningOp()); + auto unstickAOp = mlir::dyn_cast(A.getDefiningOp()); + auto unstickBOp = mlir::dyn_cast(B.getDefiningOp()); if (!unstickAOp || !unstickBOp) return failure(); diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighScrubDisposablePass.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighScrubDisposablePass.cpp new file mode 100644 index 0000000000..435c75fd1f --- /dev/null +++ b/src/Accelerators/NNPA/Transform/ZHigh/ZHighScrubDisposablePass.cpp @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------------- ZHighScrubDisposablePass.cpp --------------------===// +// +// Replaces each DisposableElementsAttr with a DenseElementsAttr. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/Passes.h" + +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" +#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp" +#include "src/Dialect/ONNX/ONNXDialect.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace zhigh { +namespace { + +struct ZHighScrubDisposablePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ZHighScrubDisposablePass) + + ZHighScrubDisposablePass(bool closeAfter) : closeAfter(closeAfter) {} + + StringRef getArgument() const override { return "zhigh-scrub-disposable"; } + + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + DisposablePool *pool = getDisposablePool(); + pool->scrub(moduleOp, + {{ONNXConstantOp::getOperationName(), "value"}, + {ONNXConstantOfShapeOp::getOperationName(), "value"}, + {ZHighStickifiedConstantOp::getOperationName(), "value"}}); + if (closeAfter) + pool->close(); + } + + DisposablePool *getDisposablePool() { + // It can be hard to get the MLIRContext at the time of construction + // of the pass, so we look it up the first time the pass is run. + if (!disposablePool) + disposablePool = DisposablePool::get(&getContext()); + return disposablePool; + } + + const bool closeAfter; + DisposablePool *disposablePool = nullptr; +}; + +} // namespace + +std::unique_ptr createZHighScrubDisposablePass(bool closeAfter) { + return std::make_unique(closeAfter); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt index 3710f0929b..fb304ec35b 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt +++ b/src/Accelerators/NNPA/Transform/ZLow/CMakeLists.txt @@ -14,6 +14,7 @@ add_onnx_mlir_library(OMZLowRewrite MLIRTransformUtils MLIRViewLikeInterface OMONNXToKrnl + OMZHighToZLow OMZLowOps diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index 4db736a82d..e434f309b7 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -25,6 +25,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/Accelerators/NNPA/Conversion/ZHighToZLow/ProcessStickData.hpp" #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" #include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp" #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" @@ -40,7 +41,6 @@ #define DEBUG_TYPE "zlow-stick-expansion" // Todo: cleanup after we are done experimenting. -#define ENABLE_CSU_PAR true /* Allow parallel compiler gen Stick/Unstick. */ #define PREFETCH_CSU_DIST 0 #define PREFETCH_CSU 1 @@ -75,7 +75,7 @@ class UnstickExpansionPattern : public OpRewritePattern { layout.getValue().equals_insensitive("3D") || layout.getValue().equals_insensitive("2D") || layout.getValue().equals_insensitive("3DS") || - (layout.getValue().equals_insensitive("NHWC"))) { + layout.getValue().equals_insensitive("NHWC")) { return generateUnstickCodeNoBuffer(rewriter, unstickOp); } // Otherwise, we don't replace and keep the zdnn call. @@ -96,202 +96,27 @@ class UnstickExpansionPattern : public OpRewritePattern { Value alloc = unstickOp.getOut(); DimsExpr outputDims; create.krnlIE.getShapeAsSymbols(alloc, outputDims); - int64_t rank = outputDims.size(); - - // Info for SIMD Vector Length (VL) and associated types. - int64_t archVL = 8; // FP16 archVL. - int64_t archVLHalf = archVL / 2; // FP32 archVL. - assert(64 % archVL == 0 && "SIMD vector length must divide 64"); - Type f16Type = rewriter.getF16Type(); - Type f32Type = rewriter.getF32Type(); - VectorType vecF16Type = VectorType::get({archVL}, f16Type); - MemRefType bufferType = MemRefType::get({archVL}, f32Type); - - // Define useful literals. - IndexExpr litZero = LitIE(0); - IndexExpr lit1 = LitIE(1); - IndexExpr litArchVLHalf = LitIE(archVLHalf); - IndexExpr litArchVL = LitIE(archVL); - IndexExpr lit64 = LitIE(64); - - // Useful references for indexing dimensions (neg val are not used). - int64_t E1 = rank - 1; - // Create loop iterations. Note that we iterate over E1 as tiles of 64 - // elements. - ValueRange loopDefs = create.krnl.defineLoops(rank); - DimsExpr lbs(rank, litZero); + DimsExpr lbs(outputDims.size(), LitIE(0)); DimsExpr ubs = outputDims; - IndexExpr T1 = outputDims[E1].ceilDiv(64); - ubs[E1] = T1; // E1 dim is over tiles. - - // Parallel... - if (enableParallel) { - int64_t parId; - // TODO: may want to check if ub of rank makes sense here. - if (findSuitableParallelDimension(lbs, ubs, 0, rank, parId, 8)) { - create.krnl.parallel(loopDefs[parId]); - onnxToKrnlParallelReport(op, true, parId, lbs[parId], ubs[parId], - "compiler-generated stickify"); - } else { - onnxToKrnlParallelReport(op, false, -1, -1, - "no dim with enough work in compiler-generated stickify"); - } - } - - // Compute max tiles. It is actually not easy to compute the max number of - // tiles. Since we don't allocate, it is just a "view", we only need to - // index by the "tile size", it is sufficient to assume 2 or more. Tiles are - // 64. - IndexExpr T = LitIE(2); - DimsExpr reallocTileDims = {T, lit64}; - Value inputAsTx64 = - create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims); - - // Outer loop (E4, E3, E2, E1 iterates over tiles of 64 elements) - create.krnl.iterateIE( - loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) { - MDBuilder create(b); - IndexExprScope outerScope(create.krnl, &allocScope); - DimsExpr outerIndices = DimListIE(loopInd); - // Computation for reading inputs. - DimsExpr inputAF = outerIndices; - IndexExpr e1 = outerIndices[E1] * 64; - inputAF[E1] = e1; - // Translate the tile index t1 to the actual targetted data. - Value inputOffset = - create.krnl.getLinearOffsetIndexIE(input, inputAF); - IndexExpr inputDataOffset = SymIE(inputOffset); - IndexExpr inputTileOffset = inputDataOffset.floorDiv(64); - -// Prefetch -#if PREFETCH_CSU - DimsExpr prefetchAF = inputAF; - // Prefetch current line - create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false, - /*locality*/ 1); - create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true, - /*locality*/ 1); -#if PREFETCH_CSU_DIST > 0 - // Prefetch line in advance. - prefetchAF[E1] = prefetchAF[E1] + (PREFETCH_CSU_DIST * 64); - create.krnl.prefetchIE(input, prefetchAF, /*isWrite*/ false, - /*locality*/ 1); - create.krnl.prefetchIE(alloc, prefetchAF, /*isWrite*/ true, - /*locality*/ 1); -#endif -#endif - - // I may process here up to [e1 ... e1 + m*64), make sure its - // not going out of bound, i.e. beyond outputDIms[E1]; - IndexExpr ub1 = SymIE(outputDims[E1]); - IndexExpr lit64Bis = LitIE(64); - IndexExpr isFull = create.krnlIE.isTileFull(e1, lit64, ub1); - IndexExpr isFullLogical = isFull >= 0; - create.scf.ifThenElse( - // Condition - isFullLogical.getValue(), - // Then (is full). - [&](SCFBuilder b) { - MDBuilder create(b); - // Loop (tried unroll of 2 and 8, 4 was best). - const int64_t unrollVL = 4; - const int64_t totVL = unrollVL * archVL; - assert(totVL <= 64 && "bad unroll"); - create.scf.forLoop(litZero.getValue(), lit64.getValue(), totVL, - [&](SCFBuilder b, Value loopIndex) { - MDBuilder create(b); - IndexExprScope innerScope(b, &outerScope); - IndexExpr l = DimIE(loopIndex); - Value vecF16[unrollVL], vecF32H[unrollVL], - vecF32L[unrollVL]; - // Load f16 values from input via reinterpreted data tile. - for (int64_t i = 0; i < unrollVL; ++i) { - vecF16[i] = create.vec.loadIE(vecF16Type, inputAsTx64, - {SymIE(inputTileOffset), l + (i * archVL)}, {}); - } - // Convert back to f32. - for (int64_t i = 0; i < unrollVL; ++i) { - auto convertOp = - rewriter.create( - loc, vecF16[i]); - vecF32H[i] = convertOp.getResult(0); - vecF32L[i] = convertOp.getResult(1); - } - // Store f32 values back to the (normal layout) output. - DimsExpr outputAF = SymListIE(inputAF); - outputAF[E1] = outputAF[E1] + l; - for (int64_t i = 0; i < unrollVL; ++i) { - LitIE iH(i * archVL), iL(i * archVL + archVL / 2); - create.vec.storeIE( - vecF32H[i], alloc, outputAF, {iH.getValue()}); - create.vec.storeIE( - vecF32L[i], alloc, outputAF, {iL.getValue()}); - } - }); - }, - // else, we don't have a full (64 e1) tile. - [&](SCFBuilder b) { - MDBuilder create(b); - IndexExprScope middleScope(b, &outerScope); - IndexExpr tripCount = SymIE(ub1) - SymIE(e1); - // Note: if we only have multiple of VL, loop below will handle - // all as we subtract (VL-1). Aka if VL=8 and tripCount = 16, - // tripCountWithoutPartialLastVL is 16 - 7 = 9. Thus we iterate - // over i=0 & i=8 as both are < 9. - IndexExpr tripCountWithoutPartialLastVL = - tripCount - (archVL - 1); - create.scf.forLoop(litZero.getValue(), - tripCountWithoutPartialLastVL.getValue(), archVL, - [&](SCFBuilder b, Value loopIndex) { - MDBuilder create(b); - IndexExprScope innerScope(b, &middleScope); - IndexExpr l = DimIE(loopIndex); - // Load f16 values from input via reinterpreted data tile. - Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64, - {SymIE(inputTileOffset), l}, {}); - // Convert back to f32. - auto convertOp = - rewriter.create( - loc, vecF16); - Value vecF32H = convertOp.getResult(0); - Value vecF32L = convertOp.getResult(1); - // Store f32 values back to the (normal layout) output. - DimsExpr outputAF = SymListIE(inputAF); - outputAF[E1] = outputAF[E1] + l; - create.vec.storeIE(vecF32H, alloc, outputAF, {}); - create.vec.storeIE( - vecF32L, alloc, outputAF, {litArchVLHalf.getValue()}); - }); - // Deal with the last values: compute f32 using simd. - IndexExpr remainingScalarValues = tripCount % archVL; - IndexExpr lastL = tripCount - remainingScalarValues; - Value vecF16 = create.vec.loadIE(vecF16Type, inputAsTx64, - {SymIE(inputTileOffset), lastL}, {}); - // Convert back to f32. - auto convertOp = - rewriter.create(loc, vecF16); - Value vecF32H = convertOp.getResult(0); - Value vecF32L = convertOp.getResult(1); - // Save into archVL value buffer. - Value bufferF32 = create.mem.alignedAlloca(bufferType); - create.vec.storeIE(vecF32H, bufferF32, {litZero}, {}); - create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {}); - // Save the remaining values as scalars. - create.scf.forLoop(litZero.getValue(), - remainingScalarValues.getValue(), 1, - [&](SCFBuilder b, Value loopIndex) { - MDBuilder create(b); - IndexExprScope innerScope(b, &middleScope); - IndexExpr l = DimIE(loopIndex); - // Load converted value. - Value f32 = create.krnl.loadIE(bufferF32, {l}); - DimsExpr outputAF = SymListIE(inputAF); - outputAF[E1] = outputAF[E1] + SymIE(lastL); - outputAF[E1] = outputAF[E1] + l; - create.krnl.storeIE(f32, alloc, outputAF); - }); - }); + IterateOverStickInputData(/* Affine, fine to use Krnl.*/ + create.krnl, op, lbs, ubs, outputDims, unstickOp.getLayoutAttr(), input, + alloc, /*unroll*/ 4, enableParallel, PREFETCH_CSU, + [&](const KrnlBuilder &b, SmallVectorImpl &vecOfF32Vals, + DimsExpr &loopIndices) { + MultiDialectBuilder create(b); + // Save the vectors of vecOfF32Vals in consecutive values of alloc. + int64_t size = vecOfF32Vals.size(); + for (int64_t i = 0; i < size; ++i) { + Value val = vecOfF32Vals[i]; + IndexExpr offset = LitIE(4 * i); // Vector of float have 4 values. + create.vec.storeIE(val, alloc, loopIndices, {offset.getValue()}); + } + }, + [&](const KrnlBuilder &b, mlir::Value scalarF32Val, + DimsExpr &loopIndices) { + // Save scalar value in alloc. + b.storeIE(scalarF32Val, alloc, loopIndices); }); rewriter.eraseOp(unstickOp); return success(); @@ -423,20 +248,19 @@ class StickExpansionPattern : public OpRewritePattern { // 64 elements. IndexExpr T = LitIE(2); DimsExpr reallocTileDims = {T, lit64}; - Value allocAsTx64 = - create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims); + Value allocAsTx64 = create.mem.reinterpretCast(alloc, reallocTileDims); // Outer loop (E1 iterates over tiles of 64 elements). - create.krnl.iterateIE( - loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) { + create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs, + [&](const KrnlBuilder &b, ValueRange loopInd) { MDBuilder create(b); IndexExprScope outerScope(create.krnl, &allocScope); DimsExpr outerIndices; - getIndexExprList(loopInd, outerIndices); + getIndexExprList(loopInd, outerIndices); DimsExpr memAF = outerIndices; memAF[E1] = memAF[E1] * 64; // Loop index for E1 is in tiles of 64. Value allocOffset = create.krnl.getLinearOffsetIndexIE(alloc, memAF); - IndexExpr allocTileIndex = SymIE(allocOffset).floorDiv(64); + IndexExpr allocTileIndex = DimIE(allocOffset).floorDiv(64); #if PREFETCH_CSU DimsExpr prefetchAF = memAF; // Prefetch current lines. @@ -454,12 +278,12 @@ class StickExpansionPattern : public OpRewritePattern { #endif #endif - create.affine.forIE(litZero, simdLoopUB, totVL, - [&](AffineBuilder &b, ValueRange loopInd) { + create.affine.forLoopIE(litZero, simdLoopUB, totVL, + [&](const AffineBuilder &b, ValueRange loopInd) { MDBuilder create(b); DimsExpr inputAF; IndexExprScope innerScope(create.krnl, &outerScope); - SymIE l(loopInd[0]); + DimIE l(loopInd[0]); getIndexExprList(memAF, inputAF); // E1: add the "l" local E1 offset. inputAF[E1] = inputAF[E1] + l; diff --git a/src/Accelerators/NNPA/zdnn.cmake b/src/Accelerators/NNPA/zdnn.cmake index 6585c32b0a..ff7a280c1f 100644 --- a/src/Accelerators/NNPA/zdnn.cmake +++ b/src/Accelerators/NNPA/zdnn.cmake @@ -1,7 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 function(setup_zdnn version) - set(ZDNN_GITHUB_URL https://github.com/IBM/zDNN) + # Set policy CMP0097 to NEW for it to not initialize submodules + cmake_policy(SET CMP0097 NEW) + + set(ZDNN_GITHUB_URL https://github.com/IBM/zDNN.git) + message("Git clone zDNN. The ZDNN_GITHUB_URL is: ${ZDNN_GITHUB_URL}") + set(ZDNN_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/zDNN) set(ZDNN_TOPDIR ${ZDNN_PREFIX}/src/zdnn) set(ZDNN_OBJDIR ${ZDNN_TOPDIR}/zdnn/obj) @@ -12,6 +17,7 @@ function(setup_zdnn version) ExternalProject_Add(zdnn GIT_REPOSITORY ${ZDNN_GITHUB_URL} GIT_TAG ${version} + GIT_SUBMODULES "" PREFIX ${ZDNN_PREFIX} BUILD_IN_SOURCE ON CONFIGURE_COMMAND sh -c "autoconf && ./configure" @@ -55,6 +61,7 @@ function(setup_zdnn version) ExternalProject_Add(zdnn GIT_REPOSITORY ${ZDNN_GITHUB_URL} GIT_TAG ${version} + GIT_SUBMODULES "" PREFIX ${ZDNN_PREFIX} BUILD_IN_SOURCE ON CONFIGURE_COMMAND "" diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 9035b81f06..67b56a94a7 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -60,10 +60,12 @@ SUPPRESS_WARNINGS_POP using namespace mlir; -namespace onnx_mlir { - namespace { +bool isDefaultDomain(std::string_view domain) { + return domain.empty() || (domain == "ai.onnx"); +} + /// We consider opset < 6 is old. Users will see a warning if their model /// contains ops of old opset. constexpr int32_t MINIMUM_SUPPORTED_OPSET = 6; @@ -146,6 +148,8 @@ void replaceAttrRefs(onnx::GraphProto &graph, const AttrMap &attr_map) { } // namespace +namespace onnx_mlir { + namespace detail { using ValueSymbolMapping = SymbolMapping; @@ -167,7 +171,7 @@ class FrontendGenImpl { in_model_functions_ = GetModelLocalFunctions(model); importGraph(model.graph()); if (options_.verboseOutput) { - llvm::outs() + llvm::errs() << "The ONNX model has " << num_of_parameters_ << " elements in its initializers. This value would be close to and " "greater than the number of parameters in the model. Because " @@ -257,7 +261,7 @@ class FrontendGenImpl { onnx::TypeProto onnxType; if (mlir::isa(mlirType)) { // Done: Uninitialized TypeProto onnxType represents NoneType. - } else if (auto mlirTensorType = dyn_cast(mlirType)) { + } else if (auto mlirTensorType = mlir::dyn_cast(mlirType)) { onnx::TypeProto::Tensor &onnxTensorType = *onnxType.mutable_tensor_type(); onnxTensorType.set_elem_type( mlirTypeToOnnxType(mlirTensorType.getElementType())); @@ -304,12 +308,10 @@ class FrontendGenImpl { } auto shape_proto = tensor_type.shape(); for (int i = 0; i < shape_proto.dim_size(); i++) { - if (shape_proto.dim()[i].dim_value()) { + if (shape_proto.dim()[i].has_dim_value()) { // Dim is a constant value. int dim_numeric_size = shape_proto.dim()[i].dim_value(); - assert(dim_numeric_size != 0 && - "Parsed an tensor with a dimension size of zero"); - if (dim_numeric_size > 0) { + if (dim_numeric_size >= 0) { dims.push_back(dim_numeric_size); } else { // If dim_value < 0, then dim is parametric. @@ -784,7 +786,10 @@ class FrontendGenImpl { // Variadic output is a single ODS result. if (variadicOut) j = 0; - if (j < outputMap.size() && outputMap[j] >= MAX_NUM_TYPES) { + if (!givenOutputTypes.empty()) { + outputTypes.emplace_back( + UnrankedTensorType::get(givenOutputTypes[i])); + } else if (j < outputMap.size() && outputMap[j] >= MAX_NUM_TYPES) { // Mapping gives a connection with an input. Type inputType = inputs[outputMap[j] - MAX_NUM_TYPES].getType(); if (mlir::isa(inputType)) { @@ -800,9 +805,6 @@ class FrontendGenImpl { Type elementType = buildTypeFromIndex(outputMap[j]); auto outType = UnrankedTensorType::get(elementType); outputTypes.emplace_back(outType); - } else if (!givenOutputTypes.empty()) { - outputTypes.emplace_back( - UnrankedTensorType::get(givenOutputTypes[i])); } else { outputTypes.emplace_back(builder_.getNoneType()); } @@ -823,7 +825,7 @@ class FrontendGenImpl { "Op contains subgraph attributes but does not " "implement HasOnnxSubgraphOpInterface interface."); auto opWithSubgraph = - cast(op.getOperation()); + mlir::cast(op.getOperation()); auto regionIdx = opWithSubgraph.getSubgraphRegionIdx(attr.name()); auto ®ion = op->getRegion(regionIdx); region.push_back(new Block); @@ -839,7 +841,7 @@ class FrontendGenImpl { } } if (auto opWithTypeInference = - dyn_cast(op.getOperation())) { + mlir::dyn_cast(op.getOperation())) { auto outTypes = opWithTypeInference.resultTypeInference(); for (int i = 0; i < node.output().size(); i++) { OpResult result = op->getResult(i); @@ -906,10 +908,16 @@ class FrontendGenImpl { std::vector attributes; for (int i = 0; i < node.attribute_size(); ++i) { auto attr = node.attribute(i); - auto mlir_type = convertONNXTypeToMLIRType( - builder_, static_cast(attr.i())); - Attribute mlirAttr = TypeAttr::get(mlir_type); - attributes.push_back(builder_.getNamedAttr(attr.name(), mlirAttr)); + // The 'to' attribute is an integer in ONNX that represents a type. + if (attr.name() == "to") { + auto mlir_type = convertONNXTypeToMLIRType( + builder_, static_cast(attr.i())); + Attribute mlirAttr = TypeAttr::get(mlir_type); + attributes.push_back(builder_.getNamedAttr(attr.name(), mlirAttr)); + } else { + NamedAttribute na = convertOnnxAttributeProtoToMlirNamedAttribute(attr); + attributes.push_back(na); + } } // If the node has a name, then import it. @@ -961,8 +969,11 @@ class FrontendGenImpl { if (nOuts == 1) { // Inference mode with one output. buildOperation(node); + } else if (nOuts == 5) { + // Training mode with four trailing optional outputs. + buildOperation(node); } else { - // Training mode with four trailing optional outputs. Not handled yet. + // Training mode with two trailing optional outputs. buildOperation(node); } } @@ -1140,17 +1151,17 @@ class FrontendGenImpl { auto opset_list = opset_list_it->second; - // A new opset is added to onnx-mlir when it becomes imcompactible. + // A new opset is added to onnx-mlir when it becomes incompatible. // But the lowest opset in op_dialect_version_map_ is an exception. // It is the current opset when onnx-mlir project is started. - // All opset lower than the last opset should use the last opset(version) - if (node.domain().compare("ai.onnx.ml") != 0 && - current_opset < opset_list.back() && + // All opset lower than the last opset should use the last opset(version). + // Note the minimum supported opset only applies to the default domain. + if (isDefaultDomain(node.domain()) && current_opset < opset_list.back() && current_opset < MINIMUM_SUPPORTED_OPSET) - llvm::outs() << "Warning: ONNX " << node.op_type() + llvm::errs() << "\nWarning: ONNX " << node.op_type() << " in your model is using Opset " << current_opset << ", which is quite old. Please consider regenerating your " - "model with a newer Opset.\n"; + "model with a newer Opset.\n\n"; for (int i = opset_list.size() - 1; i > 0; i--) { if (current_opset < opset_list[i - 1]) { @@ -1355,20 +1366,53 @@ class FrontendGenImpl { int nOut = 0; getNodeInputs(node, inputs); nOut = node.output().size(); + std::vector givenOutputTypes; + + // We lack a way of specifying import behavior for custom domains. For now + // some are hard-coded here, but an extension specification would be + // preferred. + if (node.domain().compare("com.microsoft") == 0) { + Type outElementType = {}; + if (opName == "DequantizeLinear") { + outElementType = + cast(inputs.at(1).getType()).getElementType(); + } else if (opName == "QuantizeLinear") { + outElementType = + cast(inputs.at(2).getType()).getElementType(); + } else if (opName == "Gelu") { + outElementType = + cast(inputs.at(0).getType()).getElementType(); + } + if (outElementType) { + auto outElemTypeAttr = builder_.getNamedAttr( + "output_element_type", TypeAttr::get(outElementType)); + attributes.push_back(outElemTypeAttr); + givenOutputTypes.push_back(outElementType); + + auto shapeInferAttr = builder_.getNamedAttr( + "shape_infer_pattern", builder_.getStringAttr("MDBroadcast")); + attributes.push_back(shapeInferAttr); + } + } + // ToFix: The type inference may go wrong if the element type of the output // of CustomOp is not the same as the first input. - buildOutputAndOperation(node, inputs, nIn, nOut, attributes); + buildOutputAndOperation( + node, inputs, nIn, nOut, attributes, givenOutputTypes); } void ImportNode(const onnx::NodeProto &node) { - std::string opName = node.op_type() + GetImportVersionOfNode(node); - auto handler = import_handler_map_.find(opName); - std::vector funcs = options_.functionsToDecompose; - if (!(std::find(funcs.begin(), funcs.end(), opName) != funcs.end())) { - if (handler != import_handler_map_.end()) { - // It's a regular op with a registered handler. - (this->*(handler->second))(node); - return; + if (isDefaultDomain(node.domain()) || (node.domain() == "ai.onnx.ml") || + (node.domain() == "ai.onnx.preview.training")) { + std::string opName = node.op_type() + GetImportVersionOfNode(node); + auto handler = import_handler_map_.find(opName); + std::vector funcs = options_.functionsToDecompose; + if (!(std::find(funcs.begin(), funcs.end(), opName) != funcs.end())) { + if (handler != import_handler_map_.end()) { + // It's a regular op with a registered handler. + (this->*(handler->second))(node); + return; + } } } @@ -1416,8 +1460,8 @@ class FrontendGenImpl { if (output.type().value_case() == onnx::TypeProto::kTensorType) { Type outTy = ImportType(output.type(), dim_params); if (std::getenv("IMPORTER_FORCE_DYNAMIC")) - outTy = - UnrankedTensorType::get(cast(outTy).getElementType()); + outTy = UnrankedTensorType::get( + mlir::cast(outTy).getElementType()); if (output.type().tensor_type().has_shape()) { val.setType(outTy); } @@ -1518,7 +1562,7 @@ bool ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context, // Code copied from onnx/onnx/version_coverter/convert.cc for (auto it = model.opset_import().begin(); it != model.opset_import().end(); ++it) { - if (it->domain() == "" || it->domain() == "ai.onnx") { + if (isDefaultDomain(it->domain())) { originVersion = it->version(); break; } @@ -1526,7 +1570,7 @@ bool ImportFrontendModelInternal(onnx::ModelProto &model, MLIRContext &context, if (options.allowSorting && !IsTopologicallySorted(model.graph())) { if (!SortGraph(model.mutable_graph())) { - llvm::outs() << "The graph is not topologically sortable.\n"; + llvm::errs() << "The graph is not topologically sortable.\n"; return false; } } @@ -1590,9 +1634,9 @@ int readAndStripComments( if (line->contains("//")) { // Not stripping end-of-line comments because there's no robust way to // distinguish them from valid uses of // in the json itself. - llvm::errs() << "Warning: possible invalid end-of-line // comment in " + llvm::errs() << "\nWarning: possible invalid end-of-line // comment in " "json input file " - << fname.str() << ":" << line.line_number() << "\n"; + << fname.str() << ":" << line.line_number() << "\n\n"; } contents.append(*line); } diff --git a/src/Builder/ModelInputShaper.cpp b/src/Builder/ModelInputShaper.cpp index 067ab67896..6b77d9dc33 100644 --- a/src/Builder/ModelInputShaper.cpp +++ b/src/Builder/ModelInputShaper.cpp @@ -84,9 +84,9 @@ void ModelInputShaper::setShapeInformation( } if (hasAllInputSetting && (inputs_shape_information_.size() > 1)) { llvm::outs() - << "Warning: Found multiple settings that includes -1:d1xd2x...xdn " + << "\nWarning: Found multiple settings that includes -1:d1xd2x...xdn " "for all inputs. Only the first -1:d1xd2x...xdn is effective and " - "the other settings are ignored.\n"; + "the other settings are ignored.\n\n"; } } } diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index 6122d0c205..a77a9028a5 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -5,8 +5,8 @@ //******************************************************** op_dialect_version_map_["Abs"] = {13}; -op_dialect_version_map_["Acos"] = {7}; -op_dialect_version_map_["Acosh"] = {9}; +op_dialect_version_map_["Acos"] = {22}; +op_dialect_version_map_["Acosh"] = {22}; op_dialect_version_map_["Adagrad"] = {1}; op_dialect_version_map_["Adam"] = {1}; op_dialect_version_map_["Add"] = {14}; @@ -14,13 +14,13 @@ op_dialect_version_map_["And"] = {7}; op_dialect_version_map_["ArgMax"] = {13}; op_dialect_version_map_["ArgMin"] = {13}; op_dialect_version_map_["ArrayFeatureExtractor"] = {1}; -op_dialect_version_map_["Asin"] = {7}; -op_dialect_version_map_["Asinh"] = {9}; -op_dialect_version_map_["Atan"] = {7}; -op_dialect_version_map_["Atanh"] = {9}; -op_dialect_version_map_["AveragePool"] = {19}; -op_dialect_version_map_["BatchNormalization"] = {15}; -op_dialect_version_map_["Bernoulli"] = {15}; +op_dialect_version_map_["Asin"] = {22}; +op_dialect_version_map_["Asinh"] = {22}; +op_dialect_version_map_["Atan"] = {22}; +op_dialect_version_map_["Atanh"] = {22}; +op_dialect_version_map_["AveragePool"] = {22}; +op_dialect_version_map_["BatchNormalization"] = {15, 9}; +op_dialect_version_map_["Bernoulli"] = {22}; op_dialect_version_map_["Binarizer"] = {1}; op_dialect_version_map_["BitShift"] = {11}; op_dialect_version_map_["BitwiseAnd"] = {18}; @@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18}; op_dialect_version_map_["BitwiseOr"] = {18}; op_dialect_version_map_["BitwiseXor"] = {18}; op_dialect_version_map_["BlackmanWindow"] = {17}; -op_dialect_version_map_["Cast"] = {19}; +op_dialect_version_map_["Cast"] = {21}; op_dialect_version_map_["CastLike"] = {19}; op_dialect_version_map_["CastMap"] = {1}; op_dialect_version_map_["CategoryMapper"] = {1}; @@ -41,60 +41,60 @@ op_dialect_version_map_["Concat"] = {13}; op_dialect_version_map_["ConcatFromSequence"] = {11}; op_dialect_version_map_["Constant"] = {19}; op_dialect_version_map_["ConstantOfShape"] = {20}; -op_dialect_version_map_["Conv"] = {11}; +op_dialect_version_map_["Conv"] = {22}; op_dialect_version_map_["ConvInteger"] = {10}; -op_dialect_version_map_["ConvTranspose"] = {11}; -op_dialect_version_map_["Cos"] = {7}; -op_dialect_version_map_["Cosh"] = {9}; +op_dialect_version_map_["ConvTranspose"] = {22}; +op_dialect_version_map_["Cos"] = {22}; +op_dialect_version_map_["Cosh"] = {22}; op_dialect_version_map_["Col2Im"] = {18}; op_dialect_version_map_["CumSum"] = {14}; -op_dialect_version_map_["DeformConv"] = {19}; +op_dialect_version_map_["DeformConv"] = {22}; op_dialect_version_map_["DepthToSpace"] = {13}; op_dialect_version_map_["DequantizeLinear"] = {19}; -op_dialect_version_map_["Det"] = {11}; +op_dialect_version_map_["Det"] = {22}; op_dialect_version_map_["DFT"] = {20, 17}; op_dialect_version_map_["DictVectorizer"] = {1}; op_dialect_version_map_["Div"] = {14}; -op_dialect_version_map_["Dropout"] = {13}; +op_dialect_version_map_["Dropout"] = {22}; op_dialect_version_map_["DynamicQuantizeLinear"] = {11}; op_dialect_version_map_["Einsum"] = {12}; -op_dialect_version_map_["Elu"] = {6}; +op_dialect_version_map_["Elu"] = {22}; op_dialect_version_map_["Equal"] = {19}; op_dialect_version_map_["Erf"] = {13}; op_dialect_version_map_["Exp"] = {13}; op_dialect_version_map_["Expand"] = {13}; -op_dialect_version_map_["EyeLike"] = {9}; +op_dialect_version_map_["EyeLike"] = {22}; op_dialect_version_map_["FeatureVectorizer"] = {1}; op_dialect_version_map_["Flatten"] = {13}; op_dialect_version_map_["Floor"] = {13}; -op_dialect_version_map_["GRU"] = {14}; +op_dialect_version_map_["GRU"] = {22}; op_dialect_version_map_["Gather"] = {13}; op_dialect_version_map_["GatherElements"] = {13}; op_dialect_version_map_["GatherND"] = {13}; op_dialect_version_map_["Gelu"] = {20}; op_dialect_version_map_["Gemm"] = {13}; -op_dialect_version_map_["GlobalAveragePool"] = {1}; +op_dialect_version_map_["GlobalAveragePool"] = {22}; op_dialect_version_map_["GlobalLpPool"] = {2}; -op_dialect_version_map_["GlobalMaxPool"] = {1}; +op_dialect_version_map_["GlobalMaxPool"] = {22}; op_dialect_version_map_["Gradient"] = {1}; op_dialect_version_map_["Greater"] = {13}; op_dialect_version_map_["GreaterOrEqual"] = {16}; -op_dialect_version_map_["GridSample"] = {16}; -op_dialect_version_map_["GroupNormalization"] = {18}; +op_dialect_version_map_["GridSample"] = {22, 20, 16}; +op_dialect_version_map_["GroupNormalization"] = {21, 18}; op_dialect_version_map_["HammingWindow"] = {17}; op_dialect_version_map_["HannWindow"] = {17}; -op_dialect_version_map_["HardSigmoid"] = {6}; +op_dialect_version_map_["HardSigmoid"] = {22}; op_dialect_version_map_["Hardmax"] = {13}; -op_dialect_version_map_["HardSwish"] = {14}; +op_dialect_version_map_["HardSwish"] = {22}; op_dialect_version_map_["Identity"] = {19}; op_dialect_version_map_["If"] = {19}; op_dialect_version_map_["Imputer"] = {1}; -op_dialect_version_map_["InstanceNormalization"] = {6}; +op_dialect_version_map_["InstanceNormalization"] = {22}; op_dialect_version_map_["IsInf"] = {20}; op_dialect_version_map_["IsNaN"] = {20}; op_dialect_version_map_["LayerNormalization"] = {17}; op_dialect_version_map_["LRN"] = {13}; -op_dialect_version_map_["LSTM"] = {14}; +op_dialect_version_map_["LSTM"] = {22}; op_dialect_version_map_["LabelEncoder"] = {2}; op_dialect_version_map_["LeakyRelu"] = {16}; op_dialect_version_map_["Less"] = {13}; @@ -104,25 +104,25 @@ op_dialect_version_map_["LinearRegressor"] = {1}; op_dialect_version_map_["Log"] = {13}; op_dialect_version_map_["LogSoftmax"] = {13}; op_dialect_version_map_["Loop"] = {19}; -op_dialect_version_map_["LpNormalization"] = {1}; -op_dialect_version_map_["LpPool"] = {18}; +op_dialect_version_map_["LpNormalization"] = {22}; +op_dialect_version_map_["LpPool"] = {22}; op_dialect_version_map_["MatMul"] = {13}; op_dialect_version_map_["MatMulInteger"] = {10}; op_dialect_version_map_["Max"] = {13}; -op_dialect_version_map_["MaxPool"] = {12}; -op_dialect_version_map_["MaxRoiPool"] = {1}; -op_dialect_version_map_["MaxUnpool"] = {11}; +op_dialect_version_map_["MaxPool"] = {22}; +op_dialect_version_map_["MaxRoiPool"] = {22}; +op_dialect_version_map_["MaxUnpool"] = {22}; op_dialect_version_map_["Mean"] = {13}; op_dialect_version_map_["MeanVarianceNormalization"] = {13}; op_dialect_version_map_["MelWeightMatrix"] = {17}; op_dialect_version_map_["Min"] = {13}; -op_dialect_version_map_["Mish"] = {18}; +op_dialect_version_map_["Mish"] = {22}; op_dialect_version_map_["Mod"] = {13}; op_dialect_version_map_["Momentum"] = {1}; op_dialect_version_map_["Mul"] = {14}; -op_dialect_version_map_["Multinomial"] = {7}; +op_dialect_version_map_["Multinomial"] = {22}; op_dialect_version_map_["Neg"] = {13}; -op_dialect_version_map_["NegativeLogLikelihoodLoss"] = {13}; +op_dialect_version_map_["NegativeLogLikelihoodLoss"] = {22}; op_dialect_version_map_["NonMaxSuppression"] = {11}; op_dialect_version_map_["NonZero"] = {13}; op_dialect_version_map_["Normalizer"] = {1}; @@ -139,11 +139,11 @@ op_dialect_version_map_["Pow"] = {15}; op_dialect_version_map_["QLinearConv"] = {10}; op_dialect_version_map_["QLinearMatMul"] = {10}; op_dialect_version_map_["QuantizeLinear"] = {19}; -op_dialect_version_map_["RNN"] = {14}; -op_dialect_version_map_["RandomNormal"] = {1}; -op_dialect_version_map_["RandomNormalLike"] = {1}; -op_dialect_version_map_["RandomUniform"] = {1}; -op_dialect_version_map_["RandomUniformLike"] = {1}; +op_dialect_version_map_["RNN"] = {22}; +op_dialect_version_map_["RandomNormal"] = {22}; +op_dialect_version_map_["RandomNormalLike"] = {22}; +op_dialect_version_map_["RandomUniform"] = {22}; +op_dialect_version_map_["RandomUniformLike"] = {22}; op_dialect_version_map_["Range"] = {11}; op_dialect_version_map_["Reciprocal"] = {13}; op_dialect_version_map_["ReduceL1"] = {18, 13}; @@ -160,8 +160,8 @@ op_dialect_version_map_["Relu"] = {14}; op_dialect_version_map_["Reshape"] = {19}; op_dialect_version_map_["Resize"] = {19, 18, 13, 11, 10}; op_dialect_version_map_["ReverseSequence"] = {10}; -op_dialect_version_map_["RoiAlign"] = {16}; -op_dialect_version_map_["Round"] = {11}; +op_dialect_version_map_["RoiAlign"] = {22}; +op_dialect_version_map_["Round"] = {22}; op_dialect_version_map_["SVMClassifier"] = {1}; op_dialect_version_map_["SVMRegressor"] = {1}; op_dialect_version_map_["Scaler"] = {1}; @@ -169,7 +169,7 @@ op_dialect_version_map_["Scan"] = {19}; op_dialect_version_map_["Scatter"] = {11}; op_dialect_version_map_["ScatterElements"] = {18}; op_dialect_version_map_["ScatterND"] = {18}; -op_dialect_version_map_["Selu"] = {6}; +op_dialect_version_map_["Selu"] = {22}; op_dialect_version_map_["SequenceAt"] = {11}; op_dialect_version_map_["SequenceConstruct"] = {11}; op_dialect_version_map_["SequenceEmpty"] = {11}; @@ -181,14 +181,14 @@ op_dialect_version_map_["Shape"] = {19}; op_dialect_version_map_["Shrink"] = {9}; op_dialect_version_map_["Sigmoid"] = {13}; op_dialect_version_map_["Sign"] = {13}; -op_dialect_version_map_["Sin"] = {7}; -op_dialect_version_map_["Sinh"] = {9}; +op_dialect_version_map_["Sin"] = {22}; +op_dialect_version_map_["Sinh"] = {22}; op_dialect_version_map_["Size"] = {19}; op_dialect_version_map_["Slice"] = {13}; op_dialect_version_map_["Softmax"] = {13, 11}; op_dialect_version_map_["SoftmaxCrossEntropyLoss"] = {13}; -op_dialect_version_map_["Softplus"] = {1}; -op_dialect_version_map_["Softsign"] = {1}; +op_dialect_version_map_["Softplus"] = {22}; +op_dialect_version_map_["Softsign"] = {22}; op_dialect_version_map_["SpaceToDepth"] = {13}; op_dialect_version_map_["Split"] = {18, 13, 11}; op_dialect_version_map_["SplitToSequence"] = {11}; @@ -198,10 +198,10 @@ op_dialect_version_map_["StringNormalizer"] = {10}; op_dialect_version_map_["STFT"] = {17}; op_dialect_version_map_["Sub"] = {14}; op_dialect_version_map_["Sum"] = {13}; -op_dialect_version_map_["Tan"] = {7}; +op_dialect_version_map_["Tan"] = {22}; op_dialect_version_map_["Tanh"] = {13}; op_dialect_version_map_["TfIdfVectorizer"] = {9}; -op_dialect_version_map_["ThresholdedRelu"] = {10}; +op_dialect_version_map_["ThresholdedRelu"] = {22}; op_dialect_version_map_["Tile"] = {13}; op_dialect_version_map_["TopK"] = {11}; op_dialect_version_map_["Transpose"] = {13}; @@ -240,6 +240,8 @@ import_handler_map_["AveragePool"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["BatchNormalization"] = &onnx_mlir::detail::FrontendGenImpl::ImportNodeBatchNormalization; +import_handler_map_["BatchNormalizationV9"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["Bernoulli"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["BitShift"] = @@ -356,8 +358,14 @@ import_handler_map_["GreaterOrEqual"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GridSample"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GridSampleV20"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GridSampleV16"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["GroupNormalization"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GroupNormalizationV18"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["HammingWindow"] = &onnx_mlir::detail::FrontendGenImpl::buildOperation; import_handler_map_["HannWindow"] = diff --git a/src/Compiler/CMakeLists.txt b/src/Compiler/CMakeLists.txt index 5f64e8b7f0..14a9c27025 100644 --- a/src/Compiler/CMakeLists.txt +++ b/src/Compiler/CMakeLists.txt @@ -37,6 +37,8 @@ if (ONNX_MLIR_VENDOR) endif() target_compile_definitions(OMVersion PUBLIC ${DEFINITIONS}) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + add_onnx_mlir_library(OMCompilerOptions CompilerOptions.cpp @@ -70,6 +72,7 @@ add_onnx_mlir_library(OMCompilerDialects OMONNXOps MLIRIR MLIROpenMPToLLVMIRTranslation + ${dialect_libs} ) add_onnx_mlir_library(OMCompilerPasses @@ -159,6 +162,7 @@ add_onnx_mlir_library(OMCompilerUtils ExternalUtil llc opt + ${dialect_libs} INCLUDE_DIRS PRIVATE ${FILE_GENERATE_DIR} @@ -167,10 +171,12 @@ add_onnx_mlir_library(OMCompilerUtils ${ONNX_MLIR_SRC_ROOT}/include LINK_LIBS PUBLIC + MLIRBytecodeWriter OMCompilerDialects OMCompilerPasses OMAccelerator OMVersion + ${dialect_libs} # Link LLVM libraries necessary to query which target architectures # are configured. @@ -234,11 +240,12 @@ endif() pybind11_add_module(PyCompile PyOMCompileSession.cpp) add_dependencies(PyCompile onnx_proto) -target_compile_options(PyCompile - PRIVATE - $<$,$,$>:-frtti -fexceptions> - $<$:/EHsc /GR> - ) +if (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "MSVC") + target_compile_options(PyCompile PRIVATE /EHsc /GR) +elseif (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "GNU") + target_compile_options(PyCompile PRIVATE -frtti -fexceptions) +endif() + target_compile_definitions(PyCompile PRIVATE $ @@ -252,6 +259,8 @@ target_link_libraries(PyCompile OMCompiler ) -install(TARGETS PyCompile - DESTINATION lib - ) +if(ONNX_MLIR_INSTALL_PYTHON_EXTENSIONS) + install(TARGETS PyCompile + DESTINATION lib + ) +endif() diff --git a/src/Compiler/CompilerDialects.cpp b/src/Compiler/CompilerDialects.cpp index a54d014977..87fab41a9c 100644 --- a/src/Compiler/CompilerDialects.cpp +++ b/src/Compiler/CompilerDialects.cpp @@ -46,8 +46,9 @@ DialectRegistry registerDialects(ArrayRef accels) { for (auto *accel : accel::Accelerator::getAccelerators()) accel->registerDialects(registry); - if (useOldBufferization) - memref::registerAllocationOpInterfaceExternalModels(registry); + // Register interface needed by both old and new buffer deallocation pass. + memref::registerAllocationOpInterfaceExternalModels(registry); + arith::registerBufferDeallocationOpInterfaceExternalModels(registry); return registry; } diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 6d010bd219..0b2d764856 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -4,7 +4,7 @@ //===------------------------ CompilerOptions.cpp -------------------------===// // -// Copyright 2022, 2023 The IBM Research Authors. +// Copyright 2022, 2024 The IBM Research Authors. // // ============================================================================= // @@ -33,6 +33,7 @@ std::vector maccel; // common for both OptLevel OptimizationLevel; // common for both std::string mtriple; // common for both std::string mcpu; // common for both +float nnpaEpsilon; // common for both std::string march; // common for both InstrumentStages instrumentStage; // common for both bool onnxConstPropRoundFPToInt; // common for both @@ -42,10 +43,15 @@ bool enableONNXHybridPass; // common for both std::vector functionsToDecompose; // common for both std::string opsForCall; // common for both bool disableKrnlOpFusion; // common for both +bool disableQuantZeroPoint; // common for both +bool enableKrnlBufferReuse; // common for both +bool disableMemRefPrefetch; // common for both EmissionTargetType emissionTarget; // onnx-mlir only bool invokeOnnxVersionConverter; // onnx-mlir only bool preserveLocations; // onnx-mlir only bool printIR; // onnx-mlir only +bool printBytecode; // onnx-mlir only +bool doNotEmitFullMLIRCode; // onnx-mlir only bool preserveBitcode; // onnx-mlir only bool preserveLLVMIR; // onnx-mlir only bool preserveMLIR; // onnx-mlir only @@ -54,6 +60,7 @@ int repeatOnnxTransform; // onnx-mlir only std::string shapeInformation; // onnx-mlir only std::string dimParams; // onnx-mlir only ModelSize modelSize; // onnx-mlir only +std::string externalDataDir; // onnx-mlir only bool storeConstantsToFile; // onnx-mlir only float constantsToFileTotalThreshold; // onnx-mlir only float constantsToFileSingleThreshold; // onnx-mlir only @@ -70,7 +77,9 @@ int onnxOpTransformThreshold; // onnx-mlir only bool onnxOpTransformReport; // onnx-mlir only bool enableParallel; // onnx-mlir only bool disableSimdOption; // onnx-mlir only +bool enableFastMathOption; // onnx-mlir only bool disableRecomposeOption; // onnx-mlir only +bool disableConvTransposeDecomposeOption; // onnx-mlir only bool enableSimdDataLayout; // onnx-mlir only bool verifyInputTensors; // onnx-mlir only bool allowSorting; // onnx-mlir only @@ -93,11 +102,11 @@ bool allowUnregisteredDialects; // onnx-mlir-opt only // Category for common options shared between onnx-mlir and onnx-mlir-opt. llvm::cl::OptionCategory OnnxMlirCommonOptions("common options", - "These are options shared between onnx-mlir and onnx-mlir-opt"); + "These are options shared between onnx-mlir and onnx-mlir-opt."); // Category for options for onnx-mlir only. llvm::cl::OptionCategory OnnxMlirOptions( - "onnx-mlir options", "These are onnx-mlir frontend options"); + "onnx-mlir options", "These are onnx-mlir frontend options."); // Category for options for onnx-mlir-opt only. llvm::cl::OptionCategory OnnxMlirOptOptions( @@ -112,7 +121,7 @@ static llvm::cl::opt inputFilenameOpt(llvm::cl::Positional, static llvm::cl::opt outputBaseNameOpt("o", llvm::cl::desc("For onnx-mlir, specify the base path for output file, " - "extension will be added. Default is input filename " + "extension will be added.\nDefault is input filename " "without the extension, or \"stdin\" if input is stdin.\n" "For onnx-mlir-opt, specify the output filename. Default is " "stdout."), @@ -123,7 +132,7 @@ static llvm::cl::opt outputBaseNameOpt("o", static llvm::cl::list> maccelOpt("maccel", - llvm::cl::desc("Specify an accelerator to generate code for"), + llvm::cl::desc("Specify an accelerator to generate code for."), llvm::cl::location(maccel), // clang-format off llvm::cl::values( @@ -135,27 +144,36 @@ static llvm::cl::list OptimizationLevelOpt( llvm::cl::desc("Levels:"), - llvm::cl::values(clEnumVal(O0, "Optimization level 0 (default):"), - clEnumVal(O1, "Optimization level 1"), - clEnumVal(O2, "Optimization level 2"), - clEnumVal(O3, "Optimization level 3, SIMD is enabled")), + llvm::cl::values(clEnumVal(O0, "Optimization level 0 (default)."), + clEnumVal(O1, "Optimization level 1."), + clEnumVal(O2, "Optimization level 2."), + clEnumVal(O3, "Optimization level 3, SIMD is enabled.")), llvm::cl::location(OptimizationLevel), llvm::cl::init(O0), llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt mtripleOpt("mtriple", - llvm::cl::desc("Override target triple for module"), + llvm::cl::desc("Override target triple for module."), llvm::cl::value_desc("LLVM target triple"), llvm::cl::location(mtriple), llvm::cl::init(kDefaultTriple), llvm::cl::cat(OnnxMlirCommonOptions), llvm::cl::ValueRequired); static llvm::cl::opt mcpuOpt("mcpu", - llvm::cl::desc("Target cpu"), + llvm::cl::desc("Target cpu."), llvm::cl::value_desc("Target a specific CPU type"), llvm::cl::location(mcpu), llvm::cl::cat(OnnxMlirCommonOptions), llvm::cl::ValueRequired); +static llvm::cl::opt nnpaEpsilonOpt("nnpa-epsilon", + // TODO: what text should go here. + llvm::cl::desc("A value added to inputs during computations to prevent " + "undefined mathematical operations, \n" + "such as division by zero or logarithms of zero. Default " + "value set to 1e-5."), + llvm::cl::value_desc("Float value"), llvm::cl::location(nnpaEpsilon), + llvm::cl::cat(OnnxMlirCommonOptions), llvm::cl::init(1e-5)); + static llvm::cl::opt marchOpt("march", - llvm::cl::desc("Target architecture to generate code for"), + llvm::cl::desc("Target architecture to generate code for."), llvm::cl::value_desc("Target a specific architecture type"), llvm::cl::location(march), llvm::cl::cat(OnnxMlirCommonOptions), llvm::cl::ValueRequired); @@ -171,16 +189,17 @@ static llvm::cl::opt onnxConstPropRoundFPToIntOpt( "onnx-const-prop-round-fp-to-int", llvm::cl::desc("If true constant propagates onnx.Cast from a floating " "point type to an integer type by rounding to nearest, " - "ties to even. If false truncates towards zero."), + "ties to even.\nIf false truncates towards zero."), llvm::cl::location(onnxConstPropRoundFPToInt), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt onnxConstPropExpansionBoundOpt( "onnx-const-prop-expansion-bound", - llvm::cl::desc("ONNX dialect constant propagation maximum expansion factor." - " Constants are not propagated if their bytes size exceed" - " the aggregate operands' sizes by more than this factor." - " Set to -1 to always propagate, which is the default."), + llvm::cl::desc( + "ONNX dialect constant propagation maximum expansion factor\n" + "Constants are not propagated if their bytes size exceed " + "the aggregate operands' sizes by more than this factor\n" + "Set to -1 to always propagate, which is the default."), llvm::cl::location(onnxConstPropExpansionBound), llvm::cl::init(-1), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -193,29 +212,59 @@ static llvm::cl::list> llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt enableONNXHybridPassOpt("onnx-hybrid-pass", - llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n" + llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n" "Set to 'false' if you want to disable ONNX hybrid pass."), llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::list> functionsToDecomposeOpt("functions-to-decompose", - llvm::cl::desc("Specify ONNX functions to decompose"), + llvm::cl::desc("Specify ONNX functions to decompose."), llvm::cl::location(functionsToDecompose), llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt disableKrnlOpFusionOpt( "disable-krnl-op-fusion", - llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n" + llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n" "Set to 'true' if you want to disable fusion."), llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +static llvm::cl::opt disable_quantization_zero_point( + "disable-quantization-zero-point", + llvm::cl::desc( + "Disable the use of zero-point in quantization (default=false).\n" + "Set to 'true' if you want to disable the use of zero-point\n" + "in dyn/static quantization/dequantization."), + llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + +static llvm::cl::opt enableKrnlBufferReuseOpt( + "enable-krnl-buffer-reuse", + llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass " + "(default=false).\n" + "Set to 'true' if you want to enable buffer reuse."), + llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + +static llvm::cl::opt disableMemRefPrefetchOpt( + "disable-memref-prefetch", + llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n" + "Set to 'true' if you want to disable prefetch."), + llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + static llvm::cl::opt disableRecomposeOptionOpt("disable-recompose", llvm::cl::desc("Disable recomposition of ONNX operations."), llvm::cl::location(disableRecomposeOption), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +static llvm::cl::opt disableConvTranposeDecomposeOptionOpt( + "disable-convtranspose-decompose", + llvm::cl::desc("Disable decomposition of ONNX ConvTranspose operator."), + llvm::cl::location(disableConvTransposeDecomposeOption), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); + // Options for onnx-mlir only static llvm::cl::opt emissionTargetOpt( llvm::cl::desc("Choose target to emit:"), @@ -238,116 +287,150 @@ static llvm::cl::opt emissionTargetOpt( static llvm::cl::opt invokeOnnxVersionConverterOpt( "invokeOnnxVersionConverter", - llvm::cl::desc( - "call onnx version converter to convert ONNX model to current version"), + llvm::cl::desc("Call onnx version converter to convert ONNX model to " + "current version."), llvm::cl::location(invokeOnnxVersionConverter), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt preserveLocationsOpt("preserveLocations", - llvm::cl::desc("emit location data:"), + llvm::cl::desc("Emit location data."), llvm::cl::location(preserveLocations), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt printIROpt("printIR", - llvm::cl::desc("print the IR to stdout:"), llvm::cl::location(printIR), + llvm::cl::desc("Print the IR to stdout:."), llvm::cl::location(printIR), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); -static llvm::cl::opt preserveBitcodeOpt("preserveBitcode", +static llvm::cl::opt printBytecodeOpt("printBytecode", + llvm::cl::desc("print bytecode to stdout:"), + llvm::cl::location(printBytecode), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + +static llvm::cl::opt doNotEmitFullMLIRCodeOpt( + "do-not-emit-full-mlir-code", llvm::cl::desc( - "dont delete the bitcode files (optimized and unoptimized):"), + "Do not emit the MLIR the constant values are embeded " + "(onnx.mlir). Emit only the MLIR without the constants " + "(.tmp). Need to be used with emitting MLIR options such as " + "--EmitONNXIR and --EmitMLIR."), + llvm::cl::location(doNotEmitFullMLIRCode), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + +static llvm::cl::opt preserveBitcodeOpt("preserveBitcode", + llvm::cl::desc("Preserve the bitcode files (optimized and unoptimized)."), llvm::cl::location(preserveBitcode), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt preserveLLVMIROpt("preserveLLVMIR", - llvm::cl::desc("dont delete the LLVMIR files:"), + llvm::cl::desc("Preserve the LLVMIR files."), llvm::cl::location(preserveLLVMIR), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt preserveMLIROpt("preserveMLIR", - llvm::cl::desc("dont delete the MLIR files (input and llvm):"), + llvm::cl::desc("Preserve the MLIR files (input and llvm)."), llvm::cl::location(preserveMLIR), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt useOnnxModelTypesOpt("useOnnxModelTypes", - llvm::cl::desc("use types and shapes from ONNX model"), + llvm::cl::desc("Use types and shapes from ONNX model."), llvm::cl::location(useOnnxModelTypes), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt repeatOnnxTransformOpt("repeatOnnxTransform", - llvm::cl::desc( - "invoke extra onnx transform pass(shape inference, constant and etc.)"), + llvm::cl::desc("Invoke extra onnx transform pass(shape inference, constant " + "and etc.)."), llvm::cl::location(repeatOnnxTransform), llvm::cl::init(0), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt shapeInformationOpt("shapeInformation", llvm::cl::desc( - "Custom shapes for the inputs of the ONNX model, e.g. setting static " + "Custom shapes for the inputs of the ONNX model, e.g. setting " + "static " "shapes for dynamic inputs.\n" "\"value\" is in the format of " "\"INPUT_ID1:D1xD2x...xDn,INPUT_ID2:D1xD2x...xDn, ...\",\n" - "where \"INPUT_ID1, INPUT_ID2, ...\" are input indices (starting from " + "where \"INPUT_ID1, INPUT_ID2, ...\" are input indices (starting " + "from " "0 or being -1 for all input indices), and\n" "\"D1, D2, ...\" are dimension sizes (positive integers or -1 for " - "unknown dimensions)"), + "unknown dimensions)."), llvm::cl::value_desc("value"), llvm::cl::location(shapeInformation), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt dimParamsOpt("dimParams", llvm::cl::desc( - "Custom onnx.dim_params attributes for the inputs of the ONNX model for" - "specifying relationship among dynamic dimensions of the inputs.\n" + "Custom onnx.dim_params attributes for the inputs of the ONNX " + "model " + "for specifying relationship among dynamic dimensions of the " + "inputs.\n" "\"value\" is in the format of " "\"INPUT_ID1:D1=S1,D2=S2,...,Dn=Sn|INPUT_ID2:D1=T1,D2=T2,...Dn=Tn|" "...\" where \"INPUT_ID1, INPUT_ID2, ...\" are input indices " "(starting from 0 or being -1 for all input indices), and\n" - "\"S1, S2, ...\" and \"T2, T2, ...\" are symbols to specify that same " + "\"S1, S2, ...\" and \"T2, T2, ...\" are symbols to specify that " + "same " "symbols have the same value. " "All dimensions of onnx.dim_params for a specified input index in " - "the original onnx model are cleared and repalced by this option. " - "onnx.dim_params for other input indices in the original onnx model " + "the original onnx model are cleared and replaced by this option. " + "onnx.dim_params for other input indices in the original onnx " + "model " "are not cleared"), llvm::cl::value_desc("value"), llvm::cl::location(dimParams), llvm::cl::cat(OnnxMlirOptions)); // Default value is defined by the OnnxMlirEnvOptionName constant string -// variable, but the default setting mechanism here cannot be used here as we -// need to evaluate this value prior to the compiler options being set. Proper -// handling of the value of this compiler option is set by the calling the -// parseCustomEnvFlagsCommandLineOption(...) function. +// variable, but the default setting mechanism here cannot be used here as +// we need to evaluate this value prior to the compiler options being set. +// Proper handling of the value of this compiler option is set by the +// calling the parseCustomEnvFlagsCommandLineOption(...) function. static llvm::cl::opt customEnvFlagsOpt("customEnvFlags", llvm::cl::desc("Override default option env var OnnxMlirEnvOptionName: " - "ONNX_MLIR_FLAGS"), + "ONNX_MLIR_FLAGS."), llvm::cl::value_desc("option env var"), llvm::cl::location(customEnvFlags), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt modelSizeOpt("modelSize", - llvm::cl::desc("Model to generate code"), + llvm::cl::desc("Model to generate code:"), llvm::cl::value_desc("Only support small or large"), llvm::cl::location(modelSize), llvm::cl::values( clEnumVal(small, "Generate code for the small model. " "No special treatment at this moment. This is the " - "default code model"), + "default code model."), clEnumVal(large, "Generate code for the large model. " "Global constants are put into large read-only data section.")), llvm::cl::init(small), llvm::cl::cat(OnnxMlirOptions), llvm::cl::ValueRequired); +static llvm::cl::opt externalDataDirOpt("external-data-dir", + llvm::cl::desc( + "ONNX constant initializers can be stored in a separate file. " + "The filename is stored in the ONNX model without a path. By default " + "the path is the same folder as the ONNX model. This allows that " + "default to be overridden.\n" + "Default is empty (use default)."), + llvm::cl::location(externalDataDir), llvm::cl::init(""), + llvm::cl::cat(OnnxMlirOptions)); + static llvm::cl::opt storeConstantsToFileOpt( "store-constants-to-file", llvm::cl::desc( "Constants will be stored on a binary file instead of be embedded " - "into the model.so when compiling a big model. The binary file is in " - "the same folder as the model.so and has the same name as the model " - "with the extension of .constants.bin. For inference, " + "into the model.so when compiling a big model.\nThe binary file is " + "in " + "the same folder as the model.so and has the same name as the " + "model " + "with the extension of .constants.bin.\nFor inference, " "model.constants.bin must be at the same folder as the inference " - "program. If model.constants.bin is at another folder, use the " - "environment variable OM_CONSTANT_PATH to set the constant folder. " + "program.\nIf model.constants.bin is at another folder, use the " + "environment variable OM_CONSTANT_PATH to set the constant " + "folder.\n" "When using this option, two other options " "constants-to-file-single-threshold and " - "constants-to-file-total-threshold can be used to finetune the amount " - "of constants stored on the file. Windows will be supported soon. " + "constants-to-file-total-threshold can be used to fine-tune the " + "amount " + "of constants stored on the file.\n" "Default is True."), llvm::cl::location(storeConstantsToFile), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); @@ -356,8 +439,9 @@ static llvm::cl::opt constantsToFileTotalThresholdOpt( "constants-to-file-total-threshold", llvm::cl::desc( "Put global constants to a file if the total size in " - "bytes of constants is greater than this threshold. " - "store-constants-to-file must be enabled for this to be effective. " + "bytes of constants is greater than this threshold.\n" + "store-constants-to-file must be enabled for this to be " + "effective.\n" "Only count constants whose size is greater than " "constants-to-file-single-threshold. Value is in GB. Default is " "1.5GB."), @@ -368,40 +452,48 @@ static llvm::cl::opt constantsToFileSingleThresholdOpt( "constants-to-file-single-threshold", llvm::cl::desc( "Put global constants to a file if a single constant's size in " - "bytes is greater than this threshold. " - "store-constants-to-file must be enabled for this to be effective. " + "bytes is greater than this threshold.\n" + "store-constants-to-file must be enabled for this to be " + "effective.\n" "Total sizes in bytes of satisfied constants must be greater than " - "constants-to-file-total-threshold. Value is in KB. Default is 1KB."), + "constants-to-file-total-threshold. Value is in KB. Default is " + "1KB."), llvm::cl::location(constantsToFileSingleThreshold), llvm::cl::init(1.0), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt VerboseOutputOpt("v", - llvm::cl::desc("Use verbose output"), llvm::cl::location(VerboseOutput), + llvm::cl::desc("Use verbose output."), llvm::cl::location(VerboseOutput), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::list> XoptOpt("Xopt", - llvm::cl::desc("Arguments to forward to LLVM's 'opt' option processing"), + llvm::cl::desc("Arguments to forward to LLVM's 'opt' option processing " + "multiple arguments to 'opt' need to be pass with " + "separate 'Xopt'.\n" + "For example, '-Xopt opt1 -Xopt opt2 ...'"), llvm::cl::value_desc("A valid LLVM's 'opt' option"), llvm::cl::location(Xopt), llvm::cl::cat(OnnxMlirOptions), llvm::cl::Hidden, llvm::cl::ValueRequired, llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); static llvm::cl::list> XllcOpt("Xllc", - llvm::cl::desc("Arguments to forward to LLVM's 'llc' option processing"), + llvm::cl::desc("Arguments to forward to LLVM's 'llc' option processing " + "multiple arguments to 'llc' need to be pass with " + "separate 'Xllc'.\n" + "For example, '-Xllc opt1 -Xllc opt2 ...'"), llvm::cl::value_desc("A valid LLVM's 'llc' option"), llvm::cl::location(Xllc), llvm::cl::cat(OnnxMlirOptions), llvm::cl::Hidden, llvm::cl::ValueRequired, llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); static llvm::cl::opt mllvmOpt("mllvm", - llvm::cl::desc( - "Arguments to forward to LLVM's 'opt' and 'llc' option processing"), + llvm::cl::desc("Arguments to forward to LLVM's 'opt' and 'llc' option " + "processing."), llvm::cl::value_desc("A valid LLVM's 'opt' and 'llc' option"), llvm::cl::location(mllvm), llvm::cl::cat(OnnxMlirOptions), llvm::cl::Hidden, llvm::cl::ValueRequired); static llvm::cl::opt instrumentOpsOpt("instrument-ops", llvm::cl::desc("Specify operations to be instrumented:\n" - "\"NONE\" or \"\" for no instrument (default),\n" - "\"ALL\" for instrument of all ops,\n" + "\"NONE\" or \"\" for no instrument (default).\n" + "\"ALL\" for instrument of all ops.\n" "\"ops1,ops2, ...\" for the multiple ops.\n" "e.g. \"onnx.Conv,onnx.Add\" for Conv and Add ops.\n" "Asterisk is also available.\n" @@ -423,8 +515,8 @@ static llvm::cl::bits instrumentControlBitsOpt( static llvm::cl::opt parallelizeOpsOpt("parallelize-ops", llvm::cl::desc("Specify explicitly which operations to parallelize:\n" - "\"ALL\" or \"\" for all available operations (default),\n" - "\"NONE\" for no instrument,\n" + "\"ALL\" or \"\" for all available operations (default).\n" + "\"NONE\" for no instrument.\n" "\"ops1,ops2, ...\" for the multiple ops.\n" "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n" "Asterisk is also available.\n" @@ -436,8 +528,8 @@ static llvm::cl::opt instrumentSignatureOpt( "instrument-signature", llvm::cl::desc("Specify which high-level operations should print their" " input type(s) and shape(s)\n" - "\"ALL\" or \"\" for all available operations,\n" - "\"NONE\" for no instrument (default),\n" + "\"ALL\" or \"\" for all available operations.\n" + "\"NONE\" for no instrument (default).\n" "\"ops1,ops2, ...\" for the multiple ops.\n" "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n" "Asterisk is also available.\n" @@ -446,12 +538,13 @@ static llvm::cl::opt instrumentSignatureOpt( llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt ONNXOpStatsOpt("onnx-op-stats", - llvm::cl::desc( - "Report the occurrence frequency of ONNX ops in JSON or TXT format:\n" - "\"TXT\" for report as text,\n" - "\"JSON\" for report as JSON.\n" - "Requires targets like --EmitMLIR, --EmitLLVMIR, or binary-generating " - "commands."), + llvm::cl::desc("Report the occurrence frequency of ONNX ops in JSON or " + "TXT format:\n" + "\"TXT\" for report as text,\n" + "\"JSON\" for report as JSON.\n" + "Requires targets like --EmitMLIR, --EmitLLVMIR, or " + "binary-generating " + "commands."), llvm::cl::location(ONNXOpStats), llvm::cl::init(""), llvm::cl::cat(OnnxMlirOptions)); @@ -460,14 +553,14 @@ static llvm::cl::opt onnxOpTransformThresholdOpt( llvm::cl::desc( "Max iteration for dynamic op transform passes (default=3).\n" "If set to 0, onnxOpTransformPass will be disabled, and\n" - "static iteration will be used"), + "static iteration will be used."), llvm::cl::location(onnxOpTransformThreshold), llvm::cl::init(3), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt onnxOpTransformReportOpt( "onnx-op-transform-report", - llvm::cl::desc( - "Report diagnostic info for ONNX op transform/optimization passes."), + llvm::cl::desc("Report diagnostic info for ONNX op " + "transform/optimization passes."), llvm::cl::location(onnxOpTransformReport), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); @@ -483,31 +576,40 @@ static llvm::cl::opt disableSimdOptionOpt("disable-simd", llvm::cl::location(disableSimdOption), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +static llvm::cl::opt enableFastMathOptionOpt("enable-fast-math", + llvm::cl::desc( + "Enable fast math optimizations (default=false). Set to `true` " + "to enable fast math options at O3."), + llvm::cl::location(enableFastMathOption), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirOptions)); + static llvm::cl::opt enableSimdDataLayoutOpt("simd-data-layout", - llvm::cl::desc("Enable SIMD optimization for convolution (default=false)\n" + llvm::cl::desc("Enable SIMD optimization for convolution (default=false).\n" "Set to 'true' if you want to enable SIMD optimizations."), llvm::cl::location(enableSimdDataLayout), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt opsForCallOpt("ops-for-call", - llvm::cl::desc("Specify which ops are lowered to knrl.call instead of" - "krnl loops. op name are used to check against this option." - "Names of opa are separated with space." - "Example: ops-for-call=Conv MatMul" - "The regex match will be used to check against op name"), + llvm::cl::desc( + "Specify which ops are lowered to knrl.call instead of " + "krnl loops. op name are used to check against this option.\n" + "Names of opa are separated with space. " + "Example: ops-for-call=Conv MatMul.\n" + "The regex match will be used to check against op name."), llvm::cl::location(opsForCall), llvm::cl::init(""), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt verifyInputTensorsOpt("verifyInputTensors", - llvm::cl::desc( - "Verify input tensors whenever the entry point function is called.\n" - "Data type and shape are verified. Enable this may introduce overhead " - "at runtime."), + llvm::cl::desc("Verify input tensors whenever the entry point function " + "is called.\n" + "Data type and shape are verified. Enable this may " + "introduce overhead " + "at runtime."), llvm::cl::location(verifyInputTensors), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt allowSortingOpt("allowSorting", - llvm::cl::desc("Perform topological sort on onnx graph"), + llvm::cl::desc("Perform topological sort on onnx graph."), llvm::cl::location(allowSorting), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); @@ -515,29 +617,34 @@ static llvm::cl::list> reportHeapBeforeOpt("report-heap-before", llvm::cl::desc("A list of names of passes.\n" "Before each heap statistics are dumped to " - ".heap.log"), + ".heap.log."), llvm::cl::location(reportHeapBefore), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::list> reportHeapAfterOpt( "report-heap-after", llvm::cl::desc("A list of names of passes.\n" "After each heap statistics are dumped to " - ".heap.log"), + ".heap.log."), llvm::cl::location(reportHeapAfter), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt modelTagOpt("tag", llvm::cl::desc( "Set a tag that will be used to postfix symbols in the generated " - "LLVMIR to make the symbols unique across multiple generated models. " - "By default, use the filename (without extension) of the input onnx " - "model or the value passed to `-o`. The tag will be appended to " - "global variable and function names. For backward compatibility, each " - "function has two versions with the same signature and doing the same " - "computation. For example, we will have two entry points: " + "LLVMIR to make the symbols unique across multiple generated " + "models.\n" + "By default, use the filename (without extension) of the input " + "onnx " + "model or the value passed to `-o`.\nThe tag will be appended to " + "global variable and function names. For backward compatibility, " + "each " + "function has two versions with the same signature and doing the " + "same " + "computation.\nFor example, we will have two entry points: " "`run_main_graph` and `run_main_graph_tag`, where `run_main_graph` " - "is just a wrapper of `run_main_graph_tag`. Users can call one of " - "the entry points and expect the same result. Passing `NONE` to " - "`--tag` will disable tag completely, meaning no tag is appended to " + "is just a wrapper of `run_main_graph_tag`.\nUsers can call one of " + "the entry points and expect the same result.\nPassing `NONE` to " + "`--tag` will disable tag completely, meaning no tag is appended " + "to " "the symbols."), llvm::cl::value_desc("a string that matches regex ([0-9a-z_.-]+)"), llvm::cl::location(modelTag), llvm::cl::init(""), @@ -549,28 +656,29 @@ static llvm::cl::opt enableConvOptPassOpt("enable-conv-opt-pass", llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt disableConstantPropOpt("disable-constant-prop", - llvm::cl::desc("Disable Constant Propagation (default is false)\n" + llvm::cl::desc("Disable Constant Propagation (default is false).\n" "Set to 'true' to disable Constant Propagation."), llvm::cl::location(disableConstantProp), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::list> extraLibPathsOpt( "L", - llvm::cl::desc("Specify extra directories for libraries when compiling" - "an onnx model. Will be add used as -L in the linkage step." - "Each directory can be specified with one extra-lib-dirs"), + llvm::cl::desc( + "Specify extra directories for libraries when compiling " + "an onnx model. Will be add used as -L in the linkage step.\n" + "Each directory can be specified with one extra-lib-dirs."), llvm::cl::location(extraLibPaths), llvm::cl::Prefix, llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::list> extraLibsOpt("l", llvm::cl::desc("Specify extra libraries when compiling an onnx model." - "Will be add used as -l in the linkage step." - "Each lib can be specified with one extra-libs"), + "Will be add used as -l in the linkage step.\n" + "Each lib can be specified with one extra-libs."), llvm::cl::location(extraLibs), llvm::cl::Prefix, llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt profileIROpt("profile-ir", - llvm::cl::desc("Profile operations in an IR"), + llvm::cl::desc("Profile operations in an IR:"), llvm::cl::location(profileIR), llvm::cl::values(clEnumVal(None, "No profiling. Default value."), clEnumVal( @@ -579,7 +687,7 @@ static llvm::cl::opt profileIROpt("profile-ir", llvm::cl::init(ProfileIRs::None), llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt optReportOpt("opt-report", - llvm::cl::desc("Provide information on a specific compiler optimization."), + llvm::cl::desc("Provide information on a specific compiler optimization:"), llvm::cl::location(optReport), llvm::cl::values(clEnumVal(NoReport, "No report. Default value."), clEnumVal(Parallel, @@ -595,8 +703,9 @@ static llvm::cl::opt enable_timing("enable-timing", llvm::cl::cat(OnnxMlirOptions)); static llvm::cl::opt enable_bound_check("enable-bound-check", - llvm::cl::desc("Enable runtime bound check for memrefs (default is false)\n" - "Set to 'true' if you want to enable the check."), + llvm::cl::desc( + "Enable runtime bound check for memrefs (default is false).\n" + "Set to 'true' if you want to enable the check."), llvm::cl::location(enableBoundCheck), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); @@ -604,11 +713,15 @@ static llvm::cl::opt enable_bound_check("enable-bound-check", // Option only available in debug mode: set using command options. static llvm::cl::opt test_compiler_opt("test-compiler-opt", llvm::cl::desc( - "Help compiler writers test a new (small) optimization. When false, " - "the old approach should be used. When true, the new opt should be " - "used. Utilities such as CheckONNXModel.py can then verify that the " + "Help compiler writers test a new (small) optimization. When " + "false, " + "the old approach should be used.\nWhen true, the new opt should " + "be " + "used. Utilities such as CheckONNXModel.py can then verify that " + "the " "new opt deliver the same results.\n" - "E.g. CheckONNXModel.py -m test.mlir -t -O3 -a test-compiler-opt=true\n" + "E.g. CheckONNXModel.py -m test.mlir -t -O3 -a " + "test-compiler-opt=true.\n" "Once the new opt works, it should not rely this option any more.\n" "Only defined in DEBUG build and default to false.\n"), llvm::cl::location(debugTestCompilerOpt), llvm::cl::init(false), @@ -622,33 +735,35 @@ bool debugTestCompilerOpt = false; // Options for onnx-mlir-opt only static llvm::cl::opt split_input_file_opt("split-input-file", llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), + "chunk independently."), llvm::cl::location(split_input_file), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptOptions)); static llvm::cl::opt verify_diagnostics_opt("verify-diagnostics", llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), + "expected-* lines on the corresponding line."), llvm::cl::location(verify_diagnostics), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptOptions)); static llvm::cl::opt verify_passes_opt("verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), + llvm::cl::desc("Run the verifier after each transformation pass."), llvm::cl::location(verify_passes), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptOptions)); static llvm::cl::opt allowUnregisteredDialectsOpt( "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::desc("Allow operation with no registered dialects."), llvm::cl::location(allowUnregisteredDialects), llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptOptions)); -// Removed once the new LLVM bufferization works without performance regression. +// Removed once the new LLVM bufferization works without performance +// regression. static llvm::cl::opt useOldBufferizationOpt("use-old-bufferization", llvm::cl::desc( - "Enable the old LLVM bufferization mechanism (default=true)\n" - "This option should be removed once the new LLVM bufferization works " - "well in onnx-mlir"), + "Enable the old LLVM bufferization mechanism (default=true).\n" + "This option should be removed once the new LLVM bufferization " + "works " + "well in onnx-mlir."), llvm::cl::location(useOldBufferization), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); @@ -672,8 +787,8 @@ std::string customEnvFlags; // The customEnvFlags must be scanned before the normal options. bool parseCustomEnvFlagsCommandLineOption( int argc, const char *const *argv, llvm::raw_ostream *errs) { - // Use the default ONNX MLIR Environment variable, unless specified otherwise - // by an argument, see below. + // Use the default ONNX MLIR Environment variable, unless specified + // otherwise by an argument, see below. std::string envVar = OnnxMlirEnvOptionName; // VerboseOutput is not yet set, so scan ourselves. bool verbose = false; @@ -688,14 +803,15 @@ bool parseCustomEnvFlagsCommandLineOption( verbose = true; } } - // Check that the env var does not recursively hold another -customEnvFlags. + // Check that the env var does not recursively hold another + // -customEnvFlags. const char *envValCstr; if ((envValCstr = std::getenv(envVar.c_str()))) { std::string envVal(envValCstr); if (envVal.find("-customEnvFlags") != std::string::npos) { if (errs) - *errs << "Warning: recursive use of --customEnvFlags in " - "environment flag not permited\n"; + *errs << "\nWarning: recursive use of --customEnvFlags in " + "environment flag not permited\n\n"; return false; } if (envVal.find("-v") != std::string::npos) @@ -759,7 +875,15 @@ void setTargetArch(const std::string &arch) { void clearTargetArch() { march.clear(); } -std::string getTargetArchOption() { +std::string getTargetArchOption(bool forLLVMToolchain) { + // LLVM toolchain wants a --march=systemz for all z machines; the specific + // Z architecture will be specified with the LLVM Toolchain --mcpu. + if (forLLVMToolchain) { + // Handle special case for Z. + int64_t zArchNum = getZArchNum(march, mcpu); + if (zArchNum != -1) + return "--march=systemz"; + } return (march != "") ? "--march=" + march : ""; } @@ -772,8 +896,27 @@ void setTargetCPU(const std::string &cpu) { void clearTargetCPU() { mcpu.clear(); } -std::string getTargetCPUOption() { - return (mcpu != "") ? "--mcpu=" + mcpu : ""; +// As the LLVM tooling for Z may not support the latest, cap it by this +// --mcpu=arch{MAX_LLVM_Z_ARCH_LEVEL} value. +#define MAX_LLVM_Z_ARCH_LEVEL 14 + +std::string getTargetCPUOption(bool forLLVMToolchain, bool cpuOnly) { + // With cpu only, return the mcpu value; without it, prepend with "--mcpu=". + std::string str = (cpuOnly ? "" : "--mcpu="); + + // The LLVM toolchain wants the specific Z architecture to be expressed with + // the LLVM Toolchain --mcpu. Convert below the --march into their + // corresponding --mcpu equivalent. + if (forLLVMToolchain) { + // Handle special case for Z. + int64_t zArchNum = getZArchNum(march, mcpu); + if (zArchNum != -1) { + // Cap at max supported LLVM level. + zArchNum = std::min(zArchNum, (int64_t)MAX_LLVM_Z_ARCH_LEVEL); + return str.append("arch" + std::to_string(zArchNum)); + } + } + return (mcpu != "") ? str + mcpu : ""; } // Support for Accel. @@ -887,6 +1030,21 @@ void setLLVMOption(const std::string &flag) { mllvm = flag; } void clearLLVMOption() { mllvm.clear(); } std::string getLLVMOption() { return (mllvm != "") ? mllvm : std::string(); } +static std::vector split(std::string &input) { + std::stringstream ss(input); + std::istream_iterator begin(ss); + std::istream_iterator end; + std::vector vstrings(begin, end); + return vstrings; +} + +std::vector getLLVMOptions() { + if (mllvm == "") + return std::vector(); + + return split(mllvm); +} + // Support for model tag void setModelTag(const std::string &str) { modelTag = str; } void clearModelTag() { modelTag = ""; } @@ -1061,9 +1219,9 @@ std::string getExecPath() { auto execPath = llvm::sys::fs::getMainExecutable(nullptr, nullptr); if (execPath.empty()) { llvm::errs() - << "Warning: Could not find path to current executable, falling " + << "\nWarning: Could not find path to current executable, falling " "back to default install path: " - << kExecPath << "\n"; + << kExecPath << "\n\n"; return kExecPath; } return execPath; @@ -1102,16 +1260,16 @@ std::string getLibraryPath() { // onnx-mlir currently requires llvm tools llc and opt and they are assumed // to be under llvm-project/build/bin. This doesn't work with the case where -// llvm-project has been installed system wide (typically under /usr/local/...) -// and its source has been removed. +// llvm-project has been installed system wide (typically under +// /usr/local/...) and its source has been removed. // // To account for this scenario, we first search for the tools in the same -// directory where onnx-mlir is run. If they are found, it means both onnx-mlir -// and llvm-project have been installed system wide under the same directory, -// so we get them from that directory (typically /usr/local/bin). Otherwise, -// at least one of onnx-mlir and llvm-project has not been installed system -// wide. In this case, getToolPath returns the fallback directory where llvm -// is built which is typically llvm-project/build/bin. +// directory where onnx-mlir is run. If they are found, it means both +// onnx-mlir and llvm-project have been installed system wide under the same +// directory, so we get them from that directory (typically /usr/local/bin). +// Otherwise, at least one of onnx-mlir and llvm-project has not been +// installed system wide. In this case, getToolPath returns the fallback +// directory where llvm is built which is typically llvm-project/build/bin. // // Note that this will not work if both onnx-mlir and llvm-project have been // installed system wide but to different places and their sources have been @@ -1119,8 +1277,8 @@ std::string getLibraryPath() { // llvm-project. // // If the flag is true, getToolPath will simply return the path detected by -// cmake at compile time. This is used for system wide tools such as cc, ld, ar, -// etc. Note that this means the path is valid only on the system where +// cmake at compile time. This is used for system wide tools such as cc, ld, +// ar, etc. Note that this means the path is valid only on the system where // onnx-mlir is built. If onnx-mlir is subsequently run on a system that does // not have these tools installed in the "standard" places, it will fail. // @@ -1128,7 +1286,6 @@ std::string getLibraryPath() { // as lrodataScript. std::string getToolPath( const std::string &tool, bool flag /*false by default*/) { - if (!flag) { std::string execDir = llvm::sys::path::parent_path(getExecPath()).str(); llvm::SmallString<8> toolPath(execDir); @@ -1177,8 +1334,8 @@ void initCompilerConfig() { // Test option requirements. if (!ONNXOpStats.empty() && emissionTarget <= EmitONNXIR) llvm::errs() - << "Warning: --onnx-op-stats requires targets like --EmitMLIR, " - "--EmitLLVMIR, or binary-generating emit commands.\n"; + << "\nWarning: --onnx-op-stats requires targets like --EmitMLIR, " + "--EmitLLVMIR, or binary-generating emit commands.\n\n"; // Library setup for EmitLib and EmitJNI targets if (emissionTarget == EmitLib || emissionTarget == EmitJNI) { @@ -1199,6 +1356,19 @@ void initCompilerConfig() { addCompilerConfig(CCM_SHARED_LIB_DEPS, extraLibs); addCompilerConfig(CCM_SHARED_LIB_PATH_DEPS, extraLibPaths); } + + // Enable aggressive optimization for NNPA with -O3 + if (OptimizationLevel == OptLevel::O3 && + getTargetAccel().find("NNPA") != std::string::npos) { + // Have O3 and NNPA. May enable fast math default in the future. + } + + // Enabling unsafe math. + if (enableFastMathOption && + getLLVMOption().find("enable-unsafe-fp-math") == std::string::npos) { + // Fast math option is enabled (in general) + setLLVMOption(getLLVMOption() + " --enable-unsafe-fp-math"); + } } } // namespace onnx_mlir diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 2ed9f251e1..b21c8a3b31 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -14,7 +14,7 @@ #ifndef ONNX_MLIR_COMPILER_OPTIONS_H #define ONNX_MLIR_COMPILER_OPTIONS_H -#include "onnx-mlir/Compiler/OMCompilerTypes.h" + #include "src/Accelerators/Accelerator.hpp" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" @@ -78,6 +78,7 @@ extern std::vector maccel; // common for both extern OptLevel OptimizationLevel; // common for both extern std::string mtriple; // common for both extern std::string mcpu; // common for both +extern float nnpaEpsilon; // common for both extern std::string march; // common for both extern InstrumentStages instrumentStage; // common for both extern bool onnxConstPropRoundFPToInt; // common for both @@ -87,18 +88,24 @@ extern bool enableONNXHybridPass; // common for both extern std::vector functionsToDecompose; // common for both extern std::string opsForCall; // common for both extern bool disableKrnlOpFusion; // common for both +extern bool disableQuantZeroPoint; // common for both +extern bool enableKrnlBufferReuse; // common for both +extern bool disableMemRefPrefetch; // common for both extern EmissionTargetType emissionTarget; // onnx-mlir only extern bool invokeOnnxVersionConverter; // onnx-mlir only extern bool preserveLocations; // onnx-mlir only extern bool printIR; // onnx-mlir only +extern bool printBytecode; // onnx-mlir only extern bool preserveBitcode; // onnx-mlir only extern bool preserveLLVMIR; // onnx-mlir only extern bool preserveMLIR; // onnx-mlir only +extern bool doNotEmitFullMLIRCode; // onnx-mlir only extern bool useOnnxModelTypes; // onnx-mlir only extern int repeatOnnxTransform; // onnx-mlir only extern std::string shapeInformation; // onnx-mlir only extern std::string dimParams; // onnx-mlir only extern ModelSize modelSize; // onnx-mlir only +extern std::string externalDataDir; // onnx-mlir only extern bool storeConstantsToFile; // onnx-mlir only extern float constantsToFileTotalThreshold; // onnx-mlir only extern float constantsToFileSingleThreshold; // onnx-mlir only @@ -115,7 +122,9 @@ extern int onnxOpTransformThreshold; // onnx-mlir only extern bool onnxOpTransformReport; // onnx-mlir only extern bool enableParallel; // onnx-mlir only extern bool disableSimdOption; // onnx-mlir only +extern bool enableFastMathOption; // onnx-mlir only extern bool disableRecomposeOption; // onnx-mlir only +extern bool disableConvTransposeDecomposeOption; // onnx-mlir only extern bool enableSimdDataLayout; // onnx-mlir only extern bool verifyInputTensors; // onnx-mlir only extern bool allowSorting; // onnx-mlir only @@ -154,11 +163,35 @@ std::string getTargetTripleOption(); void setTargetArch(const std::string &arch); void clearTargetArch(); -std::string getTargetArchOption(); + +// AMD: inlined to avoid linking errors +// Sort out architectures for Z systems (hybrid archXX and zYY names). +inline int64_t decodeZArchNum(std::string str) { + if (str == "arch12" || str == "z14") // Z14 and equivalents. + return 12; + if (str == "arch13" || str == "z15") // Z15 and equivalents. + return 13; + if (str == "arch14" || str == "z16") // Z16 and equivalents. + return 14; + if (str == "arch15") + return 15; + return -1; +} + +// AMD: inlined to avoid linking errors +inline int64_t getZArchNum(const std::string &arch, const std::string cpu) { + // Give priority to march, use (deprecated) mcpu if march is not defined. + int64_t num = decodeZArchNum(arch); + if (num == -1) + num = decodeZArchNum(cpu); + return num; +} +std::string getTargetArchOption(bool forLLVMToolchain = false); void setTargetCPU(const std::string &cpu); void clearTargetCPU(); -std::string getTargetCPUOption(); +std::string getTargetCPUOption( + bool forLLVMToolchain = false, bool cpuOnly = false); int setTargetAccel(const std::string &str); void setTargetAccel(const accel::Accelerator::Kind accel); @@ -180,6 +213,10 @@ std::vector getXllcOption(); void setLLVMOption(const std::string &flag); void clearLLVMOption(); std::string getLLVMOption(); +// Break down the result of getLLVMOption into substrings +std::vector getLLVMOptions(); +std::vector getLLVMOPTOptions(); +std::vector getLLVMLLCOptions(); // Options support for OMCompilerOptions. using CompilerOptionList = diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index c8f84e5565..88209eeb02 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -44,6 +44,18 @@ using namespace mlir; namespace onnx_mlir { void configurePasses() { + // Handle deprecated mcpu. + if (!mcpu.empty()) { + if (!march.empty()) { + llvm::outs() << "\nWarning: Got values for both --march and --mcpu, " + "ignore --mcpu. " + "Please remove deprecated --mcpu in the near future.\n\n"; + } else { + llvm::outs() + << "\nWarning: Got deprecated --mcpu option. Please switch to " + "--march in the near future.\n\n"; + } + } // Set global vector machine support. VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, ""); configureConstPropONNXToONNXPass(onnxConstPropRoundFPToInt, @@ -54,7 +66,8 @@ void configurePasses() { !disableSimdOption); } -void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { +void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, + bool donotScrubDisposableElementsAttr) { // This is a transition from previous static passes to full dynamic passes // Static passes are kept and the dynamic pass is added as IF-THEN // with the static iteration. @@ -66,8 +79,9 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { // In future, only the dynamic pass, ONNXOpTransformPass, will be used for // this function. - pm.addInstrumentation( - std::make_unique(pm.getContext())); + if (!donotScrubDisposableElementsAttr) + pm.addInstrumentation( + std::make_unique(pm.getContext())); // Decompose first. Eliminates some unsupported ops without shape inference. pm.addNestedPass(onnx_mlir::createDecomposeONNXToONNXPass()); @@ -132,7 +146,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) { pm.addPass(mlir::createSymbolDCEPass()); // Replace every DisposableElementsAttr with DenseElementsAttr. - pm.addPass(createScrubDisposablePass()); + if (!donotScrubDisposableElementsAttr) + pm.addPass(createScrubDisposablePass()); // Set onnx_node_name if it is missing. Keep this pass at the end of this // function and just before instrumentation. @@ -189,6 +204,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, instrumentSignatureString)); pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, + /*enableFastMath*/ optLevel >= 3 && enableFastMathOption, /*opsToCall*/ opsForCall)); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization @@ -251,6 +267,7 @@ void addKrnlToLLVMPasses( // The alloca_scope ops are somewhat fragile; canonicalize remove them when // redundant, which helps reliability of the compilation of these ops. pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(onnx_mlir::createProcessKrnlParallelClausePass()); } // The pass below is needed for subview and collapseShape.. Unfortunately, @@ -299,8 +316,10 @@ void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, EmissionTargetType emissionTarget, std::string outputNameNoExt) { InputIRLevelType inputIRLevel = determineInputIRLevel(module); + // NOTE: FlexML sets the targetCPU flag to false, as we do not want to run + // the CPU specific transformations. if (inputIRLevel <= ONNXLevel && emissionTarget >= EmitONNXIR) - addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty()); + addONNXToMLIRPasses(pm, /*target CPU*/ false); if (emissionTarget >= EmitMLIR) { if (inputIRLevel <= ONNXLevel) diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp index f0c0499f8f..9a6987cf19 100644 --- a/src/Compiler/CompilerPasses.hpp +++ b/src/Compiler/CompilerPasses.hpp @@ -20,7 +20,8 @@ namespace onnx_mlir { // Configures passes up front based on command line options. void configurePasses(); -void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU); +void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, + bool donotScrubDisposableElementsAttr = false); void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, std::string instrumentSignatureString, std::string ONNXOpsStatFilename); void addKrnlToAffinePasses(mlir::PassManager &pm); diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 67c53ce1b3..d2b32bf8e4 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -20,6 +20,7 @@ #include #include +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Parser/Parser.h" @@ -60,9 +61,9 @@ mlir::TimingScope rootTimingScope; namespace onnx_mlir { // Values to report the current phase of compilation. -// Increase TOTAL_COMPILE_PHASE when having more phases. uint64_t CURRENT_COMPILE_PHASE = 1; -uint64_t TOTAL_COMPILE_PHASE = 6; +uint64_t TOTAL_COMPILE_PHASE = 0; +static DiagnosticEngine::HandlerID diagnosticHandlerID = 0; // Make a function that forces preserving all files using the runtime arguments // and/or the overridePreserveFiles enum. @@ -72,6 +73,26 @@ enum class KeepFilesOfType { All, MLIR, LLVMIR, Bitcode, Object, None }; // flags. static constexpr KeepFilesOfType overridePreserveFiles = KeepFilesOfType::None; +// Get optimization level from onnx-mlir only when it is not specified +std::string getOptimizationLevelUniqueOption( + std::vector> specialOptionsList) { + if (std::any_of(specialOptionsList.begin(), specialOptionsList.end(), + [](std::vector specialOptions) { + if (std::any_of(specialOptions.begin(), specialOptions.end(), + [](std::string str) { + std::regex optimization("^-O[0-9]"); + std::smatch m; + return std::regex_match(str, m, optimization); + })) // End of one options + return true; + else + return false; + })) + return std::string(); + else + return getOptimizationLevelOption(); +} + static bool keepFiles(KeepFilesOfType preserve) { // When wanting to preserve all files, do it regardless of isBitcode. if (overridePreserveFiles == KeepFilesOfType::All) @@ -170,7 +191,7 @@ int Command::exec(std::string wdir) const { } void showCompilePhase(std::string msg) { - time_t rawTime; + time_t rawTime = 0; struct tm *timeInfo; char buffer[80]; // Remember first time. @@ -178,10 +199,13 @@ void showCompilePhase(std::string msg) { static bool hasFirstRawTime = false; // Get current date. - time(&rawTime); - timeInfo = localtime(&rawTime); - strftime(buffer, 80, "%c", timeInfo); - std::string currentTime(buffer); + std::string currentTime(""); + if (time(&rawTime) == -1 || (timeInfo = localtime(&rawTime)) == NULL || + (strftime(buffer, 80, "%c", timeInfo)) == 0) { + currentTime = "Error obtaining current time"; + } else { + currentTime = buffer; + } // Compute time difference in seconds. int diff = 0; @@ -191,10 +215,10 @@ void showCompilePhase(std::string msg) { firstRawTime = rawTime; hasFirstRawTime = true; } - llvm::outs() << "[" << CURRENT_COMPILE_PHASE++ << "/" << TOTAL_COMPILE_PHASE + llvm::errs() << "[" << CURRENT_COMPILE_PHASE++ << "/" << TOTAL_COMPILE_PHASE << "] " << currentTime << " (" << diff << "s) " << msg << "\n"; // Flush so that if there are errors, we know where it came from. - llvm::outs().flush(); + llvm::errs().flush(); // Reset current phase. if (CURRENT_COMPILE_PHASE > TOTAL_COMPILE_PHASE) { @@ -242,7 +266,7 @@ static void loadMLIR(std::string inputFilename, mlir::MLIRContext &context, if ((numOfFuncOp == 1) && (!shapeInformation.empty())) { ModelInputShaper modelInputShaper_; modelInputShaper_.setShapeInformation(shapeInformation); - auto funcType = dyn_cast(funcOp.getFunctionType()); + auto funcType = mlir::dyn_cast(funcOp.getFunctionType()); ArrayRef argTypes = funcType.getInputs(); SmallVector newArgTypes; for (uint64_t i = 0; i < argTypes.size(); ++i) { @@ -320,13 +344,15 @@ static void tailorLLVMIR(llvm::Module &llvmModule) { llvmModule.getNamedGlobal(StringRef("_entry_point_arrays" + tag))) { if (GV->isConstant() && GV->hasDefinitiveInitializer()) { llvm::Constant *initializer = GV->getInitializer(); - llvm::ArrayType *AT = dyn_cast(initializer->getType()); + llvm::ArrayType *AT = + mlir::dyn_cast(initializer->getType()); for (uint64_t i = 0; i < AT->getNumElements() - 1; ++i) { llvm::GlobalVariable *entryGV = llvmModule.getNamedGlobal( StringRef("_entry_point_" + std::to_string(i) + tag)); if (entryGV->isConstant()) { llvm::ConstantDataSequential *entry = - dyn_cast(entryGV->getInitializer()); + mlir::dyn_cast( + entryGV->getInitializer()); exportedFuncs.emplace_back(entry->getAsCString()); } } @@ -396,6 +422,12 @@ static int genLLVMBitcode(const mlir::OwningOpRef &module, return InvalidTemporaryFileAccess; } + // In the LLVM translation, we get some warnings, so disable in non-verbose + // mode. + if (diagnosticHandlerID && !VerboseOutput) { + module.get().getContext()->getDiagEngine().eraseHandler( + diagnosticHandlerID); + } llvm::LLVMContext llvmContext; mlir::registerBuiltinDialectTranslation(*(module.get().getContext())); mlir::registerLLVMDialectTranslation(*(module.get().getContext())); @@ -432,12 +464,14 @@ static int genLLVMBitcode(const mlir::OwningOpRef &module, std::string optPath = getToolPath("opt"); Command optBitcode(/*exePath=*/optPath); setXoptOption({"--code-model", modelSizeStr[modelSize]}); - int rc = optBitcode.appendStr(getOptimizationLevelOption()) + int rc = optBitcode + .appendStr(getOptimizationLevelUniqueOption( + {getLLVMOptions(), getXoptOption()})) .appendStr(getTargetTripleOption()) - .appendStr(getTargetArchOption()) - .appendStr(getTargetCPUOption()) + .appendStr(getTargetArchOption(/*forLLVM toolchain*/ true)) + .appendStr(getTargetCPUOption(/*forLLVM*/ true)) .appendList(getXoptOption()) - .appendStr(getLLVMOption()) + .appendList(getLLVMOptions()) .appendList({"-o", optimizedBitcodeNameWithExt}) .appendStr(unoptimizedBitcodeNameWithExt) .exec(); @@ -454,12 +488,14 @@ static int genModelObject( std::string llcPath = getToolPath("llc"); Command llvmToObj(/*exePath=*/llcPath); setXllcOption({"--code-model", modelSizeStr[modelSize]}); - int rc = llvmToObj.appendStr(getOptimizationLevelOption()) + int rc = llvmToObj + .appendStr(getOptimizationLevelUniqueOption( + {getLLVMOptions(), getXllcOption()})) .appendStr(getTargetTripleOption()) - .appendStr(getTargetArchOption()) - .appendStr(getTargetCPUOption()) + .appendStr(getTargetArchOption(/*LLVM toolchain*/ true)) + .appendStr(getTargetCPUOption(/*LLVM*/ true)) .appendList(getXllcOption()) - .appendStr(getLLVMOption()) + .appendList(getLLVMOptions()) .appendStr("-filetype=obj") .appendStr("-relocation-model=pic") .appendList({"-o", modelObjNameWithExt}) @@ -553,6 +589,9 @@ static int genJniJar(const mlir::OwningOpRef &module, // Copy javaruntime.jar to model jar. llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath); + if (VerboseOutput) + llvm::outs() << "cp " << javaRuntimeJarPath << " " << modelJniJarPath + << "\n"; // Add shared library to model jar. Command jar(getToolPath("jar", true)); @@ -686,7 +725,8 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context, options.shapeInformation = shapeInformation; options.dimParams = dimParams; options.allowSorting = allowSorting; - options.externalDataDir = dirName(inputFilename); + options.externalDataDir = + externalDataDir.empty() ? dirName(inputFilename) : externalDataDir; options.functionsToDecompose.insert(options.functionsToDecompose.end(), functionsToDecompose.begin(), functionsToDecompose.end()); return ImportFrontendModelFile( @@ -769,8 +809,8 @@ static int emitOutputFiles(std::string outputNameNoExt, return rc; } if (VerboseOutput) - printf( - "Object file '%s' has been compiled.\n", modelObjNameWithExt.c_str()); + llvm::outs() << "Object file '" << modelObjNameWithExt + << "' has been compiled.\n"; } break; case EmitLib: { std::string sharedLibNameWithExt; @@ -784,8 +824,8 @@ static int emitOutputFiles(std::string outputNameNoExt, return rc; } if (VerboseOutput) - printf("Shared library '%s' has been compiled.\n", - sharedLibNameWithExt.c_str()); + llvm::outs() << "Shared library '" << sharedLibNameWithExt + << "' has been compiled.\n"; } break; case EmitJNI: { int rc = compileModuleToJniJar(module, outputNameNoExt); @@ -797,29 +837,32 @@ static int emitOutputFiles(std::string outputNameNoExt, return rc; } if (VerboseOutput) - printf( - "JNI archive '%s.jar' has been compiled.\n", outputNameNoExt.c_str()); + llvm::outs() << "JNI archive '" << outputNameNoExt + << ".jar' has been compiled.\n"; } break; default: { // Emit the version with all constants included. - std::string ouputNameWithExt = + std::string outputNameWithExt = getTargetFilename(outputNameNoExt, emissionTarget); - int rc = outputCode(module, ouputNameWithExt); - if (VerboseOutput) - printf("Full MLIR code written to: \n\t%s\n\n", ouputNameWithExt.c_str()); - if (rc != CompilerSuccess) - return rc; - + if (!doNotEmitFullMLIRCode) { + int rc = outputCode(module, outputNameWithExt); + if (VerboseOutput) + llvm::outs() << "Full MLIR code written to:\n" + << "\t" << outputNameWithExt << "\n\n"; + if (rc != CompilerSuccess) + return rc; + } // Elide element attributes if larger than 100. if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR || emissionTarget == EmitMLIR) { std::string tempNameWithExt = outputNameNoExt + ".tmp"; int rc = outputCode(module, tempNameWithExt, /*largeElementLimit=*/100); if (VerboseOutput) { - printf("Constant-free MLIR Code written to: \n\t%s\n\n", - tempNameWithExt.c_str()); - printf("Use:\n\t%s\nto continue lowering the code to other dialects.\n", - ouputNameWithExt.c_str()); + llvm::outs() << "Constant-free MLIR Code written to:\n" + << "\t" << tempNameWithExt << "\n\n"; + llvm::outs() << "Use:\n" + << "\t" << outputNameWithExt + << "\nto continue lowering the code to other dialects.\n"; } if (rc != CompilerSuccess) return rc; @@ -845,14 +888,16 @@ static const llvm::Target *getLLVMTarget( return LLVMTarget; } -/// Return the module datalayout string. The datalayout string is determined +/// Return the module data layout string. The data layout string is determined /// by creating a target machine using the target triple and target cpu. static std::string getDataLayout(const Location &loc) { const llvm::Target &LLVMTarget = *getLLVMTarget(mtriple, loc); llvm::TargetOptions ops; + std::string mcpuForLLVMToolchain = getTargetCPUOption( + /*forLLVM*/ true, /*cpu only*/ true); auto targetMachine = std::unique_ptr{LLVMTarget.createTargetMachine( - mtriple, mcpu, "" /*features*/, ops, std::nullopt)}; + mtriple, mcpuForLLVMToolchain, "" /*features*/, ops, std::nullopt)}; if (!targetMachine) { emitError(loc, "failed to create target machine"); return nullptr; @@ -860,7 +905,7 @@ static std::string getDataLayout(const Location &loc) { const llvm::DataLayout &dl = targetMachine->createDataLayout(); std::string dataLayoutString = dl.getStringRepresentation(); - assert(dataLayoutString != "" && "Expecting a valid target datalayout"); + assert(dataLayoutString != "" && "Expecting a valid target data layout"); return dataLayoutString; } @@ -936,6 +981,12 @@ static int emitOutput(mlir::OwningOpRef &module, outputModule(module, llvm::outs()); return CompilerSuccess; } + if (printBytecode) { + if (failed(mlir::writeBytecodeToFile(*module, llvm::outs()))) { + return CompilerFailure; + } + return CompilerSuccess; + } return emitOutputFiles(outputNameNoExt, emissionTarget, context, module); } @@ -943,11 +994,15 @@ static int emitOutput(mlir::OwningOpRef &module, int compileModule(mlir::OwningOpRef &module, mlir::MLIRContext &context, std::string outputNameNoExt, EmissionTargetType emissionTarget) { + // When a C++ program calls this function directly without using onnx-mlir + // driver, there is no importing phase (e.g. the model is .mlir, not .onnx). + // Thus, decrease the total number of phases. + if (CURRENT_COMPILE_PHASE == 1) { + SET_TOTAL_COMPILE_PHASE(emissionTarget); + TOTAL_COMPILE_PHASE--; + } + std::string msg = "Compiling and Optimizing MLIR Module"; - // There is no importing phase (e.g. the model is .mlir, not .onnx), adjust to - // correctly reflect the current phase. - if (CURRENT_COMPILE_PHASE == 1) - CURRENT_COMPILE_PHASE++; showCompilePhase(msg); auto compileModuleTiming = rootTimingScope.nest("[onnx-mlir] " + msg); @@ -957,6 +1012,17 @@ int compileModule(mlir::OwningOpRef &module, configurePasses(); + // Enable printing for error handler on llvm error stream. Save ID if we want + // to disable it later. We currently disable for the llvm lowering, as + // otherwise we currently get an unrecognized warning for the "onnx.name" + // attribute in function operations. In Verbose mode, we keep the error + // handling all the way to the end. + diagnosticHandlerID = + context.getDiagEngine().registerHandler([](Diagnostic &diag) { + llvm::errs() << diag << "\n"; + return mlir::LogicalResult::success(); + }); + mlir::PassManager pm( module.get()->getName(), mlir::OpPassManager::Nesting::Implicit); // TODO(tung): Revise adding passes. The current mechanism does not work if @@ -975,7 +1041,7 @@ int compileModule(mlir::OwningOpRef &module, pm.addInstrumentation(std::make_unique( heapLogFileame, reportHeapBefore, reportHeapAfter)); } - (void)mlir::applyPassManagerCLOptions(pm); + static_cast(mlir::applyPassManagerCLOptions(pm)); if (enableTiming) { pm.enableTiming(compileModuleTiming); diff --git a/src/Compiler/CompilerUtils.hpp b/src/Compiler/CompilerUtils.hpp index e3ecc1bd72..713e2fb8e3 100644 --- a/src/Compiler/CompilerUtils.hpp +++ b/src/Compiler/CompilerUtils.hpp @@ -33,10 +33,22 @@ extern mlir::TimingScope rootTimingScope; namespace onnx_mlir { // Values to report the current phase of compilation. -// Increase TOTAL_COMPILE_PHASE when having more phases. extern uint64_t CURRENT_COMPILE_PHASE; extern uint64_t TOTAL_COMPILE_PHASE; +// When having more phases, let increase TOTAL_COMPILE_PHASE. +#define SET_TOTAL_COMPILE_PHASE(emissionTarget) \ + { \ + if (emissionTarget == EmitObj) \ + TOTAL_COMPILE_PHASE = 5; \ + else if (emissionTarget == EmitLib) \ + TOTAL_COMPILE_PHASE = 6; \ + else if (emissionTarget == EmitJNI) \ + TOTAL_COMPILE_PHASE = 8; \ + else \ + TOTAL_COMPILE_PHASE = 3; \ + } + struct Command { std::string _path; diff --git a/src/Compiler/DisposableGarbageCollector.cpp b/src/Compiler/DisposableGarbageCollector.cpp index 0eabc179df..4233652ff3 100644 --- a/src/Compiler/DisposableGarbageCollector.cpp +++ b/src/Compiler/DisposableGarbageCollector.cpp @@ -28,7 +28,7 @@ DisposableGarbageCollector::~DisposableGarbageCollector() {} void DisposableGarbageCollector::runAfterPass(Pass *pass, Operation *op) { if (!disposablePool.isActive()) return; - ModuleOp moduleOp = dyn_cast(op); + ModuleOp moduleOp = mlir::dyn_cast(op); if (!moduleOp) return; disposablePool.garbageCollectUnreachable( diff --git a/src/Conversion/KrnlSeqToMemref/CMakeLists.txt b/src/Conversion/KrnlSeqToMemref/CMakeLists.txt index 73e6a07d0e..3757e5a0ad 100644 --- a/src/Conversion/KrnlSeqToMemref/CMakeLists.txt +++ b/src/Conversion/KrnlSeqToMemref/CMakeLists.txt @@ -12,5 +12,6 @@ add_onnx_mlir_library(OMSeqToMemref OMSupport MLIRTransforms MLIRAffineUtils + MLIRMathTransforms OMONNXToKrnl ) diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp index 10f982864d..f175a5ba28 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqAlloc.cpp @@ -41,7 +41,7 @@ class KrnlSeqAllocOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { KrnlSeqAllocOpAdaptor operandAdaptor(operands); - KrnlSeqAllocOp thisOp = dyn_cast(op); + KrnlSeqAllocOp thisOp = mlir::dyn_cast(op); Location loc = op->getLoc(); MultiDialectBuilder create(rewriter, loc); diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp index d09f020c54..c5ff228d6b 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqExtract.cpp @@ -41,7 +41,7 @@ class KrnlSeqExtractOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { KrnlSeqExtractOpAdaptor operandAdaptor(operands); - KrnlSeqExtractOp thisOp = dyn_cast(op); + KrnlSeqExtractOp thisOp = mlir::dyn_cast(op); Location loc = op->getLoc(); MultiDialectBuilder create(rewriter, loc); @@ -62,7 +62,7 @@ class KrnlSeqExtractOpLowering : public ConversionPattern { llvm_unreachable( "Not implemented: type of onnx seq element is not tensor"); auto outputType = mlir::cast(output.getType()); - SmallVector allocParams; + SmallVector allocParams; for (size_t i = 0; i < outputType.getShape().size(); i++) { if (outputType.isDynamicDim(i)) { allocParams.emplace_back(create.mem.dim(output, i)); diff --git a/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp b/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp index cb58f47a1e..b8025e6ad8 100644 --- a/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp +++ b/src/Conversion/KrnlSeqToMemref/KrnlSeqStore.cpp @@ -46,7 +46,7 @@ class KrnlSeqStoreOpLowering : public ConversionPattern { // Allocate a new tensor and copy input tensor into it auto inputType = mlir::cast(operandAdaptor.getInput().getType()); - SmallVector allocParams; + SmallVector allocParams; for (size_t i = 0; i < inputType.getShape().size(); i++) { if (inputType.isDynamicDim(i)) { allocParams.emplace_back(create.mem.dim(operandAdaptor.getInput(), i)); diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index abcd008004..73609c2f14 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -154,7 +154,7 @@ class LoopBodyMover { */ struct Movable { std::optional movableOp; - std::optional> loopsToSkip; + std::optional> loopsToSkip; // Movable that stores a KrnlMovableOp. explicit Movable(KrnlMovableOp op) : movableOp(op) {} @@ -171,7 +171,7 @@ class LoopBodyMover { // Only skip non-unroll loops. Loops that are unrolled are by // definitions a loop whose loopRef is used by a KrnlUnrollOp. if (llvm::all_of(val.getUsers(), [&](Operation *user) { - return dyn_cast_or_null(user); + return mlir::dyn_cast_or_null(user); })) values.emplace_back(val); } @@ -265,10 +265,10 @@ class LoopBodyMover { // Move iterator to point to the next AffineFor Op. while (insertPt != loopBody.end() && - (!dyn_cast_or_null(&*insertPt) || - !dyn_cast_or_null(&*insertPt)) && + (!mlir::dyn_cast_or_null(&*insertPt) || + !mlir::dyn_cast_or_null(&*insertPt)) && loopToSkip) { - assert(dyn_cast_or_null(&*insertPt) && + assert(mlir::dyn_cast_or_null(&*insertPt) && "Expecting a KrnlMovableOp"); insertPt++; } @@ -289,7 +289,7 @@ class LoopBodyMover { } private: - llvm::DenseMap> movingPlan; + llvm::DenseMap> movingPlan; }; /*! @@ -338,7 +338,7 @@ static void markLoopBodyAsMovable( movableRegion, builder, delimeterOp->getLoc()); mover.toMoveUnder(LoopBodyMover::Movable(movableOp), root); - if (auto iterateOp = dyn_cast_or_null(delimeterOp)) + if (auto iterateOp = mlir::dyn_cast_or_null(delimeterOp)) mover.toMoveUnder(LoopBodyMover::Movable(iterateOp), root); movableBeginOp = delimeterOp->getNextNode(); @@ -354,11 +354,12 @@ static void lowerGetInductionVariableValueOp( for (const auto &operandAndResult : zippedOperandsResults) { auto operand = std::get<0>(operandAndResult); auto result = std::get<1>(operandAndResult); - if (auto forOp = dyn_cast_or_null(loopRefToOp[operand])) { + if (auto forOp = + mlir::dyn_cast_or_null(loopRefToOp[operand])) { result.replaceAllUsesWith(forOp.getInductionVar()); } else { auto parallelOp = - dyn_cast_or_null(loopRefToOp[operand]); + mlir::dyn_cast_or_null(loopRefToOp[operand]); assert(parallelOp && "expected affine.parallelOp only"); result.replaceAllUsesWith(parallelOp.getIVs()[0]); } @@ -423,7 +424,8 @@ static void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &builder, // For last optimized loop. // yield the iterateOp yield value. builder.setInsertionPointToEnd(forOp.getBody()); - auto Yield = cast(iterateOp.getBody()->getTerminator()); + auto Yield = + mlir::cast(iterateOp.getBody()->getTerminator()); builder.create(iterateOp.getLoc(), Yield.getOperands()); // replace use of iterateOp iterArgs with forOp iterArgs. @@ -554,7 +556,8 @@ static void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &builder, auto innerForOp = newForOps.back(); auto prevTerm = innerForOp.getBody()->getTerminator(); builder.setInsertionPointToEnd(innerForOp.getBody()); - auto iterTerm = cast(iterateOp.getBody()->getTerminator()); + auto iterTerm = + mlir::cast(iterateOp.getBody()->getTerminator()); builder.create(iterateOp.getLoc(), iterTerm.getOperands()); // Remove the old terminator. prevTerm->erase(); @@ -569,7 +572,8 @@ static void lowerIterateOp(KrnlIterateOp &iterateOp, OpBuilder &builder, // When there's no loop but iterateOp has result. else if (!isLoop && iterateHasResult) { // Replace use of iteratedOp with the yield value. - auto Yield = cast(iterateOp.getBody()->getTerminator()); + auto Yield = + mlir::cast(iterateOp.getBody()->getTerminator()); for (auto [result, yieldValue] : llvm::zip(iterateOp.getResults(), Yield.getOperands())) { result.replaceAllUsesWith(yieldValue); @@ -666,7 +670,7 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, } } - if (auto iterateOp = dyn_cast_or_null(op)) { + if (auto iterateOp = mlir::dyn_cast_or_null(op)) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " interpret iterate op " << iterateOp << "\n"); // If an iterateOp has no unoptimized loop references, then we need to lower @@ -676,7 +680,7 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, opsToErase.insert(iterateOp); } return success(); - } else if (auto blockOp = dyn_cast_or_null(op)) { + } else if (auto blockOp = mlir::dyn_cast_or_null(op)) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " interpret block op " << blockOp << "\n"); SmallVector tiledLoops; @@ -709,7 +713,7 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, opsToErase.insert(op); return success(); - } else if (auto permuteOp = dyn_cast_or_null(op)) { + } else if (auto permuteOp = mlir::dyn_cast_or_null(op)) { LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " interpret permute op " << permuteOp << "\n"); // TODO(tjingrant): call it whenever an operation lowering completes. @@ -731,13 +735,17 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, opsToErase.insert(op); return success(); - } else if (auto parallelOp = dyn_cast_or_null(op)) { + } else if (auto parallelOp = mlir::dyn_cast_or_null(op)) { // Parallelism the given loop by transform the tagged affine.for op to // affine.parallel LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " interpret parallel op " << parallelOp << "\n"); // ToFix handle multiple parallel loop ValueRange loopRefs = parallelOp.getLoops(); + Value numThreads = parallelOp.getNumThreads(); + StringAttr procBind = parallelOp.getProcBindAttr(); + bool needParallelClause = + numThreads || (procBind && procBind.getValue().size() > 0); // Obtain the the reference the loop that needs to be parallelized for (Value loopRef : loopRefs) { @@ -774,6 +782,23 @@ static LogicalResult interpretOperation(Operation *op, OpBuilder &builder, parallelLoop.getRegion().takeBody(loopToParallel.getRegion()); Operation *yieldOp = ¶llelLoop.getBody()->back(); yieldOp->setOperands(reducedValues); + if (needParallelClause) { + // Use clause only for the first one (expected the outermost one). + // Ideally, we would generate here a single, multi-dimensional + // AffineParallelOp, and we would not need to reset the flag. + needParallelClause = false; + // Currently approach: insert after yield and then move before it. + PatternRewriter::InsertionGuard insertGuard(builder); + builder.setInsertionPointAfter(yieldOp); + // Get induction variable. + ValueRange optionalLoopIndices = parallelLoop.getIVs(); + assert(optionalLoopIndices.size() >= 1 && + "expected at least one loop index"); + Value parallelLoopIndex = optionalLoopIndices[0]; + Operation *newOp = opBuilder.create( + loc, parallelLoopIndex, numThreads, procBind); + newOp->moveBefore(yieldOp); + } // Replace the affine.forOp with affine.parallelOp in loopRefToTop loopRefToOp[loopRef] = parallelLoop; loopToParallel.erase(); @@ -789,20 +814,18 @@ AffineTypeConverter::AffineTypeConverter() { addConversion([](Type type) { return type; }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) - return std::nullopt; + return Value(); return builder.create(loc, resultType, inputs) .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) - return std::nullopt; + return Value(); return builder.create(loc, resultType, inputs) .getResult(0); @@ -846,7 +869,8 @@ void ConvertKrnlToAffinePass::runOnOperation() { // Move invariant instructions outside of the loops as many as possible. This // helps make loops perfectly nested, which facilitates transformations. funcOp.walk([&](KrnlIterateOp loopOp) { - moveLoopInvariantCode(cast(loopOp.getOperation())); + moveLoopInvariantCode( + mlir::cast(loopOp.getOperation())); }); // We use the end of the function body as a staging area for movable ops. @@ -879,10 +903,10 @@ void ConvertKrnlToAffinePass::runOnOperation() { std::vector iterateOps; for (auto result : defineOp.getResults()) for (auto *user : result.getUsers()) - if (auto iterateOp = dyn_cast_or_null(user)) + if (auto iterateOp = mlir::dyn_cast_or_null(user)) if (std::find(iterateOps.begin(), iterateOps.end(), iterateOp) == iterateOps.end()) - iterateOps.push_back(dyn_cast(user)); + iterateOps.push_back(mlir::dyn_cast(user)); // Lower iterate operations and record the mapping between loop references // and affine for loop operations in loopRefToOp map. @@ -924,8 +948,9 @@ void ConvertKrnlToAffinePass::runOnOperation() { auto &blockOps = block.getOperations(); for (auto itr = blockOps.begin(); itr != blockOps.end(); ++itr) { Operation *genericOp = &(*itr); - if (auto getIVOp = dyn_cast_or_null( - genericOp)) { + if (auto getIVOp = + mlir::dyn_cast_or_null( + genericOp)) { lowerGetInductionVariableValueOp(getIVOp, loopRefToOp); opsToErase.insert(genericOp); } @@ -944,13 +969,14 @@ void ConvertKrnlToAffinePass::runOnOperation() { funcOp->walk([&](Operation *op) { if (SpecializedKernelOpInterface kernelOp = - dyn_cast(op)) { + mlir::dyn_cast(op)) { OperandRange loopRefs = kernelOp.getLoopRefs(); for (auto loopRef : loopRefs) opsToErase.insert(loopRefToOp[loopRef]); kernelOp.getLoopRefs().clear(); } - if (auto getIVOp = dyn_cast_or_null(op)) { + if (auto getIVOp = + mlir::dyn_cast_or_null(op)) { lowerGetInductionVariableValueOp(getIVOp, loopRefToOp); opsToErase.insert(op); } @@ -968,6 +994,7 @@ void ConvertKrnlToAffinePass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); diff --git a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp index 352bc9d6be..5b5f761a35 100644 --- a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp +++ b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp @@ -39,7 +39,8 @@ class KrnlCopyFromBufferLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - KrnlCopyFromBufferOp copyFromBufferOp = cast(op); + KrnlCopyFromBufferOp copyFromBufferOp = + mlir::cast(op); Location loc = copyFromBufferOp.getLoc(); MultiDialectBuilder create( rewriter, loc); @@ -89,7 +90,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern { return success(); } - void genCopyLoops(AffineBuilderKrnlMem &createAffine, + void genCopyLoops(const AffineBuilderKrnlMem &createAffine, IndexExprScope *enclosingScope, Value buffMemref, Value destMemref, IndexExpr zeroIE, SmallVectorImpl &starts, SmallVectorImpl &writeUBs, SmallVectorImpl &loopIndices, @@ -123,9 +124,9 @@ class KrnlCopyFromBufferLowering : public ConversionPattern { // Nothing to write. } else { // Loop to copy the data. - createAffine.forIE(zeroIE, writeUBs[i], 1, - [&](AffineBuilderKrnlMem &createAffine, Value index) { - loopIndices.emplace_back(index); + createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { + loopIndices.emplace_back(loopInd[0]); genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref, zeroIE, starts, writeUBs, loopIndices, i + 1, buffRank); loopIndices.pop_back_n(1); diff --git a/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp b/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp index a08766d63a..7e6bbf20d2 100644 --- a/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp +++ b/src/Conversion/KrnlToAffine/KrnlCopyToBuffer.cpp @@ -39,7 +39,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Get info from operands. - KrnlCopyToBufferOp copyToBufferOp = cast(op); + KrnlCopyToBufferOp copyToBufferOp = mlir::cast(op); Location loc = copyToBufferOp.getLoc(); MultiDialectBuilder create( rewriter, loc); @@ -129,7 +129,7 @@ class KrnlCopyToBufferLowering : public ConversionPattern { return success(); } - void genCopyLoops(AffineBuilderKrnlMem &createAffine, + void genCopyLoops(const AffineBuilderKrnlMem &createAffine, IndexExprScope *enclosingScope, Value buffMemref, Value sourceMemref, SmallVectorImpl &srcLoopMap, Value padVal, IndexExpr zeroIE, SmallVectorImpl &starts, SmallVectorImpl &readUBs, @@ -168,9 +168,9 @@ class KrnlCopyToBufferLowering : public ConversionPattern { if (readUBs[i].isLiteralAndIdenticalTo(0)) { // Nothing to read, skip. } else { - createAffine.forIE(zeroIE, readUBs[i], 1, - [&](AffineBuilderKrnlMem &createAffine, Value index) { - loopIndices.emplace_back(index); + createAffine.forLoopIE(zeroIE, readUBs[i], 1, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { + loopIndices.emplace_back(loopInd[0]); genCopyLoops(createAffine, enclosingScope, buffMemref, sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs, padUBs, loopIndices, i + 1, buffRank, @@ -181,9 +181,9 @@ class KrnlCopyToBufferLowering : public ConversionPattern { if (padUBs[i].isLiteralAndIdenticalTo(0)) { // No padding needed. } else { - createAffine.forIE(readUBs[i], padUBs[i], 1, - [&](AffineBuilderKrnlMem &createAffine, Value index) { - loopIndices.emplace_back(index); + createAffine.forLoopIE(readUBs[i], padUBs[i], 1, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { + loopIndices.emplace_back(loopInd[0]); genCopyLoops(createAffine, enclosingScope, buffMemref, sourceMemref, srcLoopMap, padVal, zeroIE, starts, readUBs, padUBs, loopIndices, i + 1, buffRank, diff --git a/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp b/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp index 99dc1037f2..8ca09343f2 100644 --- a/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp +++ b/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp @@ -53,7 +53,8 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern { auto memrefTy = llvm::dyn_cast(memref.getType()); int64_t rank = memrefTy.getRank(); - assert((int64_t)mapResults.value().size() == rank && "Invalid indices"); + assert(static_cast(mapResults.value().size()) == rank && + "Invalid indices"); // Only lower this op after the memref is normalized. if (!memrefTy.getLayout().isIdentity()) @@ -63,10 +64,10 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern { SmallVector dims; create.krnlIE.getShapeAsDims(memref, dims); // Compute the linear offset using strides. - IndexExpr offsetIE = LiteralIndexExpr(0); - IndexExpr strideIE = LiteralIndexExpr(1); + IndexExpr offsetIE = LitIE(0); + IndexExpr strideIE = LitIE(1); for (int64_t i = rank - 1; i >= 0; --i) { - IndexExpr strideOffset = strideIE * DimIndexExpr(indices[i]); + IndexExpr strideOffset = strideIE * DimIE(indices[i]); offsetIE = offsetIE + strideOffset; if (i > 0) strideIE = strideIE * dims[i]; diff --git a/src/Conversion/KrnlToAffine/KrnlLoad.cpp b/src/Conversion/KrnlToAffine/KrnlLoad.cpp index 0e44cbdff8..b4ef3ce131 100644 --- a/src/Conversion/KrnlToAffine/KrnlLoad.cpp +++ b/src/Conversion/KrnlToAffine/KrnlLoad.cpp @@ -38,7 +38,7 @@ class KrnlLoadLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto loadOp = cast(op); + auto loadOp = mlir::cast(op); KrnlLoadOpAdaptor operandAdaptor(loadOp); // Prepare inputs. diff --git a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp index 289d394a10..8ab9ef7b1a 100644 --- a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp @@ -48,7 +48,7 @@ class KrnlMatmulLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto matmulOp = cast(op); + auto matmulOp = mlir::cast(op); KrnlMatMulOpAdaptor operandAdaptor(matmulOp); // Option. bool fullUnrollAndJam = matmulOp.getUnroll(); @@ -122,7 +122,7 @@ class KrnlMatmulLowering : public ConversionPattern { jGlobalUB.getLiteral() == 1; // Investigate SIMD - IndexExpr vectorLen = LiteralIndexExpr(1); // Assume no simd. + IndexExpr vectorLen = LitIE(1); // Assume no simd. if (simdize) { if (matVectorProduct) { // Matrix (I x K) times vector (K x 1). We currently vectorize along the @@ -134,7 +134,7 @@ class KrnlMatmulLowering : public ConversionPattern { uint64_t archVL = create.vec.getArchVectorLength(elementType); if (i % archVL == 0 && k % archVL == 0) { // Right now, vector length must be archVL. - vectorLen = LiteralIndexExpr(archVL); + vectorLen = LitIE(archVL); } else { simdize = false; LLVM_DEBUG(llvm::dbgs() << "Matmul: mat*vec with bad sizes: i " << i @@ -167,34 +167,31 @@ class KrnlMatmulLowering : public ConversionPattern { // A[i, k]; SmallVector aStart, bStart, cStart; for (int t = 0; t < aRank - 2; t++) - aStart.emplace_back( - SymbolIndexExpr(operandAdaptor.getAGlobalIndexMemStart()[t])); + aStart.emplace_back(SymIE(operandAdaptor.getAGlobalIndexMemStart()[t])); aStart.emplace_back( iGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getAGlobalIndexMemStart()[aRank - 2])); + DimIE(operandAdaptor.getAGlobalIndexMemStart()[aRank - 2])); aStart.emplace_back( kGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getAGlobalIndexMemStart()[aRank - 1])); + DimIE(operandAdaptor.getAGlobalIndexMemStart()[aRank - 1])); // B[k, j]; for (int t = 0; t < bRank - 2; t++) - bStart.emplace_back( - SymbolIndexExpr(operandAdaptor.getBGlobalIndexMemStart()[t])); + bStart.emplace_back(SymIE(operandAdaptor.getBGlobalIndexMemStart()[t])); bStart.emplace_back( kGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getBGlobalIndexMemStart()[bRank - 2])); + DimIE(operandAdaptor.getBGlobalIndexMemStart()[bRank - 2])); bStart.emplace_back( jGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getBGlobalIndexMemStart()[bRank - 1])); + DimIE(operandAdaptor.getBGlobalIndexMemStart()[bRank - 1])); // C[i, j] for (int t = 0; t < cRank - 2; t++) - cStart.emplace_back( - SymbolIndexExpr(operandAdaptor.getCGlobalIndexMemStart()[t])); + cStart.emplace_back(SymIE(operandAdaptor.getCGlobalIndexMemStart()[t])); cStart.emplace_back( iGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getCGlobalIndexMemStart()[cRank - 2])); + DimIE(operandAdaptor.getCGlobalIndexMemStart()[cRank - 2])); cStart.emplace_back( jGlobalIndexComputeStart - - DimIndexExpr(operandAdaptor.getCGlobalIndexMemStart()[cRank - 1])); + DimIE(operandAdaptor.getCGlobalIndexMemStart()[cRank - 1])); // Now determine if we have full/partial tiles. This is determined by the // outer dimensions of the original computations, as by definition tiling @@ -225,31 +222,33 @@ class KrnlMatmulLowering : public ConversionPattern { // SIMD code generator. if (matVectorProduct) { // clang-format off - create.affineKMem.ifThenElse(indexScope, allFullTiles, - /* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) { + create.affineKMem.ifThenElseIE(indexScope, allFullTiles, + /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) { genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen, fullUnrollAndJam); - }, /* else has partial tiles */ [&](AffineBuilderKrnlMem &createAffine) { + }, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) { genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, iTrip, jTrip, kTrip, /*unroll*/ false); }); // clang-format on } else { // clang-format off - create.affineKMem.ifThenElse(indexScope, allFullTiles, - /* then full tiles */ [&](AffineBuilderKrnlMem &createAffine) { + create.affineKMem.ifThenElseIE(indexScope, allFullTiles, + /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) { genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen, fullUnrollAndJam); - }, /* has some partial tiles */ [&](AffineBuilderKrnlMem &createAffine) { + }, + /* Else has some partial tiles */ + [&](const AffineBuilderKrnlMem &createAffine) { // Trip regardless of full/partial for N & K // Test if SIMD dim (M) is full. - createAffine.ifThenElse(indexScope, jFullTiles, - /* full SIMD */ [&](AffineBuilderKrnlMem &createAffine) { + createAffine.ifThenElseIE(indexScope, jFullTiles, + /* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) { genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false); - }, /* else partial SIMD */ [&](AffineBuilderKrnlMem &createAffine) { + }, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) { // TODO: evaluate if get performance from partial SIMD if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) { // has a known trip count along the simd dimension of at least 2 @@ -267,12 +266,12 @@ class KrnlMatmulLowering : public ConversionPattern { } else { // Scalar code generator. // clang-format off - create.affineKMem.ifThenElse(indexScope, allFullTiles, - /* then full */ [&](AffineBuilderKrnlMem &createAffine) { + create.affineKMem.ifThenElseIE(indexScope, allFullTiles, + /* then full */ [&](const AffineBuilderKrnlMem &createAffine) { genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, fullUnrollAndJam); - }, /* else partial */ [&](AffineBuilderKrnlMem &createAffine) { + }, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) { genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, iTrip, jTrip, kTrip, false); }); @@ -283,7 +282,7 @@ class KrnlMatmulLowering : public ConversionPattern { } private: - void genScalar(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, + void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, Type elementType, ArrayRef aStart, ArrayRef bStart, ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, bool unrollJam) const { @@ -302,11 +301,14 @@ class KrnlMatmulLowering : public ConversionPattern { // For i, j loops. LiteralIndexExpr zeroIE(0); Value jSaved; - createAffine.forIE( - zeroIE, I, 1, [&](AffineBuilderKrnlMem &createAffine, Value i) { - createAffine.forIE( - zeroIE, J, 1, [&](AffineBuilderKrnlMem &createAffine, Value j) { + createAffine.forLoopIE(zeroIE, I, 1, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { + Value i = loopInd[0]; + createAffine.forLoopIE(zeroIE, J, 1, + [&](const AffineBuilderKrnlMem &createAffine, + ValueRange loopInd) { MathBuilder createMath(createAffine); + Value j = loopInd[0]; // Defines induction variables, and possibly initialize C. jSaved = j; // Alloc and init temp c storage. @@ -315,9 +317,11 @@ class KrnlMatmulLowering : public ConversionPattern { // TTmpC() = affine_load(C, cAccess); createAffine.store(initVal, TmpC, tmpCAccess); // Sum over k. - createAffine.forIE(zeroIE, K, 1, - [&](AffineBuilderKrnlMem &createAffine, Value k) { + createAffine.forLoopIE(zeroIE, K, 1, + [&](const AffineBuilderKrnlMem &createAffine, + ValueRange loopInd) { MathBuilder createMath(createAffine); + Value k = loopInd[0]; Value a = createAffine.loadIE(A, aStart, {i, k}); Value b = createAffine.loadIE(B, bStart, {k, j}); Value res = createMath.mul(a, b); @@ -339,7 +343,7 @@ class KrnlMatmulLowering : public ConversionPattern { } // Initially, simdize with full K vector length. - void genSimdMatVect(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, + void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, Type elementType, ArrayRef aStart, ArrayRef bStart, ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const { @@ -368,13 +372,7 @@ class KrnlMatmulLowering : public ConversionPattern { assert(BUFFER_ALIGN >= gDefaultAllocAlign && "alignment of buffers cannot be smaller than the default alignment " "(which is set for SIMD correctness"); - // TODO: alloca is good as it help simplify away this data structures (as it - // is only used as local temp, basically extensions of registers). However, - // there might be issues with non-removed alloca when they are not in the - // innermost loop. Still think its worth it having alloca as we want - // eventually all the refs to alloca to be register/spill access, not memory - // load/stores. - Value TmpProd = create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN); + Value TmpProd = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); // Init with zero. Value fZero = create.math.constant(elementType, 0); Value vFZero = create.vec.broadcast(vecType, fZero); @@ -382,9 +380,10 @@ class KrnlMatmulLowering : public ConversionPattern { LiteralIndexExpr zeroIE(0); Value iZero = create.math.constantIndex(0); - create.affineKMem.forIE( - zeroIE, K, VL, [&](AffineBuilderKrnlMem &createAffine, Value k) { + create.affineKMem.forLoopIE(zeroIE, K, VL, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { MultiDialectBuilder create(createAffine); + Value k = loopInd[0]; // Iterates over the I indices (K is SIMD dim). // First compute A[i,k]*B[k, 1] for i=0..iUnrollFactor explicitly. // We reuse B[k][0] vector for each iteration of i. @@ -429,7 +428,7 @@ class KrnlMatmulLowering : public ConversionPattern { } // Simdize along J / memory rows in B and C. - void genSimdMatMat(AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, + void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, Type elementType, ArrayRef aStart, ArrayRef bStart, ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const { @@ -450,32 +449,29 @@ class KrnlMatmulLowering : public ConversionPattern { // Have to privatize CTmpType by unroll factor (1 if none). MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType); assert(BUFFER_ALIGN >= gDefaultAllocAlign); - // TODO: alloca is good as it help simplify away this data structures (as it - // is only used as local temp, basically extensions of registers). However, - // there might be issues with non-removed alloca when they are not in the - // innermost loop. Still think its worth it having alloca as we want - // eventually all the refs to alloca to be register/spill access, not memory - // load/stores. - Value TmpC = create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN); + Value TmpC = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); // Iterates over the I indices (j are simd dim). Value iSaved, kSaved; LiteralIndexExpr zeroIE(0); Value iZero = create.math.constantIndex(0); - createAffine.forIE( - zeroIE, I, 1, [&](AffineBuilderKrnlMem &createAffine, Value i) { + createAffine.forLoopIE(zeroIE, I, 1, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { MultiDialectBuilder create(createAffine); + Value i = loopInd[0]; iSaved = i; // Saved for unroll and jam. - // Alloca temp vector TmpC and save C(i)/0.0 into it. + // Alloc temp vector TmpC and save C(i)/0.0 into it. Value initVal = create.vec.loadIE(vecType, C, cStart, {i, iZero}); Value tmpCAccess = (unrollFactor > 1) ? i : zeroIE.getValue(); createAffine.store(initVal, TmpC, tmpCAccess); // Sum over k. - createAffine.forIE( - zeroIE, K, 1, [&](AffineBuilderKrnlMem &createAffine, Value k) { + createAffine.forLoopIE(zeroIE, K, 1, + [&](const AffineBuilderKrnlMem &createAffine, + ValueRange loopInd) { MultiDialectBuilder create( createAffine); + Value k = loopInd[0]; kSaved = k; Value a = createAffine.loadIE(A, aStart, {i, k}); Value va = create.vec.broadcast(vecType, a); diff --git a/src/Conversion/KrnlToAffine/KrnlMemset.cpp b/src/Conversion/KrnlToAffine/KrnlMemset.cpp index 67cbe0d3c3..0cda7ad596 100644 --- a/src/Conversion/KrnlToAffine/KrnlMemset.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMemset.cpp @@ -35,7 +35,7 @@ class KrnlMemsetLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Get info from operands. - auto memsetOp = cast(op); + auto memsetOp = mlir::cast(op); bool delayed = memsetOp.getDelayed(); KrnlMemsetOpAdaptor operandAdaptor(memsetOp); Value destMemRef(operandAdaptor.getDest()); @@ -55,11 +55,12 @@ class KrnlMemsetLowering : public ConversionPattern { SmallVector ubs; create.krnlIE.getShapeAsDims(destMemRef, ubs); int rank = ubs.size(); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector steps(rank, 1); + SmallVector useParallel(rank, false); // Copy data, - create.affineKMem.forIE(lbs, ubs, steps, - [&](AffineBuilderKrnlMem &createAffine, ValueRange indices) { + create.affineKMem.forLoopsIE(lbs, ubs, steps, useParallel, + [&](const AffineBuilderKrnlMem &createAffine, ValueRange indices) { createAffine.store(destVal, destMemRef, indices); }); rewriter.eraseOp(op); diff --git a/src/Conversion/KrnlToAffine/KrnlStore.cpp b/src/Conversion/KrnlToAffine/KrnlStore.cpp index 682a0c2e1f..989a108147 100644 --- a/src/Conversion/KrnlToAffine/KrnlStore.cpp +++ b/src/Conversion/KrnlToAffine/KrnlStore.cpp @@ -38,7 +38,7 @@ class KrnlStoreLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto storeOp = cast(op); + auto storeOp = mlir::cast(op); KrnlStoreOpAdaptor operandAdaptor(storeOp); // Prepare inputs. diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt index 52a583552f..92948137be 100644 --- a/src/Conversion/KrnlToLLVM/CMakeLists.txt +++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMKrnlToLLVM KrnlPrintTensor.cpp KrnlPrint.cpp KrnlRandomNormal.cpp + KrnlRoundEven.cpp KrnlStrlen.cpp KrnlStrncmp.cpp KrnlToLLVMHelper.cpp diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp index 04cc4206c5..d33abe5918 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -198,6 +198,7 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, patterns, vector::VectorTransformsOptions()); vector::populateVectorTransposeLoweringPatterns( patterns, vector::VectorTransformsOptions()); + vector::populateVectorShapeCastLoweringPatterns(patterns); populateAffineToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns); @@ -246,7 +247,7 @@ void PostfixEntrypointNames(ModuleOp &module) { .getValue() .str(); func::FuncOp entryPointFunc = - dyn_cast(module.lookupSymbol(entryPointFuncName)); + mlir::dyn_cast(module.lookupSymbol(entryPointFuncName)); assert(entryPointFunc && "entry point func must exist"); // Update the function name. entryPointFunc.setSymName( @@ -278,12 +279,12 @@ void recordInputOutputMemRefTypes(ModuleOp &module, assert(entryPointFunc && isa(entryPointFunc) && "entry point func must exist and be an llvm func op"); auto entryPointTy = mlir::dyn_cast( - dyn_cast(entryPointFunc).getFunctionType()); + mlir::dyn_cast(entryPointFunc).getFunctionType()); SmallVector inputTypes, outputTypes; for (Type ty : entryPointTy.getInputs()) - inputTypes.emplace_back(dyn_cast(ty)); + inputTypes.emplace_back(mlir::dyn_cast(ty)); for (Type ty : entryPointTy.getResults()) - outputTypes.emplace_back(dyn_cast(ty)); + outputTypes.emplace_back(mlir::dyn_cast(ty)); inputMemRefTypes.emplace( std::make_pair(entryPointFuncName.str(), inputTypes)); outputMemRefTypes.emplace( @@ -367,12 +368,12 @@ void genSignatureFunction(ModuleOp &module, // If the argument is not NULL, update its value to return the number of // entry points. create.llvm.ifThenElse(/*cond=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { Value nullPtr = createLLVM.null(i64PtrTy); return createLLVM.icmp( LLVM::ICmpPredicate::ne, numOfEntryPoints, nullPtr); }, /*then=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { Value numOfEntryPointsPtr = createLLVM.getElemPtr( i64PtrTy, i64Type, numOfEntryPoints, ArrayRef{0}); Value noep = @@ -420,7 +421,7 @@ void genSignatureFunction(ModuleOp &module, // Return the signature if found. create.llvm.ifThenElse(/*cond=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { // Read an entry point name. Value entryI8Ptr = krnl::getPtrToGlobalString(globalEntryPoint, loc, b); @@ -434,7 +435,7 @@ void genSignatureFunction(ModuleOp &module, return createLLVM.icmp( LLVM::ICmpPredicate::eq, strncmpResult, zeroI32); }, /*then=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { Value sigAddr = createLLVM.addressOf(globalSignature); Value sigI8Ptr = createLLVM.bitcast(i8PtrTy, sigAddr); createLLVM._return(sigI8Ptr); @@ -556,6 +557,7 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(module.getBody()); std::string fname = llvm::sys::path::filename(filepath).str() + '\0'; + fname = (isZOS(module)) ? krnl::a2e_s(fname) : fname; mlir::StringAttr valueAttr = mlir::StringAttr::get(context, fname); create.llvm.globalOp(LLVM::LLVMArrayType::get(llvmI8Ty, fname.size()), /*isConstant=*/true, LLVM::Linkage::Internal, @@ -599,15 +601,15 @@ void loadConstantsFromFile(ModuleOp &module, OpBuilder b(ctx); MultiDialectBuilder create(b, loc); + Type llvmI1Ty = IntegerType::get(ctx, 1); Type llvmI8Ty = IntegerType::get(ctx, 8); Type llvmI64Ty = IntegerType::get(ctx, 64); Type llvmI8PtrTy = getPointerType(ctx, llvmI8Ty); - Type llvmVoidTy = LLVM::LLVMVoidType::get(ctx); // The following function will be emitted inside the IR to load constants from // file. std::string loadAllConstantsFuncName = "omLoadConstantsFromFile"; - Type llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, {}, false); + Type llvmFnType = LLVM::LLVMFunctionType::get(llvmI1Ty, {}, false); // If calledByEntryPoint, this function will be called by entry points. // Otherwise, user program (C/C++/Java/Python) would call this function. @@ -616,6 +618,7 @@ void loadConstantsFromFile(ModuleOp &module, Operation *firstEntryPointOp = getFirstEntryOpInBlock(module, entryGlobalOps); assert(firstEntryPointOp && "No entry function exists"); + OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(firstEntryPointOp); funcOp = create.llvm.func( loadAllConstantsFuncName, llvmFnType, /*createUniqueFunc=*/true); @@ -633,13 +636,16 @@ void loadConstantsFromFile(ModuleOp &module, std::find(entryName.begin(), entryName.end(), '\0'), entryName.end()); auto entryFunc = module.lookupSymbol(entryName); assert(entryFunc && "Entry function not found"); + OpBuilder::InsertionGuard guard(b); b.setInsertionPoint( &entryFunc.getBody().front(), entryFunc.getBody().front().begin()); FlatSymbolRefAttr loadAllConstantsRef = create.llvm.getOrInsertSymbolRef( module, LLVMBuilder::SymbolPostfix(module, loadAllConstantsFuncName), - llvmVoidTy, {}, + llvmI1Ty, {}, /*isVarArg=*/false); - create.llvm.call({}, loadAllConstantsRef, {}); + Value retVal = create.llvm.call({llvmI1Ty}, loadAllConstantsRef, {}); + equalOrFailed(module, b, loc, + create.llvm.constant(llvmI1Ty, static_cast(1)), retVal); } } else { OpBuilder::InsertionGuard guard(b); @@ -684,8 +690,11 @@ void loadConstantsFromFile(ModuleOp &module, // Call a function to mmap the binary file to memory. Value isleVal = create.llvm.constant(llvmI64Ty, isle); Value sizeVal = create.llvm.constant(llvmI64Ty, dataSize); - RuntimeAPI::callApi(b, loc, apiRegistry, RuntimeAPI::API::MMAP_BINARY_FILE, + Value retVal = RuntimeAPI::callApi(b, loc, apiRegistry, + RuntimeAPI::API::MMAP_BINARY_FILE, {packedGlobalPtr, fnameI8Ptr, sizeVal, isleVal}); + equalOrReturn(module, b, loc, + create.llvm.constant(llvmI1Ty, static_cast(1)), retVal, retVal); // Now set pointers for constants in the IR module->walk([&](LLVM::GlobalOp dataGlobalOp) -> WalkResult { @@ -712,11 +721,10 @@ void loadConstantsFromFile(ModuleOp &module, RuntimeAPI::callApi(b, loc, apiRegistry, RuntimeAPI::API::GET_EXTERNAL_CONSTANT_ADDR, {dataPtr, packedGlobalPtr, offsetVal}); - return WalkResult::advance(); }); - create.llvm._return(); + create.llvm._return(create.llvm.constant(llvmI1Ty, static_cast(1))); } //===----------------------------------------------------------------------===// @@ -971,6 +979,7 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, krnl::populateLoweringKrnlUnaryMathOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlStrncmpOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlNoneOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlRoundEvenOpPattern(typeConverter, patterns, ctx); } } // namespace krnl diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp index c222913dfe..2309871db4 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -107,6 +107,10 @@ void populateLoweringKrnlVectorTypeCastOpPattern( void populateLoweringKrnlNoneOpPattern(mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateLoweringKrnlRoundEvenOpPattern( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::MLIRContext *ctx); + void determineOwnershipForOutputOMTensors(mlir::ModuleOp &module, llvm::SmallVectorImpl &outputOMTensorOwnerships); diff --git a/src/Conversion/KrnlToLLVM/KrnlCall.cpp b/src/Conversion/KrnlToLLVM/KrnlCall.cpp index 6d570ab3d4..3251fd1697 100644 --- a/src/Conversion/KrnlToLLVM/KrnlCall.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlCall.cpp @@ -4,7 +4,7 @@ //===-------------- KrnlCall.cpp - Lower KrnlCallOp -----------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -68,10 +68,27 @@ class KrnlCallOpLowering : public ConversionPattern { rewriter, op, namedAttr.getValue(), parameterTypeList, parameterList); } - FlatSymbolRefAttr callRef = - create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(), - LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList); - create.llvm.call({}, callRef, parameterList); + ValueRange returns = op->getResults(); + if (returns.size() == 0) { + // There is no return + FlatSymbolRefAttr callRef = + create.llvm.getOrInsertSymbolRef(module, krnlCallOp.getFuncName(), + LLVM::LLVMVoidType::get(module.getContext()), parameterTypeList); + create.llvm.call({}, callRef, parameterList); + + rewriter.eraseOp(op); + } else { + assert(returns.size() == 1 && + "Only one return value is allowed for krnl.call now"); + Type llvmReturnType = + llvmTypeConverter->convertType(returns[0].getType()); + + FlatSymbolRefAttr callRef = create.llvm.getOrInsertSymbolRef( + module, krnlCallOp.getFuncName(), llvmReturnType, parameterTypeList); + auto llvmCall = + create.llvm.call({llvmReturnType}, callRef, parameterList); + rewriter.replaceOp(op, llvmCall.getDefiningOp()->getResults()[0]); + } // Destroy OMTensor wrappers of parameters. const auto &apiRegistry = @@ -81,7 +98,6 @@ class KrnlCallOpLowering : public ConversionPattern { rewriter, loc, apiRegistry, RuntimeAPI::API::DESTROY_OMTENSOR, {omt}); } - rewriter.eraseOp(op); return success(); } @@ -102,11 +118,12 @@ class KrnlCallOpLowering : public ConversionPattern { // Check the original type, not after type conversion Type ty = original.getType(); - if (auto originalMemRef = dyn_cast(ty)) { + if (auto originalMemRef = mlir::dyn_cast(ty)) { auto int64Ty = IntegerType::get(context, 64); auto memRefTy = mlir::dyn_cast(parameter.getType()); auto memRefRank = krnl::getRankFromMemRefType(memRefTy); - auto memRefRankVal = create.llvm.constant(int64Ty, (int64_t)memRefRank); + auto memRefRankVal = + create.llvm.constant(int64Ty, static_cast(memRefRank)); Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); @@ -190,7 +207,7 @@ class KrnlCallOpLowering : public ConversionPattern { auto int64Ty = IntegerType::get(context, 64); auto memRefRank = memRefTy.getRank(); auto memRefRankVal = - create.llvm.constant(int64Ty, (int64_t)memRefRank); + create.llvm.constant(int64Ty, static_cast(memRefRank)); Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); diff --git a/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp index 47eaa8bb53..3eec020df7 100644 --- a/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp @@ -4,7 +4,7 @@ //===------ KrnlEntryPoint.cpp - Lower KrnlEntryPointOp -------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -151,7 +151,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { assert(mlir::isa(maccelAttr) && "onnx-mlir.accels must be ArrayAttr"); ArrayAttr accels = mlir::cast(maccelAttr); - Value zeroI64 = create.llvm.constant(int64Ty, (int64_t)0); + Value zeroI64 = create.llvm.constant(int64Ty, static_cast(0)); for (uint64_t i = 0; i < accels.size(); ++i) { assert( @@ -166,17 +166,17 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Emit code for `if (OMInitCompatibleAccelX() == 0) then return NULL`. create.llvm.ifThenElse(/*cond=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { // Call OMInitCompatibleAccelX. - Value versionNumberVal = - createLLVM.constant(int64Ty, (int64_t)versionNumberInHex); + Value versionNumberVal = createLLVM.constant( + int64Ty, static_cast(versionNumberInHex)); Value isCompatible = createLLVM.call( int64Ty, funcRef, ArrayRef({versionNumberVal})); // Condition: if (OMInitCompatibleAccelX() == 0) return createLLVM.icmp( LLVM::ICmpPredicate::eq, isCompatible, zeroI64); }, /*then=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { // return NULL. createLLVM._return(createLLVM.null(getI8PointerType(context))); }); @@ -202,7 +202,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName.lower()); auto staticEntryPointFuncTy = mlir::cast( - cast(staticEntryPointFunc).getFunctionType()); + mlir::cast(staticEntryPointFunc).getFunctionType()); LLVM_DEBUG(llvm::dbgs() << "Static entry point function type: " << staticEntryPointFuncTy << "\n"); // Static entry point is wrapped with prefix `_mlir_ciface` automatically by @@ -216,13 +216,13 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { isa(wrappedStaticEntryPointFunc) && "entry point func must exist and be an llvm func op"); auto wrappedStaticEntryPointOp = - cast(wrappedStaticEntryPointFunc); + mlir::cast(wrappedStaticEntryPointFunc); auto wrappedStaticEntryPointTy = mlir::cast( wrappedStaticEntryPointOp.getFunctionType()); Value omTensorPtrArr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_OMT_ARRAY, {omTensorInputs}); - Value one = create.llvm.constant(int64Ty, (int64_t)1); + Value one = create.llvm.constant(int64Ty, static_cast(1)); // Prepare MemRefs as inputs for the wrapped static entry point function. // MemRefs are filled with information from user' OMTensor inputs. @@ -233,15 +233,16 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // entry point instead of the wrapped static entry point. Type memRefOutTy = staticEntryPointFuncTy.getReturnTypes()[0]; Type memRefOutPtrTy = getPointerType(context, memRefOutTy); - Value ptrToOutMemRef = + Value ptrToOutMemRef = // alloca ok as there is only one entry point. create.llvm._alloca(memRefOutPtrTy, memRefOutTy, one, /*alignment=*/0); staticInputs.emplace_back(ptrToOutMemRef); // Start with param 1 because 0 is the return value. for (size_t i = 1; i < wrappedStaticEntryPointTy.getNumParams(); i++) { // Call API function to retrieve the i-th dynamic memref. - Value omTensorPtrAddr = create.llvm.getElemPtr(omTensorPtrAddrTy, - opaquePtrTy, omTensorPtrArr, ArrayRef{(int32_t)i - 1}); + Value omTensorPtrAddr = + create.llvm.getElemPtr(omTensorPtrAddrTy, opaquePtrTy, omTensorPtrArr, + ArrayRef{static_cast(i) - 1}); Value omTensorPtr = create.llvm.load(opaquePtrTy, omTensorPtrAddr); // Create a (static) memref type corresponding to the i-th memref input to @@ -249,7 +250,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Original input is shifted by 1 in the iface func. Type memRefInTy = typeConverter.convertType(origInputMemRefTypes[i - 1]); Type memRefInPtrTy = getPointerType(context, memRefInTy); - Value ptrToMemRef = + Value ptrToMemRef = // alloca ok as there is only one entry point. create.llvm._alloca(memRefInPtrTy, memRefInTy, one, /*alignment=*/0); // Fill in the memref underlying ptrToMemRef with information extracted @@ -268,7 +269,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { auto outMemRefsType = mlir::dyn_cast(outMemRefs.getType()); - std::vector outMemRefList; + std::vector outMemRefList; if (numOutputs == 1) { // If only one output tensor exists, the tensor's corresponding memref // descriptor will be returned as is. @@ -284,9 +285,10 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { } } - Value numOutput = - create.llvm.constant(int64Ty, (int64_t)outMemRefList.size()); - // Assume that OMTensor pointer size is 8 + Value numOutput = create.llvm.constant( + int64Ty, static_cast(outMemRefList.size())); + // Assume that OMTensor pointer size is 8. + // Alloca ok as its only for 1 small data structure per parameters. Value outOmtPtrsArr = create.llvm._alloca( omTensorPtrAddrTy, opaquePtrTy, numOutput, /*alignment=*/0); @@ -297,7 +299,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { auto outMemRefTy = mlir::dyn_cast(memRef.getType()); int64_t outMemRefRank = krnl::getRankFromMemRefType(outMemRefTy); Value outMemRefRankVal = - create.llvm.constant(int64Ty, (int64_t)outMemRefRank); + create.llvm.constant(int64Ty, static_cast(outMemRefRank)); Value outOMTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::CREATE_OMTENSOR, {outMemRefRankVal}); // If output is a constant tensor or a block argument, OMTensor does not @@ -309,8 +311,9 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { origOutputMemRefTypes[i].getElementType(), outOMTensor, outOwning, rewriter, loc, apiRegistry, module); - Value omTensorPtrAddr = create.llvm.getElemPtr(omTensorPtrAddrTy, - opaquePtrTy, outOmtPtrsArr, ArrayRef{(int32_t)i}); + Value omTensorPtrAddr = + create.llvm.getElemPtr(omTensorPtrAddrTy, opaquePtrTy, outOmtPtrsArr, + ArrayRef{static_cast(i)}); create.llvm.store(outOMTensor, omTensorPtrAddr); } @@ -360,7 +363,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { memRef = create.llvm.insertValue(memRefTy, memRef, dataPtr, {1}); // Use zero offset now. - Value zero = create.llvm.constant(int64Ty, (int64_t)0); + Value zero = create.llvm.constant(int64Ty, static_cast(0)); memRef = create.llvm.insertValue(memRefTy, memRef, zero, {2}); // Get rank, sizes array ptr and strides array ptr. @@ -375,14 +378,14 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Insert size of the dimension. Value dimSizePtr = create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty, - sizesArrayPtr, ArrayRef{(int32_t)i}); + sizesArrayPtr, ArrayRef{static_cast(i)}); Value dimSize = create.llvm.load(int64Ty, dimSizePtr); memRef = create.llvm.insertValue(memRefTy, memRef, dimSize, {3, i}); // Insert stride of the dimension. auto dimStridePtr = create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty, - stridesArrayPtr, ArrayRef{(int32_t)i}); + stridesArrayPtr, ArrayRef{static_cast(i)}); auto dimStride = create.llvm.load(int64Ty, dimStridePtr); memRef = create.llvm.insertValue(memRefTy, memRef, dimStride, {4, i}); } @@ -409,31 +412,6 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { rewriter.getI64Type(), {rewriter.getI64Type()}); } - // Emit code for `IF lhs != rhs THEN return null ELSE do nothing` - void equalOrFailed(ModuleOp &module, PatternRewriter &rewriter, Location loc, - Value lhs, Value rhs, std::string errorMsg = "", - bool appendRHS = true) const { - MLIRContext *context = rewriter.getContext(); - MultiDialectBuilder create(rewriter, loc); - create.llvm.ifThenElse(/*cond=*/ - [&](LLVMBuilder &createLLVM) { - return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs); - }, /*then=*/ - [&](LLVMBuilder &createLLVM) { - MultiDialectBuilder create(createLLVM); - // Print an error message. - if (appendRHS) - create.krnl.printf( - StringRef(errorMsg), rhs, rewriter.getI64Type(), true); - else - create.krnl.printf(StringRef(errorMsg + "\n")); - // Set errno. - krnl::emitErrNo(module, rewriter, loc, EINVAL); - // Return NULL. - create.llvm._return(create.llvm.null(getI8PointerType(context))); - }); - } - void emitVerificationCodeForInputTensors(ModuleOp &module, PatternRewriter &rewriter, Location loc, const RuntimeAPIRegistry &apiRegistry, Value omTensorInputs, @@ -451,7 +429,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // Verify the number of inputs. equalOrFailed(module, rewriter, loc, - create.llvm.constant(int64Ty, (int64_t)inputNum), + create.llvm.constant(int64Ty, static_cast(inputNum)), RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_OMTENSOR_LIST_SIZE, {omTensorInputs}), "Wrong number of input tensors: expect " + std::to_string(inputNum) + @@ -462,9 +440,9 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { RuntimeAPI::API::GET_OMT_ARRAY, {omTensorInputs}); for (int64_t i = 0; i < inputNum; ++i) { // Call API function to retrieve the i-th omTensor. - Value omTensorPtrAddr = - create.llvm.getElemPtr(getPointerType(context, opaquePtrTy), - opaquePtrTy, omTensorPtrArr, ArrayRef{(int32_t)i}); + Value omTensorPtrAddr = create.llvm.getElemPtr( + getPointerType(context, opaquePtrTy), opaquePtrTy, omTensorPtrArr, + ArrayRef{static_cast(i)}); Value omTensorPtr = create.llvm.load(opaquePtrTy, omTensorPtrAddr); // Verify data type. @@ -488,7 +466,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { auto JSONDimArray = JSONItem->getArray("dims"); int64_t rank = JSONDimArray->size(); equalOrFailed(module, rewriter, loc, - create.llvm.constant(int64Ty, (int64_t)rank), + create.llvm.constant(int64Ty, static_cast(rank)), RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_DATA_RANK, {omTensorPtr}), "Wrong rank for the input " + std::to_string(i) + ": expect " + @@ -499,9 +477,10 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { RuntimeAPI::API::GET_DATA_SHAPE, {omTensorPtr}); for (int d = 0; d < rank; ++d) { // Get actual dimension size. - Value actualDim = create.llvm.load(int64Ty, - create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty, - sizesArrayPtr, ArrayRef{(int32_t)d})); + Value actualDim = create.llvm.load( + int64Ty, create.llvm.getElemPtr(getPointerType(context, int64Ty), + int64Ty, sizesArrayPtr, + ArrayRef{static_cast(d)})); // Get reference dimension size. auto JSONDimValue = (*JSONDimArray)[d].getAsInteger(); assert(JSONDimValue && "failed to get value"); @@ -511,12 +490,13 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { // In case that the reference dimension size is unknown, verify that // the actual dimension size is a non-negative value. create.llvm.ifThenElse(/*cond=*/ - [&](LLVMBuilder &createLLVM) { - Value zero = createLLVM.constant(int64Ty, (int64_t)d); + [&](const LLVMBuilder &createLLVM) { + Value zero = + createLLVM.constant(int64Ty, static_cast(d)); return createLLVM.icmp( LLVM::ICmpPredicate::slt, actualDim, zero); }, /*then=*/ - [&](LLVMBuilder &createLLVM) { + [&](const LLVMBuilder &createLLVM) { MultiDialectBuilder create( createLLVM); // Print an error message. @@ -533,7 +513,8 @@ class KrnlEntryPointOpLowering : public OpRewritePattern { create.llvm.null(getI8PointerType(context))); }); } else { - Value referenceDim = create.llvm.constant(int64Ty, (int64_t)dim); + Value referenceDim = + create.llvm.constant(int64Ty, static_cast(dim)); equalOrFailed(module, rewriter, loc, referenceDim, actualDim, "Wrong size for the dimension " + std::to_string(d) + " of the input " + std::to_string(i) + ": expect " + diff --git a/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp index 667683d826..8f9fcf4b8a 100644 --- a/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp @@ -32,7 +32,7 @@ class KrnlFindIndexOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto findIndexOp = cast(op); + auto findIndexOp = mlir::cast(op); MLIRContext *ctx = findIndexOp.getContext(); Location loc = findIndexOp.getLoc(); KrnlFindIndexOpAdaptor operandAdaptor(operands); diff --git a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp index 7312be6e61..d4e0bfe861 100644 --- a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp @@ -5,7 +5,7 @@ //===------ KrnlInstrument.cpp - Lower KrnlInstrumentOp -------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -82,14 +82,14 @@ class KrnlInstrumentOpLowering : public ConversionPattern { else name.pop_back(); // remove last "-" Location newLoc = NameLoc::get(rewriter.getStringAttr(name)); - nodeName = cast(newLoc).getName(); + nodeName = mlir::cast(newLoc).getName(); } else if (auto fileLineColLoc = mlir::dyn_cast(loc)) { std::string filename = llvm::sys::path::filename(fileLineColLoc.getFilename().str()).str(); std::string name = filename + ":" + std::to_string(fileLineColLoc.getLine()); Location newLoc = NameLoc::get(rewriter.getStringAttr(name)); - nodeName = cast(newLoc).getName(); + nodeName = mlir::cast(newLoc).getName(); } else nodeName = StringRef("NOTSET"); LLVM_DEBUG( @@ -114,7 +114,7 @@ class KrnlInstrumentOpLowering : public ConversionPattern { SET_INSTRUMENT_OP_NAME_LEN(tagWithLen, opNameLen); SET_INSTRUMENT_NODE_NAME_LEN(tagWithLen, nodeNameLen); Value tag = create.llvm.constant( - IntegerType::get(context, 64), (int64_t)tagWithLen); + IntegerType::get(context, 64), static_cast(tagWithLen)); LLVM::GlobalOp globalStr = krnl::getOrCreateGlobalString( nodeName, loc, rewriter, parentModule, typeConverter); Value nodeNamePtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter); diff --git a/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp index d2aa7c1f35..b6e85ddf93 100644 --- a/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp @@ -4,7 +4,7 @@ //===------ KrnlMemcpy.cpp - Lower KrnlMemcpyOp ---------------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -91,7 +91,7 @@ class KrnlMemcpyOpLowering : public ConversionPattern { Value sizeInBytes = create.llvm.mul(elemsToCopy, eltSizeInBytes); // Is volatile (set to false). - Value isVolatile = create.llvm.constant(i1Ty, (int64_t)0); + Value isVolatile = create.llvm.constant(i1Ty, static_cast(0)); // Memcpy call create.llvm.call( diff --git a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp index bfccfda8be..4a005452f5 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp @@ -34,7 +34,7 @@ class KrnlPrintOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto printOp = cast(op); + auto printOp = mlir::cast(op); Location loc = printOp.getLoc(); KrnlPrintOpAdaptor operandAdaptor(operands); MultiDialectBuilder create(rewriter, loc); diff --git a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp index f254dc1074..776e97b246 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp @@ -4,7 +4,7 @@ //===------ KrnlPrintTensor.cpp - Lower KrnlPrintTensorOp ----------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -34,7 +34,7 @@ class KrnlPrintTensorOpLowering : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto printTensorOp = cast(op); + auto printTensorOp = mlir::cast(op); MLIRContext *context = printTensorOp.getContext(); Location loc = printTensorOp.getLoc(); KrnlPrintTensorOpAdaptor operandAdaptor(operands); @@ -57,7 +57,8 @@ class KrnlPrintTensorOpLowering : public ConversionPattern { auto int64Ty = IntegerType::get(context, 64); auto memRefTy = mlir::dyn_cast(input.getType()); auto memRefRank = krnl::getRankFromMemRefType(memRefTy); - Value memRefRankVal = create.llvm.constant(int64Ty, (int64_t)memRefRank); + Value memRefRankVal = + create.llvm.constant(int64Ty, static_cast(memRefRank)); Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp index 0e2ece621c..e976b42b7f 100644 --- a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp @@ -41,7 +41,7 @@ class KrnlRandomNormalOpLowering : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { KrnlRandomNormalOpAdaptor operandAdaptor(operands); Location loc = op->getLoc(); - mlir::Type inType = op->getOperand(2).getType(); + Type inType = op->getOperand(2).getType(); MultiDialectBuilder create(rewriter, loc); // Get a symbol reference to the memcpy function, inserting it if necessary. diff --git a/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp b/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp new file mode 100644 index 0000000000..cd3738256a --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp @@ -0,0 +1,117 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlRoundEven.cpp - Lower KrnlRoundEvenOp ---------------------===// +// +// Copyright 2019-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlRoundEvenOp operator. +// +// Currently limited to fp32 integers, instructions supports other data types. +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/Mlir/DialectBuilder.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlRoundEvenOpLowering : public ConversionPattern { +public: + explicit KrnlRoundEvenOpLowering( + LLVMTypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlRoundEvenOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + KrnlRoundEvenOp::Adaptor operandAdaptor(operands); + Value input = operandAdaptor.getIn(); + + // Scalar or Vector? + Type inputType = input.getType(); + Type inputElemType = getElementTypeOrSelf(inputType); + assert(mlir::isa(inputElemType) && "expected float"); + int64_t inputBitWidth = inputElemType.getIntOrFloatBitWidth(); + assert(inputBitWidth == 32 && "expected 32bit float"); + VectorType inputVecType = mlir::dyn_cast(inputType); + assert(VectorMachineSupport::requireCustomASM( + GenericOps::roundEvenGop, inputElemType) && + "expected custom requirement"); + // Common between scalar and vector + MultiDialectBuilder create(rewriter, loc); + Type i32Ty = rewriter.getI32Type(); + Type f32Ty = rewriter.getF32Type(); + + if (inputVecType) { + // Vector of 4 elements. + Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4); + Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4); + // Use integer as container for inputs. + Value inputVecI32 = create.llvm.bitcast(vecTypeI32, input); + SmallVector asmVals{inputVecI32}; + // SIMD ASM round to nearest even (M5=4) op + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " VFISB $0,$1,0,4 \n\t"; + const char *asmConstraints = "=v,v"; + Value outVecI32 = + rewriter + .create(loc, vecTypeI32, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + // Cast output back to float. + Value outVecF32 = create.llvm.bitcast(vecTypeF32, outVecI32); + rewriter.replaceOp(op, {outVecF32}); + return success(); + } else { + // Scalar types. + Type typeF32 = rewriter.getF32Type(); + SmallVector asmVals{input}; + // Scalar ASM round to the nearest even (M3=4) op. + // Note the spaces are required by the z/OS assembler. + const char *asmStr = " FIEBR $0,4,$1 \n\t"; + const char *asmConstraints = "=f,f"; + Value outF32 = + rewriter + .create(loc, typeF32, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + rewriter.replaceOp(op, {outF32}); + return success(); + } + llvm_unreachable("not supported"); + } +}; + +void populateLoweringKrnlRoundEvenOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp index 979eb3db50..a3f12a59c5 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp @@ -4,7 +4,7 @@ //===------ KrnlToLLVMHelper.cpp ------------------------------------------===// // -// Copyright 2022 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -24,6 +24,7 @@ #include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" @@ -142,7 +143,7 @@ void fillOMTensorWithMemRef(Value &outMemRef, Type elemTy, Value &outOMTensor, MultiDialectBuilder create(rewriter, loc); // Set ownership, i.e., free after OMTensor is destroyed. - Value owning = create.llvm.constant(int64Ty, (int64_t)outOwning); + Value owning = create.llvm.constant(int64Ty, static_cast(outOwning)); // Extract the allocated pointer. Value outMemRefAllocatedPtr = @@ -174,15 +175,16 @@ void fillOMTensorWithMemRef(Value &outMemRef, Type elemTy, Value &outOMTensor, for (decltype(rank) i = 0; i < rank; i++) { // Transfer size of dimension from memref to dynamic memref. Value dimSize = create.llvm.extractValue(int64Ty, outMemRef, {3, i}); - Value dimSizePtr = create.llvm.getElemPtr(getPointerType(context, int64Ty), - int64Ty, sizesArrayPtr, ArrayRef{(int32_t)i}); + Value dimSizePtr = + create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty, + sizesArrayPtr, ArrayRef{static_cast(i)}); create.llvm.store(dimSize, dimSizePtr); // Transfer stride of dimension from memref to dynamic memref. Value dimStride = create.llvm.extractValue(int64Ty, outMemRef, {4, i}); Value dimStridePtr = create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty, - stridesArrayPtr, ArrayRef{(int32_t)i}); + stridesArrayPtr, ArrayRef{static_cast(i)}); create.llvm.store(dimStride, dimStridePtr); } } @@ -253,14 +255,14 @@ FlatSymbolRefAttr getOrInsertStrncmp(OpBuilder &builder, ModuleOp module) { std::string a2e_s(std::string a_s) { std::string r(a_s); for (unsigned int i = 0; i < r.size(); i++) - r[i] = a2e[(int)r[i]]; + r[i] = a2e[static_cast(r[i])]; return r; } std::string e2a_s(std::string e_s) { std::string r(e_s); for (unsigned int i = 0; i < r.size(); i++) - r[i] = e2a[(int)r[i]]; + r[i] = e2a[static_cast(r[i])]; return r; } @@ -274,7 +276,7 @@ void emitErrNo(ModuleOp module, OpBuilder &builder, Location loc, int errCode) { module, StringRef("__errno_location"), int32PtrTy, {}); Value errNoPos = createLLVM.call(int32PtrTy, errnoSymbolRef, ArrayRef({})); - Value errNoVal = createLLVM.constant(int32Ty, (int64_t)errCode); + Value errNoVal = createLLVM.constant(int32Ty, static_cast(errCode)); createLLVM.store(errNoVal, errNoPos); } @@ -341,5 +343,47 @@ bool isZOS(ModuleOp module) { return zOS; } +void equalOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc, + Value lhs, Value rhs, std::string errorMsg, bool appendRHS) { + MLIRContext *context = rewriter.getContext(); + MultiDialectBuilder create(rewriter, loc); + create.llvm.ifThenElse(/*cond=*/ + [&](const LLVMBuilder &createLLVM) { + return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs); + }, /*then=*/ + [&](const LLVMBuilder &createLLVM) { + MultiDialectBuilder create(createLLVM); + // Print an error message. + if (!errorMsg.empty()) { + if (appendRHS) + create.krnl.printf( + StringRef(errorMsg), rhs, rewriter.getI64Type(), true); + else + create.krnl.printf(StringRef(errorMsg + "\n")); + } + // Set errno. + emitErrNo(module, rewriter, loc, EINVAL); + // Return NULL. + create.llvm._return(create.llvm.null(getI8PointerType(context))); + }); +} + +void equalOrReturn(ModuleOp &module, OpBuilder &rewriter, Location loc, + Value lhs, Value rhs, Value retVal, std::string errorMsg) { + MultiDialectBuilder create(rewriter, loc); + create.llvm.ifThenElse(/*cond=*/ + [&](const LLVMBuilder &createLLVM) { + return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs); + }, /*then=*/ + [&](const LLVMBuilder &createLLVM) { + MultiDialectBuilder create(createLLVM); + // Print an error message. + if (!errorMsg.empty()) + create.krnl.printf(StringRef(errorMsg + "\n")); + // Return retVal. + create.llvm._return(retVal); + }); +} + } // namespace krnl } // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp index ef616d12f3..470209144f 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp @@ -69,6 +69,16 @@ std::string e2a_s(std::string e_s); void emitErrNo(mlir::ModuleOp module, mlir::OpBuilder &builder, mlir::Location loc, int err); +/// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`. +void equalOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter, + mlir::Location loc, mlir::Value lhs, mlir::Value rhs, + std::string errorMsg = "", bool appendRHS = true); + +/// Emit code for `IF lhs != rhs THEN return retVal ELSE do nothing`. +void equalOrReturn(mlir::ModuleOp &module, mlir::OpBuilder &rewriter, + mlir::Location loc, mlir::Value lhs, mlir::Value rhs, mlir::Value retVal, + std::string errorMsg = ""); + /// Creates an LLVM pointer type with the given element type and address space. /// This function is meant to be used in code supporting both typed and opaque /// pointers, as it will create an opaque pointer with the given address space diff --git a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp index abb5120592..2a0ee747c7 100644 --- a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp @@ -37,7 +37,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "erff"; if (type.isF64()) @@ -48,7 +48,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "acosf"; if (type.isF64()) @@ -59,7 +59,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "acoshf"; if (type.isF64()) @@ -70,7 +70,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "asinf"; if (type.isF64()) @@ -81,7 +81,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "asinhf"; if (type.isF64()) @@ -92,7 +92,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "atanf"; if (type.isF64()) @@ -103,7 +103,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "tanf"; if (type.isF64()) @@ -114,7 +114,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) return "atanhf"; if (type.isF64()) @@ -125,7 +125,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) #if (__APPLE__) return "__isinff"; @@ -140,7 +140,7 @@ struct MathFunctionName { template <> struct MathFunctionName { - static std::string functionName(mlir::Type type) { + static std::string functionName(Type type) { if (type.isF32()) #if (__APPLE__) @@ -168,9 +168,9 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { Location loc = op->getLoc(); // get the LLVM type for the function args and result - mlir::Type inType = op->getOperand(0).getType(); - mlir::Type outType = op->getResultTypes().front(); - mlir::Type llvmInType, llvmOutType; + Type inType = op->getOperand(0).getType(); + Type outType = op->getResultTypes().front(); + Type llvmInType, llvmOutType; if (inType.isF16()) llvmInType = FloatType::getF16(context); else if (inType.isF32()) @@ -207,16 +207,16 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { // declare float (float) // FlatSymbolRefAttr getOrInsertUnaryMathFunction(PatternRewriter &rewriter, - ModuleOp module, std::string mathFuncName, mlir::Type llvmInType, - mlir::Type llvmOutType) const { + ModuleOp module, std::string mathFuncName, Type llvmInType, + Type llvmOutType) const { auto *context = module.getContext(); if (module.lookupSymbol(mathFuncName)) return SymbolRefAttr::get(context, mathFuncName); // Create function declaration. // auto llvmF32Ty = FloatType::get(context); - auto llvmFnType = LLVM::LLVMFunctionType::get( - llvmOutType, ArrayRef({llvmInType})); + auto llvmFnType = + LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef({llvmInType})); // Insert the unary math function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp index 256d00572a..62d7c25de3 100644 --- a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp @@ -41,7 +41,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto krnlVectorTypeCastOp = cast(op); + auto krnlVectorTypeCastOp = mlir::cast(op); MemRefType sourceType = mlir::cast(krnlVectorTypeCastOp.getOperand().getType()); MemRefType targetType = krnlVectorTypeCastOp.getType(); diff --git a/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp b/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp index 5d810aa6eb..63146dbbba 100644 --- a/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp +++ b/src/Conversion/KrnlToLLVM/RuntimeAPI.cpp @@ -64,6 +64,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry( : registry() { MLIRContext *context = module.getContext(); auto voidTy = LLVM::LLVMVoidType::get(context); + Type int1Ty = IntegerType::get(context, 1); auto int8Ty = IntegerType::get(context, 8); auto opaquePtrTy = onnx_mlir::krnl::getPointerType(context, int8Ty); auto opaquePtrPtrTy = onnx_mlir::krnl::getPointerType(context, opaquePtrTy); @@ -88,7 +89,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry( RuntimeAPI(API::GET_OMT_ARRAY, "omTensorListGetOmtArray", opaquePtrPtrTy, {opaquePtrTy}), RuntimeAPI(API::PRINT_OMTENSOR, "omTensorPrint", voidTy, {opaquePtrTy, opaquePtrTy}), RuntimeAPI(API::GET_OMTENSOR_LIST_SIZE, "omTensorListGetSize", int64Ty, {opaquePtrTy}), - RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", voidTy, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}), + RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", int1Ty, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}), RuntimeAPI(API::GET_EXTERNAL_CONSTANT_ADDR, "omGetExternalConstantAddr", voidTy, {opaquePtrPtrTy, opaquePtrPtrTy, int64Ty}), }; // clang-format on diff --git a/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp b/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp index 66e90ca329..2657589e97 100644 --- a/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp +++ b/src/Conversion/ONNXToKrnl/Additional/LayoutTransform.cpp @@ -126,8 +126,8 @@ struct ONNXLayoutTransformOpLowering } // Outer loop (E1 iterates over tiles of 64 elements). - create.krnl.iterateIE( - loopDefs, loopDefs, lbs, ubs, [&](KrnlBuilder &b, ValueRange loopInd) { + create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs, + [&](const KrnlBuilder &b, ValueRange loopInd) { MDBuilder create(b); IndexExprScope outerScope(create.krnl); DimsExpr outerIndices; @@ -154,7 +154,7 @@ struct ONNXLayoutTransformOpLowering create.krnl.memcpy(alloc, input, len, allocOffset, inputOffset); } else { // Compute if we have a last tile. - IndexExpr modLit = LiteralIndexExpr(modVal); + IndexExpr modLit = LitIE(modVal); IndexExpr isFull = create.krnlIE.isTileFull(memAF[E1], modLit, SymIE(ub1)); IndexExpr isFullLogical = isFull >= 0; @@ -162,13 +162,13 @@ struct ONNXLayoutTransformOpLowering // Condition isFullLogical.getValue(), // Then (is full). - [&](SCFBuilder b) { + [&](const SCFBuilder b) { MDBuilder create(b); create.krnl.memcpy( alloc, input, len, allocOffset, inputOffset); }, // Else, we don't have a full tile. - [&](SCFBuilder b) { + [&](const SCFBuilder b) { MDBuilder create(b); IndexExprScope middleScope(b, &outerScope); IndexExpr tripCount = SymIE(ub1) - SymIE(memAF[E1]); @@ -218,7 +218,7 @@ struct ONNXLayoutTransformOpLowering IndexExprScope outerScope(create.krnl); SmallVector ubs; create.krnlIE.getShapeAsDims(data, ubs); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); // Insert an allocation and deallocation for the result of this // operation. @@ -263,7 +263,7 @@ struct ONNXLayoutTransformOpLowering } } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { // Simply copy the input into the output. Value val = createKrnl.load(data, indices); createKrnl.store(val, alloc, indices); diff --git a/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp b/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp index 1f13fd0656..82f5212596 100644 --- a/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp +++ b/src/Conversion/ONNXToKrnl/Additional/ShapeTransform.cpp @@ -55,12 +55,12 @@ struct ONNXShapeTransformOpLowering : public ConversionPattern { // Element-wise moving of data. ValueRange loopDef = create.krnl.defineLoops(inputRank); - SmallVector lbs(inputRank, LiteralIndexExpr(0)); + SmallVector lbs(inputRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(input, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange inputIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange inputIndices) { Value loadedVal = createKrnl.load(input, inputIndices); // Compute output indices by using affine map. SmallVector outputIndices; diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index f5faedf2a5..6a68f3cf2a 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -11,6 +11,7 @@ add_onnx_mlir_library(OMONNXToKrnl ControlFlow/If.cpp ControlFlow/Loop.cpp ControlFlow/Scan.cpp + ControlFlow/Yield.cpp ConvertONNXToKrnl.cpp ML/CategoryMapper.cpp Math/CumSum.cpp diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/If.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/If.cpp index a3b2677152..0edbb4bd62 100644 --- a/src/Conversion/ONNXToKrnl/ControlFlow/If.cpp +++ b/src/Conversion/ONNXToKrnl/ControlFlow/If.cpp @@ -54,14 +54,6 @@ struct ONNXIfOpLowering : public OpConversionPattern { rewriter.eraseBlock(&scfBranch.back()); scfBranch.takeBody(graph); - rewriter.setInsertionPointToEnd(&scfBranch.back()); - - Operation *yieldOp = scfBranch.back().getTerminator(); - llvm::SmallVector outputs; - if (failed(rewriter.getRemappedValues(yieldOp->getOperands(), outputs))) { - llvm_unreachable("failed to convert branch return values"); - } - rewriter.replaceOpWithNewOp(yieldOp, outputs); } }; diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp index 01249466ff..5eddef0d66 100644 --- a/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp +++ b/src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp @@ -100,7 +100,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { ValueRange loopDef = createKrnl.defineLoops(1); Value zero = create.math.constantIndex(0); createKrnl.iterate(loopDef, loopDef, {zero}, {maxTripCount}, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { OpBuilder::InsertionGuard insertGuard(rewriter); Value condReg = createKrnl.load(cond); @@ -281,7 +281,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // Here loop is assumed to be executed at least once. Value firstElement = create.krnl.load(output, create.math.constantIndex(0)); - SmallVector allocParams; + SmallVector allocParams; SmallVector dims; dims.emplace_back( mlir::cast(output.getType()).getShape()[0]); @@ -303,7 +303,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { KrnlBuilder createKrnl(rewriter, loc); ValueRange loopDef = createKrnl.defineLoops(1); createKrnl.iterate(loopDef, loopDef, {zero}, {maxTripCount}, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Wrap with KrnlRegionOp because emitCopy uses the result of // SeqExtract for loop bound. KrnlRegionOp regionOp = rewriter.create(loc); @@ -328,10 +328,10 @@ struct ONNXLoopOpLowering : public OpConversionPattern { return success(); } - void allocateMemoryForVFinal(mlir::Location loc, + void allocateMemoryForVFinal(Location loc, ConversionPatternRewriter &rewriter, Operation *op, - ONNXLoopOpAdaptor adaptor, SmallVectorImpl &outputs) const { - auto loopOp = dyn_cast(op); + ONNXLoopOpAdaptor adaptor, SmallVectorImpl &outputs) const { + auto loopOp = mlir::dyn_cast(op); for (const auto &ioPair : llvm::zip(adaptor.getVInitial(), loopOp.v_final())) { auto vInit = std::get<0>(ioPair); @@ -356,11 +356,11 @@ struct ONNXLoopOpLowering : public OpConversionPattern { } } - void allocateMemoryForScanOutput(mlir::Location loc, + void allocateMemoryForScanOutput(Location loc, ConversionPatternRewriter &rewriter, Operation *op, - ONNXLoopOpAdaptor adaptor, SmallVectorImpl &outputs, + ONNXLoopOpAdaptor adaptor, SmallVectorImpl &outputs, bool isWhile = false) const { - auto loopOp = dyn_cast(op); + auto loopOp = mlir::dyn_cast(op); for (const auto &opScanOutput : loopOp.scan_outputs()) { // Convert opScanOutput's type to MemRefType. Type convertedType = typeConverter->convertType(opScanOutput.getType()); @@ -380,7 +380,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { alloc = create.mem.alignedAlloc(memRefType); else { auto rankedScanOutTy = memRefType; - SmallVector allocParams; + SmallVector allocParams; // Check the loop accumulation dimension if (rankedScanOutTy.isDynamicDim(0)) { @@ -452,11 +452,11 @@ struct ONNXLoopOpLowering : public OpConversionPattern { if (srcTy.getRank() > 0) { IndexExprScope childScope(create.krnl); ValueRange loopDef = create.krnl.defineLoops(srcTy.getRank()); - SmallVector lbs(srcTy.getRank(), LiteralIndexExpr(0)); + SmallVector lbs(srcTy.getRank(), LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(src, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { SmallVector writeIV( writePrefix.begin(), writePrefix.end()); writeIV.insert(writeIV.end(), loopInd.begin(), loopInd.end()); @@ -481,7 +481,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { // iteration variable bool isWhileLoop(Operation *op) const { - auto onnxLoopOp = dyn_cast(op); + auto onnxLoopOp = mlir::dyn_cast(op); // Check whether continue condition is modified or not // Code copied from src/Dialect/ONNX/Rewrite.cpp @@ -517,7 +517,7 @@ struct ONNXLoopOpLowering : public OpConversionPattern { LogicalResult rewriteWithSCFWhile(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter) const { Location loc = ONNXLoc(op); - auto loopOp = dyn_cast(op); + auto loopOp = mlir::dyn_cast(op); MultiDialectBuilder create( rewriter, loc); diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp index 54b992c69c..62ed56c63d 100644 --- a/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp +++ b/src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp @@ -4,7 +4,7 @@ //===-------------------- Scan.cpp - Lowering Scan Op ---------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -198,11 +198,11 @@ struct ONNXScanOpLowering : public OpConversionPattern { return success(); } - static void allocateMemoryForVFinal(mlir::Location loc, + static void allocateMemoryForVFinal(Location loc, ConversionPatternRewriter &rewriter, const TypeConverter *typeConverter, Operation *op, ONNXScanOpAdaptor adaptor, - SmallVectorImpl &outputs) { - auto scanOp = dyn_cast(op); + SmallVectorImpl &outputs) { + auto scanOp = mlir::dyn_cast(op); for (const auto &ioPair : llvm::zip(scanOp.getVInitial(), scanOp.v_final())) { auto vInit = std::get<0>(ioPair); @@ -223,11 +223,11 @@ struct ONNXScanOpLowering : public OpConversionPattern { } } - static void allocateMemoryForScanOutput(mlir::Location loc, + static void allocateMemoryForScanOutput(Location loc, ConversionPatternRewriter &rewriter, const TypeConverter *typeConverter, Operation *op, ONNXScanOpAdaptor adaptor, - SmallVectorImpl &outputs) { - auto scanOp = dyn_cast(op); + SmallVectorImpl &outputs) { + auto scanOp = mlir::dyn_cast(op); for (const auto &opScanOutput : scanOp.scan_outputs()) { // Convert opScanOutput's type to MemRefType. Type convertedType = typeConverter->convertType(opScanOutput.getType()); @@ -248,7 +248,7 @@ struct ONNXScanOpLowering : public OpConversionPattern { MemRefBuilder createMemRef(rewriter, loc); OnnxBuilder onnxBuilder(rewriter, loc); auto rankedScanOutTy = memRefType; - SmallVector allocParams; + SmallVector allocParams; for (int i = 0; i < rankedScanOutTy.getRank(); i++) { if (rankedScanOutTy.isDynamicDim(i)) { if (i == 0) { @@ -274,9 +274,9 @@ struct ONNXScanOpLowering : public OpConversionPattern { } } - static mlir::Value allocateMemoryForBodyScanInput(mlir::Location loc, + static Value allocateMemoryForBodyScanInput(Location loc, ConversionPatternRewriter &rewriter, const TypeConverter *typeConverter, - mlir::Type bodyScanInputTy) { + Type bodyScanInputTy) { // Convert type to MemRefType. Type convertedType = typeConverter->convertType(bodyScanInputTy); assert(convertedType && mlir::isa(convertedType) && @@ -317,11 +317,11 @@ struct ONNXScanOpLowering : public OpConversionPattern { if (srcTy.getRank() > 0) { IndexExprScope childScope(create.krnl); ValueRange loopDef = create.krnl.defineLoops(srcTy.getRank()); - SmallVector lbs(srcTy.getRank(), LiteralIndexExpr(0)); + SmallVector lbs(srcTy.getRank(), LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(src, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { SmallVector writeIV( writePrefix.begin(), writePrefix.end()); writeIV.insert(writeIV.end(), loopInd.begin(), loopInd.end()); @@ -343,17 +343,17 @@ struct ONNXScanOpLowering : public OpConversionPattern { SmallVector readIV(readPrefix.begin(), readPrefix.end()); MultiDialectBuilder create( builder, loc); - if ((size_t)srcTy.getRank() > readIV.size()) { + if (static_cast(srcTy.getRank()) > readIV.size()) { IndexExprScope childScope(create.krnl); ValueRange loopDef = create.krnl.defineLoops(srcTy.getRank() - readPrefix.size()); SmallVector lbs( - srcTy.getRank() - readPrefix.size(), LiteralIndexExpr(0)); + srcTy.getRank() - readPrefix.size(), LitIE(0)); SmallVector ubs; for (int i = readIV.size(); i < srcTy.getRank(); i++) ubs.emplace_back(create.krnlIE.getShapeAsDim(src, i)); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { readIV.insert(readIV.end(), loopInd.begin(), loopInd.end()); Value val = createKrnl.load(src, readIV); createKrnl.store(val, dest, loopInd); diff --git a/src/Conversion/ONNXToKrnl/ControlFlow/Yield.cpp b/src/Conversion/ONNXToKrnl/ControlFlow/Yield.cpp new file mode 100644 index 0000000000..df6b03b5ce --- /dev/null +++ b/src/Conversion/ONNXToKrnl/ControlFlow/Yield.cpp @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===--------------------- Yield.cpp - Lowering Yield Op ------------------===// +// +// Copyright 2019-2023 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Yield Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +struct ONNXYieldOpLowering : public OpConversionPattern { + ONNXYieldOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(ONNXYieldOp yieldOp, ONNXYieldOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // Gather info. + Operation *op = yieldOp.getOperation(); + Location loc = ONNXLoc(op); + + MultiDialectBuilder create( + rewriter, loc); + + ValueRange inputs = yieldOp.getOperands(); + llvm::SmallVector outputs; + for (Value input : inputs) { + Type inputType = input.getType(); + Type outputType = typeConverter->convertType(inputType); + outputs.emplace_back(typeConverter->materializeTargetConversion( + rewriter, loc, outputType, input)); + } + + rewriter.replaceOpWithNewOp(yieldOp, outputs); + + onnxToKrnlSimdReport(op); + return success(); + } +}; + +void populateLoweringONNXYieldOpPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 16bab29f5c..ffca3130bc 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -47,7 +47,7 @@ class ONNXEntryPointLowering : public OpRewritePattern { StringRef entryPointName = funcRefAttr.getLeafReference().getValue(); Operation *entryPointOp = module.lookupSymbol(entryPointName); assert(entryPointOp && "entry point name not found!"); - func::FuncOp entryPointFunc = cast(entryPointOp); + func::FuncOp entryPointFunc = mlir::cast(entryPointOp); IntegerAttr numInputsAttr = rewriter.getI32IntegerAttr(entryPointFunc.getArgumentTypes().size()); @@ -190,7 +190,7 @@ std::map ONNXEntryPointLowering::typeMap = { void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, DimAnalysis *dimAnalysis, bool enableTiling, bool enableSIMD, bool enableParallel, - std::string opsForCall) { + bool enableFastMath, std::string opsForCall) { // clang-format off // Type conversion for function signatures. // Call MLIR FuncOp signature conversion when result type is a ranked tensor. @@ -203,6 +203,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, populateLoweringONNXIfOpPattern(patterns, typeConverter, ctx); populateLoweringONNXLoopOpPattern(patterns, typeConverter, ctx); populateLoweringONNXScanOpPattern(patterns, typeConverter, ctx); + populateLoweringONNXYieldOpPattern(patterns, typeConverter, ctx); // Math populateLoweringONNXCumSumOpPattern(patterns, typeConverter, ctx); populateLoweringONNXDFTOpPattern(patterns, typeConverter, ctx); @@ -223,8 +224,8 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, // ObjectDetection populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx); // Quantization - populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel); - populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel); + populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel, enableFastMath); + populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel, enableFastMath); // Tensor populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx); populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx); @@ -308,12 +309,13 @@ struct FrontendToKrnlLoweringPass FrontendToKrnlLoweringPass(const FrontendToKrnlLoweringPass &pass) : PassWrapper>() {} FrontendToKrnlLoweringPass(bool enableTiling, bool enableSIMD, - bool enableParallel, std::string opsForCall) { + bool enableParallel, bool enableFastMath, std::string opsForCall) { // Below, need explicit assignment to enable implicit conversion of bool to // Option. this->enableTiling = enableTiling; this->enableSIMD = enableSIMD; this->enableParallel = enableParallel; + this->enableFastMath = enableFastMath; this->opsForCall = opsForCall; } @@ -342,6 +344,8 @@ struct FrontendToKrnlLoweringPass llvm::cl::desc("Enable SIMD code gen"), llvm::cl::init(false)}; Option enableParallel{*this, "enable-parallel", llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)}; + Option enableFastMath{*this, "enable-fast-math", + llvm::cl::desc("Enable fast math optimizations"), llvm::cl::init(false)}; Option opsForCall{*this, "ops-for-call", llvm::cl::desc("Specify ops to be lowered to krnl.call"), llvm::cl::init("")}; @@ -372,20 +376,6 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // canonicalization after the lowering. target.addLegalOp<::mlir::ONNXNoneOp>(); - // Use krnl.load/store instead of std.load/store and affine.load/store. - // krnl.load/store will be lowered to std.load/store and affine.load/store - // by `convert-krnl-to-affine` pass. - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - // Memref builder can use affine stores, it would be awkward for it to - // generate Krnl stores as mem builder is part of MLIR. Thus the affine - // stores should not be illegal here. Since affine loads are still illegal, - // the regular krnl lowering will most likely trigger errors if non krnl mem - // ops where generally used. - // - // target.addIllegalOp(); - // Option`emitDealloc` is deprecated and turned off, make sure we don't have // buffer deallocation at this level. Will use MLIR buffer-deallocation for // this purpose instead. However, since the SequenceErase needs to emit @@ -443,7 +433,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() { // Define patterns. populateONNXToKrnlConversionPattern(patterns, krnlTypeConverter, &getContext(), dimAnalysis, enableTiling, enableSIMD, enableParallel, - opsForCall); + enableFastMath, opsForCall); // Rewrite patterns for accelerators. for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) @@ -463,9 +453,9 @@ std::unique_ptr createLowerToKrnlPass() { } std::unique_ptr createLowerToKrnlPass(bool enableTiling, bool enableSIMD, - bool enableParallel, std::string opsForCall) { + bool enableParallel, bool enableFastMath, std::string opsForCall) { return std::make_unique( - enableTiling, enableSIMD, enableParallel, opsForCall); + enableTiling, enableSIMD, enableParallel, enableFastMath, opsForCall); } //===----------------------------------------------------------------------===// diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp index 5b6e561a78..565e63a7d7 100644 --- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp +++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp @@ -140,11 +140,11 @@ struct ONNXCategoryMapperOpLowering create.krnlIE.getShapeAsDims(X, ubs); if (emitPrintStmts) - create.krnl.printTensor("Input tensor:\n", X); + create.krnl.printTensor("Input tensor:%s%d%e", X); ValueRange loopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Determine the index of 'inputElem' in the perfect hash table // 'pHash'. Note: the index might not be valid (this happens // when the 'inputElem' is not present in the perfect hash @@ -253,7 +253,7 @@ struct ONNXCategoryMapperOpLowering } Value loadElement(Value memref, ValueRange loopInd, Type elementType, - int64_t rank, KrnlBuilder &createKrnl) const { + int64_t rank, const KrnlBuilder &createKrnl) const { Value inputElem; TypeSwitch(elementType) .Case( diff --git a/src/Conversion/ONNXToKrnl/Math/CumSum.cpp b/src/Conversion/ONNXToKrnl/Math/CumSum.cpp index 195bdf0285..017d4f2b9f 100644 --- a/src/Conversion/ONNXToKrnl/Math/CumSum.cpp +++ b/src/Conversion/ONNXToKrnl/Math/CumSum.cpp @@ -4,7 +4,7 @@ //===-------------- CumSum.cpp - Lowering CumSum Ops ----------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -123,14 +123,14 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { IndexExpr axisIE = create.krnlIE.getIntFromArrayAsSymbol(axis, 0); if (axisIE.isUndefined()) return op->emitError("axis parameter could not be processed"); - axisIE = axisIE.selectOrSelf(axisIE < 0, axisIE + LiteralIndexExpr(rank)); + axisIE = axisIE.selectOrSelf(axisIE < 0, axisIE + LitIE(rank)); // Insert an allocation and deallocation for the result of this operation. Value resMemRef = create.mem.alignedAlloc(X, memRefType); Value bufMemRef = create.mem.alignedAlloc(X, memRefType); // Get the size of dimension 'axis'. - IndexExpr axisSize = LiteralIndexExpr(-1); + IndexExpr axisSize = LitIE(-1); for (uint64_t i = 0; i < rank; ++i) axisSize = IndexExpr::select(axisIE == i, xDims[i], axisSize); @@ -138,8 +138,8 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { IndexExpr numberOfStep; if (axisSize.isLiteral()) { int64_t n = axisSize.getLiteral(); - int64_t logN = (int64_t)std::ceil(std::log2(n)); - numberOfStep = LiteralIndexExpr(logN); + int64_t logN = static_cast(std::ceil(std::log2(n))); + numberOfStep = LitIE(logN); } else { Value nos = create.math.cast(f32Ty, axisSize.getValue()); // Use this when math::CeilOp is available in MLIR. @@ -147,8 +147,8 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { nos = create.math.log2(nos); nos = create.math.cast(i64Ty, nos); // Use this when math::CeilOp is available in MLIR. - // numberOfStep = SymbolIndexExpr(nos); - numberOfStep = SymbolIndexExpr(nos) + LiteralIndexExpr(1); + // numberOfStep = SymIE(nos); + numberOfStep = SymIE(nos) + LitIE(1); } // Input and output have the same shape, so they share the bounds. @@ -159,7 +159,7 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { // Initialize the temporary buffer: copy values from the input. ValueRange initLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(initLoopDef, initLoopDef, lbs, ubs, - [&](KrnlBuilder &ck, ValueRange initLoopInd) { + [&](const KrnlBuilder &ck, ValueRange initLoopInd) { MultiDialectBuilder create(ck); if (!exclusive) { Value x = create.krnl.load(X, initLoopInd); @@ -190,9 +190,8 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { }); // Outer loop iterates over the number of steps. - ValueRange stepLoopDef = create.krnl.defineLoops(1); - create.krnl.iterateIE(stepLoopDef, stepLoopDef, {zeroIE}, {numberOfStep}, - [&](KrnlBuilder &ck, ValueRange stepLoopInd) { + create.krnl.forLoopIE(zeroIE, numberOfStep, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder &ck, ValueRange stepLoopInd) { MultiDialectBuilder create(ck); // Compute index offset: offset = 2^step. @@ -210,7 +209,7 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { // y[i,k] = buf[i,k] ValueRange sumLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(sumLoopDef, sumLoopDef, lbs, ubs, - [&](KrnlBuilder &ck, ValueRange sumLoopInd) { + [&](const KrnlBuilder &ck, ValueRange sumLoopInd) { IndexExprScope ieScope(ck); MultiDialectBuilder create(ck); Value axis = axisIE.getValue(); @@ -231,7 +230,7 @@ struct ONNXCumSumOpLowering : public OpConversionPattern { // buf = y ValueRange bufLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(bufLoopDef, bufLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange bufLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange bufLoopInd) { Value x = createKrnl.load(resMemRef, bufLoopInd); createKrnl.store(x, bufMemRef, bufLoopInd); }); diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index a1c4aaa35e..1a11e22e50 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -28,6 +28,82 @@ using namespace mlir; namespace onnx_mlir { +// Check the input, x, can be reused as the output buffer +bool isBufferReusable(Value x, MemRefType outputType) { + if (!x.hasOneUse()) + return false; + + Type xType = x.getType(); + auto inputType = dyn_cast(xType); + if (!inputType) + return false; + // Currently, only static shape could be reused. + // ToFix: use DimAnalysis to handle dynamic shape. + if (!hasStaticShape(inputType)) + return false; + if (!hasStaticShape(outputType)) + return false; + + // Currently reuse requires that the shape has to be the same. + // ToFix: If the shape is not the same, memref.cast can be used. + if (getRank(inputType) != getRank(outputType)) + return false; + for (int64_t i = 0; i < getRank(inputType); i++) { + if (inputType.getShape()[i] != outputType.getShape()[i]) + return false; + } + + // ToFix: The simd padding is not checked + // We did not record whether the memref is padded or not. + // The padding added to the memref the as an attribute, or not needed. + return true; +} + +// Traverse the operands to find the candidate for buffer reuse. +// Return -1, if no candidate is found. +int whichBufferToReuse(ValueRange values, MemRefType outputType) { + for (size_t i = 0; i < values.size(); i++) { + if (isBufferReusable(values[i], outputType)) + return i; + } + return -1; +} + +// Allocate memref (as before) if no input buffer can be reused. +// Default VL=0 is used for non SIMD allocation +Value allocOrReuse(MemRefBuilder &create, Operation *op, + ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims, + int64_t alignment, int64_t VL = 0); + +Value allocOrReuse(MemRefBuilder &create, Operation *op, + ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims, + int64_t alignment, int64_t VL) { + + int indexToReuse = -1; + // By default, enableKrnlBufferReuse is false. Simply allocate a memref. + if (enableKrnlBufferReuse) { + // Be aware to use the op->getOperands() to check the number of uses. + // After buffer reuse, the number of uses of the transformed Value, + // generatedOperands, will increase. + indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType); + } + + if (indexToReuse != -1) { + int size = getSizeInBytes(outputMemRefType); + LLVM_DEBUG({ + llvm::dbgs() << " malloc_size " << size << "\n"; + op->dump(); + }); + return generatedOperands[indexToReuse]; + } else { + if (VL == 0) + return create.alignedAlloc(outputMemRefType, dims, alignment); + else + return create.alignedAllocWithSimdPadding( + outputMemRefType, dims, VL, alignment); + } +} + // ============================================================================= /// Emit post-processing for variadic element-wise ops. @@ -304,7 +380,7 @@ struct ScalarOp { template <> GenOpMix getGenOpMix(Type t, Operation *op) { - StringRef approximate = dyn_cast(op).getApproximate(); + StringRef approximate = mlir::dyn_cast(op).getApproximate(); if (approximate.equals_insensitive("none")) return {{GenericOps::ArithmeticGop, 1}, {GenericOps::ErfGop, 1}, {GenericOps::MulGop, 3}}; @@ -327,7 +403,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // "none". "approximate = none" simply implies no approximation will take // place. However, "approximate" can also have a string value of "tanh" which // indicates the use of tanh approximation. - StringRef approximate = dyn_cast(op).getApproximate(); + StringRef approximate = mlir::dyn_cast(op).getApproximate(); // Local constants Value half = create.math.constant(elementType, 0.5); @@ -388,8 +464,10 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Value negInf = create.math.negativeInf(inputType); Value posInf = create.math.positiveInf(inputType); - double detectNegAttribute = dyn_cast(op).getDetectNegative(); - double detectPosAttribute = dyn_cast(op).getDetectPositive(); + double detectNegAttribute = + mlir::dyn_cast(op).getDetectNegative(); + double detectPosAttribute = + mlir::dyn_cast(op).getDetectPositive(); // Three different cases: Infinity, Negative Infinity and Positive Infinity bool detectInf = detectPosAttribute == 1 && detectNegAttribute == 1; @@ -551,8 +629,10 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // Constant 1) CheckIfCustomScalarOpIsSupported(elementType); Value operand = scalarOperands[0]; - double alphaLit = dyn_cast(op).getAlpha().convertToFloat(); - double betaLit = dyn_cast(op).getBeta().convertToFloat(); + double alphaLit = + mlir::dyn_cast(op).getAlpha().convertToFloat(); + double betaLit = + mlir::dyn_cast(op).getBeta().convertToFloat(); // Create constants. MultiDialectBuilder create(rewriter, loc); Value zero = create.math.constant(elementType, 0); @@ -591,7 +671,7 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // %X) CheckIfCustomScalarOpIsSupported(elementType); Value operand = scalarOperands[0]; - double alphaLit = dyn_cast(op).getAlpha().convertToFloat(); + double alphaLit = mlir::dyn_cast(op).getAlpha().convertToFloat(); MultiDialectBuilder create(rewriter, loc); Value zero = create.math.constant(elementType, 0); Value one = create.math.constant(elementType, 1); @@ -651,7 +731,8 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // %X) CheckIfCustomScalarOpIsSupported(elementType); Value operand = scalarOperands[0]; - double alphaLit = dyn_cast(op).getAlpha().convertToFloat(); + double alphaLit = + mlir::dyn_cast(op).getAlpha().convertToFloat(); MultiDialectBuilder create(rewriter, loc); Value zero = create.math.constant(elementType, 0); auto alpha = create.math.constant(elementType, alphaLit); @@ -717,8 +798,8 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, // alpha))) CheckIfCustomScalarOpIsSupported(elementType); Value operand = scalarOperands[0]; - double alphaLit = dyn_cast(op).getAlpha().convertToFloat(); - double gammaLit = dyn_cast(op).getGamma().convertToFloat(); + double alphaLit = mlir::dyn_cast(op).getAlpha().convertToFloat(); + double gammaLit = mlir::dyn_cast(op).getGamma().convertToFloat(); MultiDialectBuilder create(rewriter, loc); Value zero = create.math.constant(elementType, 0); Value alpha = create.math.constant(elementType, alphaLit); @@ -1204,11 +1285,27 @@ struct ScalarOp { using IOp = NotSuportedScalarOp; }; +// Keep in sync with with KrnlBuilder::roundEven algorithm. template <> GenOpMix getGenOpMix(Type t, Operation *op) { - return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2}, - {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + // Custom? + Type inputType = op->getOperand(0).getType(); + if (VectorMachineSupport::requireCustomASM( + GenericOps::roundEvenGop, getElementTypeOrSelf(inputType))) + return {{GenericOps::ArithmeticGop, 1}}; + + // Change depending on whether KrnlBuilder use roundEven or + // RoundEvenEmulation. + bool useEmulation = true; + if (useEmulation) + return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 2}, + {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, + {GenericOps::FloorGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, + 8 /* Little parallelism in code. */}}; + + // Assume here that there is a hw op to handle this. + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -1216,9 +1313,9 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Type elementType, ArrayRef scalarOperands) { Value x = scalarOperands[0]; - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create(rewriter, loc); CheckIfCustomScalarOpIsSupported(elementType); - return create.math.round(x); + return create.krnl.roundEven(x); } //===----------------------------------------------------------------------===// @@ -1277,9 +1374,15 @@ Value emitScalarOpFor( Value scaleFloat = scalarOperands[1]; Value zeroPointInt = scalarOperands[2]; - Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); Value xFloat = create.math.cast(elementType, XInt); - Value sub = create.math.sub(xFloat, zeroPointFloat); + + Value sub; + if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) { + Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); + sub = create.math.sub(xFloat, zeroPointFloat); + } else { + sub = xFloat; + } Value res = create.math.mul(sub, scaleFloat); return res; } @@ -1318,14 +1421,14 @@ static LogicalResult getPartiallyFlattenedSimdCode( IndexExprScope allocScope(create.vec, shapeHelper->getScope()); DimsExpr outputDims; getIndexExprList(shapeHelper->getOutputDims(), outputDims); - // Alloc memory with padding for SIMD. + // Reuse the buffer from the input, or Alloc memory with padding for SIMD. // For the moment, its ok to go here; if we truly have partial flattening of // the simd code, then we only do it with static memref size that are // multiples of VL * unrollVL, so there should be no padding anyway. This // will change if we do partial flattening with non-multiple of VL * // unrollVL. - Value alloc = create.mem.alignedAllocWithSimdPadding( - outputMemRefType, outputDims, VL, alignment); + Value alloc = allocOrReuse( + create.mem, op, operands, outputMemRefType, outputDims, alignment, VL); // Create flat inputs in the last innerDinNum dims. llvm::SmallVector flatOperands; for (Value oper : operands) { @@ -1337,8 +1440,8 @@ static LogicalResult getPartiallyFlattenedSimdCode( DimsExpr operDims, flattenOperDims; create.krnlIE.getShapeAsSymbols(oper, operDims); // Because we fully fuse 1x1x128xf32 and 128xf32, the - // collapsedInnermostLoops may be higher than the rank of this input. Adjust - // collapsedInnermostLoops accordingly for the flatten below. + // collapsedInnermostLoops may be higher than the rank of this input. + // Adjust collapsedInnermostLoops accordingly for the flatten below. int64_t currRank = operDims.size(); int64_t currCollapsedNum = std::min(collapsedInnermostLoops, currRank); Value flatOper = create.mem.reshapeToFlatInnermost( @@ -1354,7 +1457,8 @@ static LogicalResult getPartiallyFlattenedSimdCode( Value flatAlloc = create.mem.reshapeToFlatInnermost( alloc, outputDims, flattenedOutputDims, collapsedInnermostLoops); - // Create loop iteration, rank-1, all but the flattened innermost [simd] loop. + // Create loop iteration, rank-1, all but the flattened innermost [simd] + // loop. int64_t outerLoopRank = rank - 1; ValueRange loopDef = create.krnl.defineLoops(outerLoopRank); // Iterate only over the blocks. @@ -1374,7 +1478,8 @@ static LogicalResult getPartiallyFlattenedSimdCode( "outer-loop of elementwise simd partially flattened"); } else { onnxToKrnlParallelReport(op, false, -1, -1, - "not enough work in outermost-loops of elementwise simd partially " + "not enough work in outermost-loops of elementwise simd " + "partially " "flattened"); } } else { @@ -1393,16 +1498,18 @@ static LogicalResult getPartiallyFlattenedSimdCode( } } } - create.krnl.iterateIE( - loopDef, loopDef, lbs, ubs, [&](KrnlBuilder &ck, ValueRange loopInd) { + create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, + [&](const KrnlBuilder &ck, ValueRange loopInd) { MultiDialectBuilder create(ck); - // LoopInd has the current indices for all but the innermost dim. Since - // we expect here the entire innermost loop iteration in one go, the - // innermost loop starts at zero. Add here to the list of Dim symbols. + // LoopInd has the current indices for all but the innermost dim. + // Since we expect here the entire innermost loop iteration in one go, + // the innermost loop starts at zero. Add here to the list of Dim + // symbols. SmallVector outputAccessExprs = DimListIE(loopInd); outputAccessExprs.emplace_back(zero); - // Have to produce the list of input values and their access functions. + // Have to produce the list of input values and their access + // functions. llvm::SmallVector inputs = flatOperands; llvm::SmallVector inputAFs; for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) { @@ -1440,8 +1547,7 @@ static LogicalResult getPartiallyFlattenedSimdCode( create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly, useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF}, - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { MultiDialectBuilder create(kb); Type currElementType = outputElementType; if (VL > 1) @@ -1453,8 +1559,8 @@ static LogicalResult getPartiallyFlattenedSimdCode( res = emitScalarOpFor( rewriter, create.getLoc(), op, currElementType, inputVals); } else { - // For non-unary ops, each op is a flattened array that need to - // be processed; process the two first ones, and then + // For non-unary ops, each op is a flattened array that need + // to be processed; process the two first ones, and then // "accumulate" one value at a time. Use the first operand as // temporary result. Value accumulated = inputVals[0]; @@ -1470,9 +1576,9 @@ static LogicalResult getPartiallyFlattenedSimdCode( res = emitPostProcessingFor(rewriter, create.getLoc(), op, currElementType, accumulated); } - resVals.emplace_back(res); - }); // SIMD kernel. - }); // Outer loops. + return res; + }}); // SIMD kernel. + }); // Outer loops. rewriter.replaceOp(op, alloc); return success(); @@ -1483,9 +1589,9 @@ static LogicalResult getPartiallyFlattenedSimdCode( //===----------------------------------------------------------------------===// // Function pointer type for the emitScalarOpFor of elementwise Ops. -typedef mlir::Value (*EmitScalarFunc)(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::Operation *op, mlir::Type elementType, - mlir::ArrayRef scalarOperands); +typedef Value (*EmitScalarFunc)(ConversionPatternRewriter &rewriter, + Location loc, mlir::Operation *op, Type elementType, + mlir::ArrayRef scalarOperands); // Utility class for Op fusion. // Start from the root op, which is being lowered as an Elementwise Op. @@ -1496,7 +1602,8 @@ typedef mlir::Value (*EmitScalarFunc)(mlir::ConversionPatternRewriter &rewriter, // loop nest generated for the root Op. // Finally the last op is replaced with the allocated memref and the other ops // are deleted. -// ToFix: fusion for a graph structure, not just line, could be added in future. +// ToFix: fusion for a graph structure, not just line, could be added in +// future. class OpFusionHelper { public: // Constructor @@ -1662,10 +1769,10 @@ bool OpFusionHelper::isControlFlowValidForFusion( // %2 = "onnx.Add"(%1, %3) : (tensor<16x24xf32>, tensor<24xf32>) -> tensor<16x24xf32> // clang-format on // In this implementation, no data dependence analysis and -// code motion for fusion is implemented yet. The only other inputs allowed are -// block argument and constant to guarantee they are before the root op. It is -// assumed the canonicalization has hoisted all constant to the beginning of the -// function by fold function. +// code motion for fusion is implemented yet. The only other inputs allowed +// are block argument and constant to guarantee they are before the root op. +// It is assumed the canonicalization has hoisted all constant to the +// beginning of the function by fold function. bool OpFusionHelper::areInputsValidForFusion( Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis) { // Do not fuse ops with scalar tensors. @@ -1803,11 +1910,11 @@ Value OpFusionHelper::emitFuseOps( // In a previous implementation, the original output of defOp is used // with 'alloc = defOp->getResult(0)' at the end of the loop. // But ONNXBroadcastOpShapeHelper.computeShape() unexpectedly used - // this parameter to generate some code (memref.dim) that is not really - // needed. Due to this live user, the original op can not be erased. - // This error occurred when there were more than one op with dynamic dim - // to be fused in the previous implementation. - // Therefore, alloc is used for all the fused op. + // this parameter to generate some code (memref.dim) that is not + // really needed. Due to this live user, the original op can not be + // erased. This error occurred when there were more than one op with + // dynamic dim to be fused in the previous implementation. Therefore, + // alloc is used for all the fused op. useOperands.emplace_back(alloc); } // Use shape helper to generate load index @@ -1970,13 +2077,14 @@ struct ONNXElementwiseUnaryOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(X, ubs); if (enableParallel) { @@ -1992,7 +2100,7 @@ struct ONNXElementwiseUnaryOpLowering } } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { SmallVector args; Value loadedVal = createKrnl.load(X, loopInd); args.emplace_back(loadedVal); @@ -2151,13 +2259,14 @@ struct ONNXElementwiseBinaryOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); // TODO adjust in the future @@ -2174,7 +2283,7 @@ struct ONNXElementwiseBinaryOpLowering } } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { IndexExprScope innerScope(createKrnl, shapeHelper.getScope()); SmallVector outputAccessExprs; getIndexExprList(loopInd, outputAccessExprs); @@ -2326,13 +2435,14 @@ struct ONNXElementwiseVariadicOpLowering outputMemRefType = opFusionHelper.getOutputType(outputMemRefType); // Insert an allocation and deallocation for the result of this operation. - Value alloc = create.mem.alignedAlloc( - outputMemRefType, shapeHelper.getOutputDims(), alignment); + Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType, + shapeHelper.getOutputDims(), alignment); + ; // Only create krnl.iterate if one of the operands is not scalar tensor. if (!isScalar) { ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); @@ -2349,7 +2459,7 @@ struct ONNXElementwiseVariadicOpLowering } } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { IndexExprScope innerScope(createKrnl, shapeHelper.getScope()); SmallVector outputAccessExprs; getIndexExprList(loopInd, outputAccessExprs); @@ -2456,7 +2566,7 @@ struct ONNXWhereOpLowering : public ConversionPattern { // Only create krnl.iterate if one of the operands is not scalar tensor. if (!hasAllScalarValues(operands)) { ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); if (enableParallel) { @@ -2472,7 +2582,7 @@ struct ONNXWhereOpLowering : public ConversionPattern { } } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { IndexExprScope innerScope(&rewriter, shapeHelper.getScope()); SmallVector outputAccessExprs; getIndexExprList(loopInd, outputAccessExprs); diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index e46dd41541..af0724c446 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -64,7 +64,7 @@ struct ONNXGemmOpLowering : public OpConversionPattern { ValueRange loopDef = create.krnl.defineLoops(3); SmallVector outerLoopDef{loopDef[0], loopDef[1]}; SmallVector innerLoopDef{loopDef[2]}; - SmallVector loopLbs(3, LiteralIndexExpr(0)); + SmallVector loopLbs(3, LitIE(0)); IndexExpr outerUb0 = shapeHelper.getOutputDims()[0]; IndexExpr outerUb1 = shapeHelper.getOutputDims()[1]; IndexExpr innerUb = shapeHelper.aDims[1]; @@ -83,16 +83,18 @@ struct ONNXGemmOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(loopDef, outerLoopDef, loopLbs, loopUbs, - [&](KrnlBuilder &createKrnl, ValueRange outerIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange outerIndices) { MultiDialectBuilder create( createKrnl); // Create temp, single scalar, no need for default alignment. + // Alloca is ok here as its for a scalar, and in the generic version + // of GEMM. Value red = create.mem.alloca(MemRefType::get({}, elementType)); // Set to zero. create.krnl.store(zeroVal, red); // Inner loop. create.krnl.iterate({}, innerLoopDef, {}, {}, - [&](KrnlBuilder &createKrnl, ValueRange innerIndex) { + [&](const KrnlBuilder &createKrnl, ValueRange innerIndex) { Value i(outerIndices[0]), j(outerIndices[1]), k(innerIndex[0]); MultiDialectBuilder create( createKrnl); @@ -122,7 +124,7 @@ struct ONNXGemmOpLowering : public OpConversionPattern { // If dim > 1, use loop index, otherwise broadcast on 0's element. DimIndexExpr dim(shapeHelper.cDims[x]); cAccess.emplace_back( - IndexExpr::select(dim > 1, DimIndexExpr(outerIndices[x]), 0) + IndexExpr::select(dim > 1, DimIE(outerIndices[x]), 0) .getValue()); } Value c = create.krnl.load(adaptor.getC(), cAccess); @@ -203,14 +205,6 @@ struct ONNXGemmOpLowering : public OpConversionPattern { MemRefType bTileType = MemRefType::get({kCacheTile, jCacheTile}, elementType); SmallVector empty; - // Allocate here on heap, only when no parallelism. - Value aBuff, bBuff, rBuff; - if (!enableParallel) { - aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); - bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN); - if (mustTileR) - rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); - } // 3) introduce the loops and permute them // I, J, K loop. @@ -250,18 +244,16 @@ struct ONNXGemmOpLowering : public OpConversionPattern { } // Compute: A[i, k] * b[k, j] -> R[i, j]) create.krnl.iterateIE({ii, jj, kk}, {ii1, jj1}, {zeroIE, zeroIE, zeroIE}, - {I, J, K}, [&](KrnlBuilder &createKrnl, ValueRange i1_j1_indices) { + {I, J, K}, + [&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) { Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]); - // If parallel, allocate on stack inside the parallel region. - if (enableParallel) { - aBuff = create.mem.alignedAlloca(aTileType, BUFFER_ALIGN); - bBuff = create.mem.alignedAlloca(bTileType, BUFFER_ALIGN); - if (mustTileR) - rBuff = create.mem.alignedAlloca(aTileType, BUFFER_ALIGN); - } + // If parallel, will stay inside, otherwise will migrate out. + Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); + Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN); + Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); createKrnl.copyToBuffer(rBuff, R, {i1, j1}, zeroVal, false); createKrnl.iterateIE({}, {kk1}, {}, {}, - [&](KrnlBuilder &createKrnl, ValueRange k1_index) { + [&](const KrnlBuilder &createKrnl, ValueRange k1_index) { Value k1(k1_index[0]); if (aTrans) createKrnl.copyToBuffer(aBuff, A, {k1, i1}, zeroVal, true); @@ -272,7 +264,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern { else createKrnl.copyToBuffer(bBuff, B, {k1, j1}, zeroVal, false); createKrnl.iterate({}, {jj2, ii2}, {}, {}, - [&](KrnlBuilder &createKrnl, ValueRange j2_i2_indices) { + [&](const KrnlBuilder &createKrnl, + ValueRange j2_i2_indices) { Value j2(j2_i2_indices[0]), i2(j2_i2_indices[1]); ArrayRef empty; createKrnl.matmul(aBuff, {i1, k1}, bBuff, {k1, j1}, @@ -316,28 +309,26 @@ struct ONNXGemmOpLowering : public OpConversionPattern { // "not currently used ones" like ii here last. Gave an error when ii was // listed first. create.krnl.iterateIE({jj, kk, ii}, {jj1, kk1}, {zeroIE, zeroIE, zeroIE}, - {J, K, I}, [&](KrnlBuilder &createKrnl, ValueRange j1_k1_indices) { + {J, K, I}, + [&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) { Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]); - // If parallel, allocate on stack inside the parallel region. - if (enableParallel) { - aBuff = create.mem.alignedAlloca(aTileType, BUFFER_ALIGN); - bBuff = create.mem.alignedAlloca(bTileType, BUFFER_ALIGN); - if (mustTileR) - rBuff = create.mem.alignedAlloca(aTileType, BUFFER_ALIGN); - } + // If parallel, it will stay inside, otherwise it will migrate out. + Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); + Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN); if (bTrans) createKrnl.copyToBuffer(bBuff, B, {j1, k1}, zeroVal, true); else createKrnl.copyToBuffer(bBuff, B, {k1, j1}, zeroVal, false); createKrnl.iterateIE({}, {ii1}, {}, {}, - [&](KrnlBuilder &createKrnl, ValueRange i1_index) { + [&](const KrnlBuilder &createKrnl, ValueRange i1_index) { Value i1(i1_index[0]); if (aTrans) createKrnl.copyToBuffer(aBuff, A, {k1, i1}, zeroVal, true); else createKrnl.copyToBuffer(aBuff, A, {i1, k1}, zeroVal, false); createKrnl.iterate({}, {jj2, ii2}, {}, {}, - [&](KrnlBuilder &createKrnl, ValueRange j2_i2_indices) { + [&](const KrnlBuilder &createKrnl, + ValueRange j2_i2_indices) { Value j2(j2_i2_indices[0]), i2(j2_i2_indices[1]); createKrnl.matmul(aBuff, {i1, k1}, bBuff, {k1, j1}, R, {z, z}, @@ -374,7 +365,7 @@ struct ONNXGemmOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(outerLoops, outerLoops, {zeroIE, zeroIE}, {I, J}, - [&](KrnlBuilder &createKrnl, ValueRange outerIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange outerIndices) { // Handle alpha/beta coefficients. Value res = createKrnl.load(R, outerIndices); MathBuilder createMath(createKrnl); @@ -387,7 +378,7 @@ struct ONNXGemmOpLowering : public OpConversionPattern { // If dim > 1, use loop index, otherwise broadcast on 0's element. DimIndexExpr dim(shapeHelper.cDims[x]); cAccess.emplace_back( - IndexExpr::select(dim > 1, DimIndexExpr(outerIndices[x]), 0) + IndexExpr::select(dim > 1, DimIE(outerIndices[x]), 0) .getValue()); } Value c = createKrnl.load(adaptor.getC(), cAccess); diff --git a/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp b/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp index cb1db12d17..69ae1153cd 100644 --- a/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Hardmax.cpp @@ -39,7 +39,7 @@ static Value emitArgmax(ConversionPatternRewriter &rewriter, Location loc, // Allocate and initialize the result. // Th result has the same shape as the input except the axis dimension is 1. SmallVector outputUBS(inputUBS); - outputUBS[axis] = LiteralIndexExpr(1); + outputUBS[axis] = LitIE(1); SmallVector outputShape; for (const IndexExpr &dim : outputUBS) outputShape.push_back( @@ -49,9 +49,9 @@ static Value emitArgmax(ConversionPatternRewriter &rewriter, Location loc, create.krnl.memset(resMemRef, zero); ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, inputUBS, - [&](KrnlBuilder &createKrnl, ValueRange inputLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange inputLoopInd) { MultiDialectBuilder create( createKrnl); // Load the index of the current max value. @@ -68,7 +68,7 @@ static Value emitArgmax(ConversionPatternRewriter &rewriter, Location loc, // Compare and update the index for the maximum value. Value gt = create.math.sgt(next, maxValue); - create.scf.ifThenElse(gt, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(gt, [&](const SCFBuilder &createSCF) { KrnlBuilder createKrnl(createSCF); createKrnl.store(inputLoopInd[axis], resMemRef, resLoopInd); }); @@ -118,9 +118,9 @@ struct ONNXHardmaxOpLowering : public OpConversionPattern { // Produce the final result. // Set value to 1 if index is argmax. Otherwise, 0. ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { MultiDialectBuilder create( createKrnl); // Load the index of the current max value. @@ -132,13 +132,13 @@ struct ONNXHardmaxOpLowering : public OpConversionPattern { Value eq = create.math.eq(maxInd, loopInd[axis]); create.scf.ifThenElse( eq, /*then*/ - [&](SCFBuilder &createSCF) { + [&](const SCFBuilder &createSCF) { MultiDialectBuilder create(createSCF); Value one = create.math.constant(elementType, 1); create.krnl.store(one, resMemRef, loopInd); }, /*else*/ - [&](SCFBuilder &createSCF) { + [&](const SCFBuilder &createSCF) { MultiDialectBuilder create(createSCF); Value zero = create.math.constant(elementType, 0); create.krnl.store(zero, resMemRef, loopInd); diff --git a/src/Conversion/ONNXToKrnl/Math/LRN.cpp b/src/Conversion/ONNXToKrnl/Math/LRN.cpp index 8095e294e3..1b08661a2d 100644 --- a/src/Conversion/ONNXToKrnl/Math/LRN.cpp +++ b/src/Conversion/ONNXToKrnl/Math/LRN.cpp @@ -4,7 +4,7 @@ //===-------------------- LRN.cpp - Lowering LRN Op -----------------------===// // -// Copyright 2020-2023 The IBM Research Authors. +// Copyright 2020-2024 The IBM Research Authors. // // ============================================================================= // @@ -55,17 +55,17 @@ struct ONNXLRNOpLowering : public OpConversionPattern { auto f32Type = FloatType::getF32(rewriter.getContext()); Value biasValue = create.math.constant(f32Type, biasLit); Value alphaDivSizeValue = - create.math.constant(f32Type, alphaLit / (float)sizeLit); + create.math.constant(f32Type, alphaLit / static_cast(sizeLit)); Value betaValue = create.math.constant(f32Type, betaLit); Value alloc = create.mem.alignedAlloc(outputMemRefType, shapeHelper.getOutputDims()); ValueRange outputLoopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); create.krnl.iterateIE(outputLoopDef, outputLoopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange outputLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange outputLoopInd) { // Insert computation of square_sum. // square_sum[n, c, d1, ..., dk] = sum(X[n, i, d1, ..., dk] ^ 2), // where max(0, c - floor((size - 1) / 2)) <= i @@ -79,17 +79,15 @@ struct ONNXLRNOpLowering : public OpConversionPattern { constexpr int loopIndexForC = 1; DimIndexExpr cIE(outputLoopInd[loopIndexForC]); DimIndexExpr CIE = create.krnlIE.getShapeAsDim(input, loopIndexForC); - SymbolIndexExpr sizeIE = LiteralIndexExpr(sizeLit); + SymbolIndexExpr sizeIE = LitIE(sizeLit); SmallVector lbMaxList; - lbMaxList.emplace_back(LiteralIndexExpr(0)); - lbMaxList.emplace_back( - cIE - (sizeIE - 1).floorDiv(LiteralIndexExpr(2))); + lbMaxList.emplace_back(LitIE(0)); + lbMaxList.emplace_back(cIE - (sizeIE - 1).floorDiv(LitIE(2))); SmallVector ubMinList; ubMinList.emplace_back(CIE); - ubMinList.emplace_back( - cIE + 1 + (sizeIE - 1).ceilDiv(LiteralIndexExpr(2))); + ubMinList.emplace_back(cIE + 1 + (sizeIE - 1).ceilDiv(LitIE(2))); // Initialize sum, single scalar, no need for default alignment. MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); @@ -121,9 +119,9 @@ struct ONNXLRNOpLowering : public OpConversionPattern { Value loadVal = create.krnl.load(input, loadIndices); Value squareVal = create.math.mul(loadVal, loadVal); - Value sumValue = create.krnl.load(sumAlloc, ArrayRef{}); + Value sumValue = create.krnl.load(sumAlloc); sumValue = create.math.add(sumValue, squareVal); - create.krnl.store(sumValue, sumAlloc, ArrayRef{}); + create.krnl.store(sumValue, sumAlloc); // Compute and store the output // y = x / ((bias + (alpha / nsize) * square_sum) ** beta) diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index 574638d510..ba16633e3c 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -59,7 +59,7 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { int outerLoopNum = shapeHelper.getOutputDims().size(); int totLoopNum = outerLoopNum + 1; // Add reduction inner loop. ValueRange loopDef = create.krnl.defineLoops(totLoopNum); - SmallVector loopLbs(totLoopNum, LiteralIndexExpr(0)); + SmallVector loopLbs(totLoopNum, LitIE(0)); SmallVector loopUbs; // All getOutputDims, plus reduction. SmallVector outerLoops; // All but the last loop def. for (int i = 0; i < outerLoopNum; ++i) { @@ -86,14 +86,14 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { // Non-reduction loop iterations: output-rank. create.krnl.iterateIE(loopDef, outerLoops, loopLbs, loopUbs, - [&](KrnlBuilder &createKrnl, ValueRange outerIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange outerIndices) { MultiDialectBuilder create( createKrnl); ValueRange inits = ValueRange(fZero); // Inner loop for reduction. auto innerIterate = create.krnl.iterate({}, innerLoop, {}, {}, inits, - [&](KrnlBuilder &createKrnl, ValueRange innerIndex, + [&](const KrnlBuilder &createKrnl, ValueRange innerIndex, ValueRange iterArgs) { // Get last argument for the iterate body. Value iterArg = iterArgs.back(); @@ -340,7 +340,7 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { } } create.krnl.iterate({ii, jj, kk}, {ii1, jj1, kk1}, {zero, zero, zero}, - {I, J, K}, [&](KrnlBuilder &createKrnl, ValueRange indices) { + {I, J, K}, [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value i1(indices[0]), j1(indices[1]), k1(indices[2]); createKrnl.matmul(A, {zero, zero}, B, {zero, zero}, C, {zero, zero}, {ii2, jj2, kk2}, {i1, j1, k1}, {I, J, K}, @@ -408,7 +408,7 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { if (enableParallel) { int64_t parId; // Could check out more than the outer dim of the broadcasts... - SmallVector lb(1, LiteralIndexExpr(0)), + SmallVector lb(1, LitIE(0)), ub(1, shapeHelper.getOutputDims()[0]); if (findSuitableParallelDimension(lb, ub, 0, 1, parId, /*min iter for going parallel*/ 4)) { @@ -420,7 +420,7 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { } } create.krnl.iterate(broadcastLoop, broadcastLoop, broadcastLB, broadcastUB, - [&](KrnlBuilder &createKrnl, ValueRange broadcastIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange broadcastIndices) { MultiDialectBuilder create(createKrnl); // I, J, K loop. ValueRange origLoop = create.krnl.defineLoops(3); @@ -436,7 +436,8 @@ struct ONNXMatMulOpLowering : public OpConversionPattern { create.krnl.permute( {ii1, ii2, jj1, jj2, kk1, kk2}, {0, 3, 1, 4, 2, 5}); create.krnl.iterate({ii, jj, kk}, {ii1, jj1, kk1}, {zero, zero, zero}, - {I, J, K}, [&](KrnlBuilder &createKrnl, ValueRange indices) { + {I, J, K}, + [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value i1(indices[0]), j1(indices[1]), k1(indices[2]); // Compute global start for B/C: {broadcastIndices, 0, 0} SmallVector broadcastGlobalStart; diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index 8702377291..d406bcc571 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -12,6 +12,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir/IR/BuiltinTypeInterfaces.h" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" @@ -19,7 +21,8 @@ #include "src/Support/SmallVectorHelper.hpp" #define DEBUG_TYPE "lowering-to-krnl" -#define DEBUG_FORCE_SHUFFLE_REDUCTION 0 +#define DEBUG_FORCE_SHUFFLE_REDUCTION 0 /* should be 0 in repo */ +#define REDUCTION_MULTIPLE_OF_VL_ONLY 0 /* 0: improved;1: old, for debug */ using namespace mlir; @@ -279,21 +282,119 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, //===----------------------------------------------------------------------===// -using MDBuilder = - MultiDialectBuilder; +using MDBuilder = MultiDialectBuilder; +using PreProcessFn = std::function; //===----------------------------------------------------------------------===// // Helper function to perform reduction when an entire tensor is reduced to a // single value. Support the reduction for up to 2 operations at once. If only // one is needed, then pass ONNXNoneOp in the second slot. +// If PreProcessFn is given, each input value is preprocessed by the function +// before doing reduction. // Return true if we can optimize the reduction, false otherwise. -// TODO: alexe add support for parallel -// TODO: alexe see if the new simd infrastructure can be used. +template +void emitOneStepOfFullSIMDReduction(ConversionPatternRewriter &rewriter, + Operation *op, MDBuilder &create, Type elementType, IndexExpr lb, + IndexExpr ub, int64_t VL, bool simdOnly, IndexExpr t, int64_t tNum, + bool hasTwoRed, Value input1, Value input2, Value tmp1, Value tmp2, + PreProcessFn preProc1, PreProcessFn preProc2, Value output1, Value output2, + Value divisorForMean) { + + VectorType vecType = VectorType::get({VL}, elementType); + + SmallVector inputs, tmps, outputs, initVals; + SmallVector inputAFs, tmpAFs, outputAFs; + IndexExpr zero = LitIE(0); + + // Init data for 1st reduction + inputs.emplace_back(input1); + DimsExpr inputAF(1, zero); // Inputs starts at 0 + inputAFs.emplace_back(inputAF); + tmps.emplace_back(tmp1); + DimsExpr tmpAF(1, t * VL); // Each thread t starts a new section of VL values. + tmpAFs.emplace_back(tmpAF); + outputs.emplace_back(output1); + DimsExpr outputAF; // By default, no value (scalar). + if (tNum > 1) { + outputAF.emplace_back(t); // If parallel, indexed by t. + } + outputAFs.emplace_back(outputAF); + initVals.emplace_back(getIdentityValue( + rewriter, create.getLoc(), elementType)); + // Init data for 2nd reduction. + if (hasTwoRed) { + inputs.emplace_back(input2); + inputAFs.emplace_back(inputAF); + tmps.emplace_back(tmp2); + tmpAFs.emplace_back(tmpAF); + outputs.emplace_back(output2); + outputAFs.emplace_back(outputAF); + initVals.emplace_back(getIdentityValue( + rewriter, create.getLoc(), elementType)); + } + + // Create the reduction functions. + llvm::SmallVector, 2> redBodyFnList; + llvm::SmallVector, 2> + postRedBodyFnList; + // Push functions for the first reduction. + redBodyFnList.emplace_back( + [&](const BUILDER &b, Value inputVal, Value tmpVal, int64_t VL) { + Type currType = (VL > 1) ? vecType : elementType; + // Perform reduction of tmp and input. + if (preProc1) + inputVal = preProc1(inputVal); + return emitScalarOpFor( + rewriter, create.getLoc(), op, currType, {tmpVal, inputVal}); + }); + postRedBodyFnList.emplace_back( + [&](const BUILDER &b, Value tmpVal, int64_t VL) { + // Perform horizontal reductions. + Value scalarVal = + create.vec.reduction(getCombiningKind(), tmpVal); + if (tNum == 1) { /* parallel: do it for the final iteration only */ + if (divideByMean()) + scalarVal = create.math.div(scalarVal, divisorForMean); + } + return scalarVal; + }); + if (hasTwoRed) { + // Push functions for the second reduction. + redBodyFnList.emplace_back( + [&](const BUILDER &b, Value inputVal, Value tmpVal, int64_t VL) { + Type currType = (VL > 1) ? vecType : elementType; + // Perform reduction of tmp and input. + if (preProc2) + inputVal = preProc2(inputVal); + return emitScalarOpFor( + rewriter, create.getLoc(), op, currType, {tmpVal, inputVal}); + }); + postRedBodyFnList.emplace_back( + [&](const BUILDER &b, Value tmpVal, int64_t VL) { + // Perform horizontal reductions. + Value scalarVal = create.vec.reduction( + getCombiningKind(), tmpVal); + if (tNum == 1) { /* parallel: do it for the final iteration only */ + if (divideByMean()) + scalarVal = create.math.div(scalarVal, divisorForMean); + } + return scalarVal; + }); + } + // Call simd reduce. + BUILDER builder(create.vec); + builder.simdReduceIE(lb, ub, VL, simdOnly, inputs, inputAFs, tmps, tmpAFs, + outputs, outputAFs, initVals, redBodyFnList, postRedBodyFnList); +} + template bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc, - Operation *op, Value input, Value &alloc1, Value &alloc2) { + Operation *op, Value input, PreProcessFn preProc1, PreProcessFn preProc2, + Value &alloc1, Value &alloc2, bool enableParallel) { // Create scope. IndexExprScope scope(&rewriter, loc); MDBuilder create(rewriter, loc); @@ -306,6 +407,13 @@ bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc, // Flatten entirely the input memref. Value flatInput = create.mem.reshapeToFlatInnermost( input, inputDims, flatInputDims, inputRank); + IndexExpr zero = LitIE(0); + IndexExpr lb = zero; + IndexExpr ub = flatInputDims[0]; + // Compute the divisor that is the number of elements participated in + // reduction, i.e., 'divisor = size of input / size of output, where + // output size == 1'. + Value divisorForMean = create.math.cast(elementType, ub.getValue()); // Has one or 2 reductions? bool hasTwoRed = true; @@ -326,92 +434,92 @@ bool emitFullSIMDReductionFor(ConversionPatternRewriter &rewriter, Location loc, int64_t totVL = computeSuitableUnrollFactor(inputType, collapsedInnermostLoops, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); - // Current simdized loop only support SIMD only scheme. - if (!simdOnly) { - totVL = capVLForSimdOnly(inputType, totVL, simdLoopStaticTripCount); - } - if (totVL <= 1) - return false; // TODO alexe: consider staying here with VL=1 - IndexExpr VLIndexExpr = LitIE(totVL); - - // Compute type of small temporary reduction vector. - MemRefType outputType = MemRefType::get({}, elementType); - MemRefType redType = MemRefType::get({totVL}, elementType); - VectorType vecType = VectorType::get({totVL}, elementType); - - // Initialize first reduction. - Value zero = create.math.constantIndex(0); - /*output*/ alloc1 = create.mem.alloc(outputType); - Value redAlloc1 = create.mem.alignedAlloc(redType); - Value identity1 = getIdentityValue( - rewriter, create.getLoc(), elementType); - Value initVec1 = create.vec.splat(vecType, identity1); - create.vec.store(initVec1, redAlloc1, {zero}); - // Init second reduction. - alloc2 = nullptr; - Value redAlloc2 = nullptr; - if (hasTwoRed) { - /*output*/ alloc2 = create.mem.alloc(outputType); - redAlloc2 = create.mem.alignedAlloc(redType); - Value identity2 = getIdentityValue( - rewriter, create.getLoc(), elementType); - Value initVec2 = create.vec.splat(vecType, identity2); - create.vec.store(initVec2, redAlloc2, {zero}); - } - - // Loop over SIMD values. - ValueRange loopDef = create.krnl.defineLoops(1); - ValueRange blockedLoopDef = create.krnl.block(loopDef[0], totVL); - create.krnl.iterate(loopDef, {blockedLoopDef[0]}, {zero}, - {flatInputDims[0].getValue()}, [&](KrnlBuilder &ck, ValueRange loopInd) { - MDBuilder create(ck); - // Input values, loaded as a vector. - SmallVector inAccessVals; - inAccessVals.emplace_back(loopInd[0]); - Value inputVec = create.vec.load(vecType, flatInput, inAccessVals); - // Process first reduction. - Value redVec1 = create.vec.load(vecType, redAlloc1, {zero}); - Value accumulatedVec1 = emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {redVec1, inputVec}); - create.vec.store(accumulatedVec1, redAlloc1, {zero}); - // Process second reduction. - if (hasTwoRed) { - Value redVec2 = create.vec.load(vecType, redAlloc2, {zero}); - Value accumulatedVec2 = emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {redVec2, inputVec}); - create.vec.store(accumulatedVec2, redAlloc2, {zero}); - } - }); - - // First reduction horizontal sum. - Value reductionVec1 = create.vec.load(vecType, redAlloc1, {zero}); - Value res1 = - create.vec.reduction(getCombiningKind(), reductionVec1); - // Second reduction horizontal sum. - Value res2 = nullptr; - if (hasTwoRed) { - Value reductionVec2 = create.vec.load(vecType, redAlloc2, {zero}); - res2 = create.vec.reduction( - getCombiningKind(), reductionVec2); + // Test if loop trip count is long enough for a parallel execution. + if (enableParallel) { + int64_t parId; + if (findSuitableParallelDimension({lb}, {ub}, 0, 1, parId, 32 * totVL)) { + onnxToKrnlParallelReport( + op, true, parId, lb, ub, "simd reduction to one element"); + } else { + enableParallel = false; + onnxToKrnlParallelReport(op, false, -1, -1, + "not enough work in simd reduction to one element"); + } } + if (!enableParallel) { + // Allocate temp and output memory + Value tmp1, tmp2; + MemRefType redType = MemRefType::get({totVL}, elementType); + MemRefType outputType = MemRefType::get({}, elementType); + tmp1 = create.mem.alignedAlloc(redType); + /*output*/ alloc1 = create.mem.alloc(outputType); + + alloc2 = nullptr; + if (hasTwoRed) { + tmp2 = create.mem.alignedAlloc(redType); + /*output*/ alloc2 = create.mem.alloc(outputType); + } + int64_t tNum = 1; // No parallelism. + IndexExpr t = zero; + // OK to use Krnl builder here as we have a simple loop structure. + emitOneStepOfFullSIMDReduction(rewriter, op, create, elementType, lb, ub, totVL, + simdOnly, t, tNum, hasTwoRed, flatInput, flatInput, tmp1, tmp2, + preProc1, preProc2, alloc1, alloc2, divisorForMean); + } else { + // Performs 2 rounds: first round compute a parallel partial reduction + // where each (possibly virtual) thread is responsible for one chunk. + // Second round computes the final reduction done by one thread. + + // TODO: this should not be hardwired but gotten from an option. + int64_t tNum = 8; + + // Round 1. + MemRefType redType = MemRefType::get({tNum * totVL}, elementType); + MemRefType outputType = MemRefType::get({tNum}, elementType); + Value tmp1, tmp2, output1, output2; + + tmp1 = create.mem.alignedAlloc(redType); + output1 = create.mem.alloc(outputType); + if (hasTwoRed) { + tmp2 = create.mem.alignedAlloc(redType); + output2 = create.mem.alloc(outputType); + } - // Handle mean if any. - Value divisorForMean = nullptr; - if (divideByMean() || divideByMean()) { - // Compute the divisor that is the number of elements participated in - // reduction, i.e., 'divisor = size of input / size of output, where output - // size == 1'. - divisorForMean = create.math.cast(elementType, flatInputDims[0].getValue()); + IndexExpr tNumIE = LitIE(tNum); + bool simdOnly = false; // Refine, but since we are chunking input, safer. + create.krnl.forExplicitParallelLoopIE( + lb, ub, tNumIE, [&](const KrnlBuilder &ck, mlir::ValueRange loopInd) { + IndexExprScope scope(ck); + MDBuilder create(ck); + IndexExpr t = DimIE(loopInd[0]); + IndexExpr currLB = SymIE(loopInd[1]); + IndexExpr currUB = SymIE(loopInd[2]); + // Use SCF builder because the partition of outer loop into block + // makes the formulas non-affine. + emitOneStepOfFullSIMDReduction(rewriter, op, create, elementType, currLB, + currUB, totVL, simdOnly, t, tNum, hasTwoRed, flatInput, flatInput, + tmp1, tmp2, preProc1, preProc2, output1, output2, nullptr); + // Result here, each iteration would have generate 1 value in + // output1 &2, + }); + // Now we need to reduce output's tNum values into one. Reuse tmps. + MemRefType finalOutputType = MemRefType::get({}, elementType); + /*output*/ alloc1 = create.mem.alloc(finalOutputType); + alloc2 = nullptr; + if (hasTwoRed) + /*output*/ alloc2 = create.mem.alloc(finalOutputType); + IndexExpr finalLB = zero; + IndexExpr finalUB = tNumIE; + IndexExpr t = zero; + // Reduction here is straight forward, Krnl builder is fine. + emitOneStepOfFullSIMDReduction(rewriter, op, create, elementType, finalLB, finalUB, + /*VL*/ 1, /*simd only*/ false, t, /*thread num */ 1, hasTwoRed, output1, + output2, tmp1, tmp2, preProc1, preProc2, alloc1, alloc2, + divisorForMean); } - if (divideByMean()) - res1 = create.math.div(res1, divisorForMean); - if (hasTwoRed && divideByMean()) - res2 = create.math.div(res2, divisorForMean); - - // Save result. - create.affineKMem.store(res1, alloc1, {}); - if (hasTwoRed) - create.affineKMem.store(res2, alloc2, {}); if (hasTwoRed) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, @@ -427,9 +535,11 @@ void emitMinMaxReductionToScalar(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value input, Value &minAlloc, Value &maxAlloc, bool enableSIMD, bool enableParallel) { // Try optimized path first. - if (enableSIMD && emitFullSIMDReductionFor( - rewriter, loc, op, input, minAlloc, maxAlloc)) + if (enableSIMD && + emitFullSIMDReductionFor(rewriter, loc, + op, input, nullptr, nullptr, minAlloc, maxAlloc, enableParallel)) { return; + } // Could not optimize the pattern, generate default path. MultiDialectBuilder create(rewriter, loc); Type elementType = mlir::cast(input.getType()).getElementType(); @@ -442,6 +552,37 @@ void emitMinMaxReductionToScalar(ConversionPatternRewriter &rewriter, create.onnx.reduceMax(outputType, input, none, false)); } +void emitSymmetricQuantRecscaleToScalar(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Value input, uint64_t bitWidth, + Value &recscale, bool enableSIMD, bool enableParallel) { + Type elemType = getElementType(input.getType()); + assert(elemType.isF32() && "Only support f32"); + double range = static_cast((1 << (bitWidth - 1)) - 1); + + // Try optimized path first. + Value absmaxMemRef, noused; + MultiDialectBuilder create( + rewriter, loc); + if (enableSIMD && + emitFullSIMDReductionFor( + rewriter, loc, op, input, [&](Value v) { return create.math.abs(v); }, + nullptr, absmaxMemRef, noused, enableParallel)) { + Value cst = create.math.constant(elemType, range); + Value absmax = create.krnl.load(absmaxMemRef); + recscale = create.math.div(cst, absmax); + return; + } + + // Could not optimize the pattern, generate default path. + Value none = create.onnx.none(); + RankedTensorType scalarTy = RankedTensorType::get({}, elemType); + Value cst = create.onnx.constant( + DenseElementsAttr::get(scalarTy, static_cast(range))); + Value recscaleMemRef = create.onnx.toMemref( + create.onnx.div(cst, create.onnx.reduceMax(scalarTy, + create.onnx.abs(input), none, false, false))); + recscale = create.krnl.load(recscaleMemRef); +} //===----------------------------------------------------------------------===// // Generic reduction code (for current and legacy using "if constexpr". // Function use SIMD if all reductions occur consecutively in the innermost @@ -534,10 +675,13 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // Default value of having no axes. hasNoAxes = true; } else { - // Check it has a rank of 1. - assert( - create.krnlIE.getShapedTypeRank(axesVal) == 1 && "expect rank 1"); - axisShape0 = create.krnlIE.getShapeAsDim(axesVal, 0); + // Check it has a rank of 0 or 1. + int64_t axisRank = create.krnlIE.getShapedTypeRank(axesVal); + assert((axisRank == 0 || axisRank == 1) && "expect rank 0 or 1"); + if (axisRank == 0) + axisShape0 = LitIE(1); + else + axisShape0 = create.krnlIE.getShapeAsDim(axesVal, 0); if (!axisShape0.isLiteral()) // Don't even know the shape of the axis... it is dynamic. @@ -558,19 +702,34 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } } + ////////////////////////////////////////////////////////////////////// + // Reduction over all dimensions to a scalar value. + bool fullReduction = + hasNoAxes || (rawAxesIE.size() == static_cast(inRank)); + if (fullReduction && !isKeepdims && enableSIMD) { + Value alloc, none; + if (emitFullSIMDReductionFor(rewriter, loc, + op, input, nullptr, nullptr, alloc, none, enableParallel)) { + rewriter.replaceOp(op, alloc); + return success(); + } + } + ////////////////////////////////////////////////////////////////////// // Characterize literal axes: make unique and within [0, inRank). std::vector uniqueLitAxes; llvm::BitVector litAxes(inRank, false); if (hasNoAxes) { if (isNoop) { - // No axes and is noop, should we not just return the input array? - } else { - // No axes, perform a full reduction. - for (int64_t i = 0; i < inRank; ++i) { - uniqueLitAxes.push_back(i); - litAxes[i] = true; - } + // Axes is none and 'noop_with_empty_axes' is true. This behaves as a + // noop, replace op with its input + rewriter.replaceOp(op, adaptor.getData()); + return success(); + } + // No axes, perform a full reduction. + for (int64_t i = 0; i < inRank; ++i) { + uniqueLitAxes.push_back(i); + litAxes[i] = true; } } else if (!dynamicAxes) { // Check raw axes. @@ -604,6 +763,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { bool parallelSimd = false; int64_t innermostLoopCollapse = 0; int64_t totVL = 1; + bool simdOnly = false; int64_t simdLoopStaticTripCount = 0; // With dynamic axes, use this @@ -654,10 +814,10 @@ struct ONNXReductionOpLowering : public OpConversionPattern { #endif } // Currently only vectorize loops whose SIMD dimension is a multiple - // of the natural SIMD width. Aka, we don't deal with SIMD of partial - // vectors. + // of the natural SIMD width. Aka, we don't deal with SIMD of + // partial vectors. GenOpMix mix = getGenOpMix(elementOutType, op); - bool simdOnly, canOverCompute = false; + bool canOverCompute = false; totVL = computeSuitableUnrollFactor(memRefInType, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); @@ -667,11 +827,15 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // here. Some benchmarks have small trip counts (e.g. GPT2: 8). totVL = capVLForMaxUnroll(memRefInType, totVL, 1); } - // Current code gen scheme only support SIMD only scheme. +#if REDUCTION_MULTIPLE_OF_VL_ONLY + // Currently fails with krnl to affine without this. Should + // consider an affine simd iterate/reduce. onnx-mlir + // -shapeInformation=0:4x8 reducemean2.mlir -O3 --march=arm64 if (!simdOnly) { totVL = capVLForSimdOnly(memRefInType, totVL, simdLoopStaticTripCount); } +#endif LLVM_DEBUG(llvm::dbgs() << " SIMD: " << innermostLoopCollapse << " loops, totVL " << totVL << "\n"); if (totVL <= 1) { @@ -728,9 +892,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (!axisShape0.isLiteral()) { // When axes is dynamic, generate a Krnl loop KrnlBuilder createKrnl(rewriter, loc); - ValueRange loopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(loopDef, loopDef, {LiteralIndexExpr(0)}, - {axisShape0}, [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + createKrnl.forLoopIE(LitIE(0), axisShape0, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { Value axe = createKrnl.load(axesVal, loopInd[0]); Value cond = create.math.slt(axe, zeroValue); Value dim = create.math.select( @@ -791,12 +954,12 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // Compute the divisor that is the number of elements participated in // reduction, i.e., 'divisor = size of input / size of output'. IndexExprScope scope(create.krnl); - IndexExpr inputSizeExpr = LiteralIndexExpr(1); + IndexExpr inputSizeExpr = LitIE(1); for (unsigned i = 0; i < inRank; i++) { IndexExpr dimExpr = create.krnlIE.getShapeAsSymbol(input, i); inputSizeExpr = inputSizeExpr * dimExpr; } - IndexExpr outputSizeExpr = LiteralIndexExpr(1); + IndexExpr outputSizeExpr = LitIE(1); for (unsigned i = 0; i < outRank; i++) { IndexExpr dimExpr = create.krnlIE.getShapeAsSymbol(alloc, i); outputSizeExpr = outputSizeExpr * dimExpr; @@ -808,14 +971,14 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (horizontalSimd) { if (hasHorizontalSimdSupport) { genHorizontalSimdReduction(rewriter, create, op, elementOutType, input, - alloc, inRank, outRank, totVL, innermostLoopCollapse, isKeepdims, - divisorForMean, enableParallel); + alloc, inRank, outRank, totVL, simdOnly, innermostLoopCollapse, + isKeepdims, divisorForMean, enableParallel); onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "horizontal"); } else { genShuffleHorizontalSimdReduction(rewriter, create, op, elementOutType, - input, alloc, inRank, outRank, totVL, innermostLoopCollapse, - isKeepdims, divisorForMean, enableParallel); + input, alloc, inRank, outRank, totVL, simdOnly, + innermostLoopCollapse, isKeepdims, divisorForMean, enableParallel); onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "shuffle-horizontal"); } @@ -857,14 +1020,14 @@ struct ONNXReductionOpLowering : public OpConversionPattern { create.krnl.memset(alloc, identity); ValueRange loop2Def = create.krnl.defineLoops(inRank); - SmallVector lbs2(inRank, LiteralIndexExpr(0)); + SmallVector lbs2(inRank, LitIE(0)); SmallVector ubs2; create.krnlIE.getShapeAsSymbols(input, ubs2); Value trueVal = create.math.constant(rewriter.getIntegerType(1), 1); // TODO Temporary disable the 2nd loop parallelism, since its outermost // loop could be a reduction loop, where parallelism would not be safe. create.krnl.iterateIE(loop2Def, loop2Def, lbs2, ubs2, - [&](KrnlBuilder &kb, ValueRange loopInd) { + [&](const KrnlBuilder &kb, ValueRange loopInd) { MultiDialectBuilder create(kb); Value zeroIndex = create.math.constantIndex(0); // Compute accumulator access function. @@ -894,7 +1057,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (divideByMean()) { // Compute mean ValueRange loop3Def = create.krnl.defineLoops(outRank); - SmallVector lbs3(outRank, LiteralIndexExpr(0)); + SmallVector lbs3(outRank, LitIE(0)); SmallVector ubs3; create.krnlIE.getShapeAsSymbols(alloc, ubs3); if (enableParallel) { @@ -910,7 +1073,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(loop3Def, loop3Def, lbs3, ubs3, - [&](KrnlBuilder &kb, ValueRange loopInd) { + [&](const KrnlBuilder &kb, ValueRange loopInd) { MultiDialectBuilder create(kb); Value loadData = create.krnl.load(alloc, loopInd); Value meanVal = create.math.div(loadData, divisorForMean); @@ -939,8 +1102,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern { } // Generate a single reduction, eventually using a horizontal reduction - // (which, if the hardware supports it, will be one instruction; otherwise it - // will be simulated by several operations). + // (which, if the hardware supports it, will be one instruction; otherwise + // it will be simulated by several operations). // // flatInput has been flattened from [N][M][R1][R2] to [N][M][R1*R2], where // the SIMD reduction is done along the last dim. By definition of what we @@ -955,38 +1118,39 @@ struct ONNXReductionOpLowering : public OpConversionPattern { void genOneHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, VectorType vecType, - Value tmpAlloca, Value flatInput, Value flatAlloc, Value initVec, - Value divisorForMean, ValueRange outLoopInd, Value simdUB, - int64_t VL) const { - // Init temp memory to init values. - Value zero = create.math.constantIndex(0); - create.vec.store(initVec, tmpAlloca, {zero, zero}); - // Iterate over the SIMD blocks. - ValueRange simdLoopDef = create.krnl.defineLoops(1); - ValueRange blockedSimdLoopDef = create.krnl.block(simdLoopDef[0], VL); - create.krnl.iterate(simdLoopDef, {blockedSimdLoopDef[0]}, {zero}, {simdUB}, - [&](KrnlBuilder &ck, ValueRange simdLoopInd) { - MDBuilder create(ck); - // Input values, loaded as a vector. - SmallVector inAccessVals(outLoopInd); - inAccessVals.emplace_back(simdLoopInd[0]); - Value inputVec = create.vec.load(vecType, flatInput, inAccessVals); - Value tmpVec = create.vec.load(vecType, tmpAlloca, {zero, zero}); - // Sum into redVec - Value accumulatedVec = emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {tmpVec, inputVec}); - create.vec.store(accumulatedVec, tmpAlloca, {zero, zero}); - }); - // Horizontal sum. - Value reductionVec = create.vec.load(vecType, tmpAlloca, {zero, zero}); - Value accumulatedVal = - create.vec.reduction(getCombiningKind(), reductionVec); - // other operation... - if (divideByMean()) { - accumulatedVal = create.math.div(accumulatedVal, divisorForMean); - } - // Store tmp into result. - create.krnl.store(accumulatedVal, flatAlloc, outLoopInd); + Value tmpAlloc, Value flatInput, Value flatAlloc, Value initVec, + Value divisorForMean, ValueRange outLoopInd, Value simdUB, int64_t VL, + bool simdOnly) const { + IndexExpr lb = LitIE(0); + IndexExpr ub = SymIE(simdUB); + SmallVector outputAF = SymListIE(outLoopInd); + SmallVector inputAF = outputAF; + inputAF.emplace_back(lb); + SmallVector tmpAF(2, lb); // tmpAlloc is 2D + Value identity = getIdentityValue( + rewriter, create.getLoc(), elementType); + create.krnl.simdReduceIE(lb, ub, VL, simdOnly, + /* inputs*/ {flatInput}, {inputAF}, + /* temp */ {tmpAlloc}, {tmpAF}, + /* output */ {flatAlloc}, {outputAF}, + /* init */ {identity}, + /* reduction simd/scalar */ + {[&](const KrnlBuilder &kb, Value inputVal, Value tmpVal, int64_t VL) { + Type type = VL > 1 ? vecType : elementType; + return emitScalarOpFor( + rewriter, create.getLoc(), op, type, {tmpVal, inputVal}); + }}, + /* post processing */ + {[&](const KrnlBuilder &kb, Value tmpVal, int64_t VL) { + // Horizontal reduction. + Value accumulatedVal = + create.vec.reduction(getCombiningKind(), tmpVal); + // Other post reduction operation... + if (divideByMean()) { + accumulatedVal = create.math.div(accumulatedVal, divisorForMean); + } + return accumulatedVal; + }}); } // We assume here that the hardware has an efficient SIMD horizontal @@ -994,7 +1158,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // reductions that needs to be performed. void genHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, Value input, - Value alloc, int64_t inRank, int64_t outRank, int64_t VL, + Value alloc, int64_t inRank, int64_t outRank, int64_t VL, bool simdOnly, int64_t collapsedInnermostLoops, bool isKeepDims, Value divisorForMean, bool enableParallel) const { LLVM_DEBUG(llvm::dbgs() << "gen horizontal simd reduction\n"); @@ -1015,8 +1179,8 @@ struct ONNXReductionOpLowering : public OpConversionPattern { Value flatAlloc = create.mem.reshapeToFlatInnermost( alloc, outDims, flatOutDims, collapseOutInnermostLoop); int64_t flatOutRank = flatOutDims.size(); - // Flat output should have all but the flattened SIMD loop, so there should - // only be a 1 rank difference between the two. + // Flat output should have all but the flattened SIMD loop, so there + // should only be a 1 rank difference between the two. assert(flatOutRank == flatInRank - 1 && "wrong assumptions about dims"); // Parallelism only if output is not a scalar. @@ -1027,7 +1191,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { MemRefType tmpType = MemRefType::get({1, VL}, elementType); // Define loops for input dimensions, blocking the inner dim by VL ValueRange outLoopDef = create.krnl.defineLoops(flatOutRank); - SmallVector lbs(flatOutRank, LiteralIndexExpr(0)); + SmallVector lbs(flatOutRank, LitIE(0)); if (enableParallel) { int64_t parId; if (findSuitableParallelDimension(lbs, flatOutDims, 0, 1, parId, @@ -1036,125 +1200,110 @@ struct ONNXReductionOpLowering : public OpConversionPattern { onnxToKrnlParallelReport( op, true, 0, lbs[0], flatOutDims[0], "reduction h-simd"); } else { + enableParallel = false; onnxToKrnlParallelReport(op, false, 0, lbs[0], flatOutDims[0], "not enough work for reduction h-simd"); } } create.krnl.iterateIE(outLoopDef, outLoopDef, lbs, flatOutDims, - [&](KrnlBuilder &ck, ValueRange outLoopInd) { + [&](const KrnlBuilder &ck, ValueRange outLoopInd) { MDBuilder create(ck); - // Allocate temp inside loop (because of parallel). - Value tmpAlloca = create.mem.alignedAlloca(tmpType); + // When parallel, will stay inside; otherwise will migrate out. + Value tmpAlloc = create.mem.alignedAlloc(tmpType); Value identity = getIdentityValue( rewriter, create.getLoc(), elementType); Value initVec = create.vec.splat(vecType, identity); genOneHorizontalSimdReduction(rewriter, create, op, elementType, - vecType, tmpAlloca, flatInput, flatAlloc, initVec, divisorForMean, - outLoopInd, simdUB, VL); + vecType, tmpAlloc, flatInput, flatAlloc, initVec, divisorForMean, + outLoopInd, simdUB, VL, simdOnly); }); } // We perform here VL Simd Reductions at once. We are guaranteed that there // are VL reductions to be performed. The algorithm works in 2 steps. // - // In the first step, we perform the SIMD reductions of VL distinct reductions - // using the "emitScalarOp" associated with that operation. At the end of this - // step, we have VL distinct partial reductions, where each of the VL vector - // register have a partial reduction in each of their own VL SIMD slots. + // In the first step, we perform the SIMD reductions of VL distinct + // reductions using the "emitScalarOp" associated with that operation. At + // the end of this step, we have VL distinct partial reductions, where each + // of the VL vector register have a partial reduction in each of their own + // VL SIMD slots. // - // In the second step, we reduce each VL vectors of VL partial values into one - // vector of VL fully-reduced values. We use shuffle patterns to generate - // efficient code where each of the temporary vectors always contain VL - // values. This is implemented by the create.vec.multiReduction operation. + // In the second step, we reduce each VL vectors of VL partial values into + // one vector of VL fully-reduced values. We use shuffle patterns to + // generate efficient code where each of the temporary vectors always + // contain VL values. This is implemented by the create.vec.multiReduction + // operation. // // Finally, the VL full reductions are stored as a vector operation in the // flatAlloc[m][n+0...+VL-1] output. void genVlHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, VectorType vecType, - Value tmpBlockedAlloca, Value flatInput, Value flatAlloc, Value initVec, + Value tmpBlockedAlloc, Value flatInput, Value flatAlloc, Value initVec, Value divisorForMean, ValueRange blockedOutLoopInd, - IndexExpr blockedCurrIndex, Value simdUB, int64_t VL) const { - // Init temp memory to init values. - Value zero = create.math.constantIndex(0); - for (int64_t i = 0; i < VL; ++i) { - create.vec.store( - initVec, tmpBlockedAlloca, {create.math.constantIndex(i), zero}); - } - // First step: blocked simd loop. - ValueRange simdLoopDef = create.krnl.defineLoops(1); - ValueRange blockedSimdLoopDef = create.krnl.block(simdLoopDef[0], VL); - create.krnl.iterate(simdLoopDef, {blockedSimdLoopDef[0]}, {zero}, {simdUB}, - [&](KrnlBuilder &ck, ValueRange simdLoopInd) { - MDBuilder create(ck); - // Loop over blocked output loop, block guaranteed to be full. - for (int64_t i = 0; i < VL; ++i) { - IndexExpr offset = LiteralIndexExpr(i); - IndexExpr blockLocalIndIE = blockedCurrIndex + offset; - Value blockLocalInd = blockLocalIndIE.getValue(); - // All of the non-blocked loop, plus the inter tile index of the - // blocked loop, and the blocked simd loop. - SmallVector inAccessVals = - firstFew(blockedOutLoopInd, -2); - inAccessVals.emplace_back(blockLocalInd); - inAccessVals.emplace_back(simdLoopInd[0]); - Value inputVec = create.vec.load(vecType, flatInput, inAccessVals); - // The tmpInd value is between 0 and VL-1, and is local index - - // blocked index. - Value tmpInd = offset.getValue(); - Value tmpVec = - create.vec.load(vecType, tmpBlockedAlloca, {tmpInd, zero}); - // Sum into redVec - Value accumulatedVec = emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {tmpVec, inputVec}); - create.vec.store(accumulatedVec, tmpBlockedAlloca, {tmpInd, zero}); - } /* intra block output loop */ - }); /* blocked simd loop */ - // Step 2 - // Load all temp vectors. - SmallVector redIn, redOut; - for (int64_t i = 0; i < VL; ++i) { - Value val = create.vec.load( - vecType, tmpBlockedAlloca, {create.math.constantIndex(i), zero}); - redIn.emplace_back(val); - } - // Reduce all of the temp vectors at once. - auto redFct = [&](Value a, Value b) -> Value { - return emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {a, b}); - }; - create.vec.multiReduction(redIn, redFct, redOut); - // The redOut list should have one value with SIMD of VL. - assert(redOut.size() == 1 && "expected only one val"); - Value accumulatedVal = redOut[0]; - // Perform the mean computation if required. - if (divideByMean()) { - Value divisorForMeanVec = create.vec.splat(vecType, divisorForMean); - accumulatedVal = create.math.div(accumulatedVal, divisorForMeanVec); + IndexExpr blockedCurrIndex, Value simdUB, int64_t VL, + bool simdOnly) const { + IndexExpr zero = LitIE(0); + IndexExpr lb = zero; + IndexExpr ub = SymIE(simdUB); + int64_t rank = blockedOutLoopInd.size(); + DimsExpr inputAF = SymListIE(blockedOutLoopInd); + inputAF[rank - 1] = blockedCurrIndex; + inputAF.emplace_back(zero); + DimsExpr tmpAF = {zero, zero}; + DimsExpr outputAF = SymListIE(blockedOutLoopInd); + Value identity = getIdentityValue( + rewriter, create.getLoc(), elementType); + if (simdOnly) { + create.affine.simdReduce2DIE( + lb, ub, VL, simdOnly, flatInput, inputAF, tmpBlockedAlloc, tmpAF, + flatAlloc, outputAF, identity, + [&](const AffineBuilder &b, Value inputVal, Value tmpVal, + int64_t VL) { + Type type = VL > 1 ? vecType : elementType; + return emitScalarOpFor( + rewriter, b.getLoc(), op, type, {tmpVal, inputVal}); + }, + [&](const AffineBuilder &b, Value tmpVal, int VL) { + if (divideByMean()) + return create.math.div(tmpVal, divisorForMean); + return tmpVal; + }); + } else { + create.scf.simdReduce2DIE( // Affine fails with dynamic shapes. + lb, ub, VL, simdOnly, flatInput, inputAF, tmpBlockedAlloc, tmpAF, + flatAlloc, outputAF, identity, + [&](const SCFBuilder &b, Value inputVal, Value tmpVal, int64_t VL) { + Type type = VL > 1 ? vecType : elementType; + return emitScalarOpFor( + rewriter, b.getLoc(), op, type, {tmpVal, inputVal}); + }, + [&](const SCFBuilder &b, Value tmpVal, int VL) { + if (divideByMean()) + return create.math.div(tmpVal, divisorForMean); + return tmpVal; + }); } - // Store final values. - create.vec.store(accumulatedVal, flatAlloc, blockedOutLoopInd); } // Solution when there is no horizontal SIMD op support and that shuffle ops - // are needed. Assuming a (flattened) output reduction tensor of [N][M], this - // algorithm will block the inter dimension of the output tensor by VL. For - // each block of VL values to be reduced, we use the efficient functions that - // computes them using shuffles (genVlHorizontalSimdReduction). For the last - // block (if any) that has fewer than VL remaining reductions to be performed, - // we simply perform r 1 && "expected simd here"); - IndexExpr VLIndexExpr = LiteralIndexExpr(VL); VectorType vecType = VectorType::get({VL}, elementType); // Flatten the input: in[N][M][Red1][Red2] -> in[N][M][Red1*Red2] DimsExpr inDims, flatInDims; @@ -1174,13 +1323,16 @@ struct ONNXReductionOpLowering : public OpConversionPattern { Value flatAlloc = create.mem.reshapeToFlatInnermost( alloc, outDims, flatOutDims, collapseOutInnermostLoop); int64_t flatOutRank = flatOutDims.size(); - // Flat output should have all but the flattened SIMD loop, so there should - // only be a 1 rank difference between the two. + // Flat output should have all but the flattened SIMD loop, so there + // should only be a 1 rank difference between the two. assert(flatOutRank == flatInRank - 1 && "wrong assumptions about dims"); // Parallelism only if output is not a scalar. - if (flatOutRank == 0) + if (flatOutRank == 0 && enableParallel) { enableParallel = false; + onnxToKrnlParallelReport( + op, false, -1, 0, "zero flat out rank for reduction shuffle h-simd"); + } // Compute type of small temp vector. MemRefType tmpBlockedType = MemRefType::get({VL, VL}, elementType); @@ -1193,46 +1345,47 @@ struct ONNXReductionOpLowering : public OpConversionPattern { firstFew(outLoopDef, -2); optimizedOutLoopDef.emplace_back(blockedOutLoopDef[0]); // Iterate only over all but the inner loop of the flattened input. - SmallVector lbs(flatOutRank, LiteralIndexExpr(0)); + SmallVector lbs(flatOutRank, LitIE(0)); if (enableParallel) { int64_t parId; - if (findSuitableParallelDimension(lbs, flatOutDims, 0, 1, parId, - /*min iter for going parallel*/ 64 * VL)) { - create.krnl.parallel(optimizedOutLoopDef[0]); - onnxToKrnlParallelReport( - op, true, 0, lbs[0], flatOutDims[0], "reduction shuffle h-simd"); + if (findSuitableParallelDimension(lbs, flatOutDims, 0, flatOutRank, parId, + /*min iter for going parallel*/ 8 * VL)) { + create.krnl.parallel(optimizedOutLoopDef[parId]); + onnxToKrnlParallelReport(op, true, parId, lbs[parId], + flatOutDims[parId], "reduction shuffle h-simd"); } else { + enableParallel = false; onnxToKrnlParallelReport(op, false, 0, lbs[0], flatOutDims[0], "not enough work for reduction shuffle h-simd"); } } create.krnl.iterateIE(outLoopDef, optimizedOutLoopDef, lbs, flatOutDims, - [&](KrnlBuilder &ck, ValueRange blockedOutLoopInd) { + [&](const KrnlBuilder &ck, ValueRange blockedOutLoopInd) { MDBuilder create(ck); - // Create temp inside loop (because of parallel). - Value tmpBlockedAlloca = create.mem.alignedAlloca(tmpBlockedType); + // When parallel, will stay inside; otherwise will migrate out. + Value tmpBlockedAlloc = create.mem.alignedAlloc(tmpBlockedType); Value identity = getIdentityValue( rewriter, create.getLoc(), elementType); Value initVec = create.vec.splat(vecType, identity); IndexExprScope innerScope(ck); IndexExpr blockedCurrIndex = - DimIndexExpr(blockedOutLoopInd[flatOutRank - 1]); - IndexExpr blockedUB = - SymbolIndexExpr(flatOutDims[flatOutRank - 1].getValue()); - IndexExpr isFull = create.krnlIE.isTileFull( - blockedCurrIndex, LiteralIndexExpr(VL), blockedUB); + DimIE(blockedOutLoopInd[flatOutRank - 1]); + IndexExpr blockedUB = SymIE(flatOutDims[flatOutRank - 1].getValue()); + IndexExpr isFull = + create.krnlIE.isTileFull(blockedCurrIndex, LitIE(VL), blockedUB); Value zero = create.math.constantIndex(0); Value isNotFullVal = create.math.slt(isFull.getValue(), zero); create.scf.ifThenElse( isNotFullVal, - [&](SCFBuilder &scf) { + [&](const SCFBuilder &scf) { MDBuilder create(scf); // create.krnl.printf("partial tile\n"); Value startOfLastBlockVal = blockedCurrIndex.getValue(); Value blockedUBVal = blockedUB.getValue(); create.scf.forLoop(startOfLastBlockVal, blockedUBVal, 1, - [&](SCFBuilder &scf, Value blockLocalInd) { + [&](const SCFBuilder &scf, ValueRange loopInd) { MDBuilder create(scf); + Value blockLocalInd = loopInd[0]; // Output induction variables: same as the outer loop, but // with the blocked index replaced by the inner index. SmallVector outLoopInd = @@ -1240,18 +1393,18 @@ struct ONNXReductionOpLowering : public OpConversionPattern { outLoopInd.emplace_back(blockLocalInd); // Perform reduction for one output value. genOneHorizontalSimdReduction(rewriter, create, op, - elementType, vecType, tmpBlockedAlloca, flatInput, + elementType, vecType, tmpBlockedAlloc, flatInput, flatAlloc, initVec, divisorForMean, outLoopInd, - simdUB, VL); + simdUB, VL, simdOnly); }); /* for inside blocked loop */ }, - [&](SCFBuilder &scf) { + [&](const SCFBuilder &scf) { MDBuilder create(scf); // create.krnl.printf("full tile\n"); genVlHorizontalSimdReduction(rewriter, create, op, elementType, - vecType, tmpBlockedAlloca, flatInput, flatAlloc, initVec, + vecType, tmpBlockedAlloc, flatInput, flatAlloc, initVec, divisorForMean, blockedOutLoopInd, blockedCurrIndex, simdUB, - VL); + VL, simdOnly); }); }); /* blocked out loop */ } diff --git a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp index 39fee40fb3..9e19f80d1d 100644 --- a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp @@ -33,7 +33,8 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops, // Compute the maximum value along axis. ValueRange maxLoops = createKrnl.defineLoops(numberOfLoops); auto maxLoop = createKrnl.iterateIE(maxLoops, maxLoops, Lbs, Ubs, maxInits, - [&](KrnlBuilder &createKrnl, ValueRange maxIndices, ValueRange iterArgs) { + [&](const KrnlBuilder &createKrnl, ValueRange maxIndices, + ValueRange iterArgs) { // Get last argument for the iterate body. Value iterArg = iterArgs.back(); @@ -67,7 +68,8 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops, // Compute the sum of all values along axis. ValueRange sumLoops = createKrnl.defineLoops(numberOfLoops); auto sumLoop = createKrnl.iterateIE(sumLoops, sumLoops, Lbs, Ubs, sumInits, - [&](KrnlBuilder &createKrnl, ValueRange sumIndices, ValueRange iterArgs) { + [&](const KrnlBuilder &createKrnl, ValueRange sumIndices, + ValueRange iterArgs) { // Get last argument for the iterate body. Value iterArg = iterArgs.back(); @@ -106,7 +108,7 @@ static void emitInnerLoops(KrnlBuilder &createKrnl, int64_t numberOfLoops, // Compute the softmax. ValueRange softmaxLoops = createKrnl.defineLoops(numberOfLoops); createKrnl.iterateIE(softmaxLoops, softmaxLoops, Lbs, Ubs, - [&](KrnlBuilder &createKrnl, ValueRange softmaxIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange softmaxIndices) { MultiDialectBuilder create(createKrnl); IndexExprScope ieScope(createKrnl); @@ -188,7 +190,7 @@ void emitInstForSoftmax(ConversionPatternRewriter &rewriter, } } create.krnl.iterateIE(outerLoops, outerLoops, outerLbs, outerUbs, - [&](KrnlBuilder &ck, ValueRange outerIndices) { + [&](const KrnlBuilder &ck, ValueRange outerIndices) { MultiDialectBuilder create(ck); @@ -249,7 +251,7 @@ void emitInstForSoftmax(ConversionPatternRewriter &rewriter, // Emit outer loops. create.krnl.iterateIE(outerLoops, outerLoops, outerLbs, outerUbs, - [&](KrnlBuilder &ck, ValueRange outerIndices) { + [&](const KrnlBuilder &ck, ValueRange outerIndices) { MultiDialectBuilder create(ck); IndexExprScope ieScope(ck); diff --git a/src/Conversion/ONNXToKrnl/Math/TopK.cpp b/src/Conversion/ONNXToKrnl/Math/TopK.cpp index 1b937a9c43..358406637e 100644 --- a/src/Conversion/ONNXToKrnl/Math/TopK.cpp +++ b/src/Conversion/ONNXToKrnl/Math/TopK.cpp @@ -69,10 +69,10 @@ struct ONNXTopKOpLowering : public OpConversionPattern { /*ascending=*/ascendingMode); // Produce the final result. - SmallVector zeroDims(rank, LiteralIndexExpr(0)); + SmallVector zeroDims(rank, LitIE(0)); ValueRange loopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(loopDef, loopDef, zeroDims, resDims, - [&](KrnlBuilder &createKrnl, ValueRange resLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange resLoopInd) { Value resInd = createKrnl.load(argSort, resLoopInd); SmallVector resIndexLoopInd(resLoopInd); resIndexLoopInd[axis] = resInd; diff --git a/src/Conversion/ONNXToKrnl/Math/Trilu.cpp b/src/Conversion/ONNXToKrnl/Math/Trilu.cpp index e60261f517..8a1c8dd062 100644 --- a/src/Conversion/ONNXToKrnl/Math/Trilu.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Trilu.cpp @@ -54,7 +54,7 @@ struct ONNXTriluOpLowering : public OpConversionPattern { if (isNoneValue(triluOp.getK())) k = create.math.constantIndex(0); else - k = create.math.castToIndex(create.krnl.load(adaptor.getK(), {})); + k = create.math.castToIndex(create.krnl.load(adaptor.getK())); // Insert an allocation and deallocation for the result of this operation. SmallVector ubs; @@ -63,9 +63,9 @@ struct ONNXTriluOpLowering : public OpConversionPattern { // Main loop. ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { MultiDialectBuilder create( createKrnl); Value i = create.math.add(k, loopInd[rank - 2]); diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index b84bd979ee..980fe62c89 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -47,7 +47,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { auto biasOperand = operandAdaptor.getB(); bool hasBias = !mlir::isa(biasOperand.getType()); int64_t groupNum = convOp.getGroup(); - IndexExpr G = LiteralIndexExpr(groupNum); + IndexExpr G = LitIE(groupNum); Value fZero = create.math.constant(memRefType.getElementType(), 0); // Bounds for output sizes: [N x CO x HO x WO]: @@ -71,8 +71,8 @@ struct ONNXConvOpLowering : public OpConversionPattern { IndexExpr CIPerGroup = create.krnlIE.getShapeAsSymbol(filterOperand, 1); // Determine the bounds for the loops over batch & channel out. - IndexExpr iZero = LiteralIndexExpr(0); - IndexExpr iOne = LiteralIndexExpr(1); + IndexExpr iZero = LitIE(0); + IndexExpr iOne = LitIE(1); SmallVector lbsStorage, ubsStorage, stepsStorage; SmallVector outerLbs = {iZero, iZero, iZero}; @@ -96,24 +96,23 @@ struct ONNXConvOpLowering : public OpConversionPattern { // Compute the channel out index "co". DimIndexExpr g(outerIndices[1]); DimIndexExpr coPerGroup(outerIndices[2]); - IndexExpr co = g * SymbolIndexExpr(COPerGroup) + coPerGroup; + IndexExpr co = g * SymIE(COPerGroup) + coPerGroup; // Compute g * CIPerGroup for later use. - IndexExpr gTimesCIPerGroup = g * SymbolIndexExpr(CIPerGroup); + IndexExpr gTimesCIPerGroup = g * SymIE(CIPerGroup); // Determine the bounds for the output spacial dimensions. int spacialRank = outputRank - spatialStartIndex; ValueRange outputSpacialLoops = create.krnl.defineLoops(spacialRank); SmallVector outputSpacialLbs, outputSpacialUbs; for (int i = spatialStartIndex; i < outputRank; ++i) { outputSpacialLbs.emplace_back(iZero); - outputSpacialUbs.emplace_back( - SymbolIndexExpr(shapeHelper.getOutputDims()[i])); + outputSpacialUbs.emplace_back(SymIE(shapeHelper.getOutputDims()[i])); } // Spacial loops. // for ho = 0 .. HO: // for wo = 0 .. WO: create.krnl.iterateIE(outputSpacialLoops, outputSpacialLoops, outputSpacialLbs, outputSpacialUbs, - [&](KrnlBuilder &createKrnl, ValueRange outputSpatialIndices) { + [&](const KrnlBuilder &createKrnl, ValueRange outputSpatialIndices) { IndexExprScope outputSpacialScope(createKrnl); MultiDialectBuilder @@ -126,7 +125,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { SmallVector redLbs, redUbs, pMinOS; // First: loop over channel in per group. redLbs.emplace_back(iZero); - redUbs.emplace_back(SymbolIndexExpr(CIPerGroup)); + redUbs.emplace_back(SymIE(CIPerGroup)); // For each spacial dim, do the following. for (int i = 0; i < spacialRank; ++i) { // Get data for dis spacial dimension. @@ -156,7 +155,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { // for kw in lb .. ub: auto innerIterate = create.krnl.iterateIE(redLoops, redLoops, redLbs, redUbs, inits, - [&](KrnlBuilder &createKrnl, ValueRange redIndices, + [&](const KrnlBuilder &createKrnl, ValueRange redIndices, ValueRange iterArgs) { // Get last argument for the iterate body. Value iterArg = iterArgs.back(); @@ -172,7 +171,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { inputAccessFct.emplace_back(n); // ci = g * CIPerG + ciPerG DimIndexExpr ciPerG(redIndices[0]); - IndexExpr ci = SymbolIndexExpr(gTimesCIPerGroup) + ciPerG; + IndexExpr ci = SymIE(gTimesCIPerGroup) + ciPerG; inputAccessFct.emplace_back(ci); for (int i = 0; i < spacialRank; ++i) { // for each spacial dims: access is o * s + k * d - p. @@ -187,8 +186,8 @@ struct ONNXConvOpLowering : public OpConversionPattern { create.krnl.loadIE(inputOperand, inputAccessFct); // Create access fct for filter: [co, ciPerG, kh, kw]. SmallVector filterAccessFct; - filterAccessFct.emplace_back(DimIndexExpr(co)); - filterAccessFct.emplace_back(DimIndexExpr(ciPerG)); + filterAccessFct.emplace_back(DimIE(co)); + filterAccessFct.emplace_back(DimIE(ciPerG)); for (int i = 0; i < spacialRank; ++i) { DimIndexExpr k(redIndices[1 + i]); @@ -210,10 +209,10 @@ struct ONNXConvOpLowering : public OpConversionPattern { result = create.math.add(result, bias); } SmallVector resAccessFunc; - resAccessFunc.emplace_back(SymbolIndexExpr(outerIndices[0])); + resAccessFunc.emplace_back(SymIE(outerIndices[0])); resAccessFunc.emplace_back(coInOutputSpacial); for (Value o : outputSpatialIndices) - resAccessFunc.emplace_back(DimIndexExpr(o)); + resAccessFunc.emplace_back(DimIE(o)); create.krnl.storeIE(result, alloc, resAccessFunc); }); // Output spacial loops. }; @@ -231,7 +230,7 @@ struct ONNXConvOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(outerLoops, outerLoops, outerLbs, outerUbs, - [&](KrnlBuilder &create, ValueRange outerIndices) { + [&](const KrnlBuilder &create, ValueRange outerIndices) { bodyFunction(outerIndices); }); } diff --git a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp index 7ff02d5849..417980aa94 100644 --- a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp @@ -193,6 +193,7 @@ struct ONNXInstanceNormalizationOpLowering create.krnlIE.getShapeAsSymbols(inputMemRef, inputBounds); MemRefType tmpType = MemRefType::get({}, elementType); Value fZero = create.math.constant(elementType, 0); + // Ok to use alloca, just one scalar. Value tmpMemRef = create.mem.alloca(tmpType); // Compute the number of values in a single channel: product of spatial @@ -208,7 +209,7 @@ struct ONNXInstanceNormalizationOpLowering ValueRange n_c_loopDef = create.krnl.defineLoops(2); create.krnl.iterateIE(n_c_loopDef, n_c_loopDef, {iZero, iZero}, {inputBounds[0], inputBounds[1]}, - [&](KrnlBuilder &ck, ValueRange n_c_loopInd) { + [&](const KrnlBuilder &ck, ValueRange n_c_loopInd) { MultiDialectBuilder create( ck); IndexExprScope channelScope(ck); @@ -219,16 +220,16 @@ struct ONNXInstanceNormalizationOpLowering SmallVector lbs(rank - 2, iZero); SmallVector ubs; for (int d = 2; d < rank; ++d) - ubs.emplace_back(SymbolIndexExpr(inputBounds[d])); + ubs.emplace_back(SymIE(inputBounds[d])); // First compute the mean: store zero in reduction value, then sum up // all of the values in the channel, and divide by the number of // values. - create.krnl.store(fZero, tmpMemRef, {}); + create.krnl.store(fZero, tmpMemRef); // Iterate over kernel and add values. ValueRange spatial2_loopDef = create.krnl.defineLoops(rank - 2); create.krnl.iterateIE(spatial2_loopDef, spatial2_loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { MultiDialectBuilder create( createKrnl); SmallVector inputAccessFct = { @@ -236,7 +237,7 @@ struct ONNXInstanceNormalizationOpLowering for (int d = 0; d < rank - 2; ++d) inputAccessFct.emplace_back(spatial_loopInd[d]); // tmp += input[n,c, spatial dims] - Value oldSum = create.krnl.load(tmpMemRef, {}); + Value oldSum = create.krnl.load(tmpMemRef); Value val = create.krnl.load(inputMemRef, inputAccessFct); Value newSum = create.math.add(oldSum, val); create.krnl.store(newSum, tmpMemRef); @@ -244,10 +245,10 @@ struct ONNXInstanceNormalizationOpLowering Value sum = create.krnl.load(tmpMemRef); Value mean = create.math.div(sum, meanDenom); // Second, compute the standard dev: sum of (val - mean)2 / (num-1). - create.krnl.store(fZero, tmpMemRef, {}); + create.krnl.store(fZero, tmpMemRef); // Iterate over kernel and add values. create.krnl.iterateIE(spatial_loopDef, spatial_loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { MultiDialectBuilder create( createKrnl); SmallVector inputAccessFct = { @@ -255,7 +256,7 @@ struct ONNXInstanceNormalizationOpLowering for (int d = 0; d < rank - 2; ++d) inputAccessFct.emplace_back(spatial_loopInd[d]); // tmp += input[n,c, spatial dims] - Value oldSum = create.krnl.load(tmpMemRef, {}); + Value oldSum = create.krnl.load(tmpMemRef); Value val = create.krnl.load(inputMemRef, inputAccessFct); val = create.math.sub(val, mean); val = create.math.mul(val, val); @@ -278,7 +279,7 @@ struct ONNXInstanceNormalizationOpLowering // + term. ValueRange spatial3_loopDef = create.krnl.defineLoops(rank - 2); create.krnl.iterateIE(spatial3_loopDef, spatial3_loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange spatial_loopInd) { MultiDialectBuilder create( createKrnl); SmallVector accessFct = {n.getValue(), c.getValue()}; @@ -465,7 +466,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { IndexExpr &modFactor) const { DimsExpr &operandDims = shapeHelper.inputsDims[operandIndex]; int64_t operandRank = mlir::cast(operand.getType()).getRank(); - modFactor = LiteralIndexExpr(1); + modFactor = LitIE(1); // X: X0 X1 X2 | X3 X4 X5 . // ^ | ^ ^ @@ -773,10 +774,11 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { create.vec.store(initVec, redMemRef2, {o, zero}); }); // Perform reduction of entire vectors. - IndexExpr izero = LiteralIndexExpr(0); - create.affineKMem.forIE(izero, redDim, totVL, - [&](onnx_mlir::AffineBuilderKrnlMem &ck, mlir::Value j) { + IndexExpr izero = LitIE(0); + create.affineKMem.forLoopIE(izero, redDim, totVL, + [&](const onnx_mlir::AffineBuilderKrnlMem &ck, ValueRange loopInd) { MDBuilder create(ck); + Value j = loopInd[0]; // load X, compute X**2, sum into reductions. inlineFor(create, B, [&](int64_t d, Value o) { Value ii = create.math.add(i, o); @@ -828,9 +830,10 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { invStdDev[d] = create.math.div(oneFloat, stdDev); }); // Normalize of entire vectors. - create.affineKMem.forIE(izero, redDim, totVL, - [&](onnx_mlir::AffineBuilderKrnlMem &ck, mlir::Value j) { + create.affineKMem.forLoopIE(izero, redDim, totVL, + [&](const onnx_mlir::AffineBuilderKrnlMem &ck, ValueRange loopInd) { MDBuilder create(ck); + Value j = loopInd[0]; // load X, compute X**2, sum into reductions. inlineFor(create, B, [&](int64_t d, Value o) { Value ii = create.math.add(i, o); @@ -939,17 +942,14 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { invStdDevFlatMemRef); // Alloc mem for reductions (should be private if parallel) MemRefType tmpRedType = MemRefType::get({B, totVL}, elementType); - // Iterate over 1st dim by block - ValueRange loopDefs = create.krnl.defineLoops(1); - IndexExpr zero = LiteralIndexExpr(0); - ValueRange blockedLoopDefs = create.krnl.block(loopDefs[0], B); - Value blockedLoopDef = blockedLoopDefs[0]; + // Iterate over 1st dim by block B. + bool useParallel = false; if (enableParallel) { int64_t parId; - SmallVector lb(1, LiteralIndexExpr(0)), ub(1, XFlatDims[0]); + SmallVector lb(1, LitIE(0)), ub(1, XFlatDims[0]); if (findSuitableParallelDimension(lb, ub, 0, 1, parId, /*min iter for going parallel*/ 4)) { - create.krnl.parallel(blockedLoopDef); + useParallel = true; onnxToKrnlParallelReport(op, true, 0, lb[0], ub[0], "in layer-norm"); } else { onnxToKrnlParallelReport( @@ -958,28 +958,38 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { } else { onnxToKrnlParallelReport(op, false, -1, -1, "no parallel in layer norm"); } - create.krnl.iterateIE({loopDefs[0]}, {blockedLoopDef}, {zero}, - {XFlatDims[0]}, [&](KrnlBuilder &ck, ValueRange blockedLoopIndices) { + Value tmpRedMemRef, tmpRedMemRef2; + if (!useParallel) { + // Sequential, alloc before loop. + tmpRedMemRef = create.mem.alignedAlloc(tmpRedType); + tmpRedMemRef2 = create.mem.alignedAlloc(tmpRedType); + } + create.krnl.forLoopIE(LitIE(0), XFlatDims[0], /*step*/ B, useParallel, + [&](const KrnlBuilder &ck, ValueRange blockedLoopIndices) { MDBuilder create(ck); IndexExprScope innerScope(ck); - Value tmpRedMemRef = create.mem.alignedAlloca(tmpRedType); - Value tmpRedMemRef2 = create.mem.alignedAlloca(tmpRedType); - IndexExpr blockedCurrIndex = DimIndexExpr(blockedLoopIndices[0]); - IndexExpr blockedUB = SymbolIndexExpr(XFlatDims[0]); - IndexExpr isFull = create.krnlIE.isTileFull( - blockedCurrIndex, LiteralIndexExpr(B), blockedUB); + if (useParallel) { + // Parallel, alloc inside parallel loop. + tmpRedMemRef = create.mem.alignedAlloc(tmpRedType); + tmpRedMemRef2 = create.mem.alignedAlloc(tmpRedType); + } + IndexExpr blockedCurrIndex = DimIE(blockedLoopIndices[0]); + IndexExpr blockedUB = SymIE(XFlatDims[0]); + IndexExpr isFull = + create.krnlIE.isTileFull(blockedCurrIndex, LitIE(B), blockedUB); Value zero = create.math.constantIndex(0); Value isNotFullVal = create.math.slt(isFull.getValue(), zero); create.scf.ifThenElse( isNotFullVal, - [&](SCFBuilder &scf) { + [&](const SCFBuilder &scf) { MDBuilder create(scf); // create.krnl.printf("partial tile\n"); Value startOfLastBlockVal = blockedCurrIndex.getValue(); Value blockedUBVal = blockedUB.getValue(); create.scf.forLoop(startOfLastBlockVal, blockedUBVal, 1, - [&](SCFBuilder &scf, Value blockLocalInd) { + [&](const SCFBuilder &scf, ValueRange loopInd) { MDBuilder create(scf); + Value blockLocalInd = loopInd[0]; generateIterWithSIMD(rewriter, create, lnOp, XFlatMemRef, scaleFlatMemRef, biasFlatMemRef, YFlatMemRef, meanFlatMemRef, invStdDevFlatMemRef, tmpRedMemRef, @@ -988,7 +998,7 @@ struct GenericLayerNormaOpLowering : public OpConversionPattern { scaleModFactor, biasModFactor); }); /* for inside blocked loop */ }, - [&](SCFBuilder &scf) { + [&](const SCFBuilder &scf) { MDBuilder create(scf); // create.krnl.printf("full tile\n"); generateIterWithSIMD(rewriter, create, lnOp, XFlatMemRef, diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index 31e259c7f4..18435a5795 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -4,7 +4,7 @@ //===---------------- Pooling.cpp - Lowering Pooling Ops ------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -153,7 +153,7 @@ void postProcessPoolingWindow( Value numerator = create.krnl.load(alloc, resultIndices); Value denominator; if (countIncludePad) { - IndexExpr kernelSize = LiteralIndexExpr(1); + IndexExpr kernelSize = LitIE(1); for (unsigned int i = 0; i < kernelShape.size(); ++i) kernelSize = kernelSize * kernelShape[i]; denominator = kernelSize.getValue(); @@ -310,7 +310,7 @@ struct ONNXPoolOpLowering : public OpConversionPattern { // Identity value of the operation. auto identity = getIdentityValue(rewriter, loc, outputElementType); // Create a local reduction value for output[n][c][ho][wo]. - // Single scalar, no need for default alignment. + // Single scalar, no need for default alignment. Ok to use alloca. Value reductionVal = create.mem.alloca(MemRefType::get({}, memRefType.getElementType())); @@ -320,11 +320,11 @@ struct ONNXPoolOpLowering : public OpConversionPattern { // for ho in range(HO): // for wo in range(WO): ValueRange calcLoopDef = create.krnl.defineLoops(outputShape.size()); - SmallVector lbs(outputShape.size(), LiteralIndexExpr(0)); + SmallVector lbs(outputShape.size(), LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); create.krnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { MultiDialectBuilder create(createKrnl); @@ -334,7 +334,7 @@ struct ONNXPoolOpLowering : public OpConversionPattern { // pixel. SmallVector outputIndices; for (unsigned int i = 0; i < outputShape.size(); ++i) - outputIndices.emplace_back(DimIndexExpr(loopInd[i])); + outputIndices.emplace_back(DimIE(loopInd[i])); // 2.1 Emit: output[n][c][ho][wo] = identity create.krnl.store(identity, reductionVal); @@ -359,13 +359,13 @@ struct ONNXPoolOpLowering : public OpConversionPattern { // s0, input dim ic.emplace_back(create.krnlIE.getShapeAsDim(inputOperand, j)); // s1, kernel dim - ic.emplace_back(SymbolIndexExpr(shapeHelper.kernelShape[i])); + ic.emplace_back(SymIE(shapeHelper.kernelShape[i])); // s2, pad dim - ic.emplace_back(SymbolIndexExpr(shapeHelper.pads[i])); + ic.emplace_back(SymIE(shapeHelper.pads[i])); // s3, stride dim - ic.emplace_back(LiteralIndexExpr(shapeHelper.strides[i])); + ic.emplace_back(LitIE(shapeHelper.strides[i])); // s4, dilation dim - ic.emplace_back(LiteralIndexExpr(shapeHelper.dilations[i])); + ic.emplace_back(LitIE(shapeHelper.dilations[i])); IVExprs.emplace_back(ic); } @@ -445,7 +445,8 @@ struct ONNXPoolOpLowering : public OpConversionPattern { { // Construct inputIndices for (int i = 0; i < kernelOffset; ++i) inputIndices.emplace_back(outputIndices[i]); - for (int i = kernelOffset; i < (int)inputShape.size(); ++i) { + for (int i = kernelOffset; + i < static_cast(inputShape.size()); ++i) { int j = i - kernelOffset; DimIndexExpr hp(poolingLoopInd[j]); IndexExpr startH = windowStartExprs[j]; diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 861f6bb4dc..ad35f5a3fb 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -255,7 +255,7 @@ namespace { // Returns the DenseElementsAttr of input if it's a krnl.global constant or // onnx.Constant, or if it's one step removed from a krnl/onnx constant by a // builtin.unrealized_conversion_cast. Otherwise returns a nullptr attribute. -DenseElementsAttr getDenseElementAttrFromConstValue(mlir::Value value) { +DenseElementsAttr getDenseElementAttrFromConstValue(Value value) { Operation *definingOp = value.getDefiningOp(); if (auto castOp = dyn_cast_or_null(definingOp)) { if (castOp.getNumOperands() != 1) @@ -318,8 +318,7 @@ Value foldOrEmitONNXTransposeOpKrnl(ConversionPatternRewriter &rewriter, /// Emit MemRef ReinterpretCastOp to create a new view for 'data'. /// The new view is created using the given 'outputDims'. Value emitMemRefReinterpretCastOp(ConversionPatternRewriter &rewriter, - Location loc, Value data, SmallVectorImpl &outputDims, - Type outputType) { + Location loc, Value data, DimsExpr &outputDims, Type outputType) { MemRefBuilder createMemRef(rewriter, loc); Value newView = createMemRef.reinterpretCast(data, outputDims); // Set type to the output type to avoid unrealized_conversion_cast. @@ -355,7 +354,7 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, Value order = create.mem.alignedAlloc(type, ubs); ValueRange initLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(initLoopDef, initLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // order[axis_0, axis_1, ..., axis_k-1, k, axis_k+1, ....] = k createKrnl.store(loopInd[axis], order, loopInd); }); @@ -365,7 +364,8 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, // Emit krnl.Call to call omTensorSort API Type intType = rewriter.getIntegerType(64); Value valAxis = create.math.constant(intType, axis); - Value valAscending = create.math.constant(intType, (int64_t)ascending); + Value valAscending = + create.math.constant(intType, static_cast(ascending)); SmallVector operands = {order, input, valAxis, valAscending}; rewriter.create(loc, "omTensorSort", 1, operands); return order; @@ -376,11 +376,10 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, outerUbs[axis] = ubs[axis] - oneIE; ValueRange loopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(loopDef, loopDef, lbs, outerUbs, - [&](KrnlBuilder &createKrnl, ValueRange iLoopInd) { - IndexExpr i1 = DimIndexExpr(iLoopInd[axis]) + oneIE; - ValueRange swapLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(swapLoopDef, swapLoopDef, {i1}, {ubs[axis]}, - [&](KrnlBuilder &ck, ValueRange swapLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange iLoopInd) { + IndexExpr i1 = DimIE(iLoopInd[axis]) + oneIE; + createKrnl.forLoopIE(i1, ubs[axis], /*step*/ 1, /*parallel*/ false, + [&](const KrnlBuilder &ck, ValueRange swapLoopInd) { MultiDialectBuilder create( ck); SmallVector kLoopInd(iLoopInd); @@ -402,7 +401,7 @@ Value emitArgSort(ConversionPatternRewriter &rewriter, Location loc, cond = create.math.sgt(x, y); else cond = create.math.slt(x, y); - create.scf.ifThenElse(cond, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(cond, [&](const SCFBuilder &createSCF) { KrnlBuilder createKrnl(createSCF); createKrnl.store(kOrd, order, iLoopInd); createKrnl.store(iOrd, order, kLoopInd); @@ -422,7 +421,7 @@ Value getOptionalScalarValue(ConversionPatternRewriter &rewriter, Location loc, if (mlir::isa(optionalScalar.getType())) { return create.math.constant(elementType, defaultValue); } else if (mlir::cast(optionalScalar.getType()).getRank() == 0) { - return create.krnl.load(optionalScalar, {}); + return create.krnl.load(optionalScalar); } else { Value zero = create.math.constantIndex(0); return create.krnl.load(optionalScalar, {zero}); @@ -548,20 +547,18 @@ KrnlTypeConverter::KrnlTypeConverter() { }); addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) - return std::nullopt; + return Value(); return builder.create(loc, resultType, inputs) .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) - return std::nullopt; + return Value(); return builder.create(loc, resultType, inputs) .getResult(0); @@ -611,13 +608,13 @@ bool hasNonIdentityLayout(ValueRange operands) { // requirement by definition. If found one, it is parDim and the function // returns true. -bool findSuitableParallelDimension(llvm::SmallVectorImpl &lb, - llvm::SmallVectorImpl &ub, int64_t firstInclusiveDim, - int64_t lastExclusiveDim, int64_t &parDim, int64_t minSize) { +bool findSuitableParallelDimension(ArrayRef lb, + ArrayRef ub, int64_t firstInclusiveDim, int64_t lastExclusiveDim, + int64_t &parDim, int64_t minSize) { assert(lb.size() == ub.size() && "expected identical ranks for lb/ub"); if (firstInclusiveDim < 0) firstInclusiveDim = 0; - if (lastExclusiveDim > (int64_t)lb.size()) + if (lastExclusiveDim > static_cast(lb.size())) lastExclusiveDim = lb.size(); for (int64_t i = firstInclusiveDim; i < lastExclusiveDim; ++i) { IndexExpr tripCount = ub[i] - lb[i]; @@ -662,22 +659,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType, return 1; } // Gather operation statics - int64_t vectorizedOpNum, scalarOpNum; - double avgVL = VectorMachineSupport::getAvgArchVectorLength( - genOps, elementType, vectorizedOpNum, scalarOpNum); + int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure; + double avgVL = + VectorMachineSupport::getAvgArchVectorLength(genOps, elementType, + vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure); if (avgVL < 1.5) { LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with " << avgVL << " avg VL\n"); return 1; } - LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n"); + LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL + << ", vec op num " << vectorizedOpNum + << ", max reg pressure " + << estimatedMaxVectorRegisterPressure << "\n"); // Define a target max unroll as a function of register pressure. int64_t unrollVL; int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum(); - if (vectorizedOpNum >= vrNum / 2) + if (estimatedMaxVectorRegisterPressure >= vrNum) + unrollVL = 1; + else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum) unrollVL = 2; - else if (vectorizedOpNum >= vrNum / 4) + else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum) unrollVL = 4; else unrollVL = 8; @@ -743,6 +746,22 @@ int64_t capVLForMaxUnroll( return archVL * unrollVL; } +int64_t boostVLForMinUnroll( + MemRefType memRefType, MemRefType convertedMemRefType, int64_t totVL) { + if (totVL == 1) + return 1; // Simd already disabled, nothing to cap. + Type convertedElementType = convertedMemRefType.getElementType(); + int64_t convertedArchVL = + VectorMachineSupport::getArchVectorLength(convertedElementType); + if (convertedArchVL > totVL) { + LLVM_DEBUG(llvm::dbgs() + << " simd enable: boost totVL to " << convertedArchVL + << " because of type conversions.\n"); + return convertedArchVL; + } + return totVL; +} + int64_t capVLForSimdOnly( MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) { if (totVL == 1) @@ -794,7 +813,7 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType, } // Unless otherwise disabled, here is the estimated trip count. if (canOverCompute && - collapsedInnermostLoops == (int64_t)memRefType.getRank()) { + collapsedInnermostLoops == static_cast(memRefType.getRank())) { // Fully collapsed and can add padding to be fine simdLoopStaticTripCount = isStaticSize ? staticSize : -1; return maxUnrollVL * archVL; @@ -838,7 +857,8 @@ void impl::onnxToKrnlParallelReport(Operation *op, bool successful, // Print report on this op. printf("==PAR-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(), (successful ? "-par" : ""), nodeNameStr.c_str(), comment.c_str(), - (long long int)loopLevel, (long long int)parallelLoopTripCount); + static_cast(loopLevel), + static_cast(parallelLoopTripCount)); } void impl::onnxToKrnlSimdReport(Operation *op, bool successful, @@ -858,7 +878,8 @@ void impl::onnxToKrnlSimdReport(Operation *op, bool successful, // Print report on this op. printf("==SIMD-REPORT==, %s%s, %s, %s, %lld, %lld\n", opName.data(), (successful ? "-simd" : ""), nodeNameStr.c_str(), message.c_str(), - (long long int)vectorLength, (long long int)simdLoopTripCount); + static_cast(vectorLength), + static_cast(simdLoopTripCount)); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 485ece2370..3db45b4525 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -159,8 +159,7 @@ mlir::Value foldOrEmitONNXTransposeOpKrnl( /// The new view is created using the given 'outputDims'. mlir::Value emitMemRefReinterpretCastOp( mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, - mlir::Value data, llvm::SmallVectorImpl &outputDims, - mlir::Type outputType); + mlir::Value data, DimsExpr &outputDims, mlir::Type outputType); /// Emit krnl iterate to compute argsort of a given MemRef along a given axis. /// Output MemRef has the same shape as the input MemRef but is of IndexType. @@ -309,7 +308,7 @@ class KrnlTypeConverter : public mlir::TypeConverter { // For all ONNX operations. void populateONNXToKrnlConversionPattern(mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *, bool enableTiling, - bool enableParallel); + bool enableParallel, bool enableFastMath); // `ControlFlow` directory methods: void populateLoweringONNXIfOpPattern( @@ -318,6 +317,8 @@ void populateLoweringONNXLoopOpPattern( mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXScanOpPattern( mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXYieldOpPattern( + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); // `Math` directory methods: void populateLoweringONNXClipOpPattern( @@ -379,10 +380,10 @@ void populateLoweringONNXNonMaxSuppressionOpPattern( // `Quantization` directory methods: void populateLoweringONNXDynamicQuantizeLinearOpPattern( mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *, - bool enableSIMD, bool enableParallel); + bool enableSIMD, bool enableParallel, bool enableFastMath); void populateLoweringONNXQuantizeLinearOpPattern(mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *, bool enableSIMD, - bool enableParallel); + bool enableParallel, bool enableFastMath); // `RNN` directory methods: void populateLoweringONNXGRUOpPattern( @@ -612,8 +613,8 @@ bool hasNonIdentityLayout(mlir::ValueRange operands); // Return the outermost loop within [firstDim, lastDim) for which (ub-lb) >= // minSize. Runtime dimensions are assumed to satisfy the size requirement by // definition. If found one, it is parDim and the function returns true. -bool findSuitableParallelDimension(llvm::SmallVectorImpl &lb, - llvm::SmallVectorImpl &ub, int64_t firstInclusiveDim, +bool findSuitableParallelDimension(mlir::ArrayRef lb, + mlir::ArrayRef ub, int64_t firstInclusiveDim, int64_t lastExclusiveDim, int64_t &parDim, int64_t minSize = 4); //===----------------------------------------------------------------------===// @@ -662,6 +663,12 @@ int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType, // Cap totVL so that it is at most maxUnrollVL * archVL. int64_t capVLForMaxUnroll( mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL); +// In some type conversion loops we may have a given totVL based on a given +// memRef type and gen op mix. But the final result may be converted to a +// different type, which may requires a minimum unroll to proceed as a single +// SIMD operation. This call adjust the totVL for that case. +int64_t boostVLForMinUnroll(mlir::MemRefType memRefType, + mlir::MemRefType convertedMemRefType, int64_t totVL); // Enabling a simdOnly code generation scheme by capping totVL so that it // divides simdLoopStaticTripCount. When not possible (either because // there is no totVL that divides simdLoopStaticTripCount or trip count is @@ -747,5 +754,21 @@ void emitMinMaxReductionToScalar(mlir::ConversionPatternRewriter &rewriter, mlir::Value &minAlloc, mlir::Value &maxAlloc, bool enableSIMD, bool enableParallel); +// Compute the reciprocal scale (recscale) for the symmetric quantization. Can +// generate parallel and SIMD code as requested. Formula for recscale: +// ``` +// recscale = (2^(b-1) - 1) / absmax(X) +// ``` +// where +// - X is the input tensor, +// - b is the number of bits we want to quantize to (e.g. 8 for integer 8), and +// - absmax is a function to compute the absolute maximun value over entire +// tensor +// +void emitSymmetricQuantRecscaleToScalar( + mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Operation *op, mlir::Value input, uint64_t bitWidth, + mlir::Value &recscale, bool enableSIMD, bool enableParallel); + } // namespace onnx_mlir #endif diff --git a/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp b/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp index ab216bea57..fdaef429e0 100644 --- a/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp +++ b/src/Conversion/ONNXToKrnl/ObjectDetection/NonMaxSuppression.cpp @@ -110,51 +110,52 @@ static void suppressByScores(ConversionPatternRewriter &rewriter, Location loc, Value zero = create.math.constantIndex(0); Value one = create.math.constantIndex(1); // Store the number of scores whose value is greater than the threshold. + // Scalar, ok to use alloca. Value topk = create.mem.alloca(MemRefType::get({}, indexType)); // Compute the effective max output per class. Value effectiveMaxPerClass = create.mem.alloca(MemRefType::get({}, indexType)); - create.krnl.store(zero, effectiveMaxPerClass, {}); + create.krnl.store(zero, effectiveMaxPerClass); ValueRange bcLoopDef = create.krnl.defineLoops(2); create.krnl.iterate(bcLoopDef, bcLoopDef, {zero, zero}, {bs, cs}, - [&](KrnlBuilder &createKrnl, ValueRange bcLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange bcLoopInd) { MultiDialectBuilder create( createKrnl); Value b(bcLoopInd[0]), c(bcLoopInd[1]); // Reset the number of scores whose value is greater than the // threshold. Counting is done per class. - create.krnl.store(zero, topk, {}); + create.krnl.store(zero, topk); // Count the number of scores whose value is greater than the // threshold. Counting is done per class. ValueRange sLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(sLoopDef, sLoopDef, {zero}, {ss}, - [&](KrnlBuilder &createKrnl, ValueRange sLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange sLoopInd) { Value s(sLoopInd[0]); MathBuilder createMath(createKrnl); Value score = createKrnl.load(scores, {b, c, s}); // Increase the counter if score > threshold. Value gt = createMath.sgt(score, scoreThreshold); - Value topkVal = createKrnl.load(topk, {}); + Value topkVal = createKrnl.load(topk); Value topkPlusOneVal = createMath.add(topkVal, one); topkVal = createMath.select(gt, topkPlusOneVal, topkVal); - createKrnl.store(topkVal, topk, {}); + createKrnl.store(topkVal, topk); }); // Update the effective max output per class. - Value x = create.krnl.load(topk, {}); - Value y = create.krnl.load(effectiveMaxPerClass, {}); - create.krnl.store(create.math.max(x, y), effectiveMaxPerClass, {}); + Value x = create.krnl.load(topk); + Value y = create.krnl.load(effectiveMaxPerClass); + create.krnl.store(create.math.max(x, y), effectiveMaxPerClass); }); // Suppress the number of output bounding boxes per class. - Value x = create.krnl.load(maxOutputPerClass, {}); - Value y = create.krnl.load(effectiveMaxPerClass, {}); - create.krnl.store(create.math.min(x, y), maxOutputPerClass, {}); + Value x = create.krnl.load(maxOutputPerClass); + Value y = create.krnl.load(effectiveMaxPerClass); + create.krnl.store(create.math.min(x, y), maxOutputPerClass); } /// Bounding boxes may contain a mix of flipped and non-flipped boxes. Try to @@ -175,7 +176,7 @@ static Value tryToUnflip( ValueRange loopDef = create.krnl.defineLoops(2); create.krnl.iterateIE(loopDef, loopDef, {zeroIE, zeroIE}, {bs, ss}, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { MathBuilder createMath(createKrnl); DimIndexExpr b(loopInd[0]), s(loopInd[1]); // Load a bounding box. @@ -272,13 +273,14 @@ struct ONNXNonMaxSuppressionOpLowering // Refine the number of output boxes per class by suppressing it using // spatial dimension size and score threshold. + // Scalar, ok to use alloca. Value maxOutputPerClass = create.mem.alloca(MemRefType::get({}, indexType)); // 1. Suppress by using spatial dimension size. Value x = create.math.castToIndex(maxOutputBoxPerClass); - create.krnl.store(create.math.min(x, ss), maxOutputPerClass, {}); + create.krnl.store(create.math.min(x, ss), maxOutputPerClass); // 2. Suppress by score threshold. suppressByScores(rewriter, loc, scores, scoreTH, maxOutputPerClass); - Value MOPC = create.krnl.load(maxOutputPerClass, {}); + Value MOPC = create.krnl.load(maxOutputPerClass); // Sort scores in the descending order. Value order = emitArgSort(rewriter, loc, scores, /*axis=*/2, @@ -290,14 +292,13 @@ struct ONNXNonMaxSuppressionOpLowering boxes = tryToUnflip(rewriter, loc, boxes); // The total number of output selected indices. - IndexExpr numSelectedIndicesIE = bsIE * csIE * DimIndexExpr(MOPC); + IndexExpr numSelectedIndicesIE = bsIE * csIE * DimIE(MOPC); // Allocate a MemRef for the output. This MemRef is NOT the final output // since the number of selected indices has yet not suppressed by IOU. So // the first dimension size is larger than necessary. // Output shape : [num_selected_indices, 3] - SmallVector outputDims = { - numSelectedIndicesIE, LiteralIndexExpr(3)}; + SmallVector outputDims = {numSelectedIndicesIE, LitIE(3)}; SmallVector outputShape; if (numSelectedIndicesIE.isLiteral()) outputShape.emplace_back(numSelectedIndicesIE.getLiteral()); @@ -313,9 +314,10 @@ struct ONNXNonMaxSuppressionOpLowering // dim of the output, which is suppressed by IOU during computation and // cannot be computed in advance. // Final output shape : [effective_num_selected_indices, 3] + // Scalar, ok to use alloca. Value effectiveNumSelectedIndices = create.mem.alloca(MemRefType::get({}, indexType)); - create.krnl.store(zero, effectiveNumSelectedIndices, {}); + create.krnl.store(zero, effectiveNumSelectedIndices); // Suppress by using IOU. // Iterate over all bounding boxes in the descending order of scores. @@ -323,11 +325,11 @@ struct ONNXNonMaxSuppressionOpLowering create.mem.alloca(MemRefType::get({}, indexType)); ValueRange bcLoopDef = create.krnl.defineLoops(2); create.krnl.iterate(bcLoopDef, bcLoopDef, {zero, zero}, {bs, cs}, - [&](KrnlBuilder &createKrnl, ValueRange bcLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange bcLoopInd) { MultiDialectBuilder create( createKrnl); // Keep trace of the number of output boxes per class. - create.krnl.store(zero, effectiveMaxOutputPerClass, {}); + create.krnl.store(zero, effectiveMaxOutputPerClass); // Keep trace of removed indices per class. DimIndexExpr ssIE(ss); SmallVector dims = {ssIE}; @@ -341,7 +343,7 @@ struct ONNXNonMaxSuppressionOpLowering // Iterate in the descending order of scores. ValueRange sLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(sLoopDef, sLoopDef, {zero}, {ss}, - [&](KrnlBuilder &createKrnl, ValueRange sLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange sLoopInd) { Value b(bcLoopInd[0]), c(bcLoopInd[1]), s(sLoopInd[0]); MultiDialectBuilder create( createKrnl); @@ -355,7 +357,7 @@ struct ONNXNonMaxSuppressionOpLowering Value checkScore = create.math.sgt(score, scoreTH); // 2. Have not yet got enough outputs. Value currentMOPC = - create.krnl.load(effectiveMaxOutputPerClass, {}); + create.krnl.load(effectiveMaxOutputPerClass); Value checkMOPC = create.math.slt(currentMOPC, MOPC); // 3. Bounding box has not yet been removed. Value isRemoved = @@ -381,7 +383,7 @@ struct ONNXNonMaxSuppressionOpLowering // Store the index of the selected box to the output. // out_index = effective_num_selected_indices // selected_indices[out_index] = [b, c, selected_box_index] - Value soVal = create.krnl.load(effectiveNumSelectedIndices, {}); + Value soVal = create.krnl.load(effectiveNumSelectedIndices); create.krnl.store(b, selectedMemRef, {soVal, zero}); create.krnl.store(c, selectedMemRef, {soVal, one}); create.krnl.store(selectedBI, selectedMemRef, {soVal, two}); @@ -400,7 +402,7 @@ struct ONNXNonMaxSuppressionOpLowering // using IOU. ValueRange oLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(oLoopDef, oLoopDef, {zero}, {ss}, - [&](KrnlBuilder &createKrnl, ValueRange oLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange oLoopInd) { Value o(oLoopInd[0]); MathBuilder createMath(createKrnl); @@ -438,9 +440,8 @@ struct ONNXNonMaxSuppressionOpLowering }); // Insert allocation and deallocation for the final output. - Value effectiveNSI = create.krnl.load(effectiveNumSelectedIndices, {}); - SmallVector resDims = { - DimIndexExpr(effectiveNSI), LiteralIndexExpr(3)}; + Value effectiveNSI = create.krnl.load(effectiveNumSelectedIndices); + SmallVector resDims = {DimIE(effectiveNSI), LitIE(3)}; Value resMemRef = create.mem.alignedAlloc( MemRefType::get({ShapedType::kDynamic, 3}, elementType), resDims); @@ -448,7 +449,7 @@ struct ONNXNonMaxSuppressionOpLowering ValueRange resLoopDef = create.krnl.defineLoops(2); create.krnl.iterate(resLoopDef, resLoopDef, {zero, zero}, {effectiveNSI, three}, - [&](KrnlBuilder &createKrnl, ValueRange resLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange resLoopInd) { MathBuilder createMath(createKrnl); Value load = createKrnl.load(selectedMemRef, resLoopInd); Value res = createMath.cast(elementType, load); diff --git a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp index fdebe15d86..1222f98708 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" @@ -23,18 +24,27 @@ using namespace mlir; namespace onnx_mlir { -// Implementation of quantize helper function. -// TODO: add parallel. -void emitDynamicQuantizationLinearScalarParameters( - ConversionPatternRewriter &rewriter, Location loc, Operation *op, - MemRefType inputType, MemRefType quantizedType, Value input, Value qMin, - Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint, +// Implementation of computing the min and max over an entire tensor. Can +// generate parallel and SIMD code as requested. +void emitDynamicQuantizationLinearMinMax(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Value input, Value &inputMin, Value &inputMax, bool enableSIMD, bool enableParallel) { - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create(rewriter, loc); + Value inputMinAlloc, inputMaxAlloc; + emitMinMaxReductionToScalar(rewriter, loc, op, input, inputMinAlloc, + inputMaxAlloc, enableSIMD, enableParallel); + inputMin = create.krnl.load(inputMinAlloc); + inputMax = create.krnl.load(inputMaxAlloc); +} - // Types - Type elementType = inputType.getElementType(); - Type quantizedElementType = quantizedType.getElementType(); +// Implementation of quantize helper function. Returns Values scale, zeroPoint, +// and quantizedZeroPoint directly (not as a value to a memory location, +// directly the floating point results). +void emitDynamicQuantizationLinearScalarParametersFromMinMax( + ConversionPatternRewriter &rewriter, Location loc, Operation *op, + MemRefType inputType, MemRefType quantizedType, Value inputMin, + Value inputMax, Value qMin, Value qMax, Value &scale, Value &zeroPoint, + Value &quantizedZeroPoint, bool wantZeroPoint, bool enableParallel) { // Equations: // y_scale = (max(x) - min(x))/(qMax - qMin) @@ -44,41 +54,46 @@ void emitDynamicQuantizationLinearScalarParameters( // // where, saturate is to clip to [0, 255] for ui8. - Value inputMinAlloc, inputMaxAlloc; - emitMinMaxReductionToScalar(rewriter, loc, op, input, inputMinAlloc, - inputMaxAlloc, enableSIMD, enableParallel); - Value xMin = create.krnl.load(inputMinAlloc); - Value xMax = create.krnl.load(inputMaxAlloc); - + MultiDialectBuilder create(rewriter, loc); + // Types. + Type elementType = inputType.getElementType(); + Type quantizedElementType = quantizedType.getElementType(); // Include 0 to max(x) and min(x). // x_min = min(min(x), 0) // x_max = max(max(x), 0) Value zero = create.math.constant(elementType, 0.0); - xMax = create.math.max(xMax, zero); - xMin = create.math.min(xMin, zero); + inputMax = create.math.max(inputMax, zero); + inputMin = create.math.min(inputMin, zero); // Compute y_scale. - Value xDiff = create.math.sub(xMax, xMin); + Value xDiff = create.math.sub(inputMax, inputMin); Value boundDiff = create.math.sub(qMax, qMin); scale = create.math.div(xDiff, boundDiff); // Compute y_zero_point. - Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); - // Saturate zero point. - Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); - // Round zero point. - zeroPoint = create.math.round(saturateZeroPoint); + if (wantZeroPoint) { + Value interZeroPoint = + create.math.sub(qMin, create.math.div(inputMin, scale)); + // Saturate zero point. + Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); + // Round zero point. + zeroPoint = create.krnl.roundEven(saturateZeroPoint); + } else { + zeroPoint = zero; + } quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint); } struct ONNXDynamicQuantizeLinearOpLowering : public OpConversionPattern { ONNXDynamicQuantizeLinearOpLowering(TypeConverter &typeConverter, - MLIRContext *ctx, bool enableSIMD, bool enableParallel) + MLIRContext *ctx, bool enableSIMD, bool enableParallel, + bool enableFastMath) : OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD), - enableParallel(enableParallel) {} + enableParallel(enableParallel), enableFastMath(enableFastMath) {} bool enableSIMD = false; bool enableParallel = false; + bool enableFastMath = false; using LocalDialectBuilder = MultiDialectBuilder; @@ -94,12 +109,12 @@ struct ONNXDynamicQuantizeLinearOpLowering Value X = adaptor.getX(); // MemRefType for inputs and outputs. - auto xMemRefType = dyn_cast(X.getType()); - auto yMemRefType = dyn_cast( + auto xMemRefType = mlir::dyn_cast(X.getType()); + auto yMemRefType = mlir::dyn_cast( typeConverter->convertType(dqlOp.getResult(0).getType())); - auto yScaleMemRefType = dyn_cast( + auto yScaleMemRefType = mlir::dyn_cast( typeConverter->convertType(dqlOp.getResult(1).getType())); - auto yZeroPointMemRefType = dyn_cast( + auto yZeroPointMemRefType = mlir::dyn_cast( typeConverter->convertType(dqlOp.getResult(2).getType())); // Types @@ -118,19 +133,23 @@ struct ONNXDynamicQuantizeLinearOpLowering Value YZeroPoint = create.mem.alignedAlloc( yZeroPointMemRefType, shapeHelper.getOutputDims(2)); + Value xMin, xMax; + emitDynamicQuantizationLinearMinMax( + rewriter, loc, op, X, xMin, xMax, enableSIMD, enableParallel); Value qMax = create.math.constant(elementType, 255.0); Value qMin = create.math.constant(elementType, 0.0); Value scale, zeroPoint, zeroPointInt; - - emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op, - xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt, - enableSIMD, enableParallel); + bool wantZeroPoint = !disableQuantZeroPoint; + emitDynamicQuantizationLinearScalarParametersFromMinMax(rewriter, loc, op, + xMemRefType, yMemRefType, xMin, xMax, qMin, qMax, scale, zeroPoint, + zeroPointInt, wantZeroPoint, enableParallel); create.krnl.store(scale, YScale); create.krnl.store(zeroPointInt, YZeroPoint); emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/, + enableSIMD, enableParallel, enableFastMath); rewriter.replaceOp(op, {Y, YScale, YZeroPoint}); onnxToKrnlSimdReport(op); @@ -140,9 +159,9 @@ struct ONNXDynamicQuantizeLinearOpLowering void populateLoweringONNXDynamicQuantizeLinearOpPattern( RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, - bool enableSIMD, bool enableParallel) { + bool enableSIMD, bool enableParallel, bool enableFastMath) { patterns.insert( - typeConverter, ctx, enableSIMD, enableParallel); + typeConverter, ctx, enableSIMD, enableParallel, enableFastMath); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp index 124b854bde..c75ebd155e 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp @@ -12,25 +12,40 @@ // //===----------------------------------------------------------------------===// +#ifndef ONNX_MLIR_QUANTIZE_HELPER_HPP +#define ONNX_MLIR_QUANTIZE_HELPER_HPP 1 + #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" namespace onnx_mlir { // Given an input, scale, zero point, qMin, and qMax, perform a linear -// quantization and store in alloc. +// quantization and store in alloc. FastMath enables taking the reciprocal for +// faster results on machines where mul is faster than div. void emitQuantizationLinearScalarParameters( mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Operation *op, mlir::MemRefType inputType, mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims, mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale, - mlir::Value zeroPoint, bool enableSIMD, bool enableParallel); + mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel, bool enableFastMath); + +// Compute min max over an entire tensor, which can then be used for dynamic +// quantize linear. +void emitDynamicQuantizationLinearMinMax( + mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Operation *op, mlir::Value input, mlir::Value &inputMin, + mlir::Value &inputMax, bool enableSIMD, bool enableParallel); -// Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin -// and qMax. -void emitDynamicQuantizationLinearScalarParameters( +// Compute scale and zero points for dynamic quantization from min/max. +void emitDynamicQuantizationLinearScalarParametersFromMinMax( mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, mlir::Operation *op, mlir::MemRefType inputType, - mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin, - mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint, - mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel); + mlir::MemRefType quantizedType, mlir::Value inputMin, mlir::Value inputMax, + mlir::Value qMin, mlir::Value qMax, mlir::Value &scale, + mlir::Value &zeroPoint, mlir::Value &quantizedZeroPoint, bool wantZeroPoint, + bool enableParallel); + } // namespace onnx_mlir + +#endif diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 83b2094fc7..5743e71077 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" @@ -20,23 +21,62 @@ using namespace mlir; +#define DISABLE_FAST_MATH 0 /* disable reciprocal (for debug) */ + namespace onnx_mlir { // Helper function for quantization. void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax, - Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) { - MultiDialectBuilder create( - rewriter, loc); + Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel, bool enableFastMath) { + MultiDialectBuilder + create(rewriter, loc); // Types Type quantizedElementType = quantizedType.getElementType(); + Type inputElementType = inputType.getElementType(); int64_t rank = inputType.getRank(); + // Use fast math with reciprocal? + bool useReciprocal = + !DISABLE_FAST_MATH && enableFastMath && isa(inputElementType); + // Flatten the input data and outputs DimsExpr inputDims, flatInputDims, flatAllocDims; inputDims = allocDims; // Unput and output have the same shape. + // + if (rank == 0) { + // Do scalar computation only when the input is a scalar tensor. + Value x = create.krnl.load(input); + // Scale + Value scaleX; + if (useReciprocal) { + Value one = create.math.constant(inputElementType, 1.0); + Value scaleReciprocal = create.math.div(one, scale); + scaleX = create.math.mul(x, scaleReciprocal); + } else { + scaleX = create.math.div(x, scale); + } + // Round + Value roundX = create.krnl.roundEven(scaleX); + // Adjust + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; + // Saturate: use max into a min. + Value saturateX = create.math.clip(adjustX, qMin, qMax); + // Convert into quantized type. + Value quantSaturateX = create.math.cast(quantizedElementType, saturateX); + create.krnl.store(quantSaturateX, alloc); + onnxToKrnlSimdReport(op, /*successful*/ false, 0, 0, + "no simd in quantizationLinear whole tensor"); + return; + } + Value flatInput = create.mem.reshapeToFlatInnermost(input, inputDims, flatInputDims, rank); Value flatAlloc = @@ -49,12 +89,19 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, if (enableSIMD) { int64_t innermostLoopCollapse = 1; // Only innermost is simdized. bool canOverCompute = false; - GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5}, - {GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2}, - {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}}; + GenOpMix mixAdjust; + if (hasZeroPoint) + mixAdjust = {{GenericOps::ArithmeticGop, 1}}; + GenOpMix mixRound = getGenOpMix(inputElementType, op); + GenericOps divOrMulGenOp = + useReciprocal ? GenericOps::MulGop : GenericOps::DivGop; + GenOpMix mixOthers = {{divOrMulGenOp, 1}, {GenericOps::ConversionGop, 1}, + {GenericOps::MinMaxGop, 2}, + {GenericOps::EstimatedVectorRegisterPressure, 4}}; + GenOpMix mix1 = computeGenOpMixUnion(mixAdjust, mixRound); + GenOpMix mix2 = computeGenOpMixUnion(mix1, mixOthers); totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, - innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, + innermostLoopCollapse, mix2, canOverCompute, simdLoopStaticTripCount, simdOnly); } @@ -66,23 +113,37 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, inputAF.emplace_back(zero); DimsExpr outputAF; outputAF.emplace_back(zero); + + Value scaleReciprocal; + if (useReciprocal) { + Value one = create.math.constant(inputElementType, 1.0); + scaleReciprocal = create.math.div(one, scale); + } create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { - MultiDialectBuilder create(kb); + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { + MultiDialectBuilder create(kb); Value x = inputVals[0]; // Scale - Value scaleX = create.math.div(x, scale); + Value scaleX; + if (useReciprocal) + scaleX = create.math.mul(x, scaleReciprocal); + else + scaleX = create.math.div(x, scale); // Round - Value roundX = create.math.round(scaleX); + Value roundX = create.krnl.roundEven(scaleX); // Adjust - Value adjustX = create.math.add(roundX, zeroPoint); - // Saturate + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; + // Saturate: use max into a min. Value saturateX = create.math.clip(adjustX, qMin, qMax); - Value res = create.math.cast(quantizedElementType, saturateX); - resVals.emplace_back(res); - }); + // Convert into quantized type. + return create.math.cast(quantizedElementType, saturateX); + }}); + if (totVL > 1) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "quantizationLinear whole tensor"); @@ -94,12 +155,13 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, struct ONNXQuantizeLinearOpLowering : public OpConversionPattern { ONNXQuantizeLinearOpLowering(TypeConverter &typeConverter, MLIRContext *ctx, - bool enableSIMD, bool enableParallel) + bool enableSIMD, bool enableParallel, bool enableFastMath) : OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD), - enableParallel(enableParallel) {} + enableParallel(enableParallel), enableFastMath(enableFastMath) {} bool enableSIMD = false; bool enableParallel = false; + bool enableFastMath = false; using LocalDialectBuilder = MultiDialectBuilder; @@ -117,8 +179,8 @@ struct ONNXQuantizeLinearOpLowering Value YZeroPoint = qlOp.getYZeroPoint(); // Optional input. // MemRefType for inputs and outputs. - auto xMemRefType = dyn_cast(X.getType()); - auto yMemRefType = dyn_cast( + auto xMemRefType = mlir::dyn_cast(X.getType()); + auto yMemRefType = mlir::dyn_cast( typeConverter->convertType(qlOp.getResult().getType())); MemRefType yScaleMemRefType = mlir::cast(YScale.getType()); @@ -126,11 +188,11 @@ struct ONNXQuantizeLinearOpLowering Type elementType = xMemRefType.getElementType(); Type quantizedElementType = yMemRefType.getElementType(); - // Does not support per-axis and i8. + // Does not support per-axis and other types rather than i8. assert(yScaleMemRefType.getRank() == 0 && "Does not support per-axis quantization"); - assert(quantizedElementType.isUnsignedInteger() && - "Does not support i8 quantization"); + assert(quantizedElementType.isInteger(8) && + "Only support i8/ui8 quantization at this moment"); // Get shape. ONNXQuantizeLinearOpShapeHelper shapeHelper(op, operands, &create.krnlIE); @@ -160,15 +222,22 @@ struct ONNXQuantizeLinearOpLowering // Load y_zero_point. Value zeroPoint; + bool hasZeroPoint = false; if (!isNoneValue(YZeroPoint)) { zeroPoint = create.krnl.load(adaptor.getYZeroPoint()); zeroPoint = create.math.cast(elementType, zeroPoint); - } else - zeroPoint = create.math.constant(elementType, 0.0); - + hasZeroPoint = true; + } + if (disableQuantZeroPoint) { + // TODO: should we expect to disable hasZeroPoint forcefully, or + // generate an error if we had a zero point? Right now, just forcefully + // assert we have no zero point, i.e. ignore one even if we had a zero + // point. + hasZeroPoint = false; + } emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, hasZeroPoint, enableSIMD, enableParallel, enableFastMath); rewriter.replaceOp(op, {Y}); onnxToKrnlSimdReport(op); @@ -178,9 +247,9 @@ struct ONNXQuantizeLinearOpLowering void populateLoweringONNXQuantizeLinearOpPattern(RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, bool enableSIMD, - bool enableParallel) { + bool enableParallel, bool enableFastMath) { patterns.insert( - typeConverter, ctx, enableSIMD, enableParallel); + typeConverter, ctx, enableSIMD, enableParallel, enableFastMath); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp index b90fe14696..cc6d94d42a 100644 --- a/src/Conversion/ONNXToKrnl/RNN/GRU.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/GRU.cpp @@ -452,7 +452,7 @@ void calculateState( // Do element-wise computations. Fuse them into a single nested loop. ValueRange loops = create.krnl.defineLoops(htRank); create.krnl.iterate(loops, loops, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value bs(indices[0]), hs(indices[1]); @@ -541,7 +541,7 @@ void calculateState( // Emit rt and (rt (.) Ht-1). ValueRange loops1 = create.krnl.defineLoops(htRank); create.krnl.iterate(loops1, loops1, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value bs(indices[0]), hs(indices[1]); @@ -574,7 +574,7 @@ void calculateState( // Do element-wise computations. Fuse them into a single nested loop. ValueRange loops2 = create.krnl.defineLoops(htRank); create.krnl.iterate(loops2, loops2, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); IndexExprScope ieScope(createKrnl); Value bs(indices[0]), hs(indices[1]); diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp index dceed2cb5e..49ae86408e 100644 --- a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -349,7 +349,7 @@ void calculateState( } ValueRange loops = create.krnl.defineLoops(htRank); create.krnl.iterate(loops, loops, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MathBuilder createMath(createKrnl); Value bs(indices[0]), hs(indices[1]); // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp index dab3a299f2..c21930a7c7 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -88,14 +88,14 @@ void initializeIntermediateStates(ConversionPatternRewriter &rewriter, rewriter, loc); IndexExprScope childScope(create.krnl); ValueRange loopDef = create.krnl.defineLoops(nLoops); - SmallVector lbs(nLoops, LiteralIndexExpr(0)); + SmallVector lbs(nLoops, LitIE(0)); Value boundVal = (direction == FORWARD || direction == BIDIRECTIONAL) ? forwardHt : reverseHt; SmallVector ubs; create.krnlIE.getShapeAsDims(boundVal, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { SmallVector IVs; IVs.emplace_back(loopInd[0]); IVs.emplace_back(loopInd[1]); @@ -193,7 +193,7 @@ void initializeHiddenAndCell(ConversionPatternRewriter &rewriter, Location loc, } ValueRange loops = create.krnl.defineLoops(htRank); create.krnl.iterate(loops, loops, htLbs, htUbs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value hiddenVal = zero; if (!isNoneValue(initialH)) hiddenVal = createKrnl.load(initialH, indices); @@ -232,7 +232,7 @@ void stateToOutputForHiddenOrCell(ConversionPatternRewriter &rewriter, } ValueRange loops = create.krnl.defineLoops(2); create.krnl.iterate(loops, loops, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value b(indices[0]), h(indices[1]); // Forward. Value val = createKrnl.load(forwardVal, {b, h}); @@ -275,8 +275,8 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, ubs.emplace_back(create.mem.dim(sliceX, r)); } ValueRange loops = create.krnl.defineLoops(2); - create.krnl.iterate( - loops, loops, lbs, ubs, [&](KrnlBuilder &createKrnl, ValueRange indices) { + create.krnl.iterate(loops, loops, lbs, ubs, + [&](const KrnlBuilder &createKrnl, ValueRange indices) { Value b(indices[0]), i(indices[1]); Value val = createKrnl.load(X, {timestepIV, b, i}); createKrnl.store(val, sliceX, {b, i}); @@ -289,9 +289,10 @@ Value emitXSliceAt(ConversionPatternRewriter &rewriter, Location loc, Value X, // When a sample reachs the limit of its sequence len, nextHt will be padded // with 0 (or initialH), and Ht will keep the last value at the sequence end // so that the final value Ht is the last value at their sequence len. -Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath, - Value sequenceLens, Value initialH, Value nextHt, Value sequenceIV, - Value directionIV, Value bs, Value hs, Value Ht) { +Value handleSequenceLens(const KrnlBuilder &createKrnl, + const MathBuilder &createMath, Value sequenceLens, Value initialH, + Value nextHt, Value sequenceIV, Value directionIV, Value bs, Value hs, + Value Ht) { if (!isNoneValue(sequenceLens)) { Value sequenceUB = createKrnl.load(sequenceLens, {bs}); Value initial; diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp index c418c1d002..607bd77bea 100644 --- a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp @@ -65,10 +65,10 @@ mlir::Value emitXSliceAt(mlir::ConversionPatternRewriter &rewriter, // When a sample reachs the limit of its sequence len, nextHt will be padded // with 0 (or initialH), and Ht will keep the last value at the sequence end // so that the final value Ht is the last value at their sequence len. -mlir::Value handleSequenceLens(KrnlBuilder &createKrnl, MathBuilder &createMath, - mlir::Value sequenceLens, mlir::Value initialH, mlir::Value nextHt, - mlir::Value sequenceIV, mlir::Value directionIV, mlir::Value bs, - mlir::Value hs, mlir::Value Ht); +mlir::Value handleSequenceLens(const KrnlBuilder &createKrnl, + const MathBuilder &createMath, mlir::Value sequenceLens, + mlir::Value initialH, mlir::Value nextHt, mlir::Value sequenceIV, + mlir::Value directionIV, mlir::Value bs, mlir::Value hs, mlir::Value Ht); // Override the following methods when lowering an RNN operation: // - hasAllNoneOutput @@ -160,15 +160,14 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { if (direction == FORWARD || direction == BIDIRECTIONAL) { IndexExprScope childScope(create.krnl); - mlir::ValueRange loopDef = create.krnl.defineLoops(1); - llvm::SmallVector lbs(1, LiteralIndexExpr(0)); - llvm::SmallVector ubs; + IndexExpr lb = LitIE(0); + IndexExpr ub; if (!mlir::ShapedType::isDynamic(sequenceDimSize)) - ubs.emplace_back(LiteralIndexExpr(sequenceDimSize)); + ub = LitIE(sequenceDimSize); else - ubs.emplace_back(create.krnlIE.getShapeAsDim(X, 0)); - create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, mlir::ValueRange loopInd) { + ub = create.krnlIE.getShapeAsDim(X, 0); + create.krnl.forLoopIE(lb, ub, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder &createKrnl, mlir::ValueRange loopInd) { MathBuilder createMath(createKrnl); mlir::Value directionIV = createMath.constant(rewriter.getIndexType(), 0); @@ -185,15 +184,14 @@ struct ONNXRNNOpLowering : public mlir::OpConversionPattern { if (direction == REVERSE || direction == BIDIRECTIONAL) { IndexExprScope childScope(create.krnl); - mlir::ValueRange loopDef = create.krnl.defineLoops(1); - llvm::SmallVector lbs(1, LiteralIndexExpr(0)); - llvm::SmallVector ubs; + IndexExpr lb = LitIE(0); + IndexExpr ub; if (!mlir::ShapedType::isDynamic(sequenceDimSize)) - ubs.emplace_back(LiteralIndexExpr(sequenceDimSize)); + ub = LitIE(sequenceDimSize); else - ubs.emplace_back(create.krnlIE.getShapeAsDim(X, 0)); - create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &ck, mlir::ValueRange loopInd) { + ub = create.krnlIE.getShapeAsDim(X, 0); + create.krnl.forLoopIE(lb, ub, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder &ck, mlir::ValueRange loopInd) { MultiDialectBuilder create(ck); mlir::AffineMap reverseIVMap = mlir::AffineMap::get(1, 1, diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp index 98674670ac..b8c78f4293 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceAt.cpp @@ -38,8 +38,7 @@ struct ONNXSequenceAtOpLowering : public OpConversionPattern { auto dimSize = create.mem.dim(input_sequence, 0); SymbolIndexExpr boundIE(dimSize); - IndexExpr positionIE = - SymbolIndexExpr(create.krnl.load(adaptor.getPosition())); + IndexExpr positionIE = SymIE(create.krnl.load(adaptor.getPosition())); // Handle the negative position IndexExpr condIE = positionIE < 0; IndexExpr fixedPosition = positionIE + boundIE; diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp index fd7bcbf118..5657d0555e 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceErase.cpp @@ -55,7 +55,7 @@ struct ONNXSequenceEraseOpLowering // Erase the end of the sequence positionIE = boundIE - 1; } else { - positionIE = SymbolIndexExpr(create.krnl.load(adaptor.getPosition())); + positionIE = SymIE(create.krnl.load(adaptor.getPosition())); // Handle the negative position IndexExpr correctionIE = positionIE + boundIE; IndexExpr conditionIE = positionIE < 0; @@ -64,13 +64,8 @@ struct ONNXSequenceEraseOpLowering // Copy the elements before the position KrnlBuilder createKrnl(rewriter, loc); - SmallVector lbs; - lbs.emplace_back(LiteralIndexExpr(0)); - SmallVector ubs; - ubs.emplace_back(positionIE); - ValueRange firstLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(firstLoopDef, firstLoopDef, lbs, ubs, - [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) { + createKrnl.forLoopIE(LitIE(0), positionIE, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) { Value element = createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]); createKrnl.seqstore(element, alloc, positionIE); @@ -78,13 +73,8 @@ struct ONNXSequenceEraseOpLowering }); // Copy the elements after the position - SmallVector lbs1; - lbs1.emplace_back(positionIE + 1); - SmallVector ubs1; - ubs1.emplace_back(boundIE); - ValueRange secondLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(secondLoopDef, secondLoopDef, lbs1, ubs1, - [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) { + createKrnl.forLoopIE(positionIE + 1, boundIE, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) { Value element = createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]); Value oneIndex = create.math.constantIndex(1); diff --git a/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp b/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp index 622c80fb5a..806dbb71d3 100644 --- a/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp +++ b/src/Conversion/ONNXToKrnl/Sequence/SequenceInsert.cpp @@ -59,7 +59,7 @@ struct ONNXSequenceInsertOpLowering // ToDo (chentong): backward shape inference may help positionIE = boundIE; } else { - positionIE = SymbolIndexExpr(create.krnl.load(adaptor.getPosition())); + positionIE = SymIE(create.krnl.load(adaptor.getPosition())); // Handle the negative position IndexExpr condIE = positionIE < 0; IndexExpr fixedPosition = positionIE + boundIE; @@ -77,13 +77,8 @@ struct ONNXSequenceInsertOpLowering // compilation problem due to the unranked tensor even though // the loop will not be reached at runtime. } else { - SmallVector lbs; - lbs.emplace_back(LiteralIndexExpr(0)); - SmallVector ubs; - ubs.emplace_back(positionIE); - ValueRange firstLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(firstLoopDef, firstLoopDef, lbs, ubs, - [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) { + createKrnl.forLoopIE(LitIE(0), positionIE, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) { auto element = createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]); createKrnl.seqstore(element, alloc, positionIE); @@ -91,13 +86,8 @@ struct ONNXSequenceInsertOpLowering }); // Copy the elements after the position - SmallVector lbs1; - lbs1.emplace_back(positionIE + 1); - SmallVector ubs1; - ubs1.emplace_back(boundIE); - ValueRange secondLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(secondLoopDef, secondLoopDef, lbs1, ubs1, - [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) { + createKrnl.forLoopIE(positionIE + 1, boundIE, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) { auto element = createKrnl.load(adaptor.getInputSequence(), indicesLoopInd[0]); auto oneIndex = create.math.constantIndex(1); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp index 592ba67b1a..3ed51b37a9 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ArgMinMax.cpp @@ -98,19 +98,19 @@ struct ONNXArgMinMaxOpLowering : public OpConversionPattern { // 1. Krnl loops to initialize the result. ValueRange initLoopDef = create.krnl.defineLoops(reducedRank); - SmallVector initLbs(reducedRank, LiteralIndexExpr(0)); + SmallVector initLbs(reducedRank, LitIE(0)); create.krnl.iterateIE(initLoopDef, initLoopDef, initLbs, outputDims, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(minusOne, alloc, loopInd); }); // 2. Krnl loop to calculate arg min/arg max. ValueRange calcLoopDef = create.krnl.defineLoops(dataRank); - SmallVector lbs(dataRank, LiteralIndexExpr(0)); + SmallVector lbs(dataRank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(data, ubs); create.krnl.iterateIE(calcLoopDef, calcLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Handle the operation: SmallVector inLoopIVs, outLoopIVs, dstLoopIVs; diff --git a/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp b/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp index 13546b0ebf..8fd71ac670 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Compress.cpp @@ -55,14 +55,14 @@ struct ONNXCompressOpLowering : public OpConversionPattern { // Create temp memory for summing up the true value and init to zero. Type indexType = rewriter.getIndexType(); MemRefType indexMemRefType = MemRefType::get({}, indexType); + // Scalar, ok to use alloca. Value sumMemRef = create.mem.alloca(indexMemRefType); create.krnl.store(zeroIE.getValue(), sumMemRef); // Now create a loop to iterate over all conditions. Value condMemRef = adaptor.getCondition(); IndexExpr condShapeFirstRank = create.krnlIE.getShapeAsDim(condMemRef, 0); - ValueRange loopDef = create.krnl.defineLoops(1); - create.krnl.iterateIE(loopDef, loopDef, {zeroIE}, {condShapeFirstRank}, - [&](KrnlBuilder createKrnl, ValueRange loopInd) { + create.krnl.forLoopIE(zeroIE, condShapeFirstRank, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange loopInd) { MathBuilder createMath(createKrnl); // Load the condition Value currCond = createKrnl.load(condMemRef, loopInd); // Type i1. @@ -143,24 +143,25 @@ struct ONNXCompressOpLowering : public OpConversionPattern { } } + // Scalar, ok to use alloca. Value readIndexMemRef = create.mem.alloca(indexMemRefType); create.krnl.store(zeroIE.getValue(), readIndexMemRef); ValueRange inputLoopDef = create.krnl.defineLoops(inputRank); create.krnl.iterateIE(inputLoopDef, inputLoopDef, inputLbs, inputUbs, - [&](KrnlBuilder createKrnl, ValueRange inputLoopInd) { + [&](const KrnlBuilder createKrnl, ValueRange inputLoopInd) { MultiDialectBuilder create( createKrnl); Value readIndex = create.krnl.load(readIndexMemRef); Value inBound = trueVal; if (!skipCond) inBound = create.math.slt(readIndex, condUb); - create.scf.ifThenElse(inBound, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(inBound, [&](const SCFBuilder &createSCF) { MultiDialectBuilder create( createSCF); Value currCond = create.krnl.load(condMemRef, {readIndex}); Value copy = create.math.neq(currCond, falseVal); - create.scf.ifThenElse(copy, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(copy, [&](const SCFBuilder &createSCF) { MultiDialectBuilder create(createSCF); Value val = create.krnl.load(inputMemRef, inputLoopInd); // Copy to output. @@ -215,10 +216,9 @@ struct ONNXCompressOpLowering : public OpConversionPattern { innerLbs.emplace_back(inputLbs[i]); innerUbs.emplace_back(inputUbs[i]); } - ValueRange axisLoopDef = create.krnl.defineLoops(1); - create.krnl.iterateIE(axisLoopDef, axisLoopDef, {inputLbs[axisValue]}, - {inputUbs[axisValue]}, - [&](KrnlBuilder createKrnl, ValueRange axisLoopInd) { + create.krnl.forLoopIE(inputLbs[axisValue], inputUbs[axisValue], + /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createKrnl, ValueRange axisLoopInd) { MultiDialectBuilder create( createKrnl); // Compute the test if we have enough condition value for current @@ -227,12 +227,12 @@ struct ONNXCompressOpLowering : public OpConversionPattern { Value inBound = trueVal; if (!skipCond) inBound = create.math.slt(readIndex, condUb); - create.scf.ifThenElse(inBound, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(inBound, [&](const SCFBuilder &createSCF) { MultiDialectBuilder create( createSCF); Value currCond = create.krnl.load(condMemRef, {readIndex}); Value copy = create.math.neq(currCond, falseVal); - create.scf.ifThenElse(copy, [&](SCFBuilder &createSCF) { + create.scf.ifThenElse(copy, [&](const SCFBuilder &createSCF) { KrnlBuilder createKrnl(createSCF); // Load the write index. Value writeIndex = createKrnl.load(writeIndexMemRef); @@ -240,7 +240,7 @@ struct ONNXCompressOpLowering : public OpConversionPattern { ValueRange innerLoopDefs = createKrnl.defineLoops(innerRank); createKrnl.iterateIE(innerLoopDefs, innerLoopDefs, innerLbs, innerUbs, - [&](KrnlBuilder createKrnl, ValueRange innerLoopInd) { + [&](const KrnlBuilder createKrnl, ValueRange innerLoopInd) { MathBuilder createMath(createKrnl); // Compute access functions for input and output. SmallVector inputAccessFct, outputAccessFct; diff --git a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp index 1f0833d7f8..40fc1c9d92 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Concat.cpp @@ -75,7 +75,7 @@ struct ONNXConcatOpLowering : public OpConversionPattern { // dim of different inputs. SmallVector commonUB(shapeHelper.getOutputDims()); // IndexExprScope IEScope(&rewriter, loc); - IndexExpr accumulatedOffset = LiteralIndexExpr(0); + IndexExpr accumulatedOffset = LitIE(0); for (unsigned int i = 0; i < inputNum; ++i) { // Since the accumulatedOffsetValue will be used in a nested // IndexExprScope, we get the Value of this IndexExpr and pass it as a @@ -84,7 +84,7 @@ struct ONNXConcatOpLowering : public OpConversionPattern { OpBuilder::InsertionGuard insertGuard(rewriter); // Create loop. ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(operands[i], ubs); // For each input, only the dimension 'axis' is different @@ -101,7 +101,7 @@ struct ONNXConcatOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(loopDef, loopDef, lbs, commonUB, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Indices for the read and write. SmallVector readIndices, writeIndices; for (unsigned int r = 0; r < rank; ++r) { @@ -109,9 +109,8 @@ struct ONNXConcatOpLowering : public OpConversionPattern { writeIndices.emplace_back(loopInd[r]); else { IndexExprScope IEScope(&rewriter, loc); - IndexExpr writeOffset = DimIndexExpr(loopInd[r]); - IndexExpr accumulatedOffsetIE = - SymbolIndexExpr(accumulatedOffsetValue); + IndexExpr writeOffset = DimIE(loopInd[r]); + IndexExpr accumulatedOffsetIE = SymIE(accumulatedOffsetValue); writeOffset = writeOffset + accumulatedOffsetIE; writeIndices.emplace_back(writeOffset.getValue()); } diff --git a/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp b/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp index 24848eee8b..aa86fbbec3 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ConcatShapeTranspose.cpp @@ -125,7 +125,7 @@ struct ONNXConcatShapeTransposeOpLowering // optimization. Difference may come from constant vs. dynamic, or dynamic // dim of different inputs. SmallVector commonUB = outputConcatDims; - IndexExpr accumulatedOffset = LiteralIndexExpr(0); + IndexExpr accumulatedOffset = LitIE(0); for (unsigned int i = 0; i < numInputs; ++i) { // Since the accumulatedOffsetValue will be used in a nested // IndexExprScope, we get the Value of this IndexExpr and pass it as a @@ -134,13 +134,13 @@ struct ONNXConcatShapeTransposeOpLowering OpBuilder::InsertionGuard insertGuard(rewriter); // Create loop. ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(operands[i], ubs); // For each input, only the dimension 'axis' is different commonUB[axis] = ubs[axis]; create.krnl.iterateIE(loopDef, loopDef, lbs, commonUB, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Indices for the read and write. SmallVector readIndices, writeIndices; for (unsigned int r = 0; r < rank; ++r) { @@ -148,9 +148,8 @@ struct ONNXConcatShapeTransposeOpLowering writeIndices.emplace_back(loopInd[r]); else { IndexExprScope IEScope(&rewriter, loc); - IndexExpr writeOffset = DimIndexExpr(loopInd[r]); - IndexExpr accumulatedOffsetIE = - SymbolIndexExpr(accumulatedOffsetValue); + IndexExpr writeOffset = DimIE(loopInd[r]); + IndexExpr accumulatedOffsetIE = SymIE(accumulatedOffsetValue); writeOffset = writeOffset + accumulatedOffsetIE; writeIndices.emplace_back(writeOffset.getValue()); } diff --git a/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp b/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp index 2b6b98d30e..ebce30ce5a 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ConstantOfShape.cpp @@ -80,11 +80,11 @@ struct ONNXConstantOfShapeOpLowering if (!hasAllScalarValues({alloc})) { IndexExprScope childScope(&rewriter, loc); ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(constantVal, alloc, loopInd); }); } else diff --git a/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp b/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp index 21fc67f8da..0fe77439b8 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Dim.cpp @@ -43,7 +43,7 @@ struct ONNXDimOpLowering : public OpConversionPattern { Type elementType = outputMemRefType.getElementType(); // Output is 1D memref of one element. - SmallVector outputDims(1, LiteralIndexExpr(1)); + SmallVector outputDims(1, LitIE(1)); Value alloc = create.mem.alignedAlloc(outputMemRefType, outputDims); // Write the dimension at axis to the output. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp b/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp index b75a94e5a6..ae93d5ba9e 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Expand.cpp @@ -56,7 +56,7 @@ struct ONNXExpandOpLowering : public OpConversionPattern { SmallVector lbs(outputRank, zeroIE); create.krnl.iterateIE(outputLoopDef, outputLoopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange outputLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange outputLoopInd) { IndexExprScope outputScope(createKrnl, shapeHelper.getScope()); SmallVector outputLoopIndices, lhsAccessExprs; getIndexExprList(outputLoopInd, outputLoopIndices); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp index 35cb8c4f7e..58227a2b69 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Gather.cpp @@ -102,7 +102,7 @@ struct ONNXGatherOpLowering : public OpConversionPattern { } } create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Insert code inside the loop. IndexExprScope innerLoopScope(createKrnl); SymbolIndexExpr axisDim(dataDims[axisLit]); diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp index 9e4db0a1ca..0f3b5c24e2 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherElements.cpp @@ -73,9 +73,9 @@ struct ONNXGatherElementsOpLowering // output[i][j]...[n] = data[i][j]..[index]..[n] (index used at axis dim.) // ValueRange loopDef = create.krnl.defineLoops(indicesRank); - DimsExpr lbs(indicesRank, LiteralIndexExpr(0)); + DimsExpr lbs(indicesRank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, indicesDims, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Insert code inside the loop. IndexExprScope innerLoopScope(createKrnl); diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp index 69a38b2fce..3bbed9d647 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp @@ -4,7 +4,7 @@ //===---------------- GatherND.cpp - Lowering GatherND Op -----------------===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -32,10 +32,10 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // Debug function used to emit code to print the supplied 'indices'. static void printIndices( - StringRef title, const DimsExpr &indices, KrnlBuilder &createKrnl) { + StringRef title, const DimsExpr &indices, const KrnlBuilder &createKrnl) { llvm::Twine msg(title + ": ("); createKrnl.printf(msg.str()); - int64_t n = (int64_t)indices.size(); + int64_t n = static_cast(indices.size()); for (int64_t i = 0; i < n; ++i) { Value val = indices[i].getValue(); createKrnl.printf(" ", val); @@ -122,6 +122,7 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // Initialize the index used to store the result values. Value iZero = create.math.constantIndex(0); Value iOne = create.math.constantIndex(1); + // Scalar, ok to use alloca. Value storeIndex = create.mem.alloca(MemRefType::get({}, rewriter.getIndexType())); create.krnl.store(iZero, storeIndex); @@ -133,16 +134,15 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // } // output.reshape(outputShape) ValueRange loopDef = create.krnl.defineLoops(2); - DimsExpr lbs(2, LiteralIndexExpr(0)), - ubs = {newIndicesShape[0], newIndicesShape[1]}; + DimsExpr lbs(2, LitIE(0)), ubs = {newIndicesShape[0], newIndicesShape[1]}; if (emitPrintStmts) { - create.krnl.printTensor("reshapedIndices: ", reshapedIndices); - create.krnl.printTensor("reshapedData: ", reshapedData); + create.krnl.printTensor("reshapedIndices%s%d%e", reshapedIndices); + create.krnl.printTensor("reshapedData%s%d%e", reshapedData); } create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Insert code inside the loop. IndexExprScope innerLoopScope(createKrnl); @@ -154,7 +154,7 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // Access function for 'reshapedData'. The first index is equal to the // first loop index. DimsExpr reshapedDataAccessFct; - IndexExpr ind = SymbolIndexExpr(loopInd[0]); + IndexExpr ind = SymIE(loopInd[0]); reshapedDataAccessFct.emplace_back(ind); // The last index of the access function for 'reshapedIndices' is @@ -162,7 +162,7 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // The loaded values from 'reshapedIndices' are the next set of // indices to push to the `reshapedDataAccessFct`. for (unsigned i = 0; i < indicesLastDim; ++i) { - IndexExpr ind = LiteralIndexExpr(i); + IndexExpr ind = LitIE(i); reshapedIndicesAccessFct.emplace_back(ind); if (emitPrintStmts) @@ -185,7 +185,8 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { // When indices.shape[-1] is equal to (rank(data) - b) the // `reshapedDataAccessFct` computed so far has the same number of // indices as the rank of 'reshapedData'. - assert((int64_t)reshapedDataAccessFct.size() == reshapedDataRank && + assert(static_cast(reshapedDataAccessFct.size()) == + reshapedDataRank && "Access function should have the same rank as reshapedData"); if (emitPrintStmts) @@ -212,10 +213,10 @@ struct ONNXGatherNDOpLowering : public OpConversionPattern { Value last = reshapedDataLastDimExpr.getValue(); ValueRange innerLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(innerLoopDef, innerLoopDef, {zero}, {last}, - [&](KrnlBuilder &createKrnl, ValueRange innerLoopInd) { - IndexExpr ind = SymbolIndexExpr(innerLoopInd[0]); + [&](const KrnlBuilder &createKrnl, ValueRange innerLoopInd) { + IndexExpr ind = SymIE(innerLoopInd[0]); reshapedDataAccessFct.emplace_back(ind); - assert((int64_t)reshapedDataAccessFct.size() == + assert(static_cast(reshapedDataAccessFct.size()) == reshapedDataRank && "Access function should have the same rank as " "reshapedData"); diff --git a/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp b/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp index 42311bddd9..9443fea480 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/NonZero.cpp @@ -107,13 +107,14 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { Value zero = create.math.constant(xElementType, 0); // Bounds for the input tensor. - SmallVector xLbs(xRank, LiteralIndexExpr(0)); + SmallVector xLbs(xRank, LitIE(0)); SmallVector xUbs; create.krnlIE.getShapeAsDims(X, xUbs); // Emit a variable for the total number of nonzero values. + // Scalar, ok to use alloca. Value nonzeroCount = create.mem.alloca(MemRefType::get({}, indexTy)); - create.krnl.store(iZero, nonzeroCount, {}); + create.krnl.store(iZero, nonzeroCount); // Emit alloc and dealloc for reduction sum along each dimension. // MemRefType: [Dxi64] where D is the dimension size. @@ -130,7 +131,7 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { ValueRange initLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(initLoopDef, initLoopDef, {iZero}, {xBound.getValue()}, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { createKrnl.store(iZero, alloc, loopInd); }); rsumMemRefs.emplace_back(alloc); @@ -140,15 +141,15 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { // the reduction sum for each dimension. ValueRange rsumLoopDef = create.krnl.defineLoops(xMemRefType.getRank()); create.krnl.iterateIE(rsumLoopDef, rsumLoopDef, xLbs, xUbs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { MathBuilder createMath(createKrnl); Value x = createKrnl.load(X, loopInd); Value eqCond = createMath.eq(x, zero); Value zeroOrOne = createMath.select(eqCond, iZero, iOne); // Count the total number of nonzero values. - Value total = createKrnl.load(nonzeroCount, {}); + Value total = createKrnl.load(nonzeroCount); total = createMath.add(total, zeroOrOne); - createKrnl.store(total, nonzeroCount, {}); + createKrnl.store(total, nonzeroCount); // Reduction sum of the number of nonzero values for each dimension. for (int64_t i = 0; i < xRank; ++i) { Value sum = createKrnl.load(rsumMemRefs[i], loopInd[i]); @@ -160,10 +161,10 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { // Emit alloc and dealloc for the result of this operation. // MemRefType : [RxNxi64] where R is the input's rank, N is the number of // non zero values. - Value numberOfZeros = create.krnl.load(nonzeroCount, {}); + Value numberOfZeros = create.krnl.load(nonzeroCount); SmallVector dimExprs; - dimExprs.emplace_back(LiteralIndexExpr(xRank)); - dimExprs.emplace_back(DimIndexExpr(numberOfZeros)); + dimExprs.emplace_back(LitIE(xRank)); + dimExprs.emplace_back(DimIE(numberOfZeros)); Value resMemRef = create.mem.alignedAlloc(resMemRefType, dimExprs); // Emit code to compute the output for each dimension. @@ -176,11 +177,12 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { // out[0][i] = p // ``` + // Scalars, ok to use alloca. Value pos = create.mem.alloca(MemRefType::get({}, indexTy)); Value sum = create.mem.alloca(MemRefType::get({}, indexTy)); ValueRange iLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(iLoopDef, iLoopDef, {iZero}, {numberOfZeros}, - [&](KrnlBuilder &ck, ValueRange iLoopInd) { + [&](const KrnlBuilder &ck, ValueRange iLoopInd) { MultiDialectBuilder create(ck); @@ -191,26 +193,26 @@ struct ONNXNonZeroOpLowering : public OpConversionPattern { IndexExpr rsumBounds0 = create.krnlIE.getShapeAsDim(rsumBoundsVal, 0); - create.krnl.store(iMinusOne, pos, {}); - create.krnl.store(iZero, sum, {}); + create.krnl.store(iMinusOne, pos); + create.krnl.store(iZero, sum); ValueRange jLoopDef = create.krnl.defineLoops(1); create.krnl.iterate(jLoopDef, jLoopDef, {iZero}, {rsumBounds0.getValue()}, - [&](KrnlBuilder &createKrnl, ValueRange jLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange jLoopInd) { MathBuilder createMath(createKrnl); Value j(jLoopInd[0]); Value o = createKrnl.load(rsumMemRefs[axis], {j}); - Value s = createKrnl.load(sum, {}); - Value p = createKrnl.load(pos, {}); + Value s = createKrnl.load(sum); + Value p = createKrnl.load(pos); s = createMath.add(s, o); Value andCond = createMath.andi( createMath.slt(i, s), createMath.eq(p, iMinusOne)); p = createMath.select(andCond, j, p); - createKrnl.store(p, pos, {}); - createKrnl.store(s, sum, {}); + createKrnl.store(p, pos); + createKrnl.store(s, sum); }); - Value p = create.krnl.load(pos, {}); + Value p = create.krnl.load(pos); p = create.math.cast(resElementType, p); create.krnl.store(p, resMemRef, {axisVal, i}); } diff --git a/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp b/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp index bf55e3636f..574856ab17 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/OneHot.cpp @@ -62,7 +62,8 @@ struct ONNXOneHotOpLowering : public OpConversionPattern { create.krnlIE.getShapeAsDims(indices, indicesUbs); ValueRange indicesLoopDef = create.krnl.defineLoops(indicesRank); create.krnl.iterateIE(indicesLoopDef, indicesLoopDef, indicesLbs, - indicesUbs, [&](KrnlBuilder createKrnl, ValueRange indicesLoopInd) { + indicesUbs, + [&](const KrnlBuilder createKrnl, ValueRange indicesLoopInd) { // Loop for all input values. MathBuilder createMath(createKrnl); // Input val is allowed to be any integer/float. Read and convert to @@ -89,9 +90,8 @@ struct ONNXOneHotOpLowering : public OpConversionPattern { Value onValueIndexVal = onValueIndex.getValue(); // Now we have the index that is on, iterate over the depth values // along axis, and set the right one to the value on. - ValueRange depthLoopDef = createKrnl.defineLoops(1); - createKrnl.iterateIE(depthLoopDef, depthLoopDef, {zeroIE}, {depth}, - [&](KrnlBuilder createBuilder, ValueRange depthLoopInd) { + createKrnl.forLoopIE(zeroIE, depth, /*step*/ 1, /*par*/ false, + [&](const KrnlBuilder createBuilder, ValueRange depthLoopInd) { MathBuilder createMath(createKrnl); Value onCond = createMath.eq(depthLoopInd[0], onValueIndexVal); Value res = createMath.select(onCond, onVal, offVal); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp index 542b3c2fd8..1b3d83890c 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Pad.cpp @@ -98,11 +98,10 @@ struct ONNXPadOpLowering : public OpConversionPattern { create.krnlIE.getShapeAsDims(data, ubs); ValueRange mainLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(mainLoopDef, mainLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange dataLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange dataLoopInd) { SmallVector resLoopInd; for (uint64_t i = 0; i < rank; ++i) { - IndexExpr resInd = - DimIndexExpr(dataLoopInd[i]) + shapeHelper.pads[i]; + IndexExpr resInd = DimIE(dataLoopInd[i]) + shapeHelper.pads[i]; resLoopInd.emplace_back(resInd); } Value dataValue = createKrnl.load(data, dataLoopInd); @@ -117,12 +116,12 @@ struct ONNXPadOpLowering : public OpConversionPattern { // Iterate over the result tensor dimensions. ValueRange mainLoopDef = create.krnl.defineLoops(rank); create.krnl.iterateIE(mainLoopDef, mainLoopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange resLoopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange resLoopInd) { MultiDialectBuilder create( createKrnl); SmallVector dataLoopInd; for (uint64_t i = 0; i < rank; ++i) { - IndexExpr dataInd = DimIndexExpr(resLoopInd[i]); + IndexExpr dataInd = DimIE(resLoopInd[i]); IndexExpr pad = shapeHelper.pads[i]; IndexExpr dim = create.krnlIE.getShapeAsDim(data, i); if (padMode.equals_insensitive("edge")) { diff --git a/src/Conversion/ONNXToKrnl/Tensor/Range.cpp b/src/Conversion/ONNXToKrnl/Tensor/Range.cpp index b3bd44d8cc..8ddd91f99b 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Range.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Range.cpp @@ -150,7 +150,7 @@ struct ONNXRangeOpLowering : public OpConversionPattern { // Acc index: SmallVector accIndex; - accIndex.emplace_back(LiteralIndexExpr(0)); + accIndex.emplace_back(LitIE(0)); // Initialize accumulator with value: create.krnl.storeIE(loadedStart, acc, accIndex); @@ -158,8 +158,8 @@ struct ONNXRangeOpLowering : public OpConversionPattern { ValueRange loopDef = create.krnl.defineLoops(1); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); - create.krnl.iterateIE(loopDef, loopDef, {LiteralIndexExpr(0)}, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + create.krnl.iterateIE(loopDef, loopDef, {LitIE(0)}, ubs, + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Emit body of the loop: // output[i] = start + (i * delta); // Read value: @@ -167,7 +167,7 @@ struct ONNXRangeOpLowering : public OpConversionPattern { // Store result: SmallVector resultIndices; - resultIndices.emplace_back(DimIndexExpr(loopInd[0])); + resultIndices.emplace_back(DimIE(loopInd[0])); createKrnl.storeIE(result, alloc, resultIndices); // Increment result: diff --git a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp index 9527fba3c3..75bbed3531 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Resize.cpp @@ -118,11 +118,11 @@ struct ONNXResizeOpLowering : public OpConversionPattern { Value one = create.math.constantIndex(1); ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(alloc, ubs); - create.krnl.iterateIE( - loopDef, loopDef, lbs, ubs, [&](KrnlBuilder &ck, ValueRange loopInd) { + create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, + [&](const KrnlBuilder &ck, ValueRange loopInd) { MultiDialectBuilder create(ck); SmallVector readIndices; diff --git a/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp b/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp index 64c0b5cd27..d0a7f5d3ae 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ReverseSequence.cpp @@ -94,9 +94,9 @@ struct ONNXReverseSequenceOpLowering // Define loops and iteration trip counts (equivalent to size of output) ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { IndexExprScope innerLoopScope(&rewriter, shapeHelper.getScope()); // compute the loop indices for the output diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp index 3cc6a366e3..cdea57693d 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterElements.cpp @@ -74,10 +74,10 @@ struct ONNXScatterElementsOpLowering // output[i][j]..[index]..[n] = val (index used at position axis) // ValueRange loopDef = create.krnl.defineLoops(updatesRank); - DimsExpr lbs(updatesRank, LiteralIndexExpr(0)), ubs; + DimsExpr lbs(updatesRank, LitIE(0)), ubs; create.krnlIE.getShapeAsDims(updates, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Insert code inside the loop. IndexExprScope innerLoopScope(createKrnl); diff --git a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp index 5550c15524..2a585da52a 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/ScatterND.cpp @@ -69,10 +69,10 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { // output[indices[idx]] = updates[idx] // ValueRange loopDef = create.krnl.defineLoops(updatesRank); - DimsExpr lbs(updatesRank, LiteralIndexExpr(0)), ubs; + DimsExpr lbs(updatesRank, LitIE(0)), ubs; create.krnlIE.getShapeAsDims(updates, ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { // Insert code inside the loop. IndexExprScope innerLoopScope(createKrnl); @@ -91,15 +91,15 @@ struct ONNXScatterNDOpLowering : public OpConversionPattern { DimsExpr outputAccessFct; for (unsigned i = 0; i < dataRank; ++i) { if (i < indicesRank - 1) { - IndexExpr ind = LiteralIndexExpr(i); + IndexExpr ind = LitIE(i); DimsExpr indicesAccessFct(indicesAccessFctFirst); indicesAccessFct.emplace_back(ind); Value indexVal = createKrnl.loadIE(indices, indicesAccessFct); IndexExpr index = NonAffineIndexExpr(indexVal); outputAccessFct.emplace_back(index); } else { - IndexExpr index = SymbolIndexExpr( - loopInd[std::min(i, loopInd.size() - 1)]); + IndexExpr index = + SymIE(loopInd[std::min(i, loopInd.size() - 1)]); outputAccessFct.emplace_back(index); } } diff --git a/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp index 95c01a1859..1082490711 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Shape.cpp @@ -59,7 +59,7 @@ struct ONNXShapeOpLowering : public OpConversionPattern { for (uint64_t i = 0; i < selectedData.size(); ++i) { Value val = selectedData[i].getValue(); Value intVal = create.math.cast(elementType, val); - create.krnl.storeIE(intVal, alloc, {LiteralIndexExpr(i)}); + create.krnl.storeIE(intVal, alloc, {LitIE(i)}); } rewriter.replaceOp(op, alloc); onnxToKrnlSimdReport(op); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp b/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp index 659be67d41..956637d553 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Slice.cpp @@ -47,9 +47,9 @@ struct ONNXSliceOpLowering : public OpConversionPattern { create.mem.alignedAlloc(outputMemRefType, shapeHelper.getOutputDims()); ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange loopInd) { + [&](const KrnlBuilder &createKrnl, ValueRange loopInd) { IndexExprScope loopScope(createKrnl); // Compute indices for the load and store op. @@ -58,8 +58,8 @@ struct ONNXSliceOpLowering : public OpConversionPattern { SmallVector loadIndices, storeIndices; for (int ii = 0; ii < outputRank; ++ii) { DimIndexExpr inductionIndex(loopInd[ii]); - IndexExpr start = SymbolIndexExpr(shapeHelper.starts[ii]); - IndexExpr step = SymbolIndexExpr(shapeHelper.steps[ii]); + IndexExpr start = SymIE(shapeHelper.starts[ii]); + IndexExpr step = SymIE(shapeHelper.steps[ii]); loadIndices.emplace_back((step * inductionIndex) + start); storeIndices.emplace_back(inductionIndex); } diff --git a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp index e77d437a48..e6490ffd3f 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Split.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Split.cpp @@ -64,12 +64,12 @@ LogicalResult ONNXSplitOpLoweringCommon(OP_TYPE splitOp, OP_ADAPTOR adaptor, rewriter, loc); ValueRange loopDef = create.krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create.krnlIE.getShapeAsDims(allocs[i], ubs); create.krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { SmallVector readIndices; for (uint64_t r = 0; r < rank; ++r) { DimIndexExpr readIndex(indices[r]); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp b/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp index e4dc340c97..f76cee05de 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Tile.cpp @@ -78,10 +78,10 @@ struct ONNXTileOpLowering : public OpConversionPattern { create.mem.alignedAlloc(memRefType, shapeHelper.getOutputDims()); ValueRange loopDef = create.krnl.defineLoops(outputRank); - SmallVector lbs(outputRank, LiteralIndexExpr(0)); + SmallVector lbs(outputRank, LitIE(0)); create.krnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.getOutputDims(), - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { // Compute the indices used by the input tensor load operation. // Note: An alternative implementation can be found at the end of this // file. diff --git a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp index 39ef87b30d..c597396179 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp @@ -143,7 +143,7 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { bool enableParallel) const { uint64_t rank = mlir::cast(outputMemRef.getType()).getRank(); ValueRange loopDef = create->krnl.defineLoops(rank); - SmallVector lbs(rank, LiteralIndexExpr(0)); + SmallVector lbs(rank, LitIE(0)); SmallVector ubs; create->krnlIE.getShapeAsDims(inputMemRef, ubs); @@ -161,12 +161,12 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { } create->krnl.iterateIE(loopDef, loopDef, lbs, ubs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { // Compute the indices used by the load operation. SmallVector storeIndices; for (uint64_t i = 0; i < rank; ++i) { Value index = indices[ArrayAttrIntVal(permAttr, i)]; - storeIndices.emplace_back(DimIndexExpr(index)); + storeIndices.emplace_back(DimIE(index)); } Value loadData = createKrnl.load(inputMemRef, indices); createKrnl.storeIE(loadData, outputMemRef, storeIndices); @@ -192,13 +192,13 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { // Strides SmallVector inStrides, outStrides; inStrides.resize_for_overwrite(rank); - inStrides[rank - 1] = LiteralIndexExpr(1); - IndexExpr strideIE = LiteralIndexExpr(1); + inStrides[rank - 1] = LitIE(1); + IndexExpr strideIE = LitIE(1); for (int i = rank - 2; i >= 0; --i) { strideIE = strideIE * inUBs[i + 1]; inStrides[i] = strideIE; } - strideIE = LiteralIndexExpr(1); + strideIE = LitIE(1); outStrides.resize_for_overwrite(rank); for (int i = rank - 2; i >= 0; --i) { strideIE = strideIE * outUBs[i + 1]; @@ -207,7 +207,7 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { // The number of elements in a block to copy, computed for the last N // dimensions. - IndexExpr elemsToCopy = LiteralIndexExpr(1); + IndexExpr elemsToCopy = LitIE(1); for (uint64_t i = rank - numLastDims; i < rank; ++i) elemsToCopy = elemsToCopy * inUBs[i]; Value elemsToCopyI64 = create->math.cast(i64Ty, elemsToCopy.getValue()); @@ -220,7 +220,7 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { // Main loop defined over the outer-most dimensions. ValueRange loopDef = create->krnl.defineLoops(outerRank); - SmallVector lbs(outerRank, LiteralIndexExpr(0)); + SmallVector lbs(outerRank, LitIE(0)); if (enableParallel) { int64_t parId; // Note that if there is only 1 dim, lastExclusiveDim is automatically @@ -235,22 +235,20 @@ struct ONNXTransposeOpLowering : public OpConversionPattern { } } create->krnl.iterateIE(loopDef, loopDef, lbs, inUBs, - [&](KrnlBuilder &createKrnl, ValueRange indices) { + [&](const KrnlBuilder &createKrnl, ValueRange indices) { MultiDialectBuilder create(createKrnl); IndexExprScope loopScope(createKrnl); // Compute destination and source offsets for memcpy. - IndexExpr destOffsetIE = LiteralIndexExpr(0); - IndexExpr srcOffsetIE = LiteralIndexExpr(0); + IndexExpr destOffsetIE = LitIE(0); + IndexExpr srcOffsetIE = LitIE(0); for (uint64_t i = 0; i < outerRank; ++i) { // source offset DimIndexExpr srcIndex(indices[i]); - srcOffsetIE = - srcOffsetIE + srcIndex * SymbolIndexExpr(inStrides[i]); + srcOffsetIE = srcOffsetIE + srcIndex * SymIE(inStrides[i]); // destination offset DimIndexExpr destIndex(indices[ArrayAttrIntVal(permAttr, i)]); // Note: index for outStrides is not the permuted index. - destOffsetIE = - destOffsetIE + destIndex * SymbolIndexExpr(outStrides[i]); + destOffsetIE = destOffsetIE + destIndex * SymIE(outStrides[i]); } // call memcpy. create.krnl.memcpy(outputMemRef, inputMemRef, elemsToCopyI64, diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp index d43f14a7cb..7eee85c5be 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Unique.cpp @@ -125,16 +125,17 @@ struct ONNXUniqueOpLowering : public ConversionPattern { // Type indexTy = rewriter.getIndexType(); Value iZero = create.math.constantIndex(0); + // Scalar, ok to use alloca. Value uniqueCount = create.mem.alloca(MemRefType::get({}, indexTy)); - create.krnl.store(iZero, uniqueCount, {}); + create.krnl.store(iZero, uniqueCount); Value noneValue; emitArgUnique(rewriter, loc, uniqueCount, X, axis, /*sorted=*/sorted, noneValue, noneValue, noneValue, noneValue, /*count_only=*/true); // // Calculate shapes of output Tensors // - Value total = create.krnl.load(uniqueCount, {}); - NonAffineIndexExpr totalDimExpr = DimIndexExpr(total); + Value total = create.krnl.load(uniqueCount); + NonAffineIndexExpr totalDimExpr = DimIE(total); DimsExpr outputYDims; DimsExpr outputIndexDims; DimsExpr outputInverseIndexDims; @@ -211,7 +212,7 @@ struct ONNXUniqueOpLowering : public ConversionPattern { // // Emit a Unique call to get the outputs // - create.krnl.store(iZero, uniqueCount, {}); + create.krnl.store(iZero, uniqueCount); emitArgUnique(rewriter, loc, uniqueCount, X, axis, /*sorted=*/sorted, outputY, indices, inverseIndices, counts, /*count_only=*/false); if (isNoneValue(indices)) diff --git a/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp b/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp index 90f15db2ff..29856a79cc 100644 --- a/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp +++ b/src/Conversion/ONNXToStablehlo/DialectBuilder.cpp @@ -48,7 +48,7 @@ Value StablehloBuilder::constant(Type type, double val) const { b().create(loc(), b().getF64FloatAttr(val)); }) .Case([&](IntegerType elementType) { - assert(val == (int64_t)val && "value is ambiguous"); + assert(val == static_cast(val) && "value is ambiguous"); unsigned width = elementType.getWidth(); if (width == 1) @@ -57,12 +57,12 @@ Value StablehloBuilder::constant(Type type, double val) const { else { if (elementType.isUnsignedInteger()) { constant = b().create( - loc(), b().getIntegerAttr( - elementType, APInt(width, (uint64_t)val, false))); + loc(), b().getIntegerAttr(elementType, + APInt(width, static_cast(val), false))); } else { constant = b().create( - loc(), b().getIntegerAttr( - elementType, APInt(width, (int64_t)val, true))); + loc(), b().getIntegerAttr(elementType, + APInt(width, static_cast(val), true))); } } }) diff --git a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp index e40057b5d3..f307ce6c9d 100644 --- a/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/MatMul.cpp @@ -135,7 +135,7 @@ struct ONNXMatMulOpLoweringToStablehlo : public ConversionPattern { llvm::to_vector<4>(llvm::seq(0, paddedRank - 2)), llvm::to_vector<4>(llvm::seq(0, paddedRank - 2)), {paddedRank - 1 - oneDPadA}, {paddedRank - 2}), - nullptr); + /*precision_config*/ nullptr, /*algorithm*/ nullptr); else { dotProduct = rewriter.create(loc, op->getResultTypes().front(), broadcastedA, broadcastedB, nullptr); diff --git a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp index b7f2214c02..f1b0d61657 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp @@ -121,8 +121,7 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { Value operand = operands[0]; - assert( - hasStaticShape(operand.getType()) && "Only Static shapes are accepted"); + bool isStaticShape = hasStaticShape(operand.getType()); Location loc = op->getLoc(); Type outputType = *op->result_type_begin(); @@ -151,29 +150,51 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern { // Sum of the all the exponents for the denominator SmallVector reducedShape = getReductionShape(ExpOutputType, axes, false); - ShapedType ReducedShapeType = mlir::cast( - RankedTensorType::get(reducedShape, ExpOutputType.getElementType())); + ShapedType ReducedShapeType; + if (isStaticShape) { + ReducedShapeType = mlir::cast( + RankedTensorType::get(reducedShape, ExpOutputType.getElementType())); + } else { + SmallVector ReducedShapeVector = + getReductionShape(ExpOutputType, axes, true); + ReducedShapeType = mlir::cast(RankedTensorType::get( + ReducedShapeVector, ExpOutputType.getElementType())); + } Value identity = rewriter.create( loc, rewriter.getZeroAttr(ExpOutputType.getElementType())); Value ReduceSum = computeReduceSum(loc, ElementwiseExpStableHLO, identity, - reducedShape, axes, rewriter, false, ReducedShapeType); + reducedShape, axes, rewriter, !isStaticShape, ReducedShapeType); + if (ReduceSum == nullptr) return failure(); - SmallVector broadcast_dims = - getBroadcastDims(ElementwiseExpStableHLO, axes); - Value BroadCastOp = - rewriter.create(loc, ExpOutputType, - ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims)); + Value BroadCastOp; + if (isStaticShape) { + SmallVector broadcast_dims = + getBroadcastDims(ElementwiseExpStableHLO, axes); + BroadCastOp = + rewriter.create(loc, ExpOutputType, + ReduceSum, rewriter.getDenseI64ArrayAttr(broadcast_dims)); + } else { + mlir::Value OutputDimensions = + rewriter.create(loc, operand); + SmallVector DimIndex; + for (int64_t i = 0; i < ExpOutputType.getRank(); i++) + DimIndex.push_back(i); + BroadCastOp = rewriter.create(loc, + ExpOutputType, ReduceSum, OutputDimensions, + rewriter.getDenseI64ArrayAttr(DimIndex)); + } if (BroadCastOp == nullptr) return failure(); - Value Softmax_output = rewriter.create( + Value SoftmaxOutput = rewriter.create( loc, ElementwiseExpStableHLO, BroadCastOp); - if (Softmax_output == nullptr) + + if (SoftmaxOutput == nullptr) return failure(); - rewriter.replaceOp(op, Softmax_output); + rewriter.replaceOp(op, SoftmaxOutput); return success(); } }; diff --git a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp index 9dc5a243be..4d8c3af89a 100644 --- a/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp +++ b/src/Conversion/ONNXToStablehlo/ONNXToStablehloCommon.cpp @@ -117,7 +117,7 @@ DenseIntElementsAttr GetI64ElementsAttr( namespace { // Returns the DenseElementsAttr of input if it's a stablehlo constant or // onnx.Constant. Otherwise returns a nullptr attribute. -DenseElementsAttr getDenseElementAttrFromConstValue(mlir::Value value) { +DenseElementsAttr getDenseElementAttrFromConstValue(Value value) { Operation *definingOp = value.getDefiningOp(); if (auto globalOp = dyn_cast_or_null(definingOp)) { return mlir::dyn_cast(globalOp.getValueAttr()); diff --git a/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp b/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp index 32186f88ac..9580fd03d4 100644 --- a/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp +++ b/src/Conversion/ONNXToStablehlo/RNN/LSTM.cpp @@ -497,8 +497,8 @@ void stateToOutput(ConversionPatternRewriter &rewriter, template <> void calculateStateWithUnroll(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, llvm::StringRef direction, int64_t sequenceDimSize, - Value X, LstmState &state, LstmActivationPack activationForward, + Location loc, llvm::StringRef direction, int64_t sequenceDimSize, Value X, + LstmState &state, LstmActivationPack activationForward, LstmActivationPack activationReverse, LstmWeightPack weightForward, LstmWeightPack weightReverse, LstmBiasPack biasForward, LstmBiasPack biasReverse, Value sequenceLens, Value initialH) { @@ -506,10 +506,10 @@ void calculateStateWithUnroll(rewriter, loc, Xt, state, activationForward, @@ -520,12 +520,12 @@ void calculateStateWithUnroll(rewriter, loc, Xt, state, activationReverse, @@ -538,16 +538,16 @@ void calculateStateWithUnroll void calculateStateWithLoop(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, llvm::StringRef direction, int64_t sequenceDimSize, - Value X, LstmState &state, LstmActivationPack activationForward, + Location loc, llvm::StringRef direction, int64_t sequenceDimSize, Value X, + LstmState &state, LstmActivationPack activationForward, LstmActivationPack activationReverse, LstmWeightPack weightForward, LstmWeightPack weightReverse, LstmBiasPack biasForward, LstmBiasPack biasReverse, Value sequenceLens, Value initialH) { MultiDialectBuilder create(rewriter, loc); if (direction == FORWARD || direction == BIDIRECTIONAL) { - mlir::Value directionIV = create.onnx.constantInt64({0}); - mlir::Value sequenceIV = create.onnx.constantInt64({0}); + Value directionIV = create.onnx.constantInt64({0}); + Value sequenceIV = create.onnx.constantInt64({0}); SmallVector operands = { sequenceIV, state.allHForward, state.forwardHt, state.forwardCt}; SmallVector returnedTypes = {sequenceIV.getType(), @@ -565,7 +565,7 @@ void calculateStateWithLoop( loc, lhs, rhs, ::stablehlo::ComparisonDirection::LT); compareResult = rewriter.create<::stablehlo::ReshapeOp>( @@ -583,12 +583,12 @@ void calculateStateWithLoop(rewriter, loc, Xt, state, activationForward, weightForward, biasForward, seqIV, directionIV, sequenceLens, initialH, /*enableUnroll=*/false, /*isForward=*/true); - mlir::Value one = create.onnx.constantInt64({1}); + Value one = create.onnx.constantInt64({1}); Value newSeqIV = create.onnx.add(seqIV, one); rewriter.create<::stablehlo::ReturnOp>(loc, ValueRange( @@ -600,10 +600,9 @@ void calculateStateWithLoop operands = { reverseSequenceIV, state.allHReverse, state.reverseHt, state.reverseCt}; @@ -622,7 +621,7 @@ void calculateStateWithLoop( loc, lhs, rhs, ::stablehlo::ComparisonDirection::GE); compareResult = rewriter.create<::stablehlo::ReshapeOp>( @@ -640,12 +639,12 @@ void calculateStateWithLoop(rewriter, loc, Xt, state, activationReverse, weightReverse, biasReverse, revseqIV, directionIV, sequenceLens, initialH, /*enableUnroll=*/false, /*isForward=*/false); - mlir::Value one = create.onnx.constantInt64({1}); + Value one = create.onnx.constantInt64({1}); Value newrevseqIV = create.onnx.sub(revseqIV, one); rewriter.create<::stablehlo::ReturnOp>( loc, ValueRange({newrevseqIV, state.allHReverse, state.reverseHt, diff --git a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp index c61bca11e7..4d4bd013ff 100644 --- a/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp +++ b/src/Conversion/ONNXToStablehlo/RNN/RNNBase.cpp @@ -34,8 +34,8 @@ Value allocAllHidden( } /// Allocate the hidden or cell output. -mlir::Value allocHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::Value X, mlir::Value W, mlir::Value R) { +Value allocHiddenOrCell(mlir::ConversionPatternRewriter &rewriter, Location loc, + Value X, Value W, Value R) { MultiDialectBuilder create(rewriter, loc); RankedTensorType zeroType = RankedTensorType::get( {/*num_directions=*/dimAt(W, 0), /*batch_size=*/dimAt(X, 1), diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Reshape.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Reshape.cpp index ce73e9bdda..d15c2028ab 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Reshape.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Reshape.cpp @@ -40,8 +40,8 @@ struct ONNXReshapeOpLoweringToStablehlo : public ConversionPattern { SmallVector dims; IndexExpr::getValues(outputDims, dims); - Type outputShapeType = - RankedTensorType::get({(int64_t)dims.size()}, rewriter.getIndexType()); + Type outputShapeType = RankedTensorType::get( + {static_cast(dims.size())}, rewriter.getIndexType()); Value shape = rewriter.create(loc, dims); shape = rewriter.create(loc, outputShapeType, shape); diff --git a/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp b/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp index 6da4e17f78..3c9767c4d8 100644 --- a/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp +++ b/src/Conversion/ONNXToStablehlo/Tensor/Slice.cpp @@ -70,7 +70,7 @@ struct ONNXSliceOpLoweringToStablehlo : public ConversionPattern { int64_t axis = mlir::cast(value).getInt(); if (axis < 0) axis += rank; - assert((axis >= 0 && axis < (int64_t)rank) && + assert((axis >= 0 && axis < static_cast(rank)) && "Axes contains an out-of-bound index"); axesIntLitToIdx[axis] = idx++; } diff --git a/src/Conversion/ONNXToTOSA/CMakeLists.txt b/src/Conversion/ONNXToTOSA/CMakeLists.txt index 672263e5e2..48aa239454 100644 --- a/src/Conversion/ONNXToTOSA/CMakeLists.txt +++ b/src/Conversion/ONNXToTOSA/CMakeLists.txt @@ -4,17 +4,39 @@ add_onnx_mlir_library(OMONNXToTOSA ConvertONNXToTOSA.cpp DialectBuilder.cpp ONNXToTOSALegalizeUtils.cpp + ONNXToTOSACommon.cpp Math/Elementwise.cpp Math/Gemm.cpp Math/Softmax.cpp - Math/ReduceMean.cpp + Math/Reduce.cpp Math/Conv2D.cpp + Math/MatMul.cpp + Math/Softmax.cpp + Math/Gemm.cpp NN/MaxPoolSingleOut.cpp NN/AveragePool.cpp + NN/QuantizeLinear.cpp + NN/DequantizeLinear.cpp + NN/BatchNorm.cpp + Tensor/Concat.cpp Tensor/Constant.cpp + Tensor/Expand.cpp + Tensor/EyeLike.cpp + Tensor/Flatten.cpp + Tensor/Gather.cpp + Tensor/PaddingOp.cpp Tensor/Reshape.cpp Tensor/Resize.cpp + Tensor/Shrink.cpp + Tensor/Slice.cpp + Tensor/Split.cpp + Tensor/Squeeze.cpp + Tensor/Tile.cpp + Tensor/Transpose.cpp + Tensor/Where.cpp + Flow/EntryPoint.cpp + LINK_LIBS PUBLIC OMONNXOps diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index 2f0a357249..b08087dd77 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -13,31 +13,76 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/DialectConversion.h" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include using namespace mlir; namespace onnx_mlir { void populateONNXToTOSAConversionPattern(ConversionTarget &target, - RewritePatternSet &patterns, TypeConverter &typeConverter, - MLIRContext *ctx) { + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, + int64_t groupedConvThreshold) { // Math populateLoweringONNXElementwiseOpToTOSAPattern( target, patterns, typeConverter, ctx); - populateLoweringONNXReduceMeanOpToTOSAPattern( + populateLoweringONNXReduceOpsToTOSAPattern( target, patterns, typeConverter, ctx); populateLoweringONNXGemmOpToTOSAPattern(target, patterns, typeConverter, ctx); populateLoweringONNXSoftmaxOpToTOSAPattern( target, patterns, typeConverter, ctx); - populateLoweringONNXConvOpToTOSAPattern(target, patterns, typeConverter, ctx); + populateLoweringONNXConvOpToTOSAPattern( + target, patterns, typeConverter, ctx, groupedConvThreshold); + // Tensor + populateLoweringONNXConcatOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXReshapeOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXGatherOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXResizeOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXShrinkOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXConstOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXEyeLikeOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXPadOpToTOSAPattern(target, patterns, typeConverter, ctx); + populateLoweringONNXFlattenOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXSliceOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXSplitOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXSqueezeOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXTileOpToTOSAPattern(target, patterns, typeConverter, ctx); + populateLoweringONNXExpandOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXTransposeOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXWhereOpToTOSAPattern( + target, patterns, typeConverter, ctx); // NN populateLoweringONNXMaxPoolSingleOutOpToTOSAPattern( target, patterns, typeConverter, ctx); populateLoweringONNXAveragePoolOpToTOSAPattern( target, patterns, typeConverter, ctx); - // Tensor - populateLoweringONNXConstOpToTOSAPattern( + populateLoweringONNXQuantizeLinearOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXDequantizeLinearOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXMatMulOpToTOSAPattern( + target, patterns, typeConverter, ctx); + populateLoweringONNXBatchNormalizationOpToTOSAPattern( + target, patterns, typeConverter, ctx); + // Flow + populateLoweringONNXEntryPointOpToTOSAPattern( target, patterns, typeConverter, ctx); populateLoweringONNXReshapeOpToTOSAPattern( target, patterns, typeConverter, ctx); @@ -58,7 +103,17 @@ struct FrontendToTosaLoweringPass FrontendToTosaLoweringPass(const FrontendToTosaLoweringPass &pass) : PassWrapper>() {} + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } void runOnOperation() final; + +public: + Option groupedConvThreshold{*this, "grouped-conv-threshold", + llvm::cl::desc("The threshold used to decompose grouped convolution " + "into a concatenation of tosa.conv2d operations"), + llvm::cl::ZeroOrMore, + llvm::cl::init(std::numeric_limits::max())}; }; void FrontendToTosaLoweringPass::runOnOperation() { @@ -73,7 +128,8 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> std::optional { - if (isTOSASignedInt(type) || isTOSAFloat(type) || mlir::isa(type)) + if (isTOSAInt(type) || isa(type) || isa(type) || + isTOSABool(type)) return type; return std::nullopt; }); @@ -85,10 +141,11 @@ void FrontendToTosaLoweringPass::runOnOperation() { // Define legal dialects and operations target.addLegalDialect(); + mlir::arith::ArithDialect, mlir::shape::ShapeDialect>(); // Define patterns - populateONNXToTOSAConversionPattern(target, patterns, typeConverter, context); + populateONNXToTOSAConversionPattern( + target, patterns, typeConverter, context, groupedConvThreshold); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp index 004850b94c..a52c54053c 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp @@ -4,7 +4,7 @@ //====------ DialectBuilder.hpp - TOSA dialect builder --------------------===// // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. // // ============================================================================= // @@ -16,6 +16,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -56,6 +57,12 @@ Value TosaBuilder::createConst( } bool TosaBuilder::needsRankBroadcast(ValueRange valueRange) { + if (llvm::any_of(valueRange, [](const auto value) { + return !mlir::cast(value.getType()).hasRank(); + })) { + return false; // we have no way to determine the broadcast, so do not + // attempt it + } int64_t firstRank = mlir::cast(valueRange[0].getType()).getRank(); for (Value operand : valueRange) { auto operandType = mlir::cast(operand.getType()); @@ -112,32 +119,43 @@ Value TosaBuilder::getConst(ArrayRef vec, ArrayRef shape) { return constOp; } +Value TosaBuilder::getConst(ArrayRef vec, ArrayRef shape) { + assert(testNumberOfElementsMatch(vec, shape) && + "getConstTensor(): number of elements mismatch."); + + auto constType = RankedTensorType::get(shape, rewriter().getI8Type()); + + Value constOp = this->createConstFromRankedTensorAndVec(vec, constType); + return constOp; +} + Value TosaBuilder::getConst(ArrayRef vec, ArrayRef shape) { auto elementType = rewriter().getF32Type(); Value constOp = this->createConst(vec, shape, elementType); return constOp; } -Value TosaBuilder::getSplattedConst(float val, llvm::ArrayRef shape) { - auto constType = tosa::reduceAxisToOne(shape, rewriter().getF32Type()); +Value TosaBuilder::getSplattedConst(float val, Type dtype, int64_t rank) { + auto constType = tosa::reduceAxisToOne(rank, rewriter().getF32Type()); auto constAttr = DenseElementsAttr::get(constType, val); auto constOp = rewriter().create(loc(), constType, constAttr); - return constOp; + + return rewriter().createOrFold( + loc(), RankedTensorType::get(constType.getShape(), dtype), constOp); } -Value TosaBuilder::transpose(mlir::Value &value, llvm::ArrayRef perm) { +Value TosaBuilder::transpose(Value &value, llvm::ArrayRef perm) { int64_t valueRank = mlir::cast(value.getType()).getRank(); - assert((valueRank == (int64_t)perm.size()) && + assert((valueRank == static_cast(perm.size())) && "value and perm vector don't have the same rank"); // Create Permutation Const Value permList = this->getConst(perm, {valueRank}); auto valueType = mlir::cast(value.getType()); // get new value type Type newValueType = RankedTensorType::get( - llvm::SmallVector( - valueType.getShape().size(), ShapedType::kDynamic), + llvm::SmallVector(perm.size(), ShapedType::kDynamic), valueType.getElementType()); // create transpose for value Value newValue = tosa::CreateOpAndInfer( @@ -158,7 +176,13 @@ Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef size, return newSliceInput; } -Value TosaBuilder::reshape(mlir::Value &value, llvm::ArrayRef shape) { +std::optional TosaBuilder::gather(Value resultValue, Value inputValue, + Value indicesValue, int32_t batchDims, int32_t axis) { + return tosa::convertGatherOp(rewriter(), loc(), resultValue, inputValue, + indicesValue, batchDims, axis); +} + +Value TosaBuilder::reshape(Value value, llvm::ArrayRef shape) { auto shapeAttr = rewriter().getDenseI64ArrayAttr(shape); auto valueType = mlir::cast(value.getType()); Type newValueType = RankedTensorType::get( @@ -168,26 +192,28 @@ Value TosaBuilder::reshape(mlir::Value &value, llvm::ArrayRef shape) { rewriter(), loc(), newValueType, value, shapeAttr); } -Value TosaBuilder::mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift) { +Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) { if (needsRankBroadcast({lhs, rhs})) { llvm::SmallVector valueVec = equalizeRanks({lhs, rhs}); lhs = valueVec[0]; rhs = valueVec[1]; } auto lhsType = mlir::cast(lhs.getType()); - Type newValueType = RankedTensorType::get( - llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), - lhsType.getElementType()); + Type newValueType = + (!lhsType.hasRank()) + ? lhsType + : RankedTensorType::get(llvm::SmallVector( + lhsType.getRank(), ShapedType::kDynamic), + lhsType.getElementType()); return tosa::CreateOpAndInfer( rewriter(), loc(), newValueType, lhs, rhs, shift); } -Value TosaBuilder::intdiv(mlir::Value &lhs, mlir::Value &rhs) { +Value TosaBuilder::intdiv(Value &lhs, Value &rhs) { Type lhsElementType = mlir::cast(lhs.getType()).getElementType(); Type rhsElementType = mlir::cast(rhs.getType()).getElementType(); - assert((lhsElementType.isSignlessInteger(32) && - rhsElementType.isSignlessInteger(32)) && - "Tosa IntDivOp needs 32-bit signless integer inputs"); + assert(lhsElementType == rhsElementType && + "Tosa DivOp needs matching element types on lhs and rhs"); if (needsRankBroadcast({lhs, rhs})) { llvm::SmallVector valueVec = equalizeRanks({lhs, rhs}); @@ -196,41 +222,202 @@ Value TosaBuilder::intdiv(mlir::Value &lhs, mlir::Value &rhs) { } auto lhsType = mlir::cast(lhs.getType()); - Type newValueType = RankedTensorType::get( - llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), - lhsElementType); + Type newValueType = + (!lhsType.hasRank()) + ? lhsType + : RankedTensorType::get(llvm::SmallVector( + lhsType.getRank(), ShapedType::kDynamic), + lhsElementType); return tosa::CreateOpAndInfer( rewriter(), loc(), newValueType, lhs, rhs); } -Value TosaBuilder::reciprocal(mlir::Value &input) { - auto inputType = mlir::cast(input.getType()); - Type newValueType = RankedTensorType::get( - llvm::SmallVector(inputType.getRank(), ShapedType::kDynamic), - inputType.getElementType()); - return tosa::CreateOpAndInfer( - rewriter(), loc(), newValueType, input); -} - template -Value TosaBuilder::binaryOp(mlir::Value &lhs, mlir::Value &rhs) { +Value TosaBuilder::binaryOp(Value &lhs, Value &rhs) { if (needsRankBroadcast({lhs, rhs})) { llvm::SmallVector valueVec = equalizeRanks({lhs, rhs}); lhs = valueVec[0]; rhs = valueVec[1]; } auto lhsType = mlir::cast(lhs.getType()); - Type newValueType = RankedTensorType::get( - llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), - lhsType.getElementType()); + Type newValueType = + (!lhsType.hasRank()) + ? lhsType + : RankedTensorType::get(llvm::SmallVector( + lhsType.getRank(), ShapedType::kDynamic), + lhsType.getElementType()); return tosa::CreateOpAndInfer(rewriter(), loc(), newValueType, lhs, rhs); } -template Value TosaBuilder::binaryOp( - mlir::Value &lhs, mlir::Value &rhs); +template Value TosaBuilder::binaryOp(Value &lhs, Value &rhs); template Value TosaBuilder::binaryOp( mlir::Value &lhs, mlir::Value &rhs); + +template Value TosaBuilder::binaryOp( + mlir::Value &lhs, mlir::Value &rhs); + +template +Value TosaBuilder::unaryOp(mlir::Value &input) { + return tosa::CreateOpAndInfer(rewriter(), loc(), input.getType(), input); +} + +template Value TosaBuilder::unaryOp(mlir::Value &input); + +template Value TosaBuilder::unaryOp( + mlir::Value &input); + +template Value TosaBuilder::unaryOp(mlir::Value &input); + +template Value TosaBuilder::unaryOp(mlir::Value &input); +template Value TosaBuilder::unaryOp(mlir::Value &input); +template Value TosaBuilder::unaryOp(mlir::Value &input); + +template +Value TosaBuilder::compareOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value &lhs, mlir::Value &rhs) { + if (needsRankBroadcast({lhs, rhs})) { + llvm::SmallVector valueVec = equalizeRanks({lhs, rhs}); + lhs = valueVec[0]; + rhs = valueVec[1]; + } + return tosa::CreateOpAndInfer( + rewriter, loc, UnrankedTensorType::get(rewriter.getI1Type()), lhs, rhs); +} + +mlir::Value TosaBuilder::equal(mlir::Value &lhs, mlir::Value &rhs) { + return compareOp(rewriter(), loc(), lhs, rhs); +} + +mlir::Value TosaBuilder::greater(mlir::Value &lhs, mlir::Value &rhs) { + return compareOp(rewriter(), loc(), lhs, rhs); +} + +mlir::Value TosaBuilder::greaterEqual(mlir::Value &lhs, mlir::Value &rhs) { + return compareOp(rewriter(), loc(), lhs, rhs); +} + +mlir::Value TosaBuilder::less(mlir::Value &lhs, mlir::Value &rhs) { + return this->greater(rhs, lhs); +} + +mlir::Value TosaBuilder::lessEqual(mlir::Value &lhs, mlir::Value &rhs) { + return this->greaterEqual(rhs, lhs); +} + +Value TosaBuilder::select( + mlir::Value &cond, mlir::Value &lhs, mlir::Value &rhs) { + if (needsRankBroadcast({cond, lhs, rhs})) { + llvm::SmallVector valueVec = equalizeRanks({cond, lhs, rhs}); + cond = valueVec[0]; + lhs = valueVec[1]; + rhs = valueVec[2]; + } + auto lhsType = cast(lhs.getType()); + Type newValueType = + (!lhsType.hasRank()) + ? lhsType + : RankedTensorType::get(llvm::SmallVector( + lhsType.getRank(), ShapedType::kDynamic), + lhsType.getElementType()); + return tosa::CreateOpAndInfer( + rewriter(), loc(), newValueType, cond, lhs, rhs); +} + +mlir::Value TosaBuilder::castToNewTensorElementType( + mlir::Value in, mlir::Type newElemTy) { + auto tensorTy = cast(in.getType()); + if (tensorTy.getElementType() == newElemTy) { + // Nothing to do + return in; + } + + auto newTensorTy = tensorTy.clone(newElemTy); + return tosa::CreateOpAndInfer( + rewriter(), loc(), newTensorTy, in); +} + +Value TosaBuilder::sqrt(mlir::Value &input) { + auto inputType = cast(input.getType()); + auto oneHalf = this->getSplattedConst( + 0.5, inputType.getElementType(), inputType.getRank()); + return this->binaryOp(input, oneHalf); +} + +static bool containsNonZero(llvm::SmallVectorImpl &values) { + return llvm::any_of(values, [](int64_t value) { return value != 0; }); +} + +FailureOr TosaBuilder::resizeWindowBasedOps(mlir::Value &value, + const llvm::ArrayRef inputShape, + const llvm::ArrayRef weightSpatialShape, + llvm::SmallVectorImpl &padding, + const llvm::ArrayRef strides, + const llvm::ArrayRef dilation) { + + // Returns the number of unused values at the end of a dimension + auto getOffset = [](int64_t inputDimension, int64_t outputDimension, + int64_t kernelDimension, int64_t padFront, + int64_t padBack, int64_t stride, int64_t dilation) { + int64_t offset = inputDimension + padFront + padBack - + dilation * (kernelDimension - 1) - 1 - + outputDimension * stride + stride; + assert(offset >= 0); + return offset; + }; + + auto getOutputSpatialDimension = + [](int64_t inputDimension, int64_t kernelDimension, int64_t padFront, + int64_t padBack, int64_t stride, int64_t dilation) { + int64_t outputSpatialDimension = + std::floor((inputDimension + padFront + padBack - + dilation * (kernelDimension - 1) - 1)) / + stride + + 1; + return outputSpatialDimension; + }; + + // Only the end of a dimension is cut or padded differently. The beginning + // is unchanged. + llvm::SmallVector cellsToCut; + llvm::SmallVector cellsToPad; + for (int i = 0; i < 2; i++) { + int64_t padFront = padding[2 * i]; + int64_t padBack = padding[2 * i + 1]; + int64_t outputSpatialDimension = + getOutputSpatialDimension(inputShape[i + 1], weightSpatialShape[i], + padFront, padBack, strides[i], dilation[i]); + int64_t offset = getOffset(inputShape[i + 1], outputSpatialDimension, + weightSpatialShape[i], padFront, padBack, strides[i], dilation[i]); + if (offset > padBack) { + cellsToPad.push_back(0); + cellsToCut.push_back(offset - padBack); + } else { + cellsToPad.push_back(padBack - offset); + cellsToCut.push_back(0); + } + } + + // Edge case where the kernel only uses padding values and none of the actual + // input values + if ((inputShape[1] - cellsToCut[0] == 0) || + (inputShape[2] - cellsToCut[1] == 0)) + return rewriter().notifyMatchFailure( + loc(), "the operation does not use any value of the input tensor"); + + // Only slice if we actually need it + if (containsNonZero(cellsToCut)) { + value = this->slice(value, + {inputShape[0], inputShape[1] - cellsToCut[0], + inputShape[2] - cellsToCut[1], inputShape[3]}, + {0, 0, 0, 0}); + } + padding[1] = cellsToPad[0]; + padding[3] = cellsToPad[1]; + + return value; +} + // ============================================================================= // IndexExpr Builder for Lowering using Shape/TOSA Dialect. // ============================================================================= @@ -255,10 +442,8 @@ ElementsAttr IndexExprBuilderForTosa::getConst(Value value) { } Value IndexExprBuilderForTosa::getVal(Value intArrayVal, uint64_t i) { - MultiDialectBuilder create(*this); - // Need to add some acceptable dialects to TOSA conversion. - llvm_unreachable( - "unimplemented (see IndexExprBuilderForKrnl for functionality)."); + // TODO: unimplemented (see IndexExprBuilderForKrnl for functionality). + return {}; } Value IndexExprBuilderForTosa::getShapeVal( diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp index 1050d97053..333f90c456 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.hpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.hpp @@ -38,6 +38,9 @@ struct TosaBuilder : DialectBuilder { TosaBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~TosaBuilder() {} + std::optional gather(mlir::Value resultValue, + mlir::Value inputValue, mlir::Value indicesValue, int32_t batchDims, + int32_t axis); template mlir::Value binaryOp(mlir::Value &lhs, mlir::Value &rhs); mlir::Value mul(mlir::Value &lhs, mlir::Value &rhs, int32_t shift = 0); @@ -46,19 +49,58 @@ struct TosaBuilder : DialectBuilder { mlir::Value transpose(mlir::Value &value, llvm::ArrayRef perm); mlir::Value slice(mlir::Value &inputConst, llvm::ArrayRef size, llvm::ArrayRef start); - mlir::Value reshape(mlir::Value &value, llvm::ArrayRef shape); - mlir::Value reciprocal(mlir::Value &input); + mlir::Value reshape(mlir::Value value, llvm::ArrayRef shape); + + template + mlir::Value unaryOp(mlir::Value &input); + mlir::Value sqrt(mlir::Value &input); + + template + mlir::Value compareOp(mlir::PatternRewriter &rewriter, mlir::Location loc, + mlir::Value &lhs, mlir::Value &rhs); + mlir::Value equal(mlir::Value &lhs, mlir::Value &rhs); + mlir::Value greater(mlir::Value &lhs, mlir::Value &rhs); + mlir::Value greaterEqual(mlir::Value &lhs, mlir::Value &rhs); + mlir::Value less(mlir::Value &lhs, mlir::Value &rhs); + mlir::Value lessEqual(mlir::Value &lhs, mlir::Value &rhs); + + mlir::Value select(mlir::Value &cond, mlir::Value &lhs, mlir::Value &rhs); + mlir::Value castToNewTensorElementType(mlir::Value in, mlir::Type newElemTy); + + /// When using window based ops like maxpool or conv2d, we sometimes have + /// unused values at the end of a spatial dimension. TOSA does not allow that, + /// the input can only have values that are actually used. To achieve this we + /// have to reduce padding and if this is not enough, we even have to insert a + /// slice op. + mlir::FailureOr resizeWindowBasedOps(mlir::Value &value, + llvm::ArrayRef inputShape, + llvm::ArrayRef weightSpatialShape, + llvm::SmallVectorImpl &padding, + llvm::ArrayRef strides = {1, 1}, + llvm::ArrayRef dilation = {0, 0}); mlir::Value getConst( llvm::ArrayRef vec, llvm::ArrayRef shape); mlir::Value getConst( llvm::ArrayRef vec, llvm::ArrayRef shape); + mlir::Value getConst( + llvm::ArrayRef vec, llvm::ArrayRef shape); mlir::Value getConst( llvm::ArrayRef vec, llvm::ArrayRef shape); - // Create a 32-bit float constant operator from a float + // Create a floating-point constant operator from a float // The tensor will have the same rank as shape but all dimensions will // have size 1 (differs from tensorflow impl.) - mlir::Value getSplattedConst(float val, llvm::ArrayRef shape = {}); + // If dtype is provided, it also cast the value to the appropriate dtype. + mlir::Value getSplattedConst(float val, mlir::Type dtype, int64_t rank); + + // Creates a constant of shape <1x1x...x1> of rank `rank` with all values set + // to `value`. + template + mlir::Value getSplattedConst(T value, size_t rank) { + llvm::SmallVector tmpTensor(rank, 1); + std::vector zpVec = std::vector{value}; + return getConst(zpVec, tmpTensor); + } // Adds reshape ops to expand the rank to the max rank of the values. llvm::SmallVector equalizeRanks(mlir::ValueRange valueRange); diff --git a/src/Conversion/ONNXToTOSA/Flow/EntryPoint.cpp b/src/Conversion/ONNXToTOSA/Flow/EntryPoint.cpp new file mode 100644 index 0000000000..9a28fbd269 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Flow/EntryPoint.cpp @@ -0,0 +1,59 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- EntryPoint.cpp - EntryPoint Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file removes the "onnx.EntryPoint" and renames the func.func to @forward +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXEntryPointLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXEntryPointOp::Adaptor; + // This function is from typesTransformsToTorchPass.cpp + LogicalResult matchAndRewrite(ONNXEntryPointOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto functionName = op.getFunc().getRootReference().getValue(); + // Differs from origin to get module + auto module = op->getParentOfType(); + if (!module) + return failure(); + auto mainFuncOp = module.lookupSymbol(functionName); + if (mainFuncOp) { + StringRef forwardRef = "forward"; + auto forwardAttr = StringAttr::get(module.getContext(), forwardRef); + mainFuncOp->setAttr(llvm::StringRef("sym_name"), forwardAttr); + } + rewriter.eraseOp(op); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXEntryPointOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp index dc35d39099..0328879fae 100644 --- a/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Conv2D.cpp @@ -19,13 +19,6 @@ #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" #include -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" -#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" -#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" -#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" -#include - using namespace mlir; namespace onnx_mlir { @@ -87,15 +80,16 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op, class ONNXConvOpLoweringToTOSA : public ConversionPattern { public: - ONNXConvOpLoweringToTOSA(MLIRContext *ctx) - : ConversionPattern(ONNXConvOp::getOperationName(), 1, ctx) {} + ONNXConvOpLoweringToTOSA(MLIRContext *ctx, int64_t groupedConvThreshold) + : ConversionPattern(ONNXConvOp::getOperationName(), 1, ctx), + groupedConvThreshold(groupedConvThreshold) {} using OpAdaptor = typename ONNXConvOp::Adaptor; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { OpAdaptor adaptor(operands, op->getAttrDictionary()); auto loc = op->getLoc(); - auto convOp = llvm::cast(op); + auto convOp = mlir::cast(op); TosaBuilder tosaBuilder(rewriter, loc); @@ -106,10 +100,18 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { auto inputType = mlir::cast(input.getType()); auto weightType = mlir::cast(weights.getType()); + if (!inputType || !weightType || !inputType.hasStaticShape() || + !weightType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "only ranked tensor types are supported"); + } + // Get shapehelper for autopad attributes IndexExprBuilderForTosa createTosaIE(rewriter, convOp->getLoc()); ONNXConvOpShapeHelper shapeHelper(op, operands, &createTosaIE); - shapeHelper.computeShapeAndAssertOnFailure(); + if (shapeHelper.computeShape().failed()) { + return rewriter.notifyMatchFailure(convOp, "Could not infer shapes"); + } auto weightShape = weightType.getShape(); @@ -144,8 +146,18 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { llvm::SmallVector pads; IndexExpr::getLiteral(shapeHelper.pads, pads); // reorder padding values - DenseI64ArrayAttr newPads = - rewriter.getDenseI64ArrayAttr({pads[0], pads[2], pads[1], pads[3]}); + llvm::SmallVector reorderedPads = { + pads[0], pads[2], pads[1], pads[3]}; + FailureOr resizedInput = tosaBuilder.resizeWindowBasedOps(newInput, + cast(newInput.getType()).getShape(), + {weightShape[2], weightShape[3]}, reorderedPads, shapeHelper.strides, + shapeHelper.dilations); + + if (failed(resizedInput)) + return rewriter.notifyMatchFailure( + op, "could not resize input to match parameters"); + + DenseI64ArrayAttr newPads = rewriter.getDenseI64ArrayAttr(reorderedPads); // Handle group parameter by creating multiple convs const int64_t group = adaptor.getGroup(); @@ -159,9 +171,40 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { convOp->getLoc(), newConvOutputType, newInput, newWeight, bias, newPads, strides, dilations); } else { - conv2D = createConvInGroups(rewriter, convOp, tosaBuilder, resultType, - weightShape, newInput, newWeight, bias, group, newPads, strides, - dilations); + auto inputChannels = inputType.getDimSize(1); + auto outputChannels = resultType.cast().getDimSize(1); + if (group == inputChannels && (outputChannels % inputChannels == 0)) { + // If the group == inputChannels and + // outputChannels == inputChannels * integerNumber, + // this grouped convolution is equal to a Depthwise convolution. + + // Convert weights [OC,IC,KH,KW] -> [KH, KW, OC, M(ChannelMultiplier)] + Value transposedWeight = tosaBuilder.transpose(weights, {2, 3, 0, 1}); + // A reshape op is needed to adhere to the TOSA standard + // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d + Value newWeight = tosaBuilder.reshape( + transposedWeight, {weightShape[2], weightShape[3], inputChannels, + outputChannels / inputChannels}); + + Type newConvOutputType = RankedTensorType::get( + llvm::SmallVector(4, ShapedType::kDynamic), + resultType.cast().getElementType()); + + conv2D = tosa::CreateOpAndInfer(rewriter, + convOp->getLoc(), newConvOutputType, newInput, newWeight, bias, + newPads, strides, dilations); + } else if (group <= groupedConvThreshold) { + // Decompose group convolution into a concatenation of tosa.conv2d ops + // can be costly, so only allow it when the number of groups is less + // than configurable threshold. + + conv2D = createConvInGroups(rewriter, convOp, tosaBuilder, resultType, + weightShape, newInput, newWeight, bias, group, newPads, strides, + dilations); + } else { + return rewriter.notifyMatchFailure( + op, "this type of grouped Conv is not supported"); + } } // Convert output [N,OH,OW,OC] -> [N,OC,OH,OW] @@ -170,13 +213,17 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern { rewriter.replaceOp(convOp, {newOutput}); return success(); } + +private: + int64_t groupedConvThreshold; }; + } // namespace void populateLoweringONNXConvOpToTOSAPattern(ConversionTarget &target, - RewritePatternSet &patterns, TypeConverter &typeConverter, - MLIRContext *ctx) { - patterns.insert(ctx); + RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx, + int64_t groupedConvThreshold) { + patterns.insert(ctx, groupedConvThreshold); } } // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 2e105d2dc5..9ff84b1788 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -12,8 +12,15 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include using namespace mlir; @@ -24,25 +31,120 @@ struct TOSADialectOp { using Op = mlir::tosa::NegateOp; }; -namespace { +struct IsIntOrFloat { + static LogicalResult checkType( + ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { + if (!isa(scalarType) && !isTOSAInt(scalarType)) { + return rewriter.notifyMatchFailure( + op, "this operation only supports signed integer or float types"); + } + return success(); + } +}; + +struct IsInt { + static LogicalResult checkType( + ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { + if (!isTOSAInt(scalarType)) { + return rewriter.notifyMatchFailure( + op, "this operation only supports int types"); + } + return success(); + } +}; + +struct IsFloat { + static LogicalResult checkType( + ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { + if (!isa(scalarType)) { + return rewriter.notifyMatchFailure( + op, "this operation only supports float types"); + } + return success(); + } +}; + +struct IsBool { + static LogicalResult checkType( + ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { + if (!isTOSABool(scalarType)) { + return rewriter.notifyMatchFailure( + op, "this operation only supports bool type"); + } + return success(); + } +}; + +template +LogicalResult checkBasicTosaRequirementsForBinaryOps( + ConversionPatternRewriter &rewriter, Operation *op, OpAdaptorT adaptor, + Type resultType) { + Value lhs = adaptor.getOperands()[0]; + auto lhsType = dyn_cast(lhs.getType()); + + Value rhs = adaptor.getOperands()[1]; + auto rhsType = dyn_cast(rhs.getType()); + + auto resultTensorType = dyn_cast(resultType); + if (!lhsType || !rhsType || !resultTensorType) { + return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); + } + + Type resultElementType = resultTensorType.getElementType(); + + if (TosaOpT::template hasTrait< + ::mlir::OpTrait::SameOperandsAndResultElementType>()) { + if (lhsType.getElementType() != rhsType.getElementType() || + lhsType.getElementType() != resultElementType) { + return rewriter.notifyMatchFailure( + op, "lhs, rhs and result must have the same type"); + } + } + + if (failed(TypeChecker::checkType(rewriter, resultElementType, op))) { + return failure(); + } + + return success(); +} // Element-wise unary ops lowering to TOSA dialect. //===----------------------------------------------------------------------===// -template +template class ONNXElementwiseUnaryOpLoweringToTOSA - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename ElementwiseUnaryOp::Adaptor; - LogicalResult matchAndRewrite(ElementwiseUnaryOp op, OpAdaptor adaptor, + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ElementwiseUnaryOpONNX::Adaptor; + LogicalResult matchAndRewrite(ElementwiseUnaryOpONNX op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp>( - op, op.getType(), adaptor.getX()); + + Value input = *adaptor.getODSOperands(0).begin(); + auto inputType = dyn_cast(input.getType()); + Value output = op.getResult(); + auto outputType = dyn_cast(output.getType()); + + if (!inputType || !outputType) { + return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); + } + + Type inputElementType = inputType.getElementType(); + Type outputElementType = outputType.getElementType(); + + if (failed(InputType::checkType(rewriter, inputElementType, op))) + return failure(); + + if (failed(InputType::checkType(rewriter, outputElementType, op))) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), *adaptor.getODSOperands(0).begin()); return success(); } }; -template +template class ONNXBinaryElementwiseOpLoweringToTOSA : public OpConversionPattern { public: @@ -51,33 +153,21 @@ class ONNXBinaryElementwiseOpLoweringToTOSA LogicalResult matchAndRewrite(ONNXOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value lhs = adaptor.getA(); - auto lhsType = mlir::dyn_cast(lhs.getType()); - - Value rhs = adaptor.getB(); - auto rhsType = mlir::dyn_cast(rhs.getType()); - - auto resultType = mlir::dyn_cast(op.getResult().getType()); - if (!lhsType || !rhsType || !resultType) { - return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); - } - - Type resultElementType = resultType.getElementType(); + if (failed(checkBasicTosaRequirementsForBinaryOps(rewriter, op, adaptor, op.getResult().getType()))) + return failure(); - if (!resultElementType.isIntOrFloat()) { - return rewriter.notifyMatchFailure( - op, "only int and float are supported"); - } + auto loc = op.getLoc(); + Value lhs = adaptor.getOperands()[0]; + Value rhs = adaptor.getOperands()[1]; if (TosaOpT::template hasTrait< mlir::OpTrait::ResultsBroadcastableShape>()) { IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc()); ONNXBroadcastOpShapeHelper shapeHelper(op, {}, &createTosaIE); - shapeHelper.computeShapeAndAssertOnFailure(); - - if (shapeHelper.hasRankBroadcast()) { + if (shapeHelper.computeShape().succeeded() && + shapeHelper.hasRankBroadcast()) { TosaBuilder tosaBuilder(rewriter, loc); llvm::SmallVector newValues = tosaBuilder.equalizeRanks({lhs, rhs}); @@ -92,20 +182,24 @@ class ONNXBinaryElementwiseOpLoweringToTOSA } }; -class ONNXFloorOpLoweringToTOSA : public OpConversionPattern { +class ONNXMulOpLoweringToTosa : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename ONNXFloorOp::Adaptor; - LogicalResult matchAndRewrite(ONNXFloorOp op, OpAdaptor adaptor, + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (failed(checkBasicTosaRequirementsForBinaryOps( + rewriter, op, adaptor, op.getResult().getType()))) + return failure(); - auto scalarType = getElementTypeOrSelf(adaptor.getX()); - if (!isTOSAFloat(scalarType)) - return rewriter.notifyMatchFailure( - op, "`tosa.floor` only supports float types"); + Value lhs = adaptor.getA(); + Value rhs = adaptor.getB(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value mulOp = tosaBuilder.mul(lhs, rhs); + copySingleResultType(op, mulOp); + rewriter.replaceOp(op, {mulOp}); - rewriter.replaceOpWithNewOp( - op, op.getType(), adaptor.getX()); return success(); } }; @@ -130,6 +224,182 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { } }; +// Support for prelu/leakyrelu adapted from tensorflow to tosa implementation +static LogicalResult legalizeFloatingPointPrelu(Operation *op, + PatternRewriter &rewriter, Value input, Value alphaOrSlope, + TensorType outputType) { + auto loc = op->getLoc(); + TosaBuilder tosaBuilder(rewriter, loc); + Value constZero = tosaBuilder.getSplattedConst( + 0.0, outputType.getElementType(), outputType.getRank()); + + auto mul = tosaBuilder.mul(input, alphaOrSlope); + auto greaterEqual = tosaBuilder.greaterEqual(input, constZero); + auto select = tosaBuilder.select(greaterEqual, input, mul); + copySingleResultType(op, select); + rewriter.replaceOp(op, {select}); + return success(); +} + +class ONNXLeakyReluOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = ONNXLeakyReluOp::Adaptor; + LogicalResult matchAndRewrite(ONNXLeakyReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = cast(op.getResult().getType()); + if (failed(IsIntOrFloat::checkType( + rewriter, outputType.getElementType(), op))) { + return failure(); + } + + // ONNX docs: alpha : float (default 0.01) + float alpha = 0.01; + FloatAttr alphaAttr = adaptor.getAlphaAttr(); + if (alphaAttr) { + // No easy interface in MLIR to get value as float + alpha = alphaAttr.getValueAsDouble(); + } + auto loc = op->getLoc(); + TosaBuilder tosaBuilder(rewriter, loc); + return legalizeFloatingPointPrelu(op, rewriter, adaptor.getX(), + tosaBuilder.getSplattedConst( + alpha, outputType.getElementType(), outputType.getRank()), + outputType); + } +}; + +template +class ONNXComparisonOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OnnxCompOp::Adaptor; + LogicalResult matchAndRewrite(OnnxCompOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value input1 = adaptor.getA(); + auto input1ElemType = cast(input1.getType()).getElementType(); + if (failed(IsIntOrFloat::checkType(rewriter, input1ElemType, op))) { + return failure(); + } + + Value input2 = adaptor.getB(); + auto input2ElemType = cast(input2.getType()).getElementType(); + if (input1ElemType != input2ElemType) { + return failure(); + } + + Value res; + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + if constexpr (std::is_same_v) { + res = tosaBuilder.equal(input1, input2); + } else if constexpr (std::is_same_v) { + res = tosaBuilder.greaterEqual(input1, input2); + } else if constexpr (std::is_same_v) { + res = tosaBuilder.greater(input1, input2); + } else if constexpr (std::is_same_v) { + res = tosaBuilder.lessEqual(input1, input2); + } else if constexpr (std::is_same_v) { + res = tosaBuilder.less(input1, input2); + } + copySingleResultType(op, res); + rewriter.replaceOp(op, {res}); + return success(); + } +}; + +class ONNXClipOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = ONNXClipOp::Adaptor; + LogicalResult matchAndRewrite(ONNXClipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto res = adaptor.getInput(); + auto min = adaptor.getMin(); + auto max = adaptor.getMax(); + + auto matchIntOrFloat = [&](Value val) -> std::tuple { + APInt valueInt(64, 0); + APFloat valueFloat(0.0f); + if (matchPattern(val, m_ConstantInt(&valueInt))) { + auto intVal = valueInt.getSExtValue(); + return {true, intVal, static_cast(intVal)}; + } + if (matchPattern(val, m_ConstantFloat(&valueFloat))) { + float floatVal = valueFloat.convertToFloat(); + return {true, static_cast(floatVal), floatVal}; + } + return {false, 0, 0.0}; + }; + + // Use ClampOp if min and max are splat constants. + // Otherwise, MaximumOp and MinimumOp to clamp min and max, respectively. + auto [isSplatConstMin, minInt, minFloat] = matchIntOrFloat(min); + auto [isSplatConstMax, maxInt, maxFloat] = matchIntOrFloat(max); + if (isSplatConstMin && isSplatConstMax) { + rewriter.replaceOpWithNewOp(op, op.getType(), res, + rewriter.getI64IntegerAttr(minInt), + rewriter.getI64IntegerAttr(maxInt), + rewriter.getF32FloatAttr(minFloat), + rewriter.getF32FloatAttr(maxFloat)); + } else { + if (!isNoneValue(min)) { + res = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), op.getType(), res, min); + } + if (!isNoneValue(max)) { + res = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), op.getType(), res, max); + } + rewriter.replaceOp(op, res); + } + return success(); + } +}; + +class ONNXCastOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXCastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + + auto resultTy = dyn_cast_if_present( + getTypeConverter()->convertType(op.getType())); + if (!resultTy) { + return rewriter.notifyMatchFailure(op, "expected valid result type"); + } + auto input = adaptor.getInput(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy) { + return rewriter.notifyMatchFailure(op, "expected valid input type"); + } + if (isa(inputTy.getElementType()) && + isa(resultTy.getElementType())) { + // ONNX.Cast has truncating behavior, and tosa.cast has rounds + // half-to-even. We simulate truncate by floor for positive values and + // ceil for negative ones. Conversion to boolean works the same between + // onnx.Cast and tosa.cast. + if (resultTy.getElementType().getIntOrFloatBitWidth() != 1) { + auto zero = tosaBuilder.getSplattedConst( + 0.0f, inputTy.getElementType(), resultTy.getRank()); + auto positive = tosaBuilder.greaterEqual(input, zero); + + auto floor = tosaBuilder.unaryOp(input); + auto ceil = tosaBuilder.unaryOp(input); + input = tosaBuilder.select(positive, floor, ceil); + } + } + + rewriter.replaceOpWithNewOp(op, resultTy, input); + return success(); + } +}; + class ONNXDivOpLoweringToTOSA : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -142,31 +412,335 @@ class ONNXDivOpLoweringToTOSA : public OpConversionPattern { TosaBuilder tosaBuilder(rewriter, op->getLoc()); - if (resultElementType.isSignlessInteger(32)) { - // tosa::IntDivOp takes 32-but signless integers as inputs + if (isa(resultElementType)) { Value divOp = tosaBuilder.intdiv(lhs, rhs); + copySingleResultType(op, divOp); rewriter.replaceOp(op, {divOp}); return success(); } - // If it is not a 32-bit signless integer, decompose ONNXDivOp into - // tosa::ReciprocalOp and tosa::MulOp - Value reciprocalOp = tosaBuilder.reciprocal(rhs); + // For floating point types, decompose ONNXDivOp into + // tosa::ReciprocalOp and tosa::MulOp. + Value reciprocalOp = tosaBuilder.unaryOp(rhs); Value mulOp = tosaBuilder.mul(lhs, reciprocalOp); + copySingleResultType(op, mulOp); rewriter.replaceOp(op, {mulOp}); return success(); } }; -} // namespace +class ONNXSqrtOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resultTensorType = cast(op.getResult().getType()); + if (failed(IsFloat::checkType( + rewriter, resultTensorType.getElementType(), op))) { + return failure(); + } + + Value input = op.getX(); + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value sqrtOp = tosaBuilder.sqrt(input); + rewriter.replaceOp(op, {sqrtOp}); + return success(); + } +}; + +class ONNXEluOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXEluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // ELU(x) = x if x >= 0 + // alpha * (exp(x) - 1.) if x < 0 + + auto resultTensorType = cast(op.getResult().getType()); + if (failed(IsFloat::checkType( + rewriter, resultTensorType.getElementType(), op))) { + return failure(); + } + + Value input = op.getX(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + + Value one = tosaBuilder.getSplattedConst( + 1.0, resultTensorType.getElementType(), resultTensorType.getRank()); + Value alpha = + tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(), + resultTensorType.getElementType(), resultTensorType.getRank()); + Value constZero = tosaBuilder.getSplattedConst( + 0.0, resultTensorType.getElementType(), resultTensorType.getRank()); + + Value exp = tosaBuilder.unaryOp(input); + copySingleResultType(op, exp); + Value expMinusOne = tosaBuilder.binaryOp(exp, one); + Value alphaTimesExpMinusOne = tosaBuilder.mul(expMinusOne, alpha); + Value greaterEqual = tosaBuilder.greaterEqual(input, constZero); + auto select = + tosaBuilder.select(greaterEqual, input, alphaTimesExpMinusOne); + copySingleResultType(op, select); + rewriter.replaceOp(op, {select}); + return success(); + } +}; + +class ONNXHardSigmoidOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXHardSigmoidOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // ONNXHardSigmoid -> TOSA: + // - tosa.add(input, beta/alpha) + // - tosa.clamp(add) with min = 0, and max = 1/alpha + // - tosa.mul(clamp, alpha) + Value input = adaptor.getX(); + + auto resultType = cast(op.getResult().getType()); + auto resultElementType = resultType.getElementType(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + + auto alpha = adaptor.getAlpha(); + + auto betaOverAlpha = adaptor.getBeta(); + betaOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven); + + APFloat oneOverAlpha(alpha.getSemantics(), 1); + oneOverAlpha.divide(alpha, APFloat::rmNearestTiesToEven); + + if (!resultType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "HardSigmoid: Static shape required to create splatted const"); + } + + Value constBetaOverAlpha = + tosaBuilder.getSplattedConst(betaOverAlpha.convertToDouble(), + resultElementType, resultType.getRank()); + Value constAlpha = tosaBuilder.getSplattedConst( + alpha.convertToDouble(), resultElementType, resultType.getRank()); + + auto addOp = + tosaBuilder.binaryOp(input, constBetaOverAlpha); + Value clampOp = tosa::CreateOpAndInfer(rewriter, + op->getLoc(), resultType, addOp, rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(oneOverAlpha.convertToDouble()), + rewriter.getF32FloatAttr(0), + rewriter.getF32FloatAttr(oneOverAlpha.convertToDouble())); + auto mulOp = tosaBuilder.mul(clampOp, constAlpha); + copySingleResultType(op, mulOp); + rewriter.replaceOp(op, {mulOp}); + return success(); + } +}; + +class ONNXPReluOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = ONNXPReluOp::Adaptor; + LogicalResult matchAndRewrite(ONNXPReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = cast(op.getResult().getType()); + if (failed(IsIntOrFloat::checkType( + rewriter, outputType.getElementType(), op))) { + return failure(); + } + + return legalizeFloatingPointPrelu( + op, rewriter, adaptor.getX(), adaptor.getSlope(), outputType); + } +}; + +class ONNXSoftplusOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXSoftplusOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = cast(op.getResult().getType()); + if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) { + return failure(); + } + if (!outputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "ONNXSoftplusOp: Rank required to create splatted const"); + } + + Value input = adaptor.getX(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + auto one = tosaBuilder.getSplattedConst( + 1.0, outputType.getElementType(), outputType.getRank()); + + auto expOp = tosaBuilder.unaryOp(input); + copySingleResultType(op, expOp); + auto expPlusOne = tosaBuilder.binaryOp(expOp, one); + auto logOp = tosaBuilder.unaryOp(expPlusOne); + rewriter.replaceOp(op, {logOp}); + return success(); + } +}; + +class ONNXSeluOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXSeluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = cast(op.getResult().getType()); + if (failed(IsFloat::checkType(rewriter, outputType.getElementType(), op))) { + return failure(); + } + + Value input = adaptor.getX(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + if (!outputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "ONNXSeluOp: Rank required to create splatted const"); + } + + Value alpha = + tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(), + outputType.getElementType(), outputType.getRank()); + Value gamma = + tosaBuilder.getSplattedConst(adaptor.getGamma().convertToDouble(), + outputType.getElementType(), outputType.getRank()); + Value constZero = tosaBuilder.getSplattedConst( + 0.0, outputType.getElementType(), outputType.getRank()); + + Value exp = tosaBuilder.unaryOp(input); + Value expTimesAlpha = tosaBuilder.mul(exp, alpha); + Value expTimesAlphaMinusAlpha = + tosaBuilder.binaryOp(expTimesAlpha, alpha); + + Value greater = tosaBuilder.greater(input, constZero); + auto select = tosaBuilder.select(greater, input, expTimesAlphaMinusAlpha); + Value valTimesGamma = tosaBuilder.mul(select, gamma); + + rewriter.replaceOp(op, {valTimesGamma}); + return success(); + } +}; + +class ONNXThresholdedReluOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXThresholdedReluOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outputType = cast(op.getResult().getType()); + if (failed(IsIntOrFloat::checkType( + rewriter, outputType.getElementType(), op))) { + return failure(); + } + if (!outputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "ONNXThresholdedReluOp: Rank required to create splatted const"); + } + + Value input = adaptor.getX(); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + auto alpha = + tosaBuilder.getSplattedConst(adaptor.getAlpha().convertToDouble(), + outputType.getElementType(), outputType.getRank()); + auto zero = tosaBuilder.getSplattedConst( + 0.0, outputType.getElementType(), outputType.getRank()); + + auto greater = tosaBuilder.greater(input, alpha); + auto select = tosaBuilder.select(greater, input, zero); + + rewriter.replaceOp(op, {select}); + return success(); + } +}; + +static void populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern( + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA, + ONNXBinaryElementwiseOpLoweringToTOSA>(typeConverter, ctx); +} + +static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA, + ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); +} void populateLoweringONNXElementwiseOpToTOSAPattern(ConversionTarget &target, RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx) { - patterns.insert, - ONNXBinaryElementwiseOpLoweringToTOSA, - ONNXBinaryElementwiseOpLoweringToTOSA, - ONNXFloorOpLoweringToTOSA, ONNXReluOpLoweringToTOSA, - ONNXDivOpLoweringToTOSA>(typeConverter, ctx); + patterns.insert, + ONNXComparisonOpLoweringToTOSA, + ONNXComparisonOpLoweringToTOSA, + ONNXComparisonOpLoweringToTOSA, + ONNXComparisonOpLoweringToTOSA>(typeConverter, ctx); + + populateLoweringONNXElementwiseBinaryTemplateOpToTOSAPattern( + patterns, typeConverter, ctx); + populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( + patterns, typeConverter, ctx); } } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp index 556de3f7d7..1d1182d03b 100644 --- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp @@ -4,7 +4,7 @@ //===---------------- Gemm.cpp - Gemm Op ----------------------------------===// // -// Copyright (c) 2022 Advanced Micro Devices, Inc. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. // // ============================================================================= // @@ -49,6 +49,10 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { FloatAttr beta = adaptor.getBetaAttr(); auto AType = mlir::cast(A.getType()); auto BType = mlir::cast(B.getType()); + if (!AType.hasRank() || !BType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "Lowering Gemm to MatMul requires ranked A and B."); + } auto shapeA = AType.getShape(); auto shapeB = BType.getShape(); auto resultType = mlir::cast( @@ -102,7 +106,8 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { // A if (alpha && alpha.getValueAsDouble() != 1.) { Value splattedConstAlpha = tosaBuilder.getSplattedConst( - (float)alpha.getValueAsDouble(), newShapeA); + static_cast(alpha.getValueAsDouble()), AType.getElementType(), + newShapeA.size()); alphaMulResult = tosaBuilder.mul(splattedConstAlpha, A, 0); } @@ -110,7 +115,8 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { // a multiplication for beta * C if (beta && isCPresent && beta.getValueAsDouble() != 1.) { Value splattedConstBeta = tosaBuilder.getSplattedConst( - (float)beta.getValueAsDouble(), newShapeA); + static_cast(beta.getValueAsDouble()), AType.getElementType(), + newShapeA.size()); betaMulResult = tosaBuilder.mul(splattedConstBeta, C, 0); } diff --git a/src/Conversion/ONNXToTOSA/Math/MatMul.cpp b/src/Conversion/ONNXToTOSA/Math/MatMul.cpp new file mode 100644 index 0000000000..219ed882d7 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Math/MatMul.cpp @@ -0,0 +1,487 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- ONNXMatMulOp.cpp - ONNXMatMulOp --------------===// +// +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNXMatMulOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +// TOSA matmul is performed on two 3D inputs and generates a 3D output. +// Lower ranked tensors are dim-1 reshaped up to 3D +Value reshapeUpTo3DTensor(Value tensor, TosaBuilder &builder) { + auto tensorTy = cast(tensor.getType()); + auto rank = tensorTy.getRank(); + + assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3"); + if (rank == 3) + return tensor; + + ArrayRef shape = tensorTy.getShape(); + SmallVector newShape({1, 1, 1}); + + if (rank == 2) { // batchsize = 1 + newShape[1] = shape[0]; + newShape[2] = shape[1]; + } else { // rank 1 + newShape[2] = shape[0]; + } + + return builder.reshape(tensor, newShape); +} + +// Obtaining the rank broadcasted shapes of tensors makes it easier to +// construct the input and output reshaping logic. +void getRankBroadcastedShape(Value tensor, const int64_t maxInputRank, + bool isRHS, SmallVectorImpl &bcastedShape) { + auto tensorTy = cast(tensor.getType()); + ArrayRef tensorShape = tensorTy.getShape(); + int64_t tensorRank = tensorTy.getRank(); + + const int64_t bcastDims = maxInputRank - tensorRank; + + if (isRHS && (tensorRank == 1) && bcastDims) { + // RHS with rank1 is special. It be synthetically transposed to dim[:-2] + for (int32_t i = 0; i < bcastDims - 1; i++) + bcastedShape.push_back(1); + bcastedShape.push_back(tensorShape[0]); + bcastedShape.push_back(1); + } else { + if (bcastDims > 0) { // rank broadcast + for (uint32_t i = 0; i < bcastDims; i++) + bcastedShape.push_back(1); + } + for (const auto &dim : tensorShape) + bcastedShape.push_back(dim); + } +} + +Type getMatMulOutputType(Type inputElemTy, PatternRewriter &rewriter) { + Type outputElemTy; + if (auto floatTy = dyn_cast(inputElemTy)) { + if (floatTy.isBF16() || floatTy.isF16() || floatTy.isF32()) { + // Always accumulate on f32 + outputElemTy = rewriter.getF32Type(); + } + } else if (auto integerTy = dyn_cast(inputElemTy)) { + if (integerTy.isInteger(/*width=*/8)) { + outputElemTy = rewriter.getIntegerType(/*width=*/32); + } else if (integerTy.isInteger(/*width=*/16)) { + outputElemTy = rewriter.getIntegerType(/*width=*/48); + } + } + return outputElemTy; +} + +// Lowering based on the lowering of torch-mlir +class ONNXMatMulOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using Adaptor = ONNXMatMulOp::Adaptor; + + LogicalResult matchAndRewrite(ONNXMatMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + TosaBuilder builder(rewriter, op->getLoc()); + + auto lhs = adaptor.getA(); + auto rhs = adaptor.getB(); + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + if (!lhsTy || !rhsTy || !lhsTy.hasStaticShape() || + !rhsTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "only ranked tensor types are supported"); + } + + auto lhsRank = lhsTy.getRank(); + auto rhsRank = rhsTy.getRank(); + + auto lhsShape = lhsTy.getShape(); + auto rhsShape = rhsTy.getShape(); + + auto lhsElemTy = lhsTy.getElementType(); + auto rhsElemTy = rhsTy.getElementType(); + + if (lhsElemTy != rhsElemTy) { + return rewriter.notifyMatchFailure( + op, "expected both inputs to have same element type"); + } + + auto outputElemType = getMatMulOutputType(lhsElemTy, rewriter); + if (!outputElemType) { + return rewriter.notifyMatchFailure(op, + "Only i8 and i16 integer and bf16, f16 and " + "f32 float types are valid"); + } + + int64_t maxInputRank = lhsRank > rhsRank ? lhsRank : rhsRank; + // If performing dot product on vectors, the RHS is synthetically transposed + if (maxInputRank == 1) + maxInputRank++; + + // Step: Rank broadcast the two inputs. + SmallVector lhsBroadcastedShape; + SmallVector rhsBroadcastedShape; + getRankBroadcastedShape(lhs, maxInputRank, false, lhsBroadcastedShape); + getRankBroadcastedShape(rhs, maxInputRank, true, rhsBroadcastedShape); + + auto rankBroadcastedLhs = lhsRank == maxInputRank + ? lhs + : builder.reshape(lhs, lhsBroadcastedShape); + + auto rankBroadcastedRhs = rhsRank == maxInputRank + ? rhs + : builder.reshape(rhs, rhsBroadcastedShape); + + // Where broadcasting is required in one or more batch dims, the following + // is done. + // Where all batch dims are involved in broadcasting: + // Given A: 3x1x5x6 and B: 1x4x6x7 + // 1. Reshape A to 1x15x6 (squeeze all batchdims into dim1) + // 2. Transpose B to 6x1x4x7, Reshape to 1x6x28 + // 3. tosa.Matmul 1x15x6 1x6x28 = 1x15x28 + // 4. Reshape out to 3x5x4x7, Transpose to 3x4x5x7 + // Where there are batch dimensions that are broadcast and not, the + // treatment is to have dim0 correspond to product of all non-broadcast + // dimsizes: + // Given A: 4x8x16x32 B: 8x32x17 + // 1. Reshape A to 8x64x32 (squeeze all unbroadcasted dims into dim0, + // broadcasted dims into dim1) + // 2. No transpose or reshape of B as its batchdims are not broadcast to. + // 3. tosa.Matmul 8x64x32 8x32x17 = 8x64x17 + // 4. Reshape to 8x4x16x17, Transpose to 4x8x16x17 + + // Inputs to the tosa.matmul + Value matmulLhs; + Value matmulRhs; + + using TensorShape_t = struct { + int64_t dim; + int64_t shape; + }; + + // Transpose needs to done if transposeDims are not non-monotonically + // increasing. E.g. [0, 1, 2, 3]: No transpose [1, 0, 2, 3]: Transpose dim0 + // and dim1 The order need not be sequential, since one or more dims may + // have been removed due to broadcasting. + auto isTransposeRequired = [](ArrayRef transposedDims) -> bool { + int32_t lastDim = -1; + for (int32_t dim : transposedDims) { + if (lastDim > dim) + return true; + lastDim = dim; + } + return false; + }; + + SmallVector commonElems; + SmallVector lhsSqueezedElems; + SmallVector rhsSqueezedElems; + + // Check if we need to perform the broadcast on batch dim + // Not needed if max rank < 3, or if maxrank == 3 and dim[0] matches + auto needsBatchDimBroadcast = [&]() -> bool { + if (maxInputRank < 3) { + return false; + } + return maxInputRank != 3 || + lhsBroadcastedShape[0] != rhsBroadcastedShape[0]; + }; + + const bool performBatchDimBroadcast = needsBatchDimBroadcast(); + if (!performBatchDimBroadcast) { + // Simple with no broadcasting artifacts. Just reshape up to 3D + matmulLhs = reshapeUpTo3DTensor(rankBroadcastedLhs, builder); + matmulRhs = reshapeUpTo3DTensor(rankBroadcastedRhs, builder); + } else { + // In this case, either or both input matrices involve broadcasting on + // their batch dimensions. For example: + // 4x5x6, 1x6x7 -> 4x5x7 + // 4x1x5x6, 1x3x6x7 -> 4x3x5x7 + // Though maxInputRank is necessarily >=3 here, individual matrices may be + // lower rank. + // E.g. 3x4x5x6, 6 -> 3x4x5 + + // These are the accumulated products of the shape of each dim: + // 1. common dimensions: upper dimensions (dims other than two rightmost) + // whose shapes are the same for both LHS and RHS. + // 2. LHS squeezed dimensions: all dimensions of LHS that involve + // broadcasting in either direction, plus the LHS[-2] shape + // 3. RHS squeezed dimensions: all dimensions of RHS that involve + // broadcasting in either direction, plus the RHS[-1] shape + int64_t commonValue = 1; + int64_t lhsSqueezedValue = 1; + int64_t rhsSqueezedValue = 1; + + // For both LHS and RHS, the dimensions are separated into the common, + // squeezed and remaining dim. E.g. given + // LHS = 3x4x5x6 + // RHS = 1x4x6x7 + // common = {{dim=1, shape=4}} + // lhs squeezed = {{dim=0, shape=3}, + // {dim=2, shape=5}} + // rhs squeezed = {{dim=0, shape=1}, + // {dim=2, shape=7}} + // The matmul dim is LHS[-1] and RHS[-2], i.e. 6. + // Once this is obtained, LHS and RHS are expressed as: + // LHS = {common, lhs_squeezed, matmul_dim} + // RHS = {common, matmul_dim, rhs_squeezed} + // The matmul is then performed to obtain output: + // matmul_out = {common, lhs_squeezed, rhs_squeezed} + // Finally, we reshape to 'unsqueeze' the LHS and RHS parts and transpose + // them back to their correct positions. + + SmallVector transposedLhsShape; + SmallVector transposedLhsDims; + + // Step: generate the common dim/shape information + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + if (lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { + commonValue *= lhsBroadcastedShape[dim]; + commonElems.push_back({dim, lhsBroadcastedShape[dim]}); + } + } + + // Step: generate the LHS squeezed dim/shape information. + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]); + if (!isDynamicDim && + lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) { + lhsSqueezedValue *= lhsBroadcastedShape[dim]; + lhsSqueezedElems.push_back({dim, lhsBroadcastedShape[dim]}); + } + } + // including LHS[-2] + lhsSqueezedElems.push_back( + {maxInputRank - 2, lhsBroadcastedShape[maxInputRank - 2]}); + lhsSqueezedValue *= lhsBroadcastedShape[maxInputRank - 2]; + + // Step: Create the tosa.transpose array. If this array has a + // non-monotonic series of dims, perform transpose. + // First the common_elems + for (uint32_t i = 0; i < commonElems.size(); i++) { + transposedLhsShape.push_back(commonElems[i].shape); + transposedLhsDims.push_back(commonElems[i].dim); + } + // then the lhs_squeezed elems + for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { + transposedLhsShape.push_back(lhsSqueezedElems[i].shape); + transposedLhsDims.push_back(lhsSqueezedElems[i].dim); + } + // then the final dim + transposedLhsDims.push_back(maxInputRank - 1); + transposedLhsShape.push_back(lhsBroadcastedShape[maxInputRank - 1]); + + Value lhsReshapeInput = rankBroadcastedLhs; + if (isTransposeRequired(transposedLhsDims)) { + lhsReshapeInput = + builder.transpose(rankBroadcastedLhs, transposedLhsDims); + } + + // LHS = {common, lhs_squeezed, matmul_dim} + SmallVector newLhsShape( + {1, 1, lhsBroadcastedShape[maxInputRank - 1]}); + newLhsShape[0] = commonValue; + newLhsShape[1] = lhsSqueezedValue; + + matmulLhs = builder.reshape(lhsReshapeInput, newLhsShape); + + SmallVector transposedRhsShape; + SmallVector transposedRhsDims; + + // Step: Create the RHS transpose sequence + // RHS = {common, matmul_dim, rhs_squeezed} + // first the common_dims + for (uint32_t i = 0; i < commonElems.size(); i++) { + transposedRhsShape.push_back(commonElems[i].shape); + transposedRhsDims.push_back(commonElems[i].dim); + } + // The matmul_dim of RHS + transposedRhsDims.push_back(maxInputRank - 2); + transposedRhsShape.push_back(rhsBroadcastedShape[maxInputRank - 2]); + // finally all the rhs_squeeze dims + for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) { + if (rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) { + rhsSqueezedElems.push_back({dim, rhsBroadcastedShape[dim]}); + rhsSqueezedValue *= rhsBroadcastedShape[dim]; + } + } + rhsSqueezedElems.push_back( + {maxInputRank - 1, rhsBroadcastedShape[maxInputRank - 1]}); + rhsSqueezedValue *= rhsBroadcastedShape[maxInputRank - 1]; + for (uint32_t i = 0; i < rhsSqueezedElems.size(); i++) { + transposedRhsShape.push_back(rhsSqueezedElems[i].shape); + transposedRhsDims.push_back(rhsSqueezedElems[i].dim); + } + + auto transposedRhsValue = rankBroadcastedRhs; + if (isTransposeRequired(transposedRhsDims)) { + transposedRhsValue = + builder.transpose(rankBroadcastedRhs, transposedRhsDims); + } + + // reshape + SmallVector newRhsShape({commonValue, + rhsBroadcastedShape[maxInputRank - 2], rhsSqueezedValue}); + matmulRhs = builder.reshape(transposedRhsValue, newRhsShape); + } + + auto matmulLhsShape = + cast(matmulLhs.getType()).getShape(); + auto matmulRhsShape = + cast(matmulRhs.getType()).getShape(); + + // The reshape/transpose should ensure the tosa.matmul always has same + // batch size for either matrix. If if shapes are dynamic, they'll be + // appropriately handled. + assert(matmulLhsShape[0] == matmulRhsShape[0] && + "tosa.matmul needs same batchsize on LHS and RHS"); + + SmallVector matmulOutputShape( + {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); + + auto mmOutputTy = RankedTensorType::get(matmulOutputShape, outputElemType); + auto mmOpResult = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), mmOutputTy, matmulLhs, matmulRhs) + ->getResult(0); + auto castToOrigOp = + builder.castToNewTensorElementType(mmOpResult, lhsElemTy); + + // Perform the reshape to output shape. This is always required unless max + // input rank=3 and there was no broadcasting, in which case the tosa.matmul + // output itself is correctly shaped. + bool performOpReshape = !(maxInputRank == 3 && !performBatchDimBroadcast); + Value output = castToOrigOp; + if (performOpReshape) { + // Since the output shape may be unknown, we construct it + // independently and reshape. Otherwise reshape may be expressed for + // an unknown to-be-inferred output shape. The final tensor.cast + // reshapes the known shape to the desired output shape. + auto computeOpShape = [&](SmallVector &reshapedOpShape, + SmallVector &transposedOpShapes) { + if (maxInputRank == 1) + return; + + if (maxInputRank == 2) { + if (lhsRank == 2) + reshapedOpShape.push_back(lhsShape[0]); + if (rhsRank == 2) + reshapedOpShape.push_back(rhsShape[1]); + return; + } + + // Step: Construct the output transpose/reshape information + // First the common_dims + for (uint32_t i = 0; i < commonElems.size(); i++) { + reshapedOpShape.push_back(commonElems[i].shape); + } + + // Then the LHS squeezed dims + for (uint32_t i = 0; i < lhsSqueezedElems.size() - 1; i++) { + // Only dims that don't broadcast - broadcasting ones come from the + // other input. + if (lhsSqueezedElems[i].shape != 1) { + reshapedOpShape.push_back(lhsSqueezedElems[i].shape); + } + } + // The last squeezed dim is lhs[-2] which needs to be + // checked separately for broadcasting + if (lhsRank > 1) { + reshapedOpShape.push_back(lhsBroadcastedShape[maxInputRank - 2]); + } + + // then the RHS squeezed dims except rhs[-1] which is handled like + // lhs[-2] + for (uint32_t i = 0; i < rhsSqueezedElems.size() - 1; i++) { + if (rhsSqueezedElems[i].shape != 1) { + reshapedOpShape.push_back(rhsSqueezedElems[i].shape); + } + } + // rhs[-1] + if (rhsRank > 1) { + reshapedOpShape.push_back(rhsBroadcastedShape[maxInputRank - 1]); + } + + // Final transposed output shape construction + for (uint32_t i = 0; i < maxInputRank - 2; i++) { + if (lhsBroadcastedShape[i] == rhsBroadcastedShape[i]) { + transposedOpShapes.push_back(lhsBroadcastedShape[i]); + } else { + transposedOpShapes.push_back(lhsBroadcastedShape[i] == 1 + ? rhsBroadcastedShape[i] + : lhsBroadcastedShape[i]); + } + } + if (lhsRank > 1) + transposedOpShapes.push_back(lhsBroadcastedShape[maxInputRank - 2]); + if (rhsRank > 1) + transposedOpShapes.push_back(rhsBroadcastedShape[maxInputRank - 1]); + + return; + }; + + // Calculated output shapes for reshape and transpose + SmallVector reshapedOpShape; + SmallVector transposedOpShape; + computeOpShape(reshapedOpShape, transposedOpShape); + + // Perform reshape + auto reshapeOp = builder.reshape(castToOrigOp, reshapedOpShape); + + // Calculate transmutation required + SetVector transmutationSetVec; + for (unsigned i = 0; i < transposedOpShape.size(); i++) { + for (unsigned j = 0; j < reshapedOpShape.size(); j++) { + if (!transmutationSetVec.contains(j) && + transposedOpShape[i] == reshapedOpShape[j]) { + transmutationSetVec.insert(j); + break; + } + } + } + ArrayRef transVec = transmutationSetVec.getArrayRef(); + + // Perform final reshape + output = isTransposeRequired(transVec) + ? builder.transpose(reshapeOp, transVec) + : reshapeOp; + } + + rewriter.replaceOp(op, {output}); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXMatMulOpToTOSAPattern(ConversionTarget & /*target*/, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Math/Reduce.cpp b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp new file mode 100644 index 0000000000..72ce004df9 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Math/Reduce.cpp @@ -0,0 +1,263 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Reduce.cpp - ReduceMax Op --------------------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX reduce operators to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +// Expect the reduction op to have following configuration: +// Inputs: Data, Axes +// Attrs: KeepDims, noop_with_emty_axes +template +DenseIntElementsAttr getAxesLatestsVersionAttr(ONNXReduceOp op) { + typename ONNXReduceOp::Adaptor adaptor(op); + + Value input = adaptor.getData(); + Value axesValue = adaptor.getAxes(); + int64_t noOpIfAxesEmpty = adaptor.getNoopWithEmptyAxes(); + + // axes is mandatory for tosa + SmallVector targetAxes; + if (isNoneValue(axesValue)) { + if (noOpIfAxesEmpty == 0) { + // Default behaviour when "axes" is none and "noop_with_empty_axes" is + // set to false, it is to reduce all dims + const int64_t numberOfAxes = cast(input.getType()).getRank(); + auto iotaRange = + llvm::iota_range(0, numberOfAxes, /*Inclusive=*/false); + targetAxes = SmallVector(iotaRange.begin(), iotaRange.end()); + } else { + assert(noOpIfAxesEmpty == 1 && + "noop_with_empty_axes can only be either 0 or 1"); + // If "axes" is none and "noop_with_empty_axes" is true, then this + // behaves as an identity operator, no reduction is performed and shape + // is the same as the input. This is handed by later function just return + // an empty axis array + } + } else if (axesValue.getDefiningOp() || + axesValue.getDefiningOp()) { + // "axes" are specified, retrieve + auto axesValues = + tosa::getElementsAttrFromConst(axesValue).getValues(); + targetAxes = SmallVector(axesValues.begin(), axesValues.end()); + } else { + return {}; + } + + const int64_t numTargetAxes = targetAxes.size(); + auto i64Ty = + IntegerType::get(input.getContext(), /*width=*/64, IntegerType::Signless); + return DenseIntElementsAttr::get( + RankedTensorType::get({numTargetAxes}, i64Ty), targetAxes); +} + +// Expect the reduction op to have following configuration: +// Inputs: Data +// Attrs: KeepDims, axes +template +DenseIntElementsAttr getAxesLegacyVersionAttr(ONNXReduceOp op) { + typename ONNXReduceOp::Adaptor adaptor(op); + + Value input = adaptor.getData(); + auto axes = adaptor.getAxes(); + + // axes is mandatory for tosa + SmallVector targetAxes; + if (!axes) { + // if not present all axes are reduced + const int64_t numberOfAxes = cast(input.getType()).getRank(); + auto iotaRange = + llvm::iota_range(0, numberOfAxes, /*Inclusive=*/false); + targetAxes = SmallVector(iotaRange.begin(), iotaRange.end()); + } else { + targetAxes = extractFromIntegerArrayAttr(axes.value()); + } + + const int64_t numTargetAxes = targetAxes.size(); + auto i64Ty = + IntegerType::get(input.getContext(), /*width=*/64, IntegerType::Signless); + return DenseIntElementsAttr::get( + RankedTensorType::get({numTargetAxes}, i64Ty), targetAxes); +} + +template +class ONNXReduceOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXReduceOp::Adaptor; + LogicalResult matchAndRewrite(ONNXReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto inputType = dyn_cast(adaptor.getData().getType()); + if (!inputType) + return rewriter.notifyMatchFailure(op, "input type not a ranked tensor."); + + auto outputType = cast( + this->getTypeConverter()->convertType(op.getResult().getType())); + + return (*lowerFn)(op, inputType, outputType, rewriter); + } +}; + +template +LogicalResult reduceLatestVersionLowering(ONNXOp_t op, + RankedTensorType inputType, RankedTensorType outputType, + ConversionPatternRewriter &rewriter) { + typename ONNXOp_t::Adaptor adaptor(op); + + Value val = onnx_mlir::tosa::convertReduceOpCommon(rewriter, op, + outputType, adaptor.getData(), inputType, getAxesLatestsVersionAttr(op), + adaptor.getKeepdims()); + + // Shape inference is handled by the helper functions + rewriter.replaceOp(op, {val}); + return success(); +} + +template +LogicalResult reduceLegacyVersionsLowering(ONNXOp_t op, + RankedTensorType inputType, RankedTensorType outputType, + ConversionPatternRewriter &rewriter) { + typename ONNXOp_t::Adaptor adaptor(op); + Value val = onnx_mlir::tosa::convertReduceOpCommon(rewriter, op, + outputType, adaptor.getData(), inputType, getAxesLegacyVersionAttr(op), + adaptor.getKeepdims()); + + // Shape inference is handled by the helper functions + rewriter.replaceOp(op, {val}); + return success(); +} + +LogicalResult reduceMeanLowering(ONNXReduceMeanOp op, + RankedTensorType inputType, RankedTensorType outputType, + ConversionPatternRewriter &rewriter) { + typename ONNXReduceMeanOp::Adaptor adaptor(op); + auto newAxesAttr = getAxesLatestsVersionAttr(op); + if (!newAxesAttr) { + return rewriter.notifyMatchFailure(op, "cannot convert with dynamic axis"); + } + // reduce_mean is lowered as followed: + // op1 = reduce_sum(input) + // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis) + auto keepDims = adaptor.getKeepdims(); + int64_t inputRank = inputType.getRank(); + int64_t numElemsOnReducedAxis = 1; + for (int i = 0; i < newAxesAttr.getNumElements(); i++) { + int64_t axisVal = newAxesAttr.getValues()[i].getInt(); + if (axisVal < 0) + axisVal += inputRank; + numElemsOnReducedAxis *= inputType.getShape()[axisVal]; + } + double divScale = 1.0 / static_cast(numElemsOnReducedAxis); + + Value val = + onnx_mlir::tosa::convertReduceOpCommon(rewriter, + op, outputType, adaptor.getData(), inputType, newAxesAttr, keepDims); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value divConst = tosaBuilder.getSplattedConst( + divScale, outputType.getElementType(), outputType.getRank()); + auto output = tosaBuilder.mul(val, divConst); + + if (!output) { + return rewriter.notifyMatchFailure(op, "could not be converted"); + } + // Shape inference is handled by the helper functions + rewriter.replaceOp(op, {output}); + return success(); +} + +LogicalResult reduceMeanV13Lowering(ONNXReduceMeanV13Op op, + RankedTensorType /*inputType*/, RankedTensorType outputType, + ConversionPatternRewriter &rewriter) { + typename ONNXReduceMeanV13Op::Adaptor adaptor(op); + auto newAxesAttr = getAxesLegacyVersionAttr(op); + + auto keepDims = adaptor.getKeepdims(); + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + auto output = tosa::convertReduceMeanOp(rewriter, op, tosaBuilder, outputType, + adaptor.getData(), newAxesAttr, keepDims); + + if (!output) { + return rewriter.notifyMatchFailure(op, "Could not be converted"); + } + // Shape inference is handled by the helper functions + rewriter.replaceOp(op, {output.value()}); + return success(); +} + +} // namespace + +#define DECLARE_ONE_TO_ONE_LOWERING(ONNXOp, TOSAOp) \ + using ONNXOp##LoweringToTOSA = ONNXReduceOpLoweringToTOSA> +// Covers versions 20(latests)-18 +DECLARE_ONE_TO_ONE_LOWERING(ONNXReduceMinOp, mlir::tosa::ReduceMinOp); +// Covers versions 20(latests)-18 +DECLARE_ONE_TO_ONE_LOWERING(ONNXReduceMaxOp, mlir::tosa::ReduceMaxOp); +// Covers versions 13 (latests) +DECLARE_ONE_TO_ONE_LOWERING(ONNXReduceProdOp, mlir::tosa::ReduceProdOp); +// Covers versions 18 (latests) +DECLARE_ONE_TO_ONE_LOWERING(ONNXReduceSumOp, mlir::tosa::ReduceSumOp); +// Covers versions 18 (latests) +using ONNXReduceMeanOpLoweringToTOSA = + ONNXReduceOpLoweringToTOSA; +#undef DECLARE_ONE_TO_ONE_LOWERING + +#define DECLARE_ONE_TO_ONE_LEGACY_LOWERING(ONNXOp, TOSAOp) \ + using ONNXOp##LegacyLoweringToTOSA = ONNXReduceOpLoweringToTOSA> +// Covers versions 13-12-11 +DECLARE_ONE_TO_ONE_LEGACY_LOWERING(ONNXReduceMinV13Op, mlir::tosa::ReduceMinOp); +// Covers versions 13-12-11 +DECLARE_ONE_TO_ONE_LEGACY_LOWERING(ONNXReduceMaxV13Op, mlir::tosa::ReduceMaxOp); +// Covers version 11 +DECLARE_ONE_TO_ONE_LEGACY_LOWERING(ONNXReduceSumV11Op, mlir::tosa::ReduceSumOp); +// Covers version 13-11 +DECLARE_ONE_TO_ONE_LEGACY_LOWERING( + ONNXReduceProdV13Op, mlir::tosa::ReduceProdOp); +// Covers version 13-11 +using ONNXReduceMeanV13LegacyLoweringToTOSA = + ONNXReduceOpLoweringToTOSA; +#undef DECLARE_ONE_TO_ONE_LEGACY_LOWERING + +void populateLoweringONNXReduceOpsToTOSAPattern(ConversionTarget & /*target*/, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp b/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp deleted file mode 100644 index edaea512bb..0000000000 --- a/src/Conversion/ONNXToTOSA/Math/ReduceMean.cpp +++ /dev/null @@ -1,116 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===---------------- ReduceMean.cpp - ReduceMean Op --------------------===// -// -// Copyright (c) 2023 Advanced Micro Devices, Inc. -// -// ============================================================================= -// -// This file lowers ONNX reduce mean operator to TOSA dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" -#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" -#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include - -using namespace mlir; - -namespace onnx_mlir { - -namespace { - -class ONNXReduceMeanLoweringToTOSA - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename ONNXReduceMeanOp::Adaptor; - LogicalResult matchAndRewrite(ONNXReduceMeanOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - auto loc = op->getLoc(); - TosaBuilder tosaBuilder(rewriter, loc); - - Value input = adaptor.getData(); - Value axesValue = adaptor.getAxes(); - auto keepDims = adaptor.getKeepdims(); - auto noOpIfAxesEmpty = adaptor.getNoopWithEmptyAxes(); - - auto outputType = mlir::cast( - getTypeConverter()->convertType(op.getResult().getType())); - - RankedTensorType inputType = - mlir::dyn_cast(input.getType()); - if (!inputType) - return rewriter.notifyMatchFailure(op, "input type not a ranked tensor."); - - // axes is mandatory for tosa - llvm::SmallVector axesVec; - if (isNoneValue(axesValue) && !noOpIfAxesEmpty) { - // if not present all axes are reduced - const int64_t numberOfAxes = - mlir::cast(input.getType()).getRank(); - llvm::SmallVector allDims(numberOfAxes); - std::iota(std::begin(allDims), std::end(allDims), 0); - axesVec.append(allDims); - } else if (axesValue.getDefiningOp()) { - // if input is a tosa const get axes - auto axes = tosa::getValueFromTosaConst(axesValue); - auto axesElementsValues = axes.getValues(); - llvm::transform(axesElementsValues, std::back_inserter(axesVec), - [](int64_t axesInt) { return axesInt; }); - } - // Tosa needs a DenseElementsAttr - const int64_t vecValuesSize = axesVec.size(); - DenseElementsAttr newAxesAttr = DenseIntElementsAttr::get( - RankedTensorType::get({vecValuesSize}, rewriter.getI64Type()), axesVec); - - // reduce_mean is lowered as followed: - // op1 = reduce_sum(input) - // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis) - - int64_t inputRank = inputType.getRank(); - int64_t numElemsOnReducedAxis = 1; - for (int i = 0; i < newAxesAttr.getNumElements(); i++) { - int64_t axisVal = newAxesAttr.getValues()[i].getInt(); - if (axisVal < 0) - axisVal += inputRank; - numElemsOnReducedAxis *= inputType.getShape()[axisVal]; - } - double divScale = 1.0 / static_cast(numElemsOnReducedAxis); - mlir::Type reduceElementType = inputType.getElementType(); - - auto val = onnx_mlir::tosa::convertReduceOpCommon( - rewriter, op, outputType, input, newAxesAttr, keepDims, - reduceElementType); - - if (!val.has_value()) - return rewriter.notifyMatchFailure( - op, "could not convert generic reduce op."); - - Value divConst = tosaBuilder.getSplattedConst(divScale); - auto output = tosaBuilder.mul(val.value(), divConst); - - if (!output) { - return rewriter.notifyMatchFailure(op, "could not be converted"); - } - // Shape inference is handled by the helper functions - rewriter.replaceOp(op, {output}); - return success(); - } -}; - -} // namespace - -void populateLoweringONNXReduceMeanOpToTOSAPattern(ConversionTarget &target, - RewritePatternSet &patterns, TypeConverter &typeConverter, - MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); -} - -} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Math/Softmax.cpp b/src/Conversion/ONNXToTOSA/Math/Softmax.cpp index 81321a5754..4c9962d44b 100644 --- a/src/Conversion/ONNXToTOSA/Math/Softmax.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Softmax.cpp @@ -26,14 +26,10 @@ namespace onnx_mlir { namespace { -template -Value computeReduceSum(PatternRewriter &rewriter, Operation *op, - RankedTensorType rsumType, const Value &op1ExpIn, int axis) = delete; - // Before opset 13, softmax reduces axis and every dimension following. -template <> -Value computeReduceSum(PatternRewriter &rewriter, - Operation *op, RankedTensorType rsumType, const Value &op1ExpIn, int axis) { +template +Value computeReduction(PatternRewriter &rewriter, ONNXSoftmaxV11Op op, + RankedTensorType rsumType, const Value &op1ExpIn, int axis) { const int64_t inputRank = rsumType.getRank(); // Create shared outputType with dynamic shape. Infer method when creating // ops will insert a static shape if possible @@ -41,22 +37,22 @@ Value computeReduceSum(PatternRewriter &rewriter, llvm::SmallVector(inputRank, ShapedType::kDynamic), rsumType.getElementType()); // Create first reduce with input from function operands - Value reducedSum = tosa::CreateOpAndInfer(rewriter, - op->getLoc(), outputType, op1ExpIn, rewriter.getI32IntegerAttr(axis)); + Value reducedSum = tosa::CreateOpAndInfer(rewriter, op->getLoc(), + outputType, op1ExpIn, rewriter.getI32IntegerAttr(axis)); // Loop over all following dimensions with last reduce as input for (int i = axis + 1; i < inputRank; i++) { - reducedSum = tosa::CreateOpAndInfer(rewriter, - op->getLoc(), outputType, reducedSum, rewriter.getI32IntegerAttr(i)); + reducedSum = tosa::CreateOpAndInfer(rewriter, op->getLoc(), + outputType, reducedSum, rewriter.getI32IntegerAttr(i)); } return reducedSum; } // From opset 13, softmax uses axis as the reduce axis. -template <> -Value computeReduceSum(PatternRewriter &rewriter, Operation *op, +template +Value computeReduction(PatternRewriter &rewriter, ONNXSoftmaxOp op, RankedTensorType rsumType, const Value &op1ExpIn, int axis) { - return tosa::CreateOpAndInfer(rewriter, op->getLoc(), - rsumType, op1ExpIn, rewriter.getI32IntegerAttr(axis)); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), rsumType, + op1ExpIn, rewriter.getI32IntegerAttr(axis)); } template @@ -87,27 +83,32 @@ class ONNXSoftmaxLoweringToTOSA : public OpConversionPattern { int64_t axis = adaptor.getAxis(); // Tosa only supports positive values int64_t inputRank = inputType.getRank(); - if (axis < 0) { - axis += inputRank; - } + axis = tosa::convertNegativeAxis(axis, inputRank); // The legalization below is based on convertSoftmaxOp in // tensorflow tosa/transforms/legalize_common.cc, with the // addition of handling for axis. // Floating-point lowering is more direct: // - // op1 = exp(logits) + // m = reduce_max(logits) + // op1 = exp(logits - m) // op2 = reduce_sum(op1, -1) // op3 = reciprocal(op2) // op4 = mul(op1, op3) - Value op1ExpIn = tosa::CreateOpAndInfer( - rewriter, loc, outputType, input); RankedTensorType rsumType = RankedTensorType::get( llvm::SmallVector(inputRank, ShapedType::kDynamic), outputType.getElementType()); - Value op2ReducesumOp1 = - computeReduceSum(rewriter, op, rsumType, op1ExpIn, axis); + Value reduceMax = computeReduction( + rewriter, op, rsumType, input, axis); + + Value xLessMax = tosaBuilder.binaryOp(input, reduceMax); + + Value op1ExpIn = tosa::CreateOpAndInfer( + rewriter, loc, outputType, xLessMax); + + Value op2ReducesumOp1 = computeReduction( + rewriter, op, rsumType, op1ExpIn, axis); Value op3ReciprocalOp2 = tosa::CreateOpAndInfer( rewriter, loc, op2ReducesumOp1.getType(), op2ReducesumOp1); @@ -128,4 +129,4 @@ void populateLoweringONNXSoftmaxOpToTOSAPattern(ConversionTarget &target, ONNXSoftmaxLoweringToTOSA>(typeConverter, ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp index 9874961036..80f56bfb55 100644 --- a/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp +++ b/src/Conversion/ONNXToTOSA/NN/AveragePool.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" @@ -28,27 +29,35 @@ namespace onnx_mlir { namespace { -void handleIncludePadAttr( +LogicalResult handleIncludePadAttr( ConversionPatternRewriter &rewriter, Operation *op, Value input) { - mlir::Location loc = op->getLoc(); + Location loc = op->getLoc(); // Get shape. IndexExprBuilderForTosa createTosaIE(rewriter, loc); ONNXGenericPoolOpShapeHelper shapeHelper( op, {}, &createTosaIE); - shapeHelper.computeShapeAndAssertOnFailure(); + if (shapeHelper.computeShape().failed()) { + return rewriter.notifyMatchFailure(op, "Could not infer shapes"); + } + + auto inputType = cast(input.getType()); + if (inputType.getShape().size() != 4) { + return rewriter.notifyMatchFailure(op, "TOSA only supports 2d pooling"); + } - // Build an array with padding. - llvm::SmallVector intValues; - IndexExpr::getLiteral(shapeHelper.pads, intValues); + llvm::SmallVector pads = + tosa::createOrderedPadAttrForWindowBasedOps(rewriter, + cast(input.getType()).getShape(), shapeHelper, + /*ceilMode*/ 0, {0, 1, 2, 3}); // Create Padding and ConstPad tosa::ConstOp's TosaBuilder tosaBuilder(rewriter, loc); Value padding = tosa::buildOnnxToTosaPaddingConstOp( - rewriter, intValues, loc, {0, 0, 0, 0}, {}); - auto constTosaTensor = tosaBuilder.getSplattedConst(0.0); + rewriter, pads, loc, {0, 0, 0, 0}, {}); + auto constTosaTensor = + tosaBuilder.getSplattedConst(0.0, inputType.getElementType(), 0); - auto inputType = mlir::cast(input.getType()); auto padOp = tosa::CreateOpAndInfer(rewriter, loc, mlir::RankedTensorType::get( llvm::SmallVector( @@ -62,6 +71,7 @@ void handleIncludePadAttr( rewriter.modifyOpInPlace(op, [&]() { op->setAttr("pads", rewriter.getI32ArrayAttr({0, 0, 0, 0})); }); + return success(); } class ONNXAveragePoolOpLoweringToTOSA @@ -94,7 +104,8 @@ class ONNXAveragePoolOpLoweringToTOSA // lowering still generates transposes between ONNX and TOSA formats, and // implementation doesn't diverge much. This will modify the original onnx // op. - handleIncludePadAttr(rewriter, averagePoolOp, adaptor.getX()); + if (failed(handleIncludePadAttr(rewriter, averagePoolOp, adaptor.getX()))) + return failure(); } FailureOr newAveragePoolOp = diff --git a/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp new file mode 100644 index 0000000000..3cd31c4ebf --- /dev/null +++ b/src/Conversion/ONNXToTOSA/NN/BatchNorm.cpp @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- ONNXQuantizeLinearOp.cpp - ONNXQuantizeLinearOp---------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNXBatchNormalizationOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +namespace onnx_mlir { +namespace { +class ONNXBatchNormalizationInferenceModeOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXBatchNormalizationInferenceModeOp::Adaptor; + LogicalResult matchAndRewrite(ONNXBatchNormalizationInferenceModeOp op, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + + auto outType = getTypeConverter()->convertType(op.getResult().getType()); + if (!cast(outType).hasRank()) { + return rewriter.notifyMatchFailure(op, + "ONNXBatchNormalizationInferenceModeOp to " + "TOSA requires a ranked result type"); + } + auto outTensorType = cast(outType); + + // The layout of the output is N x C x D1 x D2 … Dn. For batch + // normalization, the C dimension is kept. The new shape should be {1, C, 1, + // 1, ...}. + SmallVector newShape = {1, outTensorType.getShape()[1]}; + for (auto i = 2; i < outTensorType.getRank(); i++) + newShape.push_back(1); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value input = op.getX(); + Value mean = op.getMean(); + Value scale = op.getScale(); + Value bias = op.getB(); + Value var = op.getVar(); + + // reshape rank-1 tensors (scale, bias, mean, variance), + // such that they have the same rank as input/output tensor + Value reshapedMean; + Value reshapedScale; + Value reshapedBias; + Value reshapedVar; + + reshapedMean = tosaBuilder.reshape(mean, ArrayRef(newShape)); + reshapedScale = tosaBuilder.reshape(scale, ArrayRef(newShape)); + reshapedBias = tosaBuilder.reshape(bias, ArrayRef(newShape)); + reshapedVar = tosaBuilder.reshape(var, ArrayRef(newShape)); + + // epsilon's shape: constant -> {1, 1, 1, ...} + newShape[1] = 1; + auto eps = tosaBuilder.getSplattedConst(op.getEpsilon().convertToFloat(), + outTensorType.getElementType(), newShape.size()); + + // output = (input - mean) * scale * rsqrt(var + eps) + bias + auto op1SubInputMean = + tosaBuilder.binaryOp(input, reshapedMean); + auto op2AddVarEps = + tosaBuilder.binaryOp(reshapedVar, eps); + auto op3RsqrtOp2 = tosaBuilder.unaryOp(op2AddVarEps); + auto op4MulOp1Op3 = tosaBuilder.mul(op1SubInputMean, op3RsqrtOp2, 0); + auto op5MulOp4Scale = tosaBuilder.mul(op4MulOp1Op3, reshapedScale, 0); + auto newOutput = + tosaBuilder.binaryOp(op5MulOp4Scale, reshapedBias); + + rewriter.replaceOp(op, {newOutput}); + return success(); + } +}; +} // namespace + +void populateLoweringONNXBatchNormalizationOpToTOSAPattern( + ConversionTarget &target, RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert( + typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp b/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp new file mode 100644 index 0000000000..5d7a27d612 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/NN/DequantizeLinear.cpp @@ -0,0 +1,112 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- ONNXDequantizeLinearOp.cpp - ONNXDequantizeLinearOp-----===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNXDequantizeLinearOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXDequantizeLinearOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXDequantizeLinearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + Value x = op.getX(); + auto resultType = dyn_cast_if_present( + getTypeConverter()->convertType(op.getResult().getType())); + if (!resultType || !resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected valid tensor result type"); + } + + auto zeroPoint = adaptor.getXZeroPoint(); + auto zpTy = zeroPoint.getType(); + if (isa(zpTy)) { + zeroPoint = {}; + } else if (auto shapedTy = dyn_cast(zpTy)) { + if (!shapedTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected zero point to have static shape"); + } + } else { + return rewriter.notifyMatchFailure( + loc, "expected zero point to be none or have tensor type"); + } + + if (auto scaleTy = cast(adaptor.getXScale().getType()); + !scaleTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected scale to have static shape"); + } + + int64_t axis = op.getAxis(); + // See https://github.com/onnx/onnx/issues/6067 + if (axis == 1 && resultType.getRank() == 1) + axis = 0; + if (axis < -resultType.getRank() || axis >= resultType.getRank()) { + return rewriter.notifyMatchFailure(loc, "axis is invalid"); + } + if (axis < 0) + axis += resultType.getRank(); + + // Dequantization formula is (x - zero_point) * scale + // Cast into the destination type first + + // Cast the operands of (x - zero_point) to float32 to avoid underflows + Type arithType = rewriter.getF32Type(); + Value casted = tosaBuilder.castToNewTensorElementType(x, arithType); + if (zeroPoint) { + auto zpConst = tosa::expandShape( + rewriter, loc, zeroPoint, axis, resultType.getRank()); + Value zpConstCast = + tosaBuilder.castToNewTensorElementType(zpConst, arithType); + casted = tosa::CreateOpAndInfer( + rewriter, loc, casted.getType(), casted, zpConstCast) + .getResult(); + } + auto scaleFactorConst = tosa::expandShape( + rewriter, loc, adaptor.getXScale(), axis, resultType.getRank()); + Value scaleFactorCast = + tosaBuilder.castToNewTensorElementType(scaleFactorConst, arithType); + Value mulOp = tosa::CreateOpAndInfer( + rewriter, loc, casted.getType(), casted, scaleFactorCast, 0) + .getResult(); + Value castOp = tosaBuilder.castToNewTensorElementType( + mulOp, resultType.getElementType()); + + rewriter.replaceOp(op, castOp); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXDequantizeLinearOpToTOSAPattern( + ConversionTarget &target, RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp index c530c1e152..8261445c4e 100644 --- a/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp +++ b/src/Conversion/ONNXToTOSA/NN/MaxPoolSingleOut.cpp @@ -36,7 +36,7 @@ class ONNXMaxPoolSingleOutOpLoweringToTOSA : public ConversionPattern { using OpAdaptor = typename ONNXMaxPoolSingleOutOp::Adaptor; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto maxpoolOp = llvm::cast(op); + auto maxpoolOp = mlir::cast(op); OpAdaptor adaptor(operands, op->getAttrDictionary()); Value input = adaptor.getX(); @@ -48,9 +48,11 @@ class ONNXMaxPoolSingleOutOpLoweringToTOSA : public ConversionPattern { return rewriter.notifyMatchFailure( op, "memrefs as inputs are unsupported by TOSA"); } - if (dilations) { + auto isOne = [](IntegerAttr attr) { return attr.getValue().isOne(); }; + if (dilations && + !llvm::all_of(dilations.getAsRange(), isOne)) { return rewriter.notifyMatchFailure( - maxpoolOp, "dilations attribute is unsupported by TOSA"); + maxpoolOp, "dilations != 1 is unsupported by TOSA"); } if (storageOrder && storageOrder.getSInt() != 0) { return rewriter.notifyMatchFailure( diff --git a/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp b/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp new file mode 100644 index 0000000000..85a1652bf4 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/NN/QuantizeLinear.cpp @@ -0,0 +1,177 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- ONNXQuantizeLinearOp.cpp - ONNXQuantizeLinearOp---------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNXQuantizeLinearOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXQuantizeLinearOpLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ONNXQuantizeLinearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultType = dyn_cast_if_present( + getTypeConverter()->convertType(op.getResult().getType())); + if (!resultType || !resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected valid tensor result type"); + } + + if (auto zpTy = dyn_cast(adaptor.getYZeroPoint().getType()); + zpTy && !zpTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected zero point to have static shape"); + } + + if (auto zpTy = dyn_cast(adaptor.getYScale().getType()); + zpTy && !zpTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + loc, "expected scale to have static shape"); + } + + if (!op.getSaturate()) { + return rewriter.notifyMatchFailure(loc, "Only saturate=1 is supported"); + } + + int64_t axis = op.getAxis(); + // See https://github.com/onnx/onnx/issues/6067 + if (axis == 1 && resultType.getRank() == 1) + axis = 0; + if (axis < -resultType.getRank() || axis >= resultType.getRank()) { + return rewriter.notifyMatchFailure(loc, "axis is invalid"); + } + if (axis < 0) + axis += resultType.getRank(); + + Value x = adaptor.getX(); + Type xType = x.getType(); + + // Quantization formula is saturate((x / y_scale) + y_zero_point) + // tosa.mul doesn't allow different ranks + auto expandedScaleFactorConst = tosa::expandShape( + rewriter, loc, adaptor.getYScale(), axis, resultType.getRank()); + // Replace the division by a reciprocal followed by a mul + Value recOp = tosa::CreateOpAndInfer(rewriter, + loc, expandedScaleFactorConst.getType(), expandedScaleFactorConst) + .getResult(); + Value scaledResult = tosa::CreateOpAndInfer( + rewriter, loc, xType, x, recOp, 0) + .getResult(); + + // Quantization to i4/i8/16/ is particular since the intermediate result of + // (x / y_scale) must round to the nearest even. This is particularly + // important if the intermediate result is e.g. 8.5. If we don't round to + // the nearest even before adding the (potentially odd) zero point, we would + // end up with a different result + bool quantizingToInt = isa(resultType.getElementType()); + if (quantizingToInt) { + // ONNX QuantizeLinear op supports those integer zero point types: + // int16, int4, int8, uint16, uint4, uint8 + // Convert the scaled result to a safe bitwith (i32) that avoids + // underflows/overflows + scaledResult = tosa::CreateOpAndInfer(rewriter, loc, + resultType.cloneWith({}, rewriter.getI32Type()), scaledResult) + .getResult(); + } + + // If there is no zero point, we are done + if (isa(adaptor.getYZeroPoint().getType())) { + Value result = tosa::CreateOpAndInfer( + rewriter, loc, resultType, scaledResult) + .getResult(); + rewriter.replaceOp(op, result); + return success(); + } + + Value expandedZpConst = tosa::expandShape( + rewriter, loc, adaptor.getYZeroPoint(), axis, resultType.getRank()); + + // Cast the expandedZpConst to have the same rank and element type as + // the scaledResult. tosa.add doesn't allow different ranks + Value castedZp; + if (quantizingToInt) { + castedZp = tosa::CreateOpAndInfer(rewriter, loc, + cast(expandedZpConst.getType()) + .cloneWith({}, rewriter.getI32Type()), + expandedZpConst) + .getResult(); + } else { + // zpConst has the same type as the result of QLinear which is always + // smaller than the input type. Cast it to the input type. + castedZp = tosa::CreateOpAndInfer( + rewriter, loc, expandedScaleFactorConst.getType(), expandedZpConst) + .getResult(); + } + + Value addOp = tosa::CreateOpAndInfer( + rewriter, loc, scaledResult.getType(), scaledResult, castedZp) + .getResult(); + + Value clampedRes = addOp; + if (quantizingToInt) { + // If the destination type is an integer, perform saturation. + IntegerType resTypeInt = + dyn_cast(resultType.getElementType()); + + // Compute the max/min values for the said type from the 64-bit max + auto width = resTypeInt.getIntOrFloatBitWidth(); + APInt maxVal = resTypeInt.isUnsigned() ? APInt::getMaxValue(width) + : APInt::getSignedMaxValue(width); + APInt minVal = resTypeInt.isUnsigned() ? APInt::getZero(width) + : APInt::getSignedMinValue(width); + + clampedRes = tosa::CreateOpAndInfer(rewriter, loc, + addOp.getType(), addOp, + rewriter.getIntegerAttr(rewriter.getI64Type(), minVal.sext(64)), + rewriter.getIntegerAttr(rewriter.getI64Type(), maxVal.zext(64)), + // We ignore floating point values, we're clamping integers. + rewriter.getFloatAttr( + rewriter.getF32Type(), (float)(minVal.getSExtValue())), + rewriter.getFloatAttr( + rewriter.getF32Type(), (float)(maxVal.getZExtValue()))); + } + + // Cast into the result type + Value result = tosa::CreateOpAndInfer( + rewriter, loc, resultType, clampedRes) + .getResult(); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXQuantizeLinearOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp new file mode 100644 index 0000000000..fd4c29393f --- /dev/null +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.cpp @@ -0,0 +1,428 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// +// +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// Copyright (c) 2021 Arm Limited. +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file contains common code shared by the functions performing the +// lowering to the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/MLIRContext.h" + +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace tosa { + +static int64_t multiplyDims(llvm::ArrayRef dims, int64_t res = 1) { + for (auto dim : dims) { + if (ShapedType::isDynamic(dim)) { + return ShapedType::kDynamic; + } + res = res * dim; + } + return res; +} + +static int64_t countDynamicDims(llvm::ArrayRef dims) { + int64_t count = 0; + for (auto dim : dims) + if (ShapedType::isDynamic(dim)) + ++count; + return count; +} + +// Lowers Gather operators to a sequence of TOSA ops. +// This Code is mostly the same as TF to TOSA. +std::optional convertGatherOp(PatternRewriter &rewriter, Location loc, + Value resultValue, Value inputValue, Value indicesValue, int32_t batchDims, + int32_t axis) { + + TosaBuilder tosaBuilder(rewriter, loc); + + auto resultType = dyn_cast(resultValue.getType()); + auto inputType = dyn_cast(inputValue.getType()); + auto indicesType = dyn_cast(indicesValue.getType()); + + if (!resultType || !inputType || !indicesType) + return std::nullopt; + + // batchDims indicates the number of batch dimensions in input and + // indices axis indicates the axis at which the gather indexing is + // applied. axis must be >= batch_dims. When axis is equal to + // batch_dims, the right-most batch dimension disappears. + // + // N: number of batches + // Computed as product of input.shape[0:batch_dims-1] + // + // W: number of indices in each batch + // Computed as product of indices.shape[batch_dims:] + // + // K: range of each index + // Computed as input.shape[axis:axis+rank(indices)-1] + // + // C: number of channels for each index + // Computed as: LeftChannels * RightChannels: + // product(input.shape[batch_dims:axis]) * product(input.shape[axis+1:]) + // + // The input tensor needs to be transposed, then reshaped to move the + // dimensions into [N, K, C] order. + // + // The dimensions of the input input[] tensor are grouped in the following + // order to begin with: + // + // [Batch, LeftChannels, Indices, RightChannels] + // |-----||------------||-------||-------------| + // N C_l K C_r + // + // Where Batch (N), Indices (K) can be one or more dimensions in size, + // while LeftChannels and RightChannels represent the group of data channels + // (C) to the left and right (C_l, C_r) of the indices; the sum of these two + // is one or more dimensions in size, but either one may be zero depending + // on how axis was specified by the caller. + // + // The resulting tensor will look like: + // + // [Batch, Indices, LeftChannels, RightChannels] + // |-----||-------||---------------------------| + // N K C + // + // The indices tensor simply needs a reshape to flatten all of the + // batch dimensions (N) together and flatten all of the indices (W) + // together. + // + // Then do the tosa.GATHER + // + // output[N,W,C] = tosa.GATHER(values[N,K,C], indices[N,W]) + // + // Finally, the resulting tensor will have shape [N, W, C], where C is a + // flattened version of [LeftChannels, RightChannels]. We need to reshape + // to unflatten to: + // + // [N, W, LeftChannels, RightChannels] + // + // and finally transpose back to the output shape + // + // [Batch, LeftChannels, Non-Batch-Indices, RightChannels] + + size_t inputRank = inputType.getShape().size(); + size_t indicesRank = indicesType.getShape().size(); + + ArrayRef inputShape = inputType.getShape(); + ArrayRef indicesShape = indicesType.getShape(); + + if (!((size_t)batchDims <= indicesRank)) { + (void)rewriter.notifyMatchFailure( + loc, "batch_dims must be <= indices_rank for a valid gather op"); + return std::nullopt; + } + + if (!(axis >= batchDims)) { + (void)rewriter.notifyMatchFailure( + loc, "axis must be >= batch_dims for a valid gather op"); + return std::nullopt; + } + + // onnx allows i64 indices, but tosa does not. + if (indicesType.getElementType().isInteger(64)) { + indicesType = + dyn_cast(indicesType.clone(rewriter.getI32Type())); + indicesValue = CreateOpAndInfer( + rewriter, loc, indicesType, indicesValue) + .getResult(); + } + + // Sizes for each of these fields. + SmallVector inputBatch, inputIndices, inputLeftChannels, + inputRightChannels; + + // Dimension indices for each of these fields. + SmallVector inputIdxBatch, inputIdxIndices, inputIdxLeftChannels, + inputIdxRightChannels; + + // Read through the input tensor dimensions left-to-right and extract the + // different fields. + for (int i = 0; i < (int)inputRank; i++) { + // When batch_dims == axis, the batch dimension gets replaced. + if (i < batchDims && i < axis) { + inputBatch.push_back(inputShape[i]); + inputIdxBatch.push_back(i); + } else if (i < axis) { + inputLeftChannels.push_back(inputShape[i]); + inputIdxLeftChannels.push_back(i); + } else if (i < (axis + 1)) { + inputIndices.push_back(inputShape[i]); + inputIdxIndices.push_back(i); + } else { + inputRightChannels.push_back(inputShape[i]); + inputIdxRightChannels.push_back(i); + } + } + + // Calculate N, K, W, C + int64_t N = multiplyDims(inputShape.take_front(batchDims)); + int64_t W = + multiplyDims(indicesShape.slice(batchDims, indicesRank - batchDims)); + int64_t K = inputShape[axis]; + + int64_t C = multiplyDims(inputShape.slice(batchDims, axis - batchDims)); + C = multiplyDims(inputShape.slice(axis + 1, inputRank - axis - 1), C); + + ///////////////////////////////////////////// + // Build up the input transpose operator + SmallVector inputTransposePerm; + + // Batch + inputTransposePerm.append(inputIdxBatch); + + // Indices + inputTransposePerm.append(inputIdxIndices); + + // LeftChannels + inputTransposePerm.append(inputIdxLeftChannels); + + // RightChannels + inputTransposePerm.append(inputIdxRightChannels); + + ///////////////////////////////////////////// + // Build up the result reshape, in prepration for transpose + // [N, W, C] -> [ Batch, Indices, LeftChannels, RightChannels ] + SmallVector resultReshapeShape; + + // Indices + // Use llvm::transform because range is an ArrayRef + llvm::transform(indicesShape, std::back_inserter(resultReshapeShape), + [](int64_t indiceDim) { return indiceDim; }); + + // Left channels + resultReshapeShape.append(inputLeftChannels); + + // Right channels. But remove the axis dimension. + resultReshapeShape.append(inputRightChannels); + + ///////////////////////////////////////////// + // Build up the result transpose operator. + SmallVector resultTransposePerm; + + // Batch dimensions + for (int i = 0; i < batchDims; i++) { + resultTransposePerm.push_back(i); + } + + // LeftChannels + for (int i = 0; i < (int)inputLeftChannels.size(); i++) { + resultTransposePerm.push_back(i + indicesType.getShape().size()); + } + + // Indices (remainder of dimensions after batch). + for (int i = batchDims; i < (int)(indicesType.getShape().size()); i++) { + resultTransposePerm.push_back(i); + } + + // RightChannels, coming from after both the Indices and LeftChannels. + for (int i = 0; i < (int)inputRightChannels.size(); i++) { + resultTransposePerm.push_back( + i + indicesType.getShape().size() + inputLeftChannels.size()); + } + + SmallVector tosaValuesShape = {N, K, C}; + SmallVector tosaIndicesShape = {N, W}; + + // Begin of rewrite. + + auto inputTransposeOp = tosaBuilder.transpose(inputValue, inputTransposePerm); + + if (countDynamicDims(tosaValuesShape) > 1) { + return (void)rewriter.notifyMatchFailure(loc, + "only one dynamic dimension allowed when reshaping indices " + "values."), + std::nullopt; + } + + auto tosaValuesReshapeOp = + tosaBuilder.reshape(inputTransposeOp, tosaValuesShape); + + if (countDynamicDims(tosaIndicesShape) > 1) { + return (void)rewriter.notifyMatchFailure(loc, + "only one dynamic dimension allowed when reshaping indices"), + std::nullopt; + } + + auto tosaIndicesReshapeOp = + tosaBuilder.reshape(indicesValue, tosaIndicesShape); + + Value tosaGatherOp = CreateOpAndInfer(rewriter, loc, + RankedTensorType::get(llvm::SmallVector(3, ShapedType::kDynamic), + resultType.getElementType()), + tosaValuesReshapeOp, tosaIndicesReshapeOp); + + if (countDynamicDims(resultReshapeShape) > 1) { + return (void)rewriter.notifyMatchFailure(loc, + "only one dynamic dimension allowed when reshaping result."), + std::nullopt; + } + + Value tosaResultReshapeOp = + tosaBuilder.reshape(tosaGatherOp, resultReshapeShape); + + return tosaBuilder.transpose(tosaResultReshapeOp, resultTransposePerm); +} + +// Common function for lowering reduce operations to TOSA ops. +template +std::optional convertReduceOpCommon(PatternRewriter &rewriter, + Operation *op, RankedTensorType output_type, Value input_value, + ElementsAttr axes_elems, bool keep_dims, Type reduce_element_type, + bool is_quantized, double input_scale, int64_t input_zp, + double output_scale, int64_t output_zp) { + RankedTensorType input_type = + dyn_cast(input_value.getType()); + if (!input_type) + return std::nullopt; + + if (!axes_elems) + return std::nullopt; + + ArrayRef input_shape = input_type.getShape(); + ArrayRef output_shape = output_type.getShape(); + auto input_rank = input_shape.size(); + Value val = input_value; + + if (axes_elems.getNumElements() == 0) { + // No axes means return the original tensor. + auto identity_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_type, val); + val = identity_op.getResult(); + } else { + // Reduce along each axis + SmallVector shape_vec(input_shape.begin(), input_shape.end()); + + if (is_quantized) { + val = buildRescaleToInt32(rewriter, op, val, input_scale, input_zp); + } + + for (int i = 0; i < axes_elems.getNumElements(); i++) { + int64_t axis_val = axes_elems.getValues()[i].getInt(); + if (axis_val < 0) + axis_val += input_rank; + auto axis_attr = rewriter.getI32IntegerAttr(axis_val); + + shape_vec[axis_val] = 1; + RankedTensorType reduce_type = + RankedTensorType::get(shape_vec, reduce_element_type); + + auto reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), reduce_type, val, axis_attr); + + val = reduce_op.getResult(); + } + + if (is_quantized) { + RankedTensorType output_rescale_type = + RankedTensorType::get(shape_vec, output_type.getElementType()); + val = buildRescale(rewriter, op, output_rescale_type, val, output_scale, + 0, output_zp, false, true); + } + + // Optionally squeeze out the reduced axes. + if (!keep_dims) { + auto reshape_op = + CreateOpAndInfer(rewriter, op->getLoc(), + output_type, val, rewriter.getDenseI64ArrayAttr(output_shape)); + val = reshape_op.getResult(); + } + } + + return val; +} + +// Lowers ReduceMean to a sequence of TOSA ops. +std::optional convertReduceMeanOp(PatternRewriter &rewriter, + Operation *op, TosaBuilder &tosaBuilder, RankedTensorType output_type, + Value input_value, ElementsAttr axes_elems, bool keep_dims) { + // reduce_mean is lowered as followed: + // op1 = reduce_sum(input) + // op2 = mul(op1, 1.0 / num_elements_on_reduced_axis) + + RankedTensorType input_type = + dyn_cast(input_value.getType()); + if (!input_type) + return std::nullopt; + + bool input_is_qtype = + isa(input_type.getElementType()); + bool output_is_qtype = + isa(output_type.getElementType()); + + if (input_is_qtype != output_is_qtype) { + op->emitOpError("ConvertReduceSumOp: input/output tensor should " + "be all quantized or all floating-point."); + return std::nullopt; + } + + // Only supports float type mean() if it's non-quantized + if (!input_is_qtype && !isa(output_type.getElementType())) { + op->emitWarning( + "Failed convertReduceMean: input unquantized type but output element " + "not FloatType!"); + return std::nullopt; + } + + int64_t input_rank = input_type.getRank(); + int64_t num_elems_on_reduced_axis = 1; + for (int i = 0; i < axes_elems.getNumElements(); i++) { + int64_t axis_val = axes_elems.getValues()[i].getInt(); + if (axis_val < 0) + axis_val += input_rank; + num_elems_on_reduced_axis *= input_type.getShape()[axis_val]; + } + double div_scale = 1.0 / static_cast(num_elems_on_reduced_axis); + + double input_scale = 1.0f; + double output_scale = 1.0f; + int64_t input_zp = 0; + int64_t output_zp = 0; + mlir::Type reduce_element_type = input_type.getElementType(); + + if (input_is_qtype) { + auto input_qtype = + cast(input_type.getElementType()); + auto output_qtype = + cast(output_type.getElementType()); + + // Combine 'div_scale' as part of output rescale + output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale(); + + input_zp = input_qtype.getZeroPoint(); + output_zp = output_qtype.getZeroPoint(); + reduce_element_type = rewriter.getI32Type(); + } + + auto val = convertReduceOpCommon(rewriter, op, + output_type, input_value, axes_elems, keep_dims, reduce_element_type, + input_is_qtype, input_scale, input_zp, output_scale, output_zp); + + if (!val.has_value()) + return std::nullopt; + + if (!input_is_qtype) { + Value div_const = tosaBuilder.getSplattedConst( + div_scale, output_type.getElementType(), output_type.getRank()); + return tosaBuilder.mul(val.value(), div_const); + } + + return val; +} +} // namespace tosa +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp index d5ef2d5053..97ffd188c2 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp @@ -5,7 +5,7 @@ //====------ ONNXToTOSACommon.hpp - ONNX dialects to TOSA lowering --------===// // // Copyright 2020-2024 The TensorFlow Authors. All Rights Reserved. -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. // // ============================================================================= // @@ -19,7 +19,7 @@ #include "DialectBuilder.hpp" #include "ONNXToTOSALegalizeUtils.hpp" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/MLIRContext.h" @@ -30,21 +30,27 @@ #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Pass/Passes.hpp" +#include +#include +#include //===----------------------------------------------------------------------===// // Functions to add lowering patterns for frontend operations. //===----------------------------------------------------------------------===// - namespace onnx_mlir { namespace tosa { -// Common function for lowering reduce operations to TOSA ops. -// Modified from TensorFlow -template -std::optional convertReduceOpCommon( - mlir::PatternRewriter &rewriter, mlir::Operation *op, - mlir::RankedTensorType outputType, mlir::Value inputValue, - mlir::ElementsAttr axesElems, bool keepDims, mlir::Type reduceElementType); +// Lowers Gather operators to a sequence of TOSA ops. +std::optional convertGatherOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value resultValue, mlir::Value inputValue, + mlir::Value indicesValue, int32_t batchDims, int32_t axis); + +// Lowers ReduceMean to a sequence of TOSA ops. +// Originates from the TorchToTosa conversion +std::optional convertReduceMeanOp(mlir::PatternRewriter &rewriter, + mlir::Operation *op, TosaBuilder &tosaBuilder, + mlir::RankedTensorType output_type, mlir::Value input_value, + mlir::ElementsAttr axes_elems, bool keep_dims); // This calculates the values that need to be added to the padding in order to // simulate the ceil mode @@ -56,7 +62,7 @@ llvm::SmallVector getCeilConstants(llvm::ArrayRef inputShape, // Create an ArrayAttr of pad from \p shapeHelper using \p padIndexOrder. // Values are calculated considering \p ceilMode. template -mlir::ArrayAttr createOrderedPadAttrForWindowBasedOps( +llvm::SmallVector createOrderedPadAttrForWindowBasedOps( mlir::PatternRewriter &rewriter, const llvm::ArrayRef inputShape, ONNXGenericPoolOpShapeHelper &shapeHelper, const int64_t ceilMode, const llvm::ArrayRef padIndexOrder); @@ -72,15 +78,23 @@ mlir::FailureOr convertPoolOp( #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc" } // namespace tosa +} // namespace onnx_mlir + +namespace onnx_mlir { //===----------------------------------------------------------------------===// // Check for valid TOSA types. //===----------------------------------------------------------------------===// -inline bool isTOSASignedInt(mlir::Type type) { +inline bool isTOSABool(mlir::Type type) { + mlir::IntegerType intType = mlir::dyn_cast(type); + return intType && intType.isSignless() && intType.getWidth() == 1; +} + +inline bool isTOSAInt(mlir::Type type) { mlir::IntegerType intType = mlir::dyn_cast(type); std::set intWidth{1, 8, 16, 32, 48, 64}; - return intType && intType.isSignless() && + return intType && (intType.isSignless() || intType.isUnsignedInteger()) && (intWidth.find(intType.getWidth()) != intWidth.end()); } @@ -107,22 +121,62 @@ void populateLoweringONNXGemmOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXSoftmaxOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); -void populateLoweringONNXReduceMeanOpToTOSAPattern(mlir::ConversionTarget &, +void populateLoweringONNXReduceOpsToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXConvOpToTOSAPattern(mlir::ConversionTarget &, - mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *, + int64_t); // `NN` directory methods: void populateLoweringONNXMaxPoolSingleOutOpToTOSAPattern( mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXAveragePoolOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); -// `Tensor` directory methods: -void populateLoweringONNXConstOpToTOSAPattern(mlir::ConversionTarget &, +void populateLoweringONNXQuantizeLinearOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXDequantizeLinearOpToTOSAPattern( + mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, + mlir::MLIRContext *); +void populateLoweringONNXMatMulOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXBatchNormalizationOpToTOSAPattern( + mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, + mlir::MLIRContext *); +// `Tensor` directory methods: void populateLoweringONNXReshapeOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXConcatOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXGatherOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); void populateLoweringONNXResizeOpToTOSAPattern(mlir::ConversionTarget &, mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXShrinkOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXConstOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXEyeLikeOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXPadOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXFlattenOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXSliceOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXSplitOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXSqueezeOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXTileOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXExpandOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXTransposeOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXWhereOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +// 'Flow' directory methods: +void populateLoweringONNXEntryPointOpToTOSAPattern(mlir::ConversionTarget &, + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); } // namespace onnx_mlir #endif diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc index 99558a5bd0..22ae929f9f 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp.inc @@ -17,16 +17,11 @@ // Common function for lowering reduce operations to TOSA ops. // Modified from TensorFlow template -std::optional convertReduceOpCommon( - mlir::PatternRewriter &rewriter, mlir::Operation *op, - mlir::RankedTensorType outputType, mlir::Value inputValue, - mlir::ElementsAttr axesElems, bool keepDims, mlir::Type reduceElementType) { +mlir::Value convertReduceOpCommon(mlir::PatternRewriter &rewriter, + mlir::Operation *op, mlir::RankedTensorType outputType, + mlir::Value inputValue, mlir::RankedTensorType inputType, + mlir::ElementsAttr axesElems, bool keepDims) { TosaBuilder tosaBuilder(rewriter, op->getLoc()); - mlir::RankedTensorType inputType = - mlir::dyn_cast(inputValue.getType()); - if (!inputType) - return std::nullopt; - llvm::ArrayRef inputShape = inputType.getShape(); llvm::ArrayRef outputShape = outputType.getShape(); auto inputRank = inputShape.size(); @@ -48,7 +43,7 @@ std::optional convertReduceOpCommon( shapeVec[axisVal] = 1; mlir::RankedTensorType reduceType = - mlir::RankedTensorType::get(shapeVec, reduceElementType); + mlir::RankedTensorType::get(shapeVec, inputType.getElementType()); auto reduceOp = CreateOpAndInfer( rewriter, op->getLoc(), reduceType, newValue, axisAttr); @@ -102,7 +97,7 @@ llvm::SmallVector getCeilConstants(llvm::ArrayRef inputShape, // Create an ArrayAttr of pad from \p shapeHelper using \p padIndexOrder. // Values are calculated considering \p ceilMode. template -mlir::ArrayAttr createOrderedPadAttrForWindowBasedOps( +llvm::SmallVector createOrderedPadAttrForWindowBasedOps( mlir::PatternRewriter &rewriter, const llvm::ArrayRef inputShape, ONNXGenericPoolOpShapeHelper &shapeHelper, const int64_t ceilMode, const llvm::ArrayRef padIndexOrder) { @@ -122,19 +117,22 @@ mlir::ArrayAttr createOrderedPadAttrForWindowBasedOps( } // reorder padding according to the passed order and considering ceilMode. - return rewriter.getI64ArrayAttr({padOrder[0], padOrder[1] + ceilConstants[0], - padOrder[2], padOrder[3] + ceilConstants[1]}); + llvm::SmallVector reorderedPads = {padOrder[0], + padOrder[1] + ceilConstants[0], padOrder[2], + padOrder[3] + ceilConstants[1]}; + + return reorderedPads; } inline mlir::LogicalResult getAvgPool2dAccType(mlir::PatternRewriter &rewriter, mlir::Value input, mlir::TypeAttr &accType) { - auto inputTy = llvm::dyn_cast(input.getType()); + auto inputTy = mlir::dyn_cast(input.getType()); if (!inputTy) return mlir::failure(); auto inputETy = inputTy.getElementType(); if (auto quantType = - llvm::dyn_cast(inputETy)) + mlir::dyn_cast(inputETy)) inputETy = quantType.getStorageType(); // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time @@ -154,10 +152,20 @@ mlir::FailureOr convertPoolOp( using OpAdaptor = typename ONNXPoolOp::Adaptor; mlir::Location loc = op->getLoc(); OpAdaptor adaptor(op->getOperands(), op->getAttrDictionary()); + + // If the attribute is absent, the default dilations are 1. + if (std::optional dilations = adaptor.getDilations()) { + auto intDilations = mlir::extractFromIntegerArrayAttr(*dilations); + if (llvm::any_of(intDilations, [](int64_t d) { return d != 1; })) + return rewriter.notifyMatchFailure(op, "TOSA does not support dilations"); + } + // Get shape. IndexExprBuilderForTosa createTosaIE(rewriter, loc); ONNXGenericPoolOpShapeHelper shapeHelper(op, {}, &createTosaIE); - shapeHelper.computeShapeAndAssertOnFailure(); + if (shapeHelper.computeShape().failed()) { + return rewriter.notifyMatchFailure(op, "Could not infer shapes"); + } TosaBuilder tosaBuilder(rewriter, loc); @@ -169,6 +177,11 @@ mlir::FailureOr convertPoolOp( } auto kernelShape = adaptor.getKernelShapeAttr(); + llvm::SmallVector kernelShapeVec; + llvm::transform(kernelShape, std::back_inserter(kernelShapeVec), + [](const mlir::Attribute &pad) { + return mlir::cast(pad).getInt(); + }); const int64_t ceilMode = adaptor.getCeilMode(); @@ -199,9 +212,21 @@ mlir::FailureOr convertPoolOp( llvm::SmallVector pads; IndexExpr::getLiteral(shapeHelper.pads, pads); - // reorder padding values - auto newPads = rewriter.getDenseI64ArrayAttr({pads[0], - pads[2] + ceilConstants[0], pads[1], pads[3] + ceilConstants[1]}); + llvm::SmallVector reorderedPads = { + pads[0], pads[2] + ceilConstants[0], pads[1], pads[3] + ceilConstants[1]}; + + mlir::FailureOr resizedInput = tosaBuilder.resizeWindowBasedOps( + input, mlir::cast(input.getType()).getShape(), + {kernelShapeVec[0], kernelShapeVec[1]}, reorderedPads, + shapeHelper.strides, shapeHelper.dilations); + + if (failed(resizedInput)) { + return rewriter.notifyMatchFailure( + op, "could not resize input to match parameters"); + } + + mlir::DenseI64ArrayAttr newPads = + rewriter.getDenseI64ArrayAttr(reorderedPads); auto strides = rewriter.getDenseI64ArrayAttr(shapeHelper.strides); @@ -212,19 +237,19 @@ mlir::FailureOr convertPoolOp( std::is_same::value, "Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp"); if constexpr (std::is_same::value) { - input = tosa::CreateOpAndInfer( - rewriter, loc, newResultType, input, newKernelShape, strides, newPads) + input = tosa::CreateOpAndInfer(rewriter, loc, newResultType, + *resizedInput, newKernelShape, strides, newPads) .getResult(); } else if constexpr (std::is_same::value) { mlir::TypeAttr accType; - if (failed(tosa::getAvgPool2dAccType(rewriter, input, accType))) { + if (failed(tosa::getAvgPool2dAccType(rewriter, *resizedInput, accType))) { (void)rewriter.notifyMatchFailure( op, "Failed to get accumulator type for pooling"); return mlir::failure(); } input = tosa::CreateOpAndInfer(rewriter, loc, newResultType, - input, newKernelShape, strides, newPads, accType) + *resizedInput, newKernelShape, strides, newPads, accType) .getResult(); } @@ -232,4 +257,4 @@ mlir::FailureOr convertPoolOp( // Construct the old result shape out of the new one mlir::Value transpose = tosaBuilder.transpose(input, {0, 3, 1, 2}); return transpose; -}; +} diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp index 4a154b52de..24ca82fcac 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp @@ -24,6 +24,7 @@ #include "mlir/Support/LLVM.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" // from @llvm-project #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -35,14 +36,79 @@ using namespace mlir; namespace onnx_mlir { namespace tosa { -mlir::RankedTensorType reduceAxisToOne(llvm::ArrayRef shape, - mlir::Type elementType, mlir::Attribute encoding) { +int64_t convertNegativeAxis(int64_t axis, int64_t inputRank) { + if (axis < 0) + axis += inputRank; + + // Check if axis is in correct range. + assert( + (axis >= 0 && axis < inputRank) && "axis attribute not in correct range"); + + return axis; +} + +llvm::SmallVector createInt64VectorFromIndexExpr( + llvm::ArrayRef indexVector) { + llvm::SmallVector literalVector(indexVector.size()); + llvm::transform(indexVector, literalVector.begin(), + [](const auto &indexExpr) { return indexExpr.getLiteral(); }); + return literalVector; +} + +mlir::RankedTensorType reduceAxisToOne( + int64_t rank, Type elementType, Attribute encoding) { return mlir::RankedTensorType::get( - llvm::SmallVector(shape.size(), 1), elementType, encoding); + llvm::SmallVector(rank, 1), elementType, encoding); +} + +mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val) { + if (auto source = val.getDefiningOp()) { + if (source.getValue()) + return cast(source.getValue().value()); + } + // if the constant is not an onnx.const it has to be a tosa.const + assert(val.getDefiningOp()); + return tosa::getValueFromTosaConst(val); +} + +// Create a TOSA rescale op from input framework tensor, zero points and +// rounding mode +Value buildRescale(PatternRewriter &rewriter, Operation *op, + ShapedType output_type, Value input_val, double scale, int64_t input_zp, + int64_t output_zp, bool double_round, bool scale32) { + int32_t multiplier; + int32_t shift; + + int32_t scale_width = scale32 ? 32 : 16; + + mlir::tosa::computeMultiplierAndShift(scale, multiplier, shift, scale_width); + + auto rescale_op = CreateOpAndInfer(rewriter, + op->getLoc(), output_type, input_val, + rewriter.getI32IntegerAttr(static_cast(input_zp)), + rewriter.getI32IntegerAttr(static_cast(output_zp)), + rewriter.getDenseI32ArrayAttr({multiplier}), + rewriter.getDenseI8ArrayAttr({(int8_t)shift}), + rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), + rewriter.getBoolAttr(false)); + + return rescale_op.getResult(); +} + +// Creates TOSA rescale op with int32 output +Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op, + Value input_val, double input_scale, int64_t input_zp) { + // Output is always int32 type + auto input_type = dyn_cast(input_val.getType()); + assert(input_type); + auto output_type = input_type.clone(rewriter.getI32Type()); + + return buildRescale(rewriter, op, output_type, input_val, input_scale, + input_zp, 0, false, true); } mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter, - llvm::ArrayRef onnxPads, mlir::Location loc, + llvm::ArrayRef onnxPads, Location loc, const std::initializer_list &initialVals, const std::initializer_list &lastVals) { @@ -66,5 +132,16 @@ mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter, return tosaBuilder.getConst(tosaPads, {numberOfDims, 2}); } +mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc, + mlir::Value tensor, size_t axis, size_t rank) { + auto inTy = cast(tensor.getType()); + llvm::SmallVector newShape(rank, 1); + newShape[axis] = inTy.getNumElements(); + auto resultTy = RankedTensorType::get(newShape, inTy.getElementType()); + + return rewriter.createOrFold( + loc, resultTy, tensor, newShape); +} + } // namespace tosa } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp index 0343c79027..5c99af6d5d 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp @@ -4,8 +4,8 @@ //==== ONNXToTosaLegalizeUtils.hpp - ONNX dialects to TOSA lowering Utils-===// // -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. +// Copyright 2020-2024 The TensorFlow Authors. All Rights Reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. // // ============================================================================= // @@ -17,7 +17,7 @@ #ifndef ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H #define ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -25,14 +25,23 @@ #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include #include namespace onnx_mlir { namespace tosa { -// Create a RankedTensorType with shape and all elements being 1 -mlir::RankedTensorType reduceAxisToOne(llvm::ArrayRef shape, - mlir::Type elementType, mlir::Attribute encoding = {}); +// ONNX can use negative indices for axis while TOSA cannot. This functions +// makes sure the axis is in the right range for TOSA. +int64_t convertNegativeAxis(int64_t axis, int64_t inputRank); + +// Get a vector of indexExpr and extract the Int64 values +llvm::SmallVector createInt64VectorFromIndexExpr( + llvm::ArrayRef indexVector); + +// Create a RankedTensorType with the given rank and all dims being 1 +mlir::RankedTensorType reduceAxisToOne( + int64_t rank, mlir::Type elementType, mlir::Attribute encoding = {}); // Returns the value TOSA ConstOp template @@ -40,6 +49,16 @@ T getValueFromTosaConst(mlir::Value &val) { return mlir::cast(val.getDefiningOp().getValue()); } +// Retrieves an ElementsAttr out of a const operator. +// This function is made to work with both onnx.const and tosa.const +mlir::ElementsAttr getElementsAttrFromConst(mlir::Value &val); + +// Takes a 1-d `tensor` with k elements and reshapes it into an `rank`-d tensor +// with shape {1, ..., 1, k, 1, ..., 1 } +// where `k` it at position `axis`. +mlir::Value expandShape(mlir::PatternRewriter &rewriter, mlir::Location loc, + mlir::Value tensor, size_t axis, size_t rank); + // Creates a TOSA operation and performs shape inference on the individual // op. This allows shape inference during the framework to TOSA lowering. template @@ -48,7 +67,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, auto op = rewriter.create(loc, result_ty, args...); mlir::InferShapedTypeOpInterface shapeInterface = - llvm::dyn_cast(op.getOperation()); + mlir::dyn_cast(op.getOperation()); if (!shapeInterface) return op; @@ -79,6 +98,17 @@ void CreateReplaceOpAndInfer(mlir::PatternRewriter &rewriter, rewriter.replaceOp(op, result->getResults()); } +// Create a TOSA rescale op from input framework scaling, zero points and +// rounding mode +mlir::Value buildRescale(mlir::PatternRewriter &rewriter, mlir::Operation *op, + mlir::ShapedType output_type, mlir::Value input_val, double scale, + int64_t input_zp, int64_t output_zp, bool double_round, bool scale32); + +// Creates TOSA rescale op with int32 output +mlir::Value buildRescaleToInt32(mlir::PatternRewriter &rewriter, + mlir::Operation *op, mlir::Value input_val, double input_scale, + int64_t input_zp); + /// Create a padding tosa::ConstOp from ONNX to Tosa format. /// The two formats are: /// ONNX : [b1, b2, b3, b4, e1, e2, e3, e4] @@ -91,4 +121,4 @@ mlir::Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter, } // namespace tosa } // namespace onnx_mlir -#endif // ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H \ No newline at end of file +#endif // ONNXMLIR_CONVERSION_ONNXTOTOSA_TOSALEGALIZEUTILS_H diff --git a/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp b/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp new file mode 100644 index 0000000000..76fa0f9b8e --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Concat.cpp @@ -0,0 +1,65 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Concat.cpp - Concat Op --------------------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX ConcatOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXConcatLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXConcatOp::Adaptor; + LogicalResult matchAndRewrite(ONNXConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + ValueRange inputs = adaptor.getInputs(); + int64_t axis = adaptor.getAxis(); + auto resultType = op.getResult().getType(); + + for (const auto &input : inputs) { + if (!onnx_mlir::isRankedShapedType(input.getType())) + return rewriter.notifyMatchFailure( + op, "inputs are not ranked shaped tensors"); + } + int64_t inputRank = onnx_mlir::getRank(inputs[0].getType()); + + // onnx allows values beetween [-r, r-1] where r is the rank. + axis = tosa::convertNegativeAxis(axis, inputRank); + + Type newConcatOutputType = RankedTensorType::get( + llvm::SmallVector(inputRank, ShapedType::kDynamic), + cast(resultType).getElementType()); + + tosa::CreateReplaceOpAndInfer( + rewriter, op, newConcatOutputType, inputs, axis); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXConcatOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp b/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp index 051a208898..43703b3323 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Constant.cpp @@ -44,6 +44,10 @@ class ONNXConstOpLoweringToTOSA : public OpConversionPattern { op, "tosa.const does not support non-tensor types"); } Type resultType = getTypeConverter()->convertType(op.getResult().getType()); + if (!resultType) { + return rewriter.notifyMatchFailure( + op, "tosa.const does not support the requested type"); + } rewriter.replaceOpWithNewOp( op, resultType, mlir::cast(currentAttr)); return success(); @@ -58,4 +62,4 @@ void populateLoweringONNXConstOpToTOSAPattern(ConversionTarget &target, patterns.insert(typeConverter, ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp b/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp new file mode 100644 index 0000000000..941fa76d62 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Expand.cpp @@ -0,0 +1,174 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------------ Expand.cpp - Expand Op ---------------------===// +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX ExpandOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXExpandLoweringToTOSA : public OpConversionPattern { + +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXExpandOp::Adaptor; + + LogicalResult matchAndRewrite(ONNXExpandOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto shape = adaptor.getShape(); + DenseIntElementsAttr denseAttr; + if (!matchPattern(shape, m_Constant(&denseAttr))) { + return rewriter.notifyMatchFailure( + op, "onnx.expand can only be lowered with constant expanded shape"); + } + + // Convert denseAttr to DenseI64ArrayAttr. This handles both splat and + // non-splat scenarios. + ArrayBuffer shapeWideNums = getElementsWideNums(denseAttr); + ArrayRef shapeArray = + castArrayRef(shapeWideNums.get()); + + auto inputType = + mlir::dyn_cast_or_null(adaptor.getInput().getType()); + auto outputType = + mlir::dyn_cast_or_null(op.getResult().getType()); + if (!inputType || !outputType || !inputType.hasStaticShape() || + !outputType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "Unranked and dynamic types are not supported"); + } + size_t inputRank = onnx_mlir::getRank(inputType); + + // If inputRank is inferior to shapeRank we need to introduce a + // reshape before the tile + auto newInput = adaptor.getInput(); + if (inputRank != shapeArray.size()) { + llvm::SmallVector newShape = + getNewShape(inputType.getShape(), outputType.getShape()); + // If the newShape size doesn't match the output shape size, it means we + // didn't find a proper reshape to match the input to. + if (newShape.size() != outputType.getShape().size()) { + return rewriter.notifyMatchFailure( + op, "Could not find a shape that satisfies the expand constraints"); + } + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + newInput = tosaBuilder.reshape(adaptor.getInput(), newShape); + } + + auto denseShape = + getMultiplies(op, cast(newInput.getType()).getShape(), + outputType.getShape()); + auto resultElementType = cast(inputType).getElementType(); + auto newResultElementType = + getTypeConverter()->convertType(resultElementType); + + if (!isSupportedElementType(newResultElementType)) { + return rewriter.notifyMatchFailure( + op, "input/output type is invalid for tosa.tile"); + } + Type newTileOutputType = RankedTensorType::get( + llvm::SmallVector( + outputType.getShape().size(), ShapedType::kDynamic), + newResultElementType); + onnx_mlir::tosa::CreateReplaceOpAndInfer( + rewriter, op, newTileOutputType, newInput, denseShape); + return success(); + } + +private: + // Supported element types for tosa.tile + static bool isSupportedElementType(Type type) { + if (auto intTy = dyn_cast_or_null(type)) { + // Supported integer bit widths + std::set intWidth({8, 16, 32}); + return isTOSABool(type) || + (intTy.isSignless() && + (intWidth.find(intTy.getWidth()) != intWidth.end())); + } + return type.isBF16() || type.isF16() || type.isF32(); + } + + static llvm::SmallVector getNewShape( + const llvm::ArrayRef &inputShape, + const llvm::ArrayRef &outputShape) { + llvm::SmallVector result; + size_t inputIdx = 0; + for (auto outputDimension : outputShape) { + // - If the inputIdx goes beyond the input shape, it means we are + // extending the shape: + // Ex: 1x3x4 -> 1x3x4x1 + // - If the input dim is < output dim and do not divide it, + // it's a dimension being added: + // Ex: 3x1 -> 2x1x6 (first dim is a new dim and not a tiled original + // one) + // - If the output dim is < input dim, + // it's also a dim being added: + // Ex: 2x3x4 -> 1x2x3x4 + if (inputIdx >= inputShape.size() || + (inputShape[inputIdx] < outputDimension && + outputDimension % inputShape[inputIdx] != 0) || + outputDimension < inputShape[inputIdx]) { + result.push_back(1); + } else { + result.push_back(inputShape[inputIdx]); + inputIdx++; + } + } + return result; + } + + static DenseI64ArrayAttr getMultiplies(ONNXExpandOp &op, + const llvm::ArrayRef &inputShape, + const llvm::ArrayRef &outputShape) { + llvm::SmallVector multipliesArray; + for (size_t i = 0; i < outputShape.size(); ++i) { + if (i >= inputShape.size() || outputShape[i] / inputShape[i] == 0) { + multipliesArray.push_back(1); + } else { + multipliesArray.push_back(outputShape[i] / inputShape[i]); + } + } + return DenseI64ArrayAttr::get(op.getContext(), multipliesArray); + } +}; + +} // namespace + +void populateLoweringONNXExpandOpToTOSAPattern(ConversionTarget & /*target*/, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/EyeLike.cpp b/src/Conversion/ONNXToTOSA/Tensor/EyeLike.cpp new file mode 100644 index 0000000000..822faddcf3 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/EyeLike.cpp @@ -0,0 +1,79 @@ +// (c) Copyright 2022 - 2024 Advanced Micro Devices, Inc. All Rights Reserved. + +#include "mlir/Support/LogicalResult.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXEyeLikeLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXEyeLikeOp::Adaptor; + LogicalResult matchAndRewrite(ONNXEyeLikeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto resultType = dyn_cast(op.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "onnx.EyeLikeOp needs to have static shape for lowering to tosa"); + } + const auto elementType = resultType.getElementType(); + const auto convertedType = dyn_cast_or_null( + getTypeConverter()->convertType(resultType)); + if (!convertedType) { + return rewriter.notifyMatchFailure( + op, "EyeLike type not supported in tosa"); + } + int64_t k = 0; + if (auto kAttr = adaptor.getKAttr()) { + k = kAttr.getSInt(); + } + DenseElementsAttr replacementAttr; + if (auto intType = dyn_cast(elementType)) { + replacementAttr = getEyeLikeAttr(convertedType, resultType.getDimSize(0), + resultType.getDimSize(1), k, APInt(intType.getWidth(), 0), + APInt(intType.getWidth(), 1)); + } else if (auto floatType = dyn_cast(elementType)) { + replacementAttr = getEyeLikeAttr(convertedType, resultType.getDimSize(0), + resultType.getDimSize(1), k, + APFloat::getZero(floatType.getFloatSemantics()), + APFloat(floatType.getFloatSemantics(), 1)); + } else { + return rewriter.notifyMatchFailure(op, "Only int and float supported"); + } + + rewriter.replaceOpWithNewOp( + op, convertedType, replacementAttr); + return success(); + } + +private: + template + DenseElementsAttr getEyeLikeAttr(const ShapedType type, const int64_t dimY, + const int64_t dimX, const int64_t k, const T zero, const T one) const { + const auto size = dimX * dimY; + SmallVector vec(size, zero); + const auto smallDim = std::min(dimX - std::abs(k), dimY); + for (int64_t i = 0; i < smallDim; ++i) { + if (k >= 0) { + vec[(dimX * i) + i + k] = one; + } else { + vec[(dimX * (i - k)) + i] = one; + } + } + return DenseElementsAttr::get(type, vec); + } +}; +} // namespace + +void populateLoweringONNXEyeLikeOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp b/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp new file mode 100644 index 0000000000..e6207dee61 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Flatten.cpp @@ -0,0 +1,94 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Flatten.cpp - Flatten Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX ReshapeOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXFlattenLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXFlattenOp::Adaptor; + LogicalResult matchAndRewrite(ONNXFlattenOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value input = adaptor.getInput(); + + // tosa.reshape does not allow a dynamic entry in the new_shape attribute + if (!hasStaticShape(input.getType())) + return rewriter.notifyMatchFailure( + op, "only static shapes are supported"); + + auto loc = op->getLoc(); + TosaBuilder tosaBuilder(rewriter, loc); + + int64_t axis = adaptor.getAxis(); + auto inputType = cast(input.getType()); + + // onnx allows values beetween [-r, r] where r is the rank. + if (axis == inputType.getRank()) { + // axis == rank is valid for Flatten + } else { + // check if the axis is in range [-r, r-1] where r is the rank + axis = tosa::convertNegativeAxis(axis, inputType.getRank()); + } + + llvm::SmallVector newShape; + auto inputShape = inputType.getShape(); + if (axis == 0) { + newShape.push_back(1); + int64_t lastShape = 1; + for (const int64_t axis : inputShape) { + lastShape *= axis; + } + newShape.push_back(lastShape); + } else { + int64_t firstShape = 1; + for (int i = 0; i < axis; i++) { + firstShape *= inputShape[i]; + } + newShape.push_back(firstShape); + int64_t secondShape = 1; + for (int i = axis; i < inputType.getRank(); i++) { + secondShape *= inputShape[i]; + } + newShape.push_back(secondShape); + } + Value reshapeOp = tosaBuilder.reshape(input, newShape); + + rewriter.replaceOp(op, reshapeOp); + + return success(); + } +}; + +} // namespace + +void populateLoweringONNXFlattenOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp new file mode 100644 index 0000000000..e9e74e838e --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Gather.cpp @@ -0,0 +1,119 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Gather.cpp - Gather Op ------------------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX GatherOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXGatherLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXGatherOp::Adaptor; + LogicalResult matchAndRewrite(ONNXGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + + TosaBuilder tosaBuilder(rewriter, loc); + + Value input = adaptor.getData(); + Value indices = adaptor.getIndices(); + int64_t axis = adaptor.getAxis(); + + auto result = op.getResult(); + + auto inputType = dyn_cast(input.getType()); + if (!onnx_mlir::isRankedShapedType(inputType)) + return rewriter.notifyMatchFailure(op, "input is not a ranked tensor"); + + if (!hasStaticShape(result.getType())) + return rewriter.notifyMatchFailure(op, "dynamic shapes not supported"); + + auto resultTy = dyn_cast(op.getType()); + if (!onnx_mlir::isRankedShapedType(resultTy)) + return rewriter.notifyMatchFailure(op, "result is not a ranked tensor"); + int64_t inputRank = onnx_mlir::getRank(inputType); + + // onnx allows values beetween [-r, r-1] where r is the rank + axis = tosa::convertNegativeAxis(axis, inputRank); + + auto indicesType = cast(indices.getType()); + + APInt indicesVal; + if (indicesType.getRank() == 0 && + matchPattern(indices, m_ConstantInt(&indicesVal)) && + indicesVal.getSExtValue() >= 0) { + llvm::SmallVector starts(inputType.getRank(), 0); + llvm::SmallVector size{inputType.getShape()}; + starts[axis] = indicesVal.getSExtValue(); + size[axis] = 1; + Value sliceOp = tosaBuilder.slice(input, size, starts); + auto reshape = tosaBuilder.reshape(sliceOp, resultTy.getShape()); + rewriter.replaceOp(op, reshape); + return success(); + } + + SmallVector newIndicesValues; + newIndicesValues.resize(indicesType.getNumElements()); + + ArrayRef inputShape = cast(inputType).getShape(); + + // ONNX allows negative indices and TOSA doesn't. + // We will emit ops to compute + // newIndices = indices >= 0 ? indices : indices + dimSize + // element-wise. + + // Create an 1x..x1 constant containing the size of the gathered dimension. + auto dimSize = tosaBuilder.getSplattedConst( + inputShape[axis], indicesType.getElementType(), indicesType.getRank()); + auto indicesPlusDimSize = + tosaBuilder.binaryOp(indices, dimSize); + + auto zero = tosaBuilder.getSplattedConst( + (int64_t)0, indicesType.getElementType(), indicesType.getRank()); + auto indicesPositive = tosaBuilder.greaterEqual(indices, zero); + + auto newIndices = + tosaBuilder.select(indicesPositive, indices, indicesPlusDimSize); + + auto newGather = + tosaBuilder.gather(result, input, newIndices, 0, (int32_t)axis); + + if (!newGather.has_value()) { + return failure(); + } + rewriter.replaceOp(op, newGather.value()); + + return success(); + } +}; + +} // namespace + +void populateLoweringONNXGatherOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp new file mode 100644 index 0000000000..058cc43bad --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/PaddingOp.cpp @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Padding.cpp - Padding Op --------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX padding operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" + +using namespace mlir; + +namespace onnx_mlir { + +class ONNXPadOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXPadOp::Adaptor; + LogicalResult matchAndRewrite(ONNXPadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + Location loc = op.getLoc(); + + Value data = adaptor.getData(); + Value pads = adaptor.getPads(); + Value constValue = adaptor.getConstantValue(); + + auto dataType = dyn_cast(data.getType()); + if (!dataType || !dataType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "input type has no static shape"); + } + + auto elementDtype = dataType.getElementType(); + if (!isa(elementDtype) && !isTOSAInt(elementDtype)) { + return rewriter.notifyMatchFailure(op, "unsupported type"); + } + + if (!adaptor.getAxes().getDefiningOp()) { + return rewriter.notifyMatchFailure(op, "only default axes are supported"); + } + + if (!(adaptor.getMode() == "constant")) { + return rewriter.notifyMatchFailure( + op, "only 'constant' mode is supported"); + } + + if (!pads.getDefiningOp() || + !(constValue.getDefiningOp() || + constValue.getDefiningOp())) { + return rewriter.notifyMatchFailure( + op, "only tosa.const operands are supported"); + } + // creating the DenseElementsAttr using pads values. + auto denseAttr = tosa::getValueFromTosaConst(pads); + + // Reading the ONNX side pads values and store in the array. + llvm::SmallVector intValues; + bool paddingNeeded = false; + for (auto n : denseAttr.getValues()) { + intValues.push_back(n.getZExtValue()); + if (!n.isZero()) + paddingNeeded = true; + } + if (!paddingNeeded) { + // We do not need to represent the no-op pad in the resulting MLIR + rewriter.replaceOp(op, {data}); + return success(); + } + + Value padsList1 = + tosa::buildOnnxToTosaPaddingConstOp(rewriter, intValues, loc, {}, {}); + + mlir::Type resultType = + getTypeConverter()->convertType(op.getResult().getType()); + + if (!isa(constValue.getType())) { + auto valueAttr = tosa::getValueFromTosaConst(constValue); + TosaBuilder tosaBuilder(rewriter, loc); + + Value constTosaTensor; + if (isa(valueAttr.getElementType())) { + auto valueIt = valueAttr.getValues().begin(); + const float valueFloat = cast(*valueIt).getValueAsDouble(); + constTosaTensor = tosaBuilder.getSplattedConst( + valueFloat, valueAttr.getElementType(), 0); + } else { + assert(isTOSAInt(elementDtype) && "Already validated"); + auto valueIt = valueAttr.getValues().begin(); + auto valueAsAPInt = cast(*valueIt).getValue(); + auto asIntegerTy = cast(valueAttr.getElementType()); + if (asIntegerTy.isUnsigned()) { + constTosaTensor = tosaBuilder.getSplattedConst( + valueAsAPInt.getZExtValue(), asIntegerTy, 0); + } else { + constTosaTensor = tosaBuilder.getSplattedConst( + valueAsAPInt.getSExtValue(), asIntegerTy, 0); + } + } + rewriter.replaceOpWithNewOp( + op, resultType, data, padsList1, constTosaTensor); + + } else { + auto constType = RankedTensorType::get({}, elementDtype); + + DenseElementsAttr constAttr; + if (isa(elementDtype)) { + constAttr = DenseElementsAttr::get(constType, 0.0F); + } else { + assert(isTOSAInt(elementDtype) && "Already validated"); + auto tyAsInt = cast(elementDtype); + constAttr = DenseElementsAttr::get(constType, + llvm::APInt(tyAsInt.getWidth(), 0, tyAsInt.getSignedness())); + } + + rewriter.replaceOpWithNewOp(op, resultType, data, + padsList1, + rewriter.create( + op->getLoc(), constType, constAttr)); + } + + return success(); + } +}; + +void populateLoweringONNXPadOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index 3f3269a029..1c5936bab2 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" @@ -69,7 +70,7 @@ ScaleHelper normalize(int64_t output, int64_t input, bool pytorchHalfPixel, // We can compute this directly based on previous values. border = denominator * (output - 1) - numerator * (input - 1) + offset; return ScaleHelper(numerator, denominator, offset, border); -}; +} void valuesFromAxis(ArrayAttr *axis, llvm::SmallVectorImpl &axisVec) { auto axisRange = axis->getAsRange(); @@ -171,7 +172,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto resizeOp = llvm::cast(op); + auto resizeOp = mlir::cast(op); Location loc = op->getLoc(); OpAdaptor adaptor(operands, op->getAttrDictionary()); @@ -203,7 +204,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { } auto elementType = inputType.getElementType(); - if (!(isTOSAFloat(elementType) || isTOSASignedInt(elementType))) { + if (!(isa(elementType) || isTOSAInt(elementType))) { return rewriter.notifyMatchFailure( resizeOp, "Element type is not supported by TOSA."); } @@ -327,4 +328,4 @@ void populateLoweringONNXResizeOpToTOSAPattern(ConversionTarget &target, patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp new file mode 100644 index 0000000000..eb1d78926d --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Shrink.cpp @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Shrink.cpp - Shrink Op-------------------------------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers the ONNX Shrink operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXShrinkOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXShrinkOp shrinkOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = shrinkOp->getLoc(); + + auto lambd = adaptor.getLambdAttr(); + auto bias = adaptor.getBiasAttr(); + auto input = adaptor.getInput(); + + TosaBuilder tosaBuilder(rewriter, loc); + + auto inputRankedTensorTy = dyn_cast(input.getType()); + if (!inputRankedTensorTy) { + return rewriter.notifyMatchFailure( + loc, "Expected RankedTensorType for input data of ShrinkOp"); + } + + // lambd and bias have float type so it is safe to conver it to float + const float lambdAsFloat = lambd.getValue().convertToFloat(); + const float biasAsFloat = bias.getValue().convertToFloat(); + auto lambdConstOp = tosaBuilder.getSplattedConst(lambdAsFloat, + inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank()); + auto negatedLambdConstOp = tosaBuilder.getSplattedConst(-lambdAsFloat, + inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank()); + auto biasConstOp = tosaBuilder.getSplattedConst(biasAsFloat, + inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank()); + auto zeroConstOp = tosaBuilder.getSplattedConst( + 0, inputRankedTensorTy.getElementType(), inputRankedTensorTy.getRank()); + + // Formula to be implemented: + // { x < -lambd, then y = x + bias + // { x > lambd, then y = x - bias + // { otherwise, then y = 0 + + auto firstCmp = tosaBuilder.compareOp( + rewriter, loc, negatedLambdConstOp, input); + auto firstFormula = + tosaBuilder.binaryOp(input, biasConstOp); + auto firstSelect = tosaBuilder.select(firstCmp, firstFormula, zeroConstOp); + + auto secondCmp = tosaBuilder.compareOp( + rewriter, loc, input, lambdConstOp); + auto secondFormula = + tosaBuilder.binaryOp(input, biasConstOp); + auto secondSelect = + tosaBuilder.select(secondCmp, secondFormula, firstSelect); + + rewriter.replaceOp(shrinkOp, secondSelect); + + return success(); + } +}; + +} // namespace + +void populateLoweringONNXShrinkOpToTOSAPattern(ConversionTarget & /*target*/, + RewritePatternSet &patterns, TypeConverter & /*typeConverter*/, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Slice.cpp b/src/Conversion/ONNXToTOSA/Tensor/Slice.cpp new file mode 100644 index 0000000000..1ac32b929e --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Slice.cpp @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Slice.cpp - Slice Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX SliceOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXSliceLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXSliceOp::Adaptor; + LogicalResult matchAndRewrite(ONNXSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + if (!adaptor.getStarts().getDefiningOp()) { + return rewriter.notifyMatchFailure(op, "starts must be constant"); + } + if (!adaptor.getEnds().getDefiningOp()) { + return rewriter.notifyMatchFailure(op, "ends must be constant"); + } + + // Get shape. + IndexExprBuilderForTosa createTosaIE(rewriter, loc); + ONNXSliceOpShapeHelper shapeHelper(op, {}, &createTosaIE); + if (failed(shapeHelper.computeShape())) { + return rewriter.notifyMatchFailure(op, "could not compute shape."); + } + + TosaBuilder tosaBuilder(rewriter, loc); + + Value input = adaptor.getData(); + if (!(IndexExpr::isLiteral(shapeHelper.starts))) + return rewriter.notifyMatchFailure(op, "starts has no literals."); + if (!(IndexExpr::isLiteral(shapeHelper.ends))) + return rewriter.notifyMatchFailure(op, "ends has no literals."); + if (!(IndexExpr::isLiteral(shapeHelper.steps))) + return rewriter.notifyMatchFailure(op, "steps has no literals."); + + llvm::SmallVector starts; + IndexExpr::getLiteral(shapeHelper.starts, starts); + llvm::SmallVector ends; + IndexExpr::getLiteral(shapeHelper.ends, ends); + llvm::SmallVector steps; + IndexExpr::getLiteral(shapeHelper.steps, steps); + + for (const int64_t step : steps) { + if (step != 1) + return rewriter.notifyMatchFailure( + op, "TOSA only supports step size 1."); + } + + llvm::SmallVector size; + size.resize(starts.size()); + for (size_t i = 0; i < starts.size(); i++) { + size[i] = ends[i] - starts[i]; + } + + Value newSliceOp = tosaBuilder.slice(input, size, starts); + + rewriter.replaceOp(op, newSliceOp); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXSliceOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Split.cpp b/src/Conversion/ONNXToTOSA/Tensor/Split.cpp new file mode 100644 index 0000000000..1de7dd9263 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Split.cpp @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------- Split.cpp - Split Op---------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX SplitOp operator to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +namespace onnx_mlir { +namespace { +class ONNXSplitOpLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXSplitOp::Adaptor; + LogicalResult matchAndRewrite(ONNXSplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getInput(); + ShapedType inputType = cast(input.getType()); + + // tosa.slice does not allow a dynamic entry in the size attribute + if (!hasStaticShape(inputType)) + return rewriter.notifyMatchFailure( + op, "only static shapes are supported"); + + uint64_t rank = inputType.getRank(); + int64_t splitAxis = adaptor.getAxis(); + if (splitAxis < 0) + splitAxis += rank; + + IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc()); + ONNXSplitOpShapeHelper shapeHelper( + op, adaptor.getOperands(), &createTosaIE); + + // compute shape + if (failed(shapeHelper.computeShape())) + return rewriter.notifyMatchFailure(op, "could not compute shape."); + + TosaBuilder tosaBuilder(rewriter, op->getLoc()); + uint64_t outputNum = op.getNumResults(); + SmallVector slices; + slices.reserve(outputNum); + + llvm::SmallVector size; + llvm::SmallVector starts(rank, 0); + int64_t start = 0; + + for (uint64_t i = 0; i < outputNum; i++) { + DimsExpr outputDim = shapeHelper.getOutputDims(i); + IndexExpr::getShape(outputDim, size); + starts[splitAxis] = start; + slices.push_back(tosaBuilder.slice(input, size, starts)); + start += size[splitAxis]; + } + rewriter.replaceOp(op, slices); + return success(); + } +}; +} // namespace + +void populateLoweringONNXSplitOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/ONNXToTOSA/Tensor/Squeeze.cpp b/src/Conversion/ONNXToTOSA/Tensor/Squeeze.cpp new file mode 100644 index 0000000000..4acd81fd0a --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Squeeze.cpp @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Unsqueeze.cpp - Unsqueeze Op ------------------------===// +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX UnsqueezeOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +template +class ONNXUnsqueezeSqueezeLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename SqueezeOp::Adaptor; + LogicalResult matchAndRewrite(SqueezeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + + auto resultTy = dyn_cast(op.getResult().getType()); + if (!resultTy || !resultTy.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "expected ranked tensor type with static shape"); + } + + TosaBuilder tosaBuilder(rewriter, loc); + rewriter.replaceOp( + op, tosaBuilder.reshape(adaptor.getData(), resultTy.getShape())); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXSqueezeOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert, + ONNXUnsqueezeSqueezeLoweringToTOSA>(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp b/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp new file mode 100644 index 0000000000..cafeb5522a --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Tile.cpp @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------------ Tile.cpp - Tile Op --------------------------===// +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX TileOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXTileLoweringToTOSA : public OpConversionPattern { + +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXTileOp::Adaptor; + + LogicalResult matchAndRewrite(ONNXTileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto inputType = adaptor.getInput().getType(); + if (!onnx_mlir::isRankedShapedType(inputType)) { + return rewriter.notifyMatchFailure( + op, "input is not a ranked shaped tensor"); + } + + auto resultElementType = cast(inputType).getElementType(); + auto newResultElementType = + getTypeConverter()->convertType(resultElementType); + + if (!isSupportedElementType(newResultElementType)) { + return rewriter.notifyMatchFailure( + op, "input/output type is invalid for tosa.tile"); + } + + int64_t inputRank = onnx_mlir::getRank(inputType); + Type newOutputType = RankedTensorType::get( + llvm::SmallVector(inputRank, ShapedType::kDynamic), + newResultElementType); + + // Create the attribute for the repetitions + DenseIntElementsAttr denseReps; + if (!matchPattern(op.getRepeats(), m_Constant(&denseReps))) { + return rewriter.notifyMatchFailure( + op, "onnx.tile can only be lowered with constant repetitions"); + } + auto newReps = rewriter.getDenseI64ArrayAttr( + llvm::to_vector(denseReps.getValues())); + + onnx_mlir::tosa::CreateReplaceOpAndInfer( + rewriter, op, newOutputType, adaptor.getInput(), newReps); + return success(); + } + +private: + static bool isSupportedElementType(Type type) { + if (auto intTy = dyn_cast_or_null(type)) { + // Supported integer bit widths + std::set intWidth({8, 16, 32}); + return isTOSABool(type) || + (intTy.isSignless() && + (intWidth.find(intTy.getWidth()) != intWidth.end())); + } + return type.isBF16() || type.isF16() || type.isF32(); + } +}; + +} // namespace + +void populateLoweringONNXTileOpToTOSAPattern(ConversionTarget & /*target*/, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp new file mode 100644 index 0000000000..514b0e5cda --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Tanspose.cpp - Transpose Op --------------------===// +// +// Copyright (c) 2022 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file lowers ONNX TransposeOp to TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +class ONNXTransposeLoweringToTOSA + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename ONNXTransposeOp::Adaptor; + LogicalResult matchAndRewrite(ONNXTransposeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + IndexExprBuilderForTosa createTosaIE(rewriter, loc); + ONNXTransposeOpShapeHelper shapeHelper(op, {}, &createTosaIE); + if (shapeHelper.computeShape().failed()) { + return rewriter.notifyMatchFailure(op, "Could not infer shapes"); + } + + TosaBuilder tosaBuilder(rewriter, loc); + + Value input = adaptor.getData(); + + auto inputType = dyn_cast(input.getType()); + + if (!inputType) + return rewriter.notifyMatchFailure(op, "input not a ranked tensor"); + + Type inputElementType = inputType.getElementType(); + + if (!isa(inputElementType) && !isTOSAInt(inputElementType) && + !inputElementType.isInteger(1)) { + return rewriter.notifyMatchFailure( + op, "input element type not supported"); + } + + auto outputType = dyn_cast(op.getResult().getType()); + + if (!outputType) + return rewriter.notifyMatchFailure(op, "output not a ranked tensor"); + + auto permVector = extractFromIntegerArrayAttr(op.getPermAttr()); + // TOSA needs a I32 array + llvm::SmallVector permVectorI32; + permVectorI32.clear(); + llvm::transform(permVector, std::back_inserter(permVectorI32), + [](const auto &valueI64) { return (int32_t)valueI64; }); + + Value transposeOp = tosaBuilder.transpose(input, permVectorI32); + + rewriter.replaceOp(op, transposeOp); + + return success(); + } +}; + +} // namespace + +void populateLoweringONNXTransposeOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToTOSA/Tensor/Where.cpp b/src/Conversion/ONNXToTOSA/Tensor/Where.cpp new file mode 100644 index 0000000000..d2141c94e1 --- /dev/null +++ b/src/Conversion/ONNXToTOSA/Tensor/Where.cpp @@ -0,0 +1,68 @@ +// (c) Copyright 2022 - 2024 Advanced Micro Devices, Inc. All Rights Reserved. + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/PatternMatch.h" +#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" +#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +class ONNXWhereLoweringToTOSA : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXWhereOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value pred = adaptor.getCondition(); + Value lhs = adaptor.getX(); + Value rhs = adaptor.getY(); + + // Check types are compatible + auto predType = dyn_cast(pred.getType()); + auto lhsType = dyn_cast(lhs.getType()); + auto rhsType = dyn_cast(rhs.getType()); + auto resultType = dyn_cast(op->getResultTypes()[0]); + + if (!predType || !lhsType || !rhsType || !resultType) { + return rewriter.notifyMatchFailure(op, "Tosa only supports TensorTypes"); + } + if (lhsType.getElementType() != rhsType.getElementType() || + lhsType.getElementType() != resultType.getElementType()) { + return rewriter.notifyMatchFailure(op, + "Expected element type for X, Y and output to be the same in " + "onnx.Where"); + } + + // Broadcast dimensions + IndexExprBuilderForTosa createTosaIE(rewriter, op->getLoc()); + ONNXBroadcastOpShapeHelper shapeHelper(op, {}, &createTosaIE); + if (shapeHelper.computeShape().succeeded() && + shapeHelper.hasRankBroadcast()) { + TosaBuilder tosaBuilder(rewriter, loc); + llvm::SmallVector newValues = + tosaBuilder.equalizeRanks({pred, lhs, rhs}); + pred = newValues[0]; + lhs = newValues[1]; + rhs = newValues[2]; + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), pred, lhs, rhs); + return success(); + } +}; + +} // namespace + +void populateLoweringONNXWhereOpToTOSAPattern(ConversionTarget &target, + RewritePatternSet &patterns, TypeConverter &typeConverter, + MLIRContext *ctx) { + patterns.insert(ctx); +} +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Dialect/Krnl/CMakeLists.txt b/src/Dialect/Krnl/CMakeLists.txt index 541437da3a..683e4500dc 100644 --- a/src/Dialect/Krnl/CMakeLists.txt +++ b/src/Dialect/Krnl/CMakeLists.txt @@ -20,6 +20,7 @@ add_onnx_mlir_library(OMKrnlOps OMSpecializedKernelOpInterface LINK_LIBS PUBLIC + OMCompilerOptions OMONNXOps MLIRLLVMCommonConversion MLIRAffineDialect diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index 59a624ad57..81b6795913 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/TypeSwitch.h" +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -56,83 +57,28 @@ static StringRef getFormat(const Type &inputType) { //====---------------- Support for Krnl Builder ----------------------===// -Value KrnlBuilder::load(Value memref, ValueRange indices) const { - if (indices.size() == 0) { - // case memref<1xdtype> - MemRefType type = dyn_cast_or_null(memref.getType()); - assert(type && "Not MemRefType"); - if (type.getRank() == 1 && type.getShape()[0] == 1) { - MultiDialectBuilder create(*this); - Value iZero = create.math.constantIndex(0); - return b().create(loc(), memref, ValueRange({iZero})); - } - } - return b().create(loc(), memref, indices); -} - -mlir::Value KrnlBuilder::load(mlir::Value memref, mlir::ValueRange indices, - mlir::ValueRange offsets) const { - SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - return load(memref, computedIndices); -} - -Value KrnlBuilder::loadIE(Value memref, ArrayRef indices) const { - if (indices.size() == 0) { - // case memref<1xdtype> - MemRefType type = dyn_cast_or_null(memref.getType()); - assert(type && "Not MemRefType"); - if (type.getRank() == 1 && type.getShape()[0] == 1) { - MultiDialectBuilder create(*this); - Value iZero = create.math.constantIndex(0); - return b().create(loc(), memref, ValueRange({iZero})); - } - } - SmallVector indexValues; - IndexExpr::getValues(indices, indexValues); - return b().create(loc(), memref, indexValues); -} - -void KrnlBuilder::store(Value val, Value memref, ValueRange indices) const { - if (indices.size() == 0) { - // case memref<1xdtype> - MemRefType type = dyn_cast_or_null(memref.getType()); - assert(type && "Not MemRefType"); - if (type.getRank() == 1 && type.getShape()[0] == 1) { - MultiDialectBuilder create(*this); - Value iZero = create.math.constantIndex(0); - b().create(loc(), val, memref, ValueRange({iZero})); - return; - } - } - b().create(loc(), val, memref, indices); -} - -void KrnlBuilder::store(mlir::Value val, mlir::Value memref, - mlir::ValueRange indices, mlir::ValueRange offsets) const { - SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - store(val, memref, computedIndices); -} - -void KrnlBuilder::storeIE( - Value val, Value memref, ArrayRef indices) const { - if (indices.size() == 0) { - // case memref<1xdtype> - MemRefType type = dyn_cast_or_null(memref.getType()); - assert(type && "Not MemRefType"); - if (type.getRank() == 1 && type.getShape()[0] == 1) { - MultiDialectBuilder create(*this); - Value iZero = create.math.constantIndex(0); - b().create(loc(), val, memref, ValueRange({iZero})); - return; - } - } - SmallVector indexValues; - IndexExpr::getValues(indices, indexValues); - b().create(loc(), val, memref, indexValues); +Value KrnlBuilder::load( + Value memref, ValueRange indices, ValueRange offsets) const { + return onnx_mlir::impl::load( + *this, memref, indices, offsets); +} + +Value KrnlBuilder::loadIE( + Value memref, ArrayRef indices, ValueRange offsets) const { + return onnx_mlir::impl::loadIE( + *this, memref, indices, offsets); +} + +void KrnlBuilder::store( + Value val, Value memref, ValueRange indices, ValueRange offsets) const { + onnx_mlir::impl::store( + *this, val, memref, indices, offsets); +} + +void KrnlBuilder::storeIE(Value val, Value memref, ArrayRef indices, + ValueRange offsets) const { + onnx_mlir::impl::storeIE( + *this, val, memref, indices, offsets); } Value KrnlBuilder::getLinearOffsetIndex( @@ -149,25 +95,27 @@ Value KrnlBuilder::getLinearOffsetIndexIE( void KrnlBuilder::prefetch(Value memref, ValueRange indices, bool isWrite, unsigned localityHint, bool isDataCache) { + if (disableMemRefPrefetch) + return; b().create( loc(), memref, indices, isWrite, localityHint, isDataCache); } void KrnlBuilder::prefetchIE(Value memref, ArrayRef indices, bool isWrite, unsigned localityHint, bool isDataCache) { + if (disableMemRefPrefetch) + return; SmallVector indexValues; IndexExpr::getValues(indices, indexValues); b().create( loc(), memref, indexValues, isWrite, localityHint, isDataCache); } -void KrnlBuilder::seqstore( - mlir::Value element, mlir::Value seq, mlir::Value index) const { +void KrnlBuilder::seqstore(Value element, Value seq, Value index) const { b().create(loc(), element, seq, index); } -void KrnlBuilder::seqstore( - mlir::Value element, mlir::Value seq, IndexExpr index) const { +void KrnlBuilder::seqstore(Value element, Value seq, IndexExpr index) const { b().create(loc(), element, seq, index.getValue()); } @@ -176,7 +124,7 @@ Value KrnlBuilder::vectorTypeCast(Value sourceMemref, int64_t vectorLen) const { } void KrnlBuilder::region( - function_ref bodyBuilderFn) const { + function_ref bodyBuilderFn) const { KrnlBuilder createKrnl(b(), loc()); KrnlRegionOp regionOp = b().create(loc()); { @@ -207,25 +155,43 @@ ValueRange KrnlBuilder::getInductionVarValue(ValueRange loops) const { } void KrnlBuilder::parallel(ValueRange loops) const { - b().template create(loc(), loops); + Value noneValue; + StringAttr noneStrAttr; + b().template create(loc(), loops, noneValue, noneStrAttr); +} + +void KrnlBuilder::parallel( + ValueRange loops, Value numThreads, StringAttr procBind) const { + if (procBind.getValue().size() > 0) { + std::string str = procBind.getValue().str(); + assert((str == "primary" || str == "close" || str == "spread") && + "expected primary, close, or spread for proc_bind"); + } + b().template create(loc(), loops, numThreads, procBind); +} + +void KrnlBuilder::parallelClause( + Value parallelLoopIndex, Value numThreads, StringAttr procBind) const { + // No need to check procBind as its value are derived from parallel(...). + b().template create( + loc(), parallelLoopIndex, numThreads, procBind); } void KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops, ValueRange lbs, ValueRange ubs, - function_ref + function_ref bodyBuilderFn) const { - auto bodyBuilderFnWrapper = [&](KrnlBuilder &createKrnl, ValueRange indices, - ValueRange iterArgs) { + auto bodyBuilderFnWrapper = [&](const KrnlBuilder &createKrnl, + ValueRange indices, ValueRange iterArgs) { bodyBuilderFn(createKrnl, indices); }; iterate(originalLoops, optimizedLoops, lbs, ubs, {}, bodyBuilderFnWrapper); } -mlir::KrnlIterateOp KrnlBuilder::iterate(ValueRange originalLoops, +// Deprecated +KrnlIterateOp KrnlBuilder::iterate(ValueRange originalLoops, ValueRange optimizedLoops, ValueRange lbs, ValueRange ubs, ValueRange inits, - function_ref - bodyBuilderFn) const { + KrnlLoopBody2Fn bodyBuilderFn) const { // Check that originalLoops, lbs, and ubs have the same rank. assert(originalLoops.size() == lbs.size() && "expected same rank"); assert(originalLoops.size() == ubs.size() && "expected same rank"); @@ -246,21 +212,18 @@ KrnlIterateOp KrnlBuilder::iterate( void KrnlBuilder::iterateIE(ValueRange originalLoops, ValueRange optimizedLoops, ArrayRef lbs, ArrayRef ubs, - function_ref - bodyBuilderFn) const { - auto bodyBuilderFnWrapper = [&](KrnlBuilder &createKrnl, ValueRange indices, - ValueRange iterArgs) { + KrnlLoopBodyFn bodyBuilderFn) const { + auto bodyBuilderFnWrapper = [&](const KrnlBuilder &createKrnl, + ValueRange indices, ValueRange iterArgs) { bodyBuilderFn(createKrnl, indices); }; iterateIE(originalLoops, optimizedLoops, lbs, ubs, {}, bodyBuilderFnWrapper); } +// Deprecated. KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops, ValueRange optimizedLoops, ArrayRef lbs, ArrayRef ubs, - mlir::ValueRange inits, - function_ref - bodyBuilderFn) const { + ValueRange inits, KrnlLoopBody2Fn bodyBuilderFn) const { // Check that originalLoops, lbs, and ubs have the same rank. assert(originalLoops.size() == lbs.size() && "expected same rank"); assert(originalLoops.size() == ubs.size() && "expected same rank"); @@ -274,228 +237,91 @@ KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops, }); } -/* -Example of how to use the interface: -Say you have a loop of i=0..256, j=0..128 and want to exploit r[i,j] = a[i,j] + -b[j] + c. For the loops, we will need access functions for a, b, and r. - -Say we already have the loop for the outer loop of i - -krnl.iterate(loop i from 0 to 256) { - ii is the loop index. - - // 1) compute access function for a, b, c - // 2) launch simd loop with - // 3) simd kernel -} - -1) Access functions - Assuming here that we are not blocking the j loop, namely the simd iteration - goes over all j values, the access functions should be defined as follows. - - aAF = {ii, 0} - bAF = {0} - rAF = {ii, 0} - - If the j loop was blocked (say j=0 to 128 by 16), then instead of `0` in the - last dim, we would have 'blocked_jj' - -2) Launch simd loop - - create.krnl.simdIterateIE( - lb=LitIE(0), ub=litIE(128), totVL=8, // loop params - fullySimd=true, useParallel=false, // loop options - inputs={A, B}, inputAFs={aAF, bAF}, // inputs - outputs={R}, outputAFs={rAF}, // outputs - krnl) // lambda function for kernel - -3) Krnl for SIMD loop - - The kernel functions has 4 inputs: - a) krnl builder to further build code - b) list of loaded input values, in the same order as in inputs - c) list of results values, that must be enqueued by the kernel - d) totVL used for the loop (VL for simd, 1 for scalar) - - The same kernel will be used in a SIMD context, in which the inputs and - outputs must be vectors of VL elements, or in a scalar context, in which the - inputs and outputs must be scalars. - - In our example, the kernel is as follows - - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { - MultiDialectBuilder create(kb); - Value aVal = inputVals[0]; // simd or scalar - Value bVal = inputVals[1]; // simd or scalar - Value cVal = create.krnl.load(C, {}); // scalar always - Value newVal = create.math.add(aVal, bVal); // simd or scalar - newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is - // splatted - res.emplace_back(newVal); // Save simd or scalar result. - } - - The krnl.simdIterateIE will be in charge of loading and saving the values in - memory. The create.math functions have been extended so that when a SIMD - value is computed with a scalar, that scalar will be automaticaly splatted - (aka promoted to a vector of identical values). As a result, the kernel can - be written in a SIMD agnostic value. However, in rare situations, we may - want to know if we are in SIMD mode or not. VL will give the totVL used here - (either totVL>1 or 1). -*/ - -// Determine if an access has one element from the innermost dimensions up to -// innerDim. -bool static hasOneElementInInnermostDims(Value value, int64_t innerDim) { - // Get info. - ShapedType type = mlir::dyn_cast(value.getType()); - assert(type && "expected shaped type"); - int64_t rank = type.getRank(); - mlir::ArrayRef shape = type.getShape(); - for (int64_t i = std::max((int64_t)0, rank - innerDim); i < rank; ++i) - if (shape[i] != 1) - return false; - return true; +void KrnlBuilder::forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, + bool useParallel, KrnlLoopBodyFn builderFn) const { + ValueRange originalLoopDef = defineLoops(1); + llvm::SmallVector optLoopDef(1, originalLoopDef[0]); + if (step > 1) { + // Block loop by step. + ValueRange blockedLoopDef = block(originalLoopDef[0], step); + optLoopDef[0] = blockedLoopDef[0]; + } + if (useParallel) + parallel(optLoopDef[0]); + iterateIE(originalLoopDef, optLoopDef, {lb}, {ub}, builderFn); +} + +void KrnlBuilder::forLoopsIE(ArrayRef lbs, ArrayRef ubs, + ArrayRef steps, ArrayRef useParallel, + KrnlLoopBodyFn builderFn) const { + impl::forLoopsIE(*this, lbs, ubs, steps, useParallel, builderFn); +} + +void KrnlBuilder::forExplicitParallelLoopIE(IndexExpr lb, IndexExpr ub, + IndexExpr threadNum, KrnlLoopBodyFn builderFn) const { + IndexExpr zero = LitIE(0); + if (threadNum.isLiteralAndIdenticalTo(1)) { + // Sequential case as we have only 1 thread (parallel disabled statically). + llvm::SmallVector params = { + zero.getValue(), lb.getValue(), ub.getValue()}; + builderFn(*this, params); + return; + } + // Compute blockSize: the number of elements of (lb...ub) per thread. + IndexExpr trip = ub - lb; // Expected to be positive, aka ub>lb. + IndexExpr blockSize = trip.ceilDiv(threadNum); + // Explicit parallelism: iterate over all threads 0..threadNum in parallel. + forLoopIE(zero, threadNum, /*step*/ 1, /*parallel*/ true, + [&](const KrnlBuilder &ck, ValueRange loopInd) { + IndexExprScope scope(ck); + IndexExpr t = DimIE(loopInd[0]); + IndexExpr tTimesBlockSize = t * SymIE(blockSize); + IndexExpr currLB = SymIE(lb) + tTimesBlockSize; + IndexExpr currUB = currLB + SymIE(blockSize); + currUB = IndexExpr::min(currUB, SymIE(ub)); + // Passes the thread ID, its lower bound, and its upper bound. + llvm::SmallVector params = { + t.getValue(), currLB.getValue(), currUB.getValue()}; + builderFn(ck, params); + }); } void KrnlBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, ArrayRef inputs, ArrayRef inputAFs, ArrayRef outputs, ArrayRef outputAFs, - function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) { - int64_t inputNum = inputs.size(); - assert(inputAFs.size() == inputs.size() && "expected same size"); - int64_t outputNum = outputs.size(); - assert(outputAFs.size() == outputs.size() && "expected same size"); - MultiDialectBuilder create(*this); - - if (VL > 1) { - // Want SIMD, execute full SIMD loops blocked by VL. - ValueRange loopDef = defineLoops(1); - ValueRange blockedLoopDef = block(loopDef[0], VL); - if (useParallel) - parallel({blockedLoopDef[0]}); - - // If we are not guaranteed that every iterations are SIMD iterations, then - // we need to reduce the trip count by a bit so as to not over compute. - // If we are not guaranteed that every iterations are SIMD iterations, then - IndexExpr simdUb = ub; - if (!fullySimd) - simdUb = simdUb - (VL - 1); - iterateIE(loopDef, {blockedLoopDef[0]}, {lb}, {simdUb}, - [&](KrnlBuilder &ck, ValueRange loopInd) { - IndexExprScope scope(ck); - MultiDialectBuilder create(ck); - IndexExpr ind = DimIE(loopInd[0]); - // Load all the inputs as vectors of VL values, with a few exceptions. - // One is if the value is a "none value", leave as is. Another one is - // if the innermost dim is a scalar (ie dim[rank-1] == 1), then we - // just load the scalar. - llvm::SmallVector vecInputVals; - for (int64_t i = 0; i < inputNum; ++i) { - Value input = inputs[i]; - if (isNoneValue(input)) { - // Simply enqueue the none value. - vecInputVals.emplace_back(input); - continue; - } - MemRefType type = mlir::cast(input.getType()); - int64_t rank = type.getRank(); - DimsExpr AF = SymListIE(inputAFs[i]); - assert(rank == (int64_t)AF.size() && "AF expected input rank refs"); - if (hasOneElementInInnermostDims(input, 1)) { - // Has a reference with a scalar innermost dim, just load as a - // scalar. No need to add the induction variable. - Value scalarVal = create.krnl.loadIE(input, AF); - vecInputVals.emplace_back(scalarVal); - } else { - // Have a vector. - VectorType vecType = VectorType::get({VL}, type.getElementType()); - AF[rank - 1] = AF[rank - 1] + ind; // Add induction var. - Value vecVal = create.vec.loadIE(vecType, input, AF, {}); - vecInputVals.emplace_back(vecVal); - } - } - // Call the method to compute the values. - llvm::SmallVector vecResVals; - bodyBuilderFn(create.krnl, vecInputVals, vecResVals, VL); - assert((int64_t)vecResVals.size() == outputNum && - "loop body with incorrect number of results"); - // Store all the outputs as vectors of VL values, - for (int64_t i = 0; i < outputNum; ++i) { - MemRefType type = mlir::cast(outputs[i].getType()); - DimsExpr AF = SymListIE(outputAFs[i]); - int64_t rank = type.getRank(); - assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs"); - AF[rank - 1] = AF[rank - 1] + ind; - create.vec.storeIE(vecResVals[i], outputs[i], AF, {}); - } - }); - if (fullySimd) - // Asserted that we only have SIMD iterations, we are done. - return; - // Account for the loop iterations performed above. - IndexExpr tripCount = ub - lb; - IndexExpr missingIters = tripCount % VL; - IndexExpr completedIters = tripCount - missingIters; - if (missingIters.isLiteralAndIdenticalTo(0)) { - // Detect that we only have SIMD iterations, we are also done. - return; - } - // We may have additional iterations to perform, adjust lb to skip the - // completed iterations. - lb = lb + completedIters; - } - // Handle remaining scalar values (from lb to ub without unrolling). - ValueRange loopDef = defineLoops(1); - iterateIE( - loopDef, loopDef, {lb}, {ub}, [&](KrnlBuilder &ck, ValueRange loopInd) { - IndexExprScope scope(ck); - MultiDialectBuilder create(ck); - IndexExpr ind = DimIE(loopInd[0]); - // Load all the inputs as scalar values, - llvm::SmallVector scalarInputVals; - for (int64_t i = 0; i < inputNum; ++i) { - Value input = inputs[i]; - if (isNoneValue(input)) { - // Simply enqueue the none value. - scalarInputVals.emplace_back(input); - continue; - } - MemRefType type = mlir::cast(input.getType()); - int64_t rank = type.getRank(); - DimsExpr AF = SymListIE(inputAFs[i]); - if (hasOneElementInInnermostDims(input, 1)) { - // Has a reference with a scalar innermost dim, just load as a - // scalar. No need to add the induction variable. - Value scalarVal = create.krnl.loadIE(input, AF); - scalarInputVals.emplace_back(scalarVal); - } else { - AF[rank - 1] = AF[rank - 1] + ind; - Value scalarVal = create.krnl.loadIE(input, AF); - scalarInputVals.emplace_back(scalarVal); - } - } - // Call the method to compute the values. - llvm::SmallVector scalarResVals; - bodyBuilderFn(create.krnl, scalarInputVals, scalarResVals, /*VL*/ 1); - assert((int64_t)scalarResVals.size() == outputNum && - "loop body with incorrect number of results"); - // Store all the outputs as vectors of VL values, - for (int64_t i = 0; i < outputNum; ++i) { - MemRefType type = mlir::cast(outputs[i].getType()); - DimsExpr AF = SymListIE(outputAFs[i]); - int64_t rank = type.getRank(); - assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs"); - AF[rank - 1] = AF[rank - 1] + ind; - create.krnl.storeIE(scalarResVals[i], outputs[i], AF); - } - }); -} - -void KrnlBuilder::yield(mlir::ValueRange iterArgs) const { + ArrayRef iterateBodyFnList) const { + onnx_mlir::impl::simdIterateIE(*this, lb, ub, VL, + fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs, + iterateBodyFnList); +} + +void KrnlBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, ArrayRef inputs, ArrayRef inputAFs, + ArrayRef tmps, ArrayRef tmpAFs, ArrayRef outputs, + ArrayRef outputAFs, ArrayRef initVals, + /* reduction function (simd or scalar) */ + ArrayRef reductionBodyFnList, + /* post reduction function (simd to scalar + post processing)*/ + ArrayRef postReductionBodyFnList) const { + onnx_mlir::impl::simdReduceIE(*this, lb, ub, VL, + fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals, + reductionBodyFnList, postReductionBodyFnList); +} + +void KrnlBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, Value input, DimsExpr inputAF, Value tmp, DimsExpr tmpAF, + Value output, DimsExpr outputAF, Value initVal, + /* reduction functions (simd or scalar) */ + KrnlSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + KrnlSimdPostReductionBodyFn postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE(*this, lb, ub, VL, + fullySimd, input, inputAF, tmp, tmpAF, output, outputAF, initVal, + reductionBodyFn, postReductionBodyFn); +} + +void KrnlBuilder::yield(ValueRange iterArgs) const { b().create(loc(), iterArgs); } @@ -558,11 +384,62 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name, alignment.value_or(nullptr)); } +//===----------------------------------------------------------------------===// +// Math style functions. + +// Keep code gen here in sync with Elementwise.cpp GenOpMix +// getGenOpMix +Value KrnlBuilder::roundEven(Value input) const { + Type elementType = getElementTypeOrSelf(input.getType()); + MultiDialectBuilder create(*this); + VectorType vecType = mlir::dyn_cast(input.getType()); + if (VectorMachineSupport::requireCustomASM( + GenericOps::roundEvenGop, elementType)) { + // Use Krnl round even op as LLVM does not support roundEven. + if (!vecType) + // Scalar. + return b().create(loc(), input.getType(), input); + + // Vector, enable unrolling of multiple archVL. + int64_t archVL = VectorMachineSupport::getArchVectorLength( + GenericOps::roundEvenGop, elementType); + assert(archVL > 1 && "expected vector with archVL>1"); + assert(vecType.getRank() == 1 && "1D vec only"); + int64_t vecSize = vecType.getShape()[0]; + assert(vecSize % archVL == 0 && "expected multiple of archVL"); + int64_t numArchVec = vecSize / archVL; + VectorType vecType2D = VectorType::get({numArchVec, archVL}, elementType); + // Cast input vector to a vector of chunks (archVL values that can be + // handled by one hardware SIMD instruction). + Value input2D = create.vec.shapeCast(vecType2D, input); + Value output2D = input2D; + // Iterates over all hardware SIMD chunks. + for (int64_t i = 0; i < numArchVec; ++i) { + // Extract one chunk, compute new value, insert result in corresponding + // output 2D vector. + Value subInput = create.vec.extractFrom2D(input2D, i); + Value subOutput = + b().create(loc(), subInput.getType(), subInput); + output2D = create.vec.insertInto2D(subOutput, output2D, i); + } + // Recast output 2D vector into the flat vector (same shape as input). + return create.vec.shapeCast(vecType, output2D); + } + // No need for custom support, use math roundEven. May want to evaluate + // whether to use the mlir roundEven or our own emulation. + // Note: MacOS CI has an issue with the roundEven instruction, thus continue + // to use emulation. May change in the future. + return create.math.roundEvenEmulation(input); +} + +//===----------------------------------------------------------------------===// +// C library functions. + void KrnlBuilder::memcpy(Value dest, Value src, Value numElems) const { MultiDialectBuilder create(*this); Value zero = create.math.constantIndex(0); - b().create( - loc(), dest, src, numElems, /*dest_offset=*/zero, /*src_offset=*/zero); + b().create(loc(), dest, src, numElems, + /*dest_offset=*/zero, /*src_offset=*/zero); } void KrnlBuilder::memcpy(Value dest, Value src, Value numElems, @@ -625,13 +502,12 @@ void KrnlBuilder::printf( // ============================================================================= // Return null if none is found. -ElementsAttr IndexExprBuilderForKrnl::getConst(mlir::Value value) { +ElementsAttr IndexExprBuilderForKrnl::getConst(Value value) { auto definingOp = value.getDefiningOp(); - if (auto globalOp = dyn_cast_or_null(definingOp)) { + if (auto globalOp = dyn_cast_or_null(definingOp)) { if (globalOp.getValue().has_value()) return mlir::dyn_cast(globalOp.getValueAttr()); - } else if (auto globalOp = - dyn_cast_or_null(definingOp)) { + } else if (auto globalOp = dyn_cast_or_null(definingOp)) { if (globalOp.getValue().has_value()) return mlir::dyn_cast(globalOp.getValueAttr()); } @@ -642,7 +518,7 @@ Value IndexExprBuilderForKrnl::getVal(Value intArrayVal, uint64_t i) { MultiDialectBuilder create(*this); uint64_t rank = getShapedTypeRank(intArrayVal); if (rank == 0) - return create.krnl.load(intArrayVal, {}); + return create.krnl.load(intArrayVal); uint64_t size = getArraySize(intArrayVal); assert(i < size && "out of bound reference"); Value iVal = create.math.constantIndex(i); diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index 85c40f9b9d..4673ba2a98 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -30,19 +30,16 @@ struct KrnlBuilder : public DialectBuilder { KrnlBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~KrnlBuilder() {} - mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}) const; - // When ranks of offsets indices) const; - void store( - mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}) const; - // When ranks of offsets indices = {}, + mlir::ValueRange offsets = {}) const; + void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, + mlir::ValueRange offsets = {}) const; void storeIE(mlir::Value val, mlir::Value memref, - mlir::ArrayRef indices) const; + mlir::ArrayRef indices, mlir::ValueRange offsets = {}) const; // Get linear offset for given memref at given index values. mlir::Value getLinearOffsetIndex( @@ -62,79 +59,155 @@ struct KrnlBuilder : public DialectBuilder { mlir::Value vectorTypeCast(mlir::Value sourceMemref, int64_t vectorLen) const; void region( - mlir::function_ref bodyBuilderFn) const; + mlir::function_ref bodyBuilderFn) + const; mlir::ValueRange defineLoops(int64_t originalLoopNum) const; mlir::ValueRange block(mlir::Value loop, int64_t blockSize) const; void permute(mlir::ValueRange loops, mlir::ArrayRef map) const; mlir::ValueRange getInductionVarValue(mlir::ValueRange loops) const; void parallel(mlir::ValueRange loops) const; + void parallel(mlir::ValueRange loops, mlir::Value numThreads, + mlir::StringAttr procBind) const; + void parallelClause(mlir::Value parallelLoopIndex, mlir::Value numThreads, + mlir::StringAttr procBind) const; // Iterate over optimized loops given the original loops, lbs and ubs. Lambda // function implement the body of the loop, and receive a KRNL builder and the // loop indices. + using KrnlLoopBodyFn = impl::LoopBodyFn; + using KrnlLoopBody2Fn = mlir::function_ref; + void iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ValueRange lbs, mlir::ValueRange ubs, - mlir::function_ref - bodyBuilderFn) const; + KrnlLoopBodyFn bodyBuilderFn) const; + // Deprecated. mlir::KrnlIterateOp iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ValueRange lbs, mlir::ValueRange ubs, mlir::ValueRange inits, - mlir::function_ref - bodyBuilderFn) const; + KrnlLoopBody2Fn bodyBuilderFn) const; + mlir::KrnlIterateOp iterate( const krnl::KrnlIterateOperandPack &operands) const; // Same versions with Index Expressions for bounds. void iterateIE(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ArrayRef lbs, - mlir::ArrayRef ubs, - mlir::function_ref - bodyBuilderFn) const; + mlir::ArrayRef ubs, KrnlLoopBodyFn bodyBuilderFn) const; + // Deprecated. mlir::KrnlIterateOp iterateIE(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ArrayRef lbs, mlir::ArrayRef ubs, mlir::ValueRange inits, - mlir::function_ref - bodyBuilderFn) const; - - // Iterate over a loop executing the loop body in SIMD mode (of vector length - // VL) from lb to ub. A scalar loop may execute up to VL-1 loop - // iterations when the trip count is not a multiple of VL. If fullySimd is - // true, then the call assumes that the trip count is a multiple of VL. - // - // This call needs be given each of the memref inputs to the loop body, given - // as an ordered pair memref value and its corresponding access function. Same - // hold for all the memref outputs of the loop body. - // - // The loop body is given a KRNL builder, a list of loaded input (same order - // as the input's memrefs and access functions). It will generate values that - // must be placed in the result list in the same order as the output's memrefs - // and access functions. - // - // It will be the responsibility of this call to load each of the inputs and - // store each of the outputs. When operating in SIMD mode, every input and - // output values are vectors of length VL. In scalar mode, they are simply - // scalar values. - // - // SIMD is exploited in the innermost dimension of each access function. - // This call is only applicable to loop bodies where every input/output is - // strided in its innermost dimension. Inputs can also be loop invariant - // (scalar), in term of the loop being iterated on. - // - // If useParallel is true, then the blocked SIMD loop is executed in parallel. - + KrnlLoopBody2Fn bodyBuilderFn) const; + + // Common loop interface (krnl/affine/scf). + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, + KrnlLoopBodyFn builderFn) const; + void forLoopsIE(mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, mlir::ArrayRef useParallel, + KrnlLoopBodyFn builderFn) const; + + // Loop with explicit parallelism. Loop body is invoked on each parallel + // thread with its threadID (0..threadNum-1) and its corresponding lb and ub + // (using static schedule). When threadNum==1 (compile time literal), we + // simply call the builderFn for the entire range as there is no + // parallelism, namely we call builderFn(builder, {0, lb, ub}). + void forExplicitParallelLoopIE(IndexExpr lb, IndexExpr ub, + IndexExpr threadNum, KrnlLoopBodyFn builderFn) const; + + // Common simd loop interface (krnl/affine/scf). + /* + Iterate over a loop executing the loop body in SIMD mode (of vector length + VL) from lb to ub. A scalar loop may execute up to VL-1 loop + iterations when the trip count is not a multiple of VL. If fullySimd is + true, then the call assumes that the trip count is a multiple of VL. + + This simdIterateIE needs be given each of the memref inputs to the loop + body, given as an ordered pair memref value and its corresponding access + function. Same hold for all the memref outputs of the loop body. + + The loop body is constructed by calling each of the KrnlSimdIterateBodyFn + given in the list. Each function is responsible for returning one output + value. The returned values are eventually stored in the output memrefs at a + location given by its respective output access function. + + To generate their output, each KrnlSimdIterateBodyFn function is given + a KRNL builder, a list of loaded input (same order + as the input's memrefs and access functions), and the current VectorLength + (VL). VL is either the original VL or 1 (when executing in scalar mode). + + It will be the responsibility of this call to load each of the inputs and + store each of the outputs. When operating in SIMD mode, every input and + output values are vectors of length VL. In scalar mode, they are simply + scalar values. + + SIMD is exploited in the innermost dimension of each access function. + This call is only applicable to loop bodies where every outputs are + strided in its innermost dimension. Inputs can also be loop invariant + (scalar), in term of the loop being iterated on. + + If useParallel is true, then the blocked SIMD loop is executed in parallel. + + A detailed example of how to use if found in + Dialect/Mlir/DialectBuilder.hpp.inc. + */ + + using KrnlSimdIterateBodyFn = impl::SimdIterateBodyFn; void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, - mlir::function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn); + mlir::ArrayRef bodyBuilderFnList) const; + + /* + Works similarly as simdIterateIE, but performs a reduction to a single + scalar per output value. Inputs must be strided in their innermost + dimensions. Temps are used to hold the temporary results (partial results + per SIMD lane), and the outputs have the scalar reduction outputs + + Two function lists are given: a list of reductionBodyFn to perform the + partial reductions into the temporary values tmps, finishing with up to VL + partial reductions The second list of postReductionBodyFn perform the + reductions of the up to VL partial reductions into a final scalar reduction + to be stored into the outputs (a scalar value). For some reductions, post + processing is also needed, for example, mean reduction divide the + accumulated sum by the number of elements. That step is also performed + here. + */ + using KrnlSimdReductionBodyFn = impl::SimdReductionBodyFn; + using KrnlSimdPostReductionBodyFn = + impl::SimdPostReductionBodyFn; + + void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef tmps, mlir::ArrayRef tmpAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef reductionBodyFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef postReductionBodyFnList) + const; + + /* + Same as simdReduceIE, but perform VL reductions at once. It expect at least + VL iterations in the second to last dimension of inputs/outputs. + + Unlike simdReduceIE, the second function is for post processing only. In + simdReduceIE, that function was also used to reduce the SIMD temporary + reduction into a single scalar. + + Also, at this time, simdReduce2DIE process only one reduction at a time, + whereas simdReduceIE could process an arbitrary number of reductions. + */ + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + KrnlSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + KrnlSimdPostReductionBodyFn postReductionBodyFn) const; void yield(mlir::ValueRange iterArgs) const; @@ -202,6 +275,9 @@ struct KrnlBuilder : public DialectBuilder { std::optional offset = std::nullopt, std::optional alignment = std::nullopt) const; + // Math style functions + mlir::Value roundEven(mlir::Value input) const; + // C library functions. void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems) const; void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems, @@ -218,6 +294,8 @@ struct KrnlBuilder : public DialectBuilder { mlir::StringRef msg, IndexExpr input, bool endsWithNewLine = false) const; void printf(mlir::StringRef msg, mlir::Value input, mlir::Type inputType, bool endsWithNewLine = false) const; + // Use "%s" for signature, "%t" for detailed type, "%d" for data, "%e" for end + // of string (recommended). If no "%X" pattern is given, we assume "%s%d". void printTensor(mlir::StringRef msg, mlir::Value input) const; // Onnx-mlir runtime functions. diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index 1d89b46f1e..c8220dfc53 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -89,6 +89,13 @@ def KrnlCallOp : Op:$numOfOutput, Variadic:$parameters); + // Return Value for the Call. + // No return if the type is NoneType (void in llvm) + // Only scalar type is supported now. + // In future, return of memref can be supported with pointer of OMTensor. + // The returned memref will be created inside the call. + let results = (outs Variadic>:$returnValue); + // builders to build KrnlCallOp from op and operands, helping conversion from // onnx to krnl. // The name of function can be determined by the op name and elemnt type of @@ -96,6 +103,8 @@ def KrnlCallOp : Op, + OpBuilder<(ins "mlir::StringAttr":$funcNameStr, "IntegerAttr":$numOfOutput, "mlir::ValueRange":$operands)>, OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector":$attributeNames)>, OpBuilder<(ins "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "bool":$copyAttrs)>, OpBuilder<(ins "std::string":$funcNameStr, "mlir::ValueRange":$results, "mlir::Operation *":$op, "mlir::ValueRange":$operands, "std::vector":$attributeNames)>, @@ -514,7 +523,7 @@ def KrnlUnrollOp : Op { }]; } -def KrnlParallelOp : Op { +def KrnlParallelOp : Op { let summary = "Mark Krnl loops as parallel loops"; let description = [{ Parallelize the specified loops. When multiple loop specifiers are passed @@ -522,18 +531,53 @@ def KrnlParallelOp : Op { krnl.parallel should be placed as the last operator before krnl.iterate, Since we do not want to parallelize the loop until we interpret krnl.block, krnl.permute and krnl.unroll. + + Optionally, a value may specifiy the number of threads requested for the + parallel loop. A proc_bind string may also be specified; valid values are + "primary", "close", or "spread". Default values are used when not specified. + ``` krnl.parallel (%i0, %i1) : !Krnl.loop, !Krnl.loop ``` }]; - let arguments = (ins Variadic:$loops); + let arguments = (ins Variadic:$loops, + Optional:$num_threads, + OptionalAttr:$proc_bind); let assemblyFormat = [{ - `(` $loops `)` attr-dict `:` type($loops) + `(` $loops `)` (`,` `num_threads` `(` $num_threads^ `)`)? attr-dict `:` type($loops) }]; } +def KrnlParallelClauseOp : Op { + let summary = "Attach OpenMP clauses to an index varialbe"; + let description = [{ + Attach OpenMP clauses to an index variable. That index variable + is used to uniquely associate a parallel loop with its clauses. + }]; + + let arguments = (ins Index: $parallel_loop_index, + Optional:$num_threads, + OptionalAttr:$proc_bind); + + let assemblyFormat = [{ + `(` $parallel_loop_index `)` (`,` `num_threads` `(` $num_threads^ `)`)? + attr-dict `:` type($parallel_loop_index) + }]; +} + +def KrnlRoundEvenOp : Op { + let summary = "Krnl round to nearest even operation"; + let description = [{ + Krnl round to nearest even operation. Accept scalar or vector float values. + Vector must be 1D of a size that is a multiple of the hardware vector size. + }]; + + let arguments = (ins FloatLike:$in); + let results = (outs FloatLike:$out); +} + def KrnlErfOp : Op { let summary = "Krnl erf scalar operation"; let description = [{ diff --git a/src/Dialect/Krnl/KrnlOps.cpp b/src/Dialect/Krnl/KrnlOps.cpp index 4036a15658..cec7b2d94d 100644 --- a/src/Dialect/Krnl/KrnlOps.cpp +++ b/src/Dialect/Krnl/KrnlOps.cpp @@ -4,7 +4,7 @@ //===---------------------- KrnlOps.cpp - Krnl Operations -----------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -154,12 +154,22 @@ void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, build(builder, odsState, funcNameStr, resultVals, op, operands, copyAttrs); } +void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, + std::string funcName, int64_t numOfOutput, ValueRange operands) { + build(builder, odsState, {}, funcName, numOfOutput, operands); +} + +void KrnlCallOp::build(OpBuilder &builder, ::mlir::OperationState &odsState, + StringAttr funcName, IntegerAttr numOfOutput, ValueRange operands) { + build(builder, odsState, {}, funcName, numOfOutput, operands); +} + void KrnlCallOp::getEffects( SmallVectorImpl> &effects) { for (size_t i = 0; i < getParameters().size(); i++) { - if (i < (size_t)getNumOfOutput()) + if (i < static_cast(getNumOfOutput())) effects.emplace_back(MemoryEffects::Write::get(), &getParametersMutable()[i], SideEffects::DefaultResource::get()); else @@ -596,7 +606,7 @@ ParseResult KrnlIterateOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -::llvm::SmallVector KrnlIterateOp::getLoopRegions() { +::llvm::SmallVector KrnlIterateOp::getLoopRegions() { return {&getBodyRegion()}; } @@ -669,7 +679,8 @@ void KrnlPermuteOp::build(::mlir::OpBuilder &odsBuilder, assert(rank >= 2 && "permute needs 2 or more loops"); assert(odsMap.size() == rank && "loop and size size must be identical"); for (unsigned int i = 0; i < rank; ++i) { - assert(odsMap[i] >= 0 && odsMap[i] < (int64_t)rank && "bad permute"); + assert(odsMap[i] >= 0 && odsMap[i] < static_cast(rank) && + "bad permute"); for (unsigned int j = i + 1; j < rank; ++j) assert( odsMap[i] != odsMap[j] && "map should be a strict permute pattern"); diff --git a/src/Dialect/Mlir/CMakeLists.txt b/src/Dialect/Mlir/CMakeLists.txt index 80c45ea5a2..a9030a80fd 100644 --- a/src/Dialect/Mlir/CMakeLists.txt +++ b/src/Dialect/Mlir/CMakeLists.txt @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMMlirDialects OMSpecializedKernelOpInterface LINK_LIBS PUBLIC + # OMCompilerOptions MLIRMathDialect MLIRAffineDialect MLIRSCFDialect diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index d04c2a8b83..901c3cf44b 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -25,6 +25,8 @@ #include "llvm/Support/Debug.h" // Please do not add dependences on ONNX or KRNL dialects. +// Disabled due to downstream linking issues +// #include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/Mlir/VectorMachineSupport.hpp" @@ -278,6 +280,19 @@ Value MathBuilder::rem(Value lhs, Value rhs) const { Value MathBuilder::round(Value x) const { Type type = x.getType(); assert(isScalarOrVectorFloat(type) && "expected float"); + return b().create(loc(), x); +} + +Value MathBuilder::roundEven(Value x) const { + Type type = x.getType(); + assert(isScalarOrVectorFloat(type) && "expected float"); + return b().create(loc(), x); +} + +Value MathBuilder::roundEvenEmulation(Value x) const { + Type type = x.getType(); + assert(isScalarOrVectorFloat(type) && "expected float"); + // Use algorithm originally posted in ONNXtoKRNL/Math/Elementwise.cpp // lowering. @@ -320,7 +335,7 @@ Value MathBuilder::round(Value x) const { return select(rEqualHalf, y2, y1); } -Value MathBuilder::copySign(mlir::Value rem, mlir::Value dividend) const { +Value MathBuilder::copySign(Value rem, Value dividend) const { splatToMatch(rem, dividend); assert(rem.getType() == dividend.getType() && "expected same type"); if (isScalarOrVectorFloat(rem)) @@ -603,7 +618,7 @@ Value MathBuilder::constant(Type type, double val) const { b().create(loc(), b().getF64FloatAttr(val)); }) .Case([&](IntegerType elementType) { - assert(val == (int64_t)val && "value is ambiguous"); + assert(val == static_cast(val) && "value is ambiguous"); unsigned width = elementType.getWidth(); if (width == 1) @@ -614,11 +629,13 @@ Value MathBuilder::constant(Type type, double val) const { if (elementType.isUnsignedInteger()) { Type signlessTy = b().getIntegerType(width); constant = b().create(loc(), - b().getIntegerAttr(signlessTy, APInt(width, (int64_t)val))); + b().getIntegerAttr(signlessTy, + APInt(width, static_cast(val), false, true))); constant = castToUnsigned(constant, width); } else { constant = b().create(loc(), - b().getIntegerAttr(elementType, APInt(width, (int64_t)val))); + b().getIntegerAttr(elementType, + APInt(width, static_cast(val), false, true))); } } }) @@ -643,7 +660,7 @@ Value MathBuilder::constantIndex(int64_t val) const { return b().create(loc(), constantAttr); } -TypedAttr MathBuilder::negativeInfAttr(mlir::Type type) const { +TypedAttr MathBuilder::negativeInfAttr(Type type) const { TypedAttr attr; TypeSwitch(type) .Case([&](Type) { @@ -681,14 +698,14 @@ TypedAttr MathBuilder::negativeInfAttr(mlir::Type type) const { default: llvm_unreachable("unsupported element type"); } - attr = b().getIntegerAttr(type, APInt(width, value)); + attr = b().getIntegerAttr(type, APInt(width, value, false, true)); }) .Default([](Type) { llvm_unreachable("unsupported element type"); }); assert(attr != nullptr && "Expecting valid attribute"); return attr; } -TypedAttr MathBuilder::positiveInfAttr(mlir::Type type) const { +TypedAttr MathBuilder::positiveInfAttr(Type type) const { TypedAttr attr; TypeSwitch(type) .Case([&](Type) { @@ -726,7 +743,7 @@ TypedAttr MathBuilder::positiveInfAttr(mlir::Type type) const { default: llvm_unreachable("unsupported element type"); } - attr = b().getIntegerAttr(type, APInt(width, value)); + attr = b().getIntegerAttr(type, APInt(width, value, false, true)); }) .Default([](Type) { llvm_unreachable("unsupported element type"); }); assert(attr != nullptr && "Expecting valid attribute"); @@ -885,6 +902,46 @@ Value MathBuilder::cast(Type destType, Value src) const { LLVM_DEBUG(llvm::dbgs() << "srcType: " << srcType << "\n"; llvm::dbgs() << "destType: " << destType << "\n";); + // Before we process with the actual cast, there is a special case that we + // want to handle here. Cast from float to int that have different width, llvm + // generate better patterns if we first cast from float to int of the same + // width, and then from int to a different size int. + // Skip that optimization if the result is a 1 bit (boolean). + if (mlir::isa(srcElemType) && + mlir::isa(destElemType) && bitTrunc && destElemWidth > 1) { + // Quantization: float to smaller int. First determine the intermediary + // type, same integer type as destination type, with the same type width as + // the source float type. + Type step1ElementType; + IntegerType destIntType = mlir::cast(destElemType); + bool destIssSigned = destIntType.isSignless() || destIntType.isSigned(); + if (destIssSigned) + step1ElementType = b().getIntegerType(srcElemWidth); + else + step1ElementType = b().getIntegerType(srcElemWidth, false); + // Perform (recursively) the 2 step conversion. Exceptionally ok here to use + // element type here as cast will promote it to a vector if src is a vector. + Value step1Val = cast(step1ElementType, src); + return cast(destType, step1Val); + } + if (mlir::isa(srcElemType) && + mlir::isa(destElemType) && bitExtend) { + // Dequantization: small int to a float. First determine the intermediary + // type, same integer type as source type, with the same type width as + // the destination float type. + Type step1ElementType; + IntegerType srcIntType = mlir::cast(srcElemType); + bool srcIssSigned = srcIntType.isSignless() || srcIntType.isSigned(); + if (srcIssSigned) + step1ElementType = b().getIntegerType(destElemWidth); + else + step1ElementType = b().getIntegerType(destElemWidth, false); + // Perform (recursively) the 2 step conversion. Exceptionally ok here to use + // element type here as cast will promote it to a vector if src is a vector. + Value step1Val = cast(step1ElementType, src); + return cast(destType, step1Val); + } + // Handle boolean first because they need special handling. // Boolean to int/float conversions. Boolean are unsigned. if (srcElemType.isInteger(1)) { @@ -1015,9 +1072,8 @@ Value MathBuilder::castToIndex(Value src) const { // Add offsets to least significant values in indices. So if indices has 4 // values, (i, j, k, l) and offsets has 2 values (K, L), the results will be (i, // j, k+K, l+L). -void MathBuilder::addOffsetToLeastSignificant(mlir::ValueRange indices, - mlir::ValueRange offsets, - llvm::SmallVectorImpl &computedIndices) const { +void MathBuilder::addOffsetToLeastSignificant(ValueRange indices, + ValueRange offsets, llvm::SmallVectorImpl &computedIndices) const { int64_t indexRank = indices.size(); int64_t offsetRank = offsets.size(); int64_t firstOffset = indexRank - offsetRank; @@ -1033,7 +1089,7 @@ void MathBuilder::addOffsetToLeastSignificant(mlir::ValueRange indices, } } -void MathBuilder::addOffsetToLeastSignificant(mlir::ArrayRef indices, +void MathBuilder::addOffsetToLeastSignificant(ArrayRef indices, ValueRange offsets, llvm::SmallVectorImpl &computedIndices) const { SmallVector indexValues; IndexExpr::getValues(indices, indexValues); @@ -1084,8 +1140,7 @@ IntegerAttr MemRefBuilder::computeAlignment(int64_t alignment) const { // values from the list of index expressions that represent the shape of the // memref. -void MemRefBuilder::computeDynSymbols(MemRefType type, - llvm::SmallVectorImpl &dims, +void MemRefBuilder::computeDynSymbols(MemRefType type, DimsExprRef dims, llvm::SmallVectorImpl &dynSymbols) const { dynSymbols.clear(); int64_t rank = type.getRank(); @@ -1110,6 +1165,33 @@ void MemRefBuilder::computeDynSymbols(Value operandOfSameType, MemRefType type, dynSymbols.emplace_back(dim(operandOfSameType, i)); } +//===----------------------------------------------------------------------===// +// Load Store ops. + +Value MemRefBuilder::load( + Value memref, ValueRange indices, ValueRange offsets) const { + return onnx_mlir::impl::load( + *this, memref, indices, offsets); +} +Value MemRefBuilder::loadIE( + Value memref, ArrayRef indices, ValueRange offsets) const { + return onnx_mlir::impl::loadIE( + *this, memref, indices, offsets); +} + +// Add offsets (if any) to the least significant memref dims. +void MemRefBuilder::store( + Value val, Value memref, ValueRange indices, ValueRange offsets) const { + onnx_mlir::impl::store( + *this, val, memref, indices, offsets); +} + +void MemRefBuilder::storeIE(Value val, Value memref, + ArrayRef indices, ValueRange offsets) const { + onnx_mlir::impl::storeIE( + *this, val, memref, indices, offsets); +} + //===----------------------------------------------------------------------===// // Alloc functions without alignment. @@ -1133,8 +1215,7 @@ memref::AllocOp MemRefBuilder::alloc( return alloc(type, dynSymbols); } -memref::AllocOp MemRefBuilder::alloc( - MemRefType type, llvm::SmallVectorImpl &dims) const { +memref::AllocOp MemRefBuilder::alloc(MemRefType type, DimsExprRef dims) const { llvm::SmallVector dynSymbols; computeDynSymbols(type, dims, dynSymbols); return alloc(type, dynSymbols); @@ -1169,8 +1250,8 @@ memref::AllocOp MemRefBuilder::alignedAlloc( return alignedAlloc(type, dynSymbols, alignment); } -memref::AllocOp MemRefBuilder::alignedAlloc(MemRefType type, - llvm::SmallVectorImpl &dims, int64_t alignment) const { +memref::AllocOp MemRefBuilder::alignedAlloc( + MemRefType type, DimsExprRef dims, int64_t alignment) const { llvm::SmallVector dynSymbols; computeDynSymbols(type, dims, dynSymbols); return alignedAlloc(type, dynSymbols, alignment); @@ -1228,9 +1309,9 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, Type elementType = type.getElementType(); assert(!(mlir::isa(elementType)) && "unsupported vector type"); ArrayRef shape = type.getShape(); - staticSize = 1; // Multiplication of static sizes. - dynSize = LiteralIndexExpr(1); // Multiplication of dyn sizes. - bool staticShape = true; // Static until proven otherwise. + staticSize = 1; // Multiplication of static sizes. + dynSize = LitIE(1); // Multiplication of dyn sizes. + bool staticShape = true; // Static until proven otherwise. int64_t rank = type.getRank(); // Process with range [lb inclusive, ub exclusive) int64_t lb = 0, ub = rank; @@ -1253,7 +1334,7 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, if (i >= lb && i < ub) { // Keep track of static shape and dynamic sizes only when inbounds. staticShape = false; - dynSize = dynSize * SymbolIndexExpr(dynSymbols[iDim]); + dynSize = dynSize * SymIE(dynSymbols[iDim]); } iDim++; } else { @@ -1268,8 +1349,8 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, } bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, - llvm::SmallVectorImpl &dims, int64_t &staticSize, - IndexExpr &dynSize, int64_t range) const { + DimsExprRef dims, int64_t &staticSize, IndexExpr &dynSize, + int64_t range) const { llvm::SmallVector dynSymbols; computeDynSymbols(type, dims, dynSymbols); return getStaticAndDynamicMemSize( @@ -1280,7 +1361,7 @@ bool MemRefBuilder::getStaticAndDynamicMemSize(MemRefType type, // Alloc functions with alignment and padding for SIMD Value MemRefBuilder::alignedAllocWithSimdPadding( - mlir::MemRefType type, int64_t VL, int64_t alignment) const { + MemRefType type, int64_t VL, int64_t alignment) const { llvm::SmallVector dynSymbols; return alignedAllocWithSimdPadding(type, dynSymbols, VL, alignment); } @@ -1316,17 +1397,16 @@ Value MemRefBuilder::alignedAllocWithSimdPadding(MemRefType type, if (bitWidth % 8 == 0) { // We have elements that have sizes of 1 or more bytes. int64_t byteWidth = bitWidth / 8; - IndexExpr totByteSize = LiteralIndexExpr(staticSize * byteWidth) * dynSize; - totPaddedByteSize = totByteSize + LiteralIndexExpr(paddingSize * byteWidth); + IndexExpr totByteSize = LitIE(staticSize * byteWidth) * dynSize; + totPaddedByteSize = totByteSize + LitIE(paddingSize * byteWidth); } else { // We have sub-byte element sizes. Need to do precise computations. Namely // first compute tot total number of bits (including static/dynamic // and padding bit sizes), and then doing a ceil division by // 8 (number of bits in a byte). - IndexExpr totBitSize = LiteralIndexExpr(staticSize * bitWidth) * dynSize; - IndexExpr totPaddedBitSize = - totBitSize + LiteralIndexExpr(paddingSize * bitWidth); - totPaddedByteSize = totPaddedBitSize.ceilDiv(LiteralIndexExpr(8)); + IndexExpr totBitSize = LitIE(staticSize * bitWidth) * dynSize; + IndexExpr totPaddedBitSize = totBitSize + LitIE(paddingSize * bitWidth); + totPaddedByteSize = totPaddedBitSize.ceilDiv(LitIE(8)); } if (staticShape) assert(totPaddedByteSize.isLiteral() && "expected literal padded tot size"); @@ -1354,9 +1434,8 @@ Value MemRefBuilder::alignedAllocWithSimdPadding(Value operandOfSameType, return alignedAllocWithSimdPadding(type, dynSymbols, VL, alignment); } -Value MemRefBuilder::alignedAllocWithSimdPadding(MemRefType type, - llvm::SmallVectorImpl &dims, int64_t VL, - int64_t alignment) const { +Value MemRefBuilder::alignedAllocWithSimdPadding( + MemRefType type, DimsExprRef dims, int64_t VL, int64_t alignment) const { llvm::SmallVector dynSymbols; computeDynSymbols(type, dims, dynSymbols); return alignedAllocWithSimdPadding(type, dynSymbols, VL, alignment); @@ -1396,7 +1475,7 @@ memref::ReshapeOp MemRefBuilder::reshape(MemRefType destType, } memref::ReshapeOp MemRefBuilder::reshape( - llvm::SmallVectorImpl &destDims, Value valToReshape) const { + DimsExpr &destDims, Value valToReshape) const { // Compute Shape. llvm::SmallVector outputShape; IndexExpr::getShape(destDims, outputShape); @@ -1426,9 +1505,7 @@ memref::ReshapeOp MemRefBuilder::reshape( // flatten at least 1 dim (which is a noop). Output rank is Rank(input) - // dimsToFlatten + 1. Value MemRefBuilder::reshapeToFlatInnermost(Value valToReshape, - llvm::SmallVectorImpl &dims, - llvm::SmallVectorImpl &flattenedDims, - int64_t dimsToFlatten) const { + DimsExprRef dims, DimsExpr &flattenedDims, int64_t dimsToFlatten) const { // Parse input. MemRefType inputType = mlir::cast(valToReshape.getType()); assert(!hasNonIdentityLayout(inputType) && "MemRef is not normalized"); @@ -1440,7 +1517,8 @@ Value MemRefBuilder::reshapeToFlatInnermost(Value valToReshape, if (dimsToFlatten == 1) { // Flattening of the last dim is really no flattening at all. Return // original value before doing the actual reshaping, which is unnecessary. - flattenedDims = dims; + for (IndexExpr d : dims) + flattenedDims.emplace_back(d); return valToReshape; } // Compute the dimensions of the flattened array. @@ -1450,7 +1528,7 @@ Value MemRefBuilder::reshapeToFlatInnermost(Value valToReshape, for (int64_t d = 0; d < axis; ++d) flattenedDims.emplace_back(dims[d]); // Last flatten dim is the product of remaining input dims. - IndexExpr numOfFlattenedElements = LiteralIndexExpr(1); + IndexExpr numOfFlattenedElements = LitIE(1); for (int64_t d = axis; d < inputRank; ++d) numOfFlattenedElements = numOfFlattenedElements * dims[d]; flattenedDims.emplace_back(numOfFlattenedElements); @@ -1458,9 +1536,8 @@ Value MemRefBuilder::reshapeToFlatInnermost(Value valToReshape, return reshape(flattenedDims, valToReshape); } -Value MemRefBuilder::reshapeToFlat2D(Value valToReshape, - llvm::SmallVectorImpl &dims, - llvm::SmallVectorImpl &flattenedDims, int64_t axis) const { +Value MemRefBuilder::reshapeToFlat2D(Value valToReshape, DimsExprRef dims, + DimsExpr &flattenedDims, int64_t axis) const { // Parse input. MemRefType inputType = mlir::cast(valToReshape.getType()); assert(!hasNonIdentityLayout(inputType) && "MemRef is not normalized"); @@ -1472,18 +1549,19 @@ Value MemRefBuilder::reshapeToFlat2D(Value valToReshape, assert(axis > 0 && axis < inputRank && "axis is out of range"); if (inputRank == 2) { // Input is already 2D, nothing to do. - flattenedDims = dims; + for (IndexExpr d : dims) + flattenedDims.emplace_back(d); return valToReshape; } // Compute the dimensions of the flattened array. flattenedDims.clear(); // First output dim: product of input dims until axis (exclusively). - IndexExpr numElement1stDim = LiteralIndexExpr(1); + IndexExpr numElement1stDim = LitIE(1); for (int64_t d = 0; d < axis; ++d) numElement1stDim = numElement1stDim * dims[d]; flattenedDims.emplace_back(numElement1stDim); // Second output dim: product of input dims after axis (inclusively). - IndexExpr numElement2ndDim = LiteralIndexExpr(1); + IndexExpr numElement2ndDim = LitIE(1); for (int64_t d = axis; d < inputRank; ++d) numElement2ndDim = numElement2ndDim * dims[d]; flattenedDims.emplace_back(numElement2ndDim); @@ -1491,8 +1569,8 @@ Value MemRefBuilder::reshapeToFlat2D(Value valToReshape, return reshape(flattenedDims, valToReshape); } -memref::ReshapeOp MemRefBuilder::reshapeFromFlat(Value valToReshape, - llvm::SmallVectorImpl &outputDims, MemRefType outputType) const { +memref::ReshapeOp MemRefBuilder::reshapeFromFlat( + Value valToReshape, DimsExpr &outputDims, MemRefType outputType) const { assert(!hasNonIdentityLayout(outputType) && "MemRef is not normalized"); return reshape(outputDims, valToReshape); } @@ -1504,20 +1582,18 @@ memref::CastOp MemRefBuilder::cast(Value input, MemRefType outputType) const { return b().create(loc(), outputType, input); } -Value MemRefBuilder::reinterpretCast( - Value input, SmallVectorImpl &outputDims) const { - // IndexExpr zero = LiteralIndexExpr(0); +Value MemRefBuilder::reinterpretCast(Value input, DimsExpr &outputDims) const { return reinterpretCast(input, nullptr, outputDims); } Value MemRefBuilder::reinterpretCast( - Value input, Value offset, SmallVectorImpl &outputDims) const { + Value input, Value offset, DimsExpr &outputDims) const { // Compute new sizes and strides. int64_t rank = outputDims.size(); SmallVector sizesIE, stridesIE; sizesIE.resize(rank); stridesIE.resize(rank); - IndexExpr strideIE = LiteralIndexExpr(1); + IndexExpr strideIE = LitIE(1); for (int i = rank - 1; i >= 0; --i) { sizesIE[i] = outputDims[i]; stridesIE[i] = strideIE; @@ -1588,25 +1664,21 @@ memref::ViewOp MemRefBuilder::view(Value input, int64_t byteOffset, loc(), outputType, input, offset, outputDynSymbols); } -memref::SubViewOp MemRefBuilder::subView(Value val, - llvm::SmallVectorImpl &offsets, - llvm::SmallVectorImpl &sizes, - llvm::SmallVectorImpl &strides) const { +memref::SubViewOp MemRefBuilder::subView(Value val, ArrayRef offsets, + ArrayRef sizes, ArrayRef strides) const { return b().create(loc(), val, offsets, sizes, strides); } memref::SubViewOp MemRefBuilder::subView(MemRefType outputType, Value val, - llvm::SmallVectorImpl &offsets, - llvm::SmallVectorImpl &sizes, - llvm::SmallVectorImpl &strides) const { + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides) const { return b().create( loc(), outputType, val, offsets, sizes, strides); } memref::SubViewOp MemRefBuilder::subView(Value input, - llvm::SmallVectorImpl &offsetsIE, - llvm::SmallVectorImpl &sizesIE, - llvm::SmallVectorImpl &stridesIE) const { + ArrayRef offsetsIE, ArrayRef sizesIE, + ArrayRef stridesIE) const { SmallVector offsets, sizes, strides; IndexExpr::getOpOrFoldResults(offsetsIE, offsets); IndexExpr::getOpOrFoldResults(sizesIE, sizes); @@ -1642,25 +1714,49 @@ Value MemRefBuilder::dim(Value val, Value index) const { void MemRefBuilder::prefetch(Value memref, ValueRange indices, bool isWrite, unsigned locality, bool isData) { + // Disabled due to downstream linking issues + // if (disableMemRefPrefetch) + // return; b().create( loc(), memref, indices, isWrite, locality, isData); } -void MemRefBuilder::prefetchIE(Value memref, - llvm::SmallVectorImpl &indices, bool isWrite, unsigned locality, - bool isData) { +void MemRefBuilder::prefetchIE(Value memref, ArrayRef indices, + bool isWrite, unsigned locality, bool isData) { + // Disabled due to downstream linking issues + // if (disableMemRefPrefetch) + // return; SmallVector indexVals; IndexExpr::getValues(indices, indexVals); prefetch(memref, indexVals, isWrite, locality, isData); } +//===----------------------------------------------------------------------===// +// Queries + +/*static*/ bool MemRefBuilder::isNoneValue(Value value) { + return mlir::isa(value.getType()); +} + +/*static*/ bool MemRefBuilder::hasOneElementInInnermostDims( + Value value, int64_t innerDim) { + // Get info. + ShapedType type = mlir::dyn_cast(value.getType()); + assert(type && "expected shaped type"); + int64_t rank = type.getRank(); + ArrayRef shape = type.getShape(); + for (int64_t i = std::max((int64_t)0, rank - innerDim); i < rank; ++i) + if (shape[i] != 1) + return false; + return true; +} + //===----------------------------------------------------------------------===// // Structured Control Flow (SCF). //===----------------------------------------------------------------------===// -void SCFBuilder::ifThenElse(Value cond, - function_ref thenFn, - function_ref elseFn) const { +void SCFBuilder::ifThenElse( + Value cond, SCFThenElseBodyFn thenFn, SCFThenElseBodyFn elseFn) const { if (!elseFn) { b().create(loc(), cond, /* then */ @@ -1687,24 +1783,39 @@ void SCFBuilder::ifThenElse(Value cond, } } -void SCFBuilder::forLoop(Value lowerBound, Value upperBound, int64_t step, - function_ref bodyFn) const { +void SCFBuilder::forLoop( + Value lb, Value ub, int64_t step, SCFLoopBodyFn bodyFn) const { MathBuilder createMath(*this); Value stepVal = createMath.constantIndex(step); - b().create(loc(), lowerBound, upperBound, stepVal, std::nullopt, + b().create(loc(), lb, ub, stepVal, std::nullopt, [&](OpBuilder &childBuilder, Location childLoc, Value inductionVar, ValueRange args) { SCFBuilder builder(childBuilder, childLoc); - bodyFn(builder, inductionVar); + bodyFn(builder, {inductionVar}); yield(); }); } -void SCFBuilder::parallelLoop(ValueRange lowerBounds, ValueRange upperBounds, - ValueRange steps, - function_ref bodyFn) const { - // SmallVectorImpl ivStorage; - b().create(loc(), lowerBounds, upperBounds, steps, +void SCFBuilder::forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, + bool useParallel, SCFLoopBodyFn bodyFn) const { + if (useParallel) { + MathBuilder createMath(*this); + Value stepVal = createMath.constantIndex(step); + parallelLoops({lb.getValue()}, {ub.getValue()}, {stepVal}, bodyFn); + } else { + forLoop(lb.getValue(), ub.getValue(), step, bodyFn); + } +} + +void SCFBuilder::forLoopsIE(ArrayRef lbs, ArrayRef ubs, + ArrayRef steps, ArrayRef useParallel, + SCFLoopBodyFn builderFn) const { + impl::forLoopsIE(*this, lbs, ubs, steps, useParallel, builderFn); +} + +void SCFBuilder::parallelLoops(ValueRange lbs, ValueRange ubs, ValueRange steps, + SCFLoopBodyFn bodyFn) const { + b().create(loc(), lbs, ubs, steps, [&](OpBuilder &childBuilder, Location childLoc, ValueRange inductionVars) { SCFBuilder builder(childBuilder, childLoc); @@ -1715,6 +1826,40 @@ void SCFBuilder::parallelLoop(ValueRange lowerBounds, ValueRange upperBounds, void SCFBuilder::yield() const { b().create(loc()); } +void SCFBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, bool useParallel, ArrayRef inputs, + ArrayRef inputAFs, ArrayRef outputs, + ArrayRef outputAFs, + ArrayRef bodyFnList) const { + onnx_mlir::impl::simdIterateIE(*this, lb, ub, VL, + fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs, bodyFnList); +} + +void SCFBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, ArrayRef inputs, ArrayRef inputAFs, + ArrayRef tmps, ArrayRef tmpAFs, ArrayRef outputs, + ArrayRef outputAFs, ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef reductionFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef postReductionFnList) const { + onnx_mlir::impl::simdReduceIE(*this, lb, ub, VL, + fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals, + reductionFnList, postReductionFnList); +} + +void SCFBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, Value input, DimsExpr inputAF, Value tmp, DimsExpr tmpAF, + Value output, DimsExpr outputAF, Value initVal, + /* reduction functions (simd or scalar) */ + SCFSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + SCFSimdPostReductionBodyFn postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE(*this, lb, ub, VL, + fullySimd, input, inputAF, tmp, tmpAF, output, outputAF, initVal, + reductionBodyFn, postReductionBodyFn); +} + //===----------------------------------------------------------------------===// // Vector Builder //===----------------------------------------------------------------------===// @@ -1762,44 +1907,33 @@ int64_t VectorBuilder::getArchVectorLength(Value vecValue) const { return getArchVectorLength(vecType.getElementType()); } -Value VectorBuilder::load( - VectorType vecType, Value memref, ValueRange indices) const { - return b().create(loc(), vecType, memref, indices); -} -mlir::Value VectorBuilder::load(mlir::VectorType vecType, mlir::Value memref, - mlir::ValueRange indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MultiDialectBuilder create(*this); - create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); - return load(vecType, memref, computedIndices); -} - -mlir::Value VectorBuilder::loadIE(mlir::VectorType vecType, mlir::Value memref, - llvm::ArrayRef indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; +Value VectorBuilder::load(VectorType vecType, Value memref, ValueRange indices, + ValueRange offsets) const { + // Cannot use the onnx_mlir::impl::load because we also need to pass the type. + llvm::SmallVector computedIndices; MultiDialectBuilder create(*this); create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); - return load(vecType, memref, computedIndices); + return b().create(loc(), vecType, memref, computedIndices); } -void VectorBuilder::store(Value val, Value memref, ValueRange indices) const { - b().create(loc(), val, memref, indices); +Value VectorBuilder::loadIE(VectorType vecType, Value memref, + llvm::ArrayRef indices, ValueRange offsets) const { + // Cannot use the onnx_mlir::impl::load because we also need to pass the type. + llvm::SmallVector indexValues; + IndexExpr::getValues(indices, indexValues); + return load(vecType, memref, indexValues, offsets); } -void VectorBuilder::store(mlir::Value val, mlir::Value memref, - mlir::ValueRange indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MultiDialectBuilder create(*this); - create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); - store(val, memref, computedIndices); +void VectorBuilder::store( + Value val, Value memref, ValueRange indices, ValueRange offsets) const { + onnx_mlir::impl::store( + *this, val, memref, indices, offsets); } -void VectorBuilder::storeIE(mlir::Value val, mlir::Value memref, - llvm::ArrayRef indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MultiDialectBuilder create(*this); - create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); - store(val, memref, computedIndices); +void VectorBuilder::storeIE(Value val, Value memref, + llvm::ArrayRef indices, ValueRange offsets) const { + onnx_mlir::impl::storeIE( + *this, val, memref, indices, offsets); } Value VectorBuilder::fma(Value lhs, Value rhs, Value acc) const { @@ -1953,7 +2087,7 @@ Value VectorBuilder::reduction( // For example, when we passe N=VL input vectors, the output has one vector; // when we passe N=2VL input vectors, the output has 2 vectors... -void VectorBuilder::multiReduction(SmallVectorImpl &inputVecArray, +void VectorBuilder::multiReduction(ArrayRef inputVecArray, F2 reductionFct, SmallVectorImpl &outputVecArray) { uint64_t N = inputVecArray.size(); assert(N > 0 && "expected at least one value to reduce"); @@ -2002,6 +2136,47 @@ void VectorBuilder::multiReduction(SmallVectorImpl &inputVecArray, } } +// Cast vectors to vectors of different shape (e.g. 1D to 2D and back). +Value VectorBuilder::shapeCast(VectorType newType, Value vector) const { + return b().create(loc(), newType, vector); +} + +// Extract 1D vector from 2D vector. +Value VectorBuilder::extractFrom2D(Value vector2D, int64_t position) const { + llvm::SmallVector pos = {position}; + return b().create(loc(), vector2D, pos); +} + +// Insert 1D vector into 2D vector. +Value VectorBuilder::insertInto2D( + Value vector, Value vector2D, int64_t position) const { + llvm::SmallVector pos = {position}; + return b().create(loc(), vector, vector2D, pos); +} + +Value VectorBuilder::extractElement(Value vector, int64_t index) const { + MultiDialectBuilder create(*this); + VectorType type = llvm::cast(vector.getType()); + int64_t VL = type.getShape()[0]; + assert(type.getRank() == 1 && "expected 1D vector only"); + assert(index >= 0 && index < VL && "out of range vector index"); + Value position = create.math.constantIndex(index); + return b().create(loc(), vector, position); +} + +Value VectorBuilder::insertElement( + Value vector, Value element, int64_t index) const { + MultiDialectBuilder create(*this); + VectorType type = llvm::cast(vector.getType()); + int64_t VL = type.getShape()[0]; + assert(type.getRank() == 1 && "expected 1D vector only"); + assert(index >= 0 && index < VL && "out of range vector index"); + Value position = create.math.constantIndex(index); + // Unlike LLVM insert element which takes , vector + // take + return b().create(loc(), element, vector, position); +} + //===----------------------------------------------------------------------===// // LLVM Builder //===----------------------------------------------------------------------===// @@ -2099,7 +2274,8 @@ Value LLVMBuilder::constant(Type type, int64_t val) const { assert(type.isSignless() && "LLVM::ConstantOp requires a signless type."); constant = b().create(loc(), type, - b().getIntegerAttr(type, APInt(width, (int64_t)val))); + b().getIntegerAttr( + type, APInt(width, static_cast(val), false, true))); } }) .Case([&](Type) { @@ -2150,7 +2326,7 @@ LLVM::LLVMFuncOp LLVMBuilder::func( StringRef funcName, Type funcType, bool createUniqueFunc) const { // If createUniqueFunc, we create two functions: name and name_postfix. // They have the same signatures and `name` will call `name_postfix`. - // `name_postfix` funtion is expected to be unique across all generated + // `name_postfix` function is expected to be unique across all generated // modules, allowing to run multiple models at the same time. LLVM::LLVMFuncOp funcOp = b().create(loc(), funcName, funcType); diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index 26314c04f1..d59f253128 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -55,7 +55,7 @@ struct DialectBuilder { protected: // Private getters of builder and location (concise version). mlir::OpBuilder &b() const { - assert(builder); + assert(builder && "builder is null"); return *builder; } mlir::Location loc() const { return location; } @@ -119,6 +119,9 @@ struct MathBuilder final : DialectBuilder { // "B" below indicates that the operation will splat scalar values if one of // the input value is itself a vector. + // "B" below indicates that the operation will splat scalar values if one of + // the input value is itself a vector. + mlir::Value abs(mlir::Value val) const; mlir::Value add(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value andi(mlir::Value lhs, mlir::Value rhs) const; // B/Int only. @@ -142,6 +145,8 @@ struct MathBuilder final : DialectBuilder { mlir::Value pow(mlir::Value base, mlir::Value exp) const; // B/Float only. mlir::Value rem(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value round(mlir::Value) const; // Float only. + mlir::Value roundEven(mlir::Value) const; // Float only. + mlir::Value roundEvenEmulation(mlir::Value) const; // Float only. mlir::Value sqrt(mlir::Value val) const; // Float only. mlir::Value sub(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value tanh(mlir::Value val) const; // Float only. @@ -254,6 +259,18 @@ struct MemRefBuilder final : DialectBuilder { // Constants static const int64_t defaultAlign; + // Common load/store interface (krnl/affine/memref) + // Add offsets (if any) to the least significant memref dims. + mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}, + mlir::ValueRange offsets = {}) const; + mlir::Value loadIE(mlir::Value memref, mlir::ArrayRef indices = {}, + mlir::ValueRange offsets = {}) const; + void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, + mlir::ValueRange offsets = {}) const; + void storeIE(mlir::Value val, mlir::Value memref, + mlir::ArrayRef indices = {}, + mlir::ValueRange offsets = {}) const; + // Info: get static and dynamic size of memory in number of elementary type. // Array of vector types are not supported at this time. // @@ -266,9 +283,8 @@ struct MemRefBuilder final : DialectBuilder { bool getStaticAndDynamicMemSize(mlir::MemRefType type, mlir::ValueRange dynSymbols, int64_t &staticSize, IndexExpr &dynSize, int64_t range = 1000) const; - bool getStaticAndDynamicMemSize(mlir::MemRefType type, - llvm::SmallVectorImpl &dims, int64_t &staticSize, - IndexExpr &dynSize, int64_t range = 1000) const; + bool getStaticAndDynamicMemSize(mlir::MemRefType type, DimsExprRef dims, + int64_t &staticSize, IndexExpr &dynSize, int64_t range = 1000) const; // Same as above, but does not track of dynamic size. static bool getStaticMemSize( mlir::MemRefType type, int64_t &staticSize, int64_t range = 1000); @@ -280,8 +296,7 @@ struct MemRefBuilder final : DialectBuilder { mlir::MemRefType type, mlir::ValueRange dynSymbols) const; mlir::memref::AllocOp alloc( mlir::Value operandOfSameType, mlir::MemRefType type) const; - mlir::memref::AllocOp alloc( - mlir::MemRefType type, llvm::SmallVectorImpl &dims) const; + mlir::memref::AllocOp alloc(mlir::MemRefType type, DimsExprRef dims) const; // Alloc for static shapes with alignment. // Minimum alignment is gDefaultAllocAlign. @@ -292,8 +307,7 @@ struct MemRefBuilder final : DialectBuilder { mlir::ValueRange dynSymbols, int64_t align = defaultAlign) const; mlir::memref::AllocOp alignedAlloc(mlir::Value operandOfSameType, mlir::MemRefType type, int64_t align = defaultAlign) const; - mlir::memref::AllocOp alignedAlloc(mlir::MemRefType type, - llvm::SmallVectorImpl &dims, + mlir::memref::AllocOp alignedAlloc(mlir::MemRefType type, DimsExprRef dims, int64_t align = defaultAlign) const; // Alloc for shapes with alignment and padding for safe full SIMD @@ -313,13 +327,15 @@ struct MemRefBuilder final : DialectBuilder { mlir::MemRefType type, int64_t VL = 1, int64_t align = defaultAlign) const; mlir::Value alignedAllocWithSimdPadding(mlir::MemRefType type, - llvm::SmallVectorImpl &dims, int64_t VL = 1, - int64_t align = defaultAlign) const; + DimsExprRef dims, int64_t VL = 1, int64_t align = defaultAlign) const; // The alloca instruction allocates memory on the stack frame of the // currently executing function, to be automatically released when this // function returns to its caller. It is strongly suggested to place alloca // instructions outside of a loop. + // + // When possible, DO NOT USE ALLOCA except for a few scalars. + // mlir::memref::AllocaOp alloca(mlir::MemRefType type) const; mlir::memref::AllocaOp alignedAlloca( mlir::MemRefType type, int64_t align = defaultAlign) const; @@ -331,44 +347,39 @@ struct MemRefBuilder final : DialectBuilder { mlir::Value valToReshape, mlir::Value outputShapeStoredInMem) const; // Reshape to dimensions passed by destDims. Will create data-structure to // hold the dims, save into it, and the perform the actual reshape. - mlir::memref::ReshapeOp reshape(llvm::SmallVectorImpl &outputDims, - mlir::Value valToReshape) const; + mlir::memref::ReshapeOp reshape( + DimsExpr &outputDims, mlir::Value valToReshape) const; // Flatten innermost dimensions of a MemRef. User provide the value to // reshape (valToReshape), its dims (dims), and the number of innermost // loops to collapse (dimsToFlatten). The function computes the new // flattened dimensions (flattenDims) and return the flattened value. Values // of dimsToFlatten are in the [1, rank of input] range. Legal only on types // with identity layouts. - mlir::Value reshapeToFlatInnermost(mlir::Value valToReshape, - llvm::SmallVectorImpl &dims, - llvm::SmallVectorImpl &flattenDims, - int64_t dimsToFlatten) const; + mlir::Value reshapeToFlatInnermost(mlir::Value valToReshape, DimsExprRef dims, + DimsExpr &flattenDims, int64_t dimsToFlatten) const; // Flatten to a 2D MemRef, with outer dim including outermost dim to axis // -1, and inner dim including the remaining innermost dims. Values of axis // are in the [1, rank of input) range. Negative axis values are taken from // the back. Legal only on types with identity layouts. - mlir::Value reshapeToFlat2D(mlir::Value valToReshape, - llvm::SmallVectorImpl &dims, - llvm::SmallVectorImpl &flattenDims, int64_t axis) const; + mlir::Value reshapeToFlat2D(mlir::Value valToReshape, DimsExprRef dims, + DimsExpr &flattenDims, int64_t axis) const; // Perform the reverse operation; given a flattened value, unflatten it by // giving the function its original unflattened dimensions (outputDims) and // type (outputType). Legal only on types with identity layouts. mlir::memref::ReshapeOp reshapeFromFlat(mlir::Value valToReshape, - llvm::SmallVectorImpl &outputDims, - mlir::MemRefType outputType) const; + DimsExpr &outputDims, mlir::MemRefType outputType) const; // Casts. mlir::memref::CastOp cast( mlir::Value input, mlir::MemRefType outputType) const; + mlir::Value reinterpretCast(mlir::Value input, DimsExpr &outputDims) const; mlir::Value reinterpretCast( - mlir::Value input, llvm::SmallVectorImpl &outputDims) const; - mlir::Value reinterpretCast(mlir::Value input, mlir::Value offset, - llvm::SmallVectorImpl &outputDims) const; + mlir::Value input, mlir::Value offset, DimsExpr &outputDims) const; // Does not support layouts at this time. Does only work for values that are // then loaded with affine or memref scalar load/store (MLIR limitations). mlir::Value collapseShape(mlir::Value input, - llvm::ArrayRef reassociation); + mlir::ArrayRef reassociation); // Create a view of input value (xi8) starting at byteOffset and // shaped by outputType. @@ -377,38 +388,43 @@ struct MemRefBuilder final : DialectBuilder { // Create a subview of val. mlir::memref::SubViewOp subView(mlir::Value val, - llvm::SmallVectorImpl &offsets, // Offset for each val dims. - llvm::SmallVectorImpl &sizes, // Sizes for each val dims. - llvm::SmallVectorImpl &strides) // Stride for each val dims. + mlir::ArrayRef offsets, // Offset for each val dims. + mlir::ArrayRef sizes, // Sizes for each val dims. + mlir::ArrayRef strides) // Stride for each val dims. const; // Create a subview of val. mlir::memref::SubViewOp subView(mlir::MemRefType outputType, mlir::Value val, - llvm::SmallVectorImpl &offsets, // Offset for each val dims. - llvm::SmallVectorImpl &sizes, // Sizes for each val dims. - llvm::SmallVectorImpl &strides) // Stride for each val dims. + mlir::ArrayRef offsets, // Offset for each val dims. + mlir::ArrayRef sizes, // Sizes for each val dims. + mlir::ArrayRef strides) // Stride for each val dims. const; // Create a subview of val. Size of 1 => remove that dim. mlir::memref::SubViewOp subView(mlir::Value val, - llvm::SmallVectorImpl &offsets, // Offset for each val dims. - llvm::SmallVectorImpl &sizes, // Sizes for each val dims. - llvm::SmallVectorImpl &strides) // Stride for each val dims. + mlir::ArrayRef offsets, // Offset for each val dims. + mlir::ArrayRef sizes, // Sizes for each val dims. + mlir::ArrayRef strides) // Stride for each val dims. const; mlir::Value dim(mlir::Value val, int64_t index) const; mlir::Value dim(mlir::Value val, mlir::Value index) const; - void prefetchIE(mlir::Value memref, llvm::SmallVectorImpl &indices, + void prefetchIE(mlir::Value memref, mlir::ArrayRef indices, bool isWrite, unsigned locality, bool isData = true); void prefetch(mlir::Value memref, mlir::ValueRange indices, bool isWrite, unsigned locality, bool isData = true); + // Queries about memory + static bool isNoneValue(mlir::Value value); + // Check if "innerDims" innermost dims are scalar (size 1). + static bool hasOneElementInInnermostDims(mlir::Value value, int64_t innerDim); + private: mlir::IntegerAttr computeAlignment(int64_t alignment) const; void computeDynSymbols( mlir::MemRefType type, // Use type to determine dynamic dimensions. - llvm::SmallVectorImpl &dims, // Get dyn syms from index expr. + DimsExprRef dims, // Get dyn syms from index expr. llvm::SmallVectorImpl &dynSymbols) // Output dim symbols. const; void computeDynSymbols( @@ -418,6 +434,39 @@ struct MemRefBuilder final : DialectBuilder { const; }; +//===----------------------------------------------------------------------===// +// Functions definitions for SIMD methods (simdIterate & simdReduce) +//===----------------------------------------------------------------------===// + +namespace impl { + +// For simdIterate: given a list of inputs, create one output value. +template +using SimdIterateBodyFn = std::function inputVals, int64_t VL)>; + +// For simdReduce: take one input & one temp reduction value, and generate the +// new reduction value. +template +using SimdReductionBodyFn = std::function; + +// For simdReduce: take one temp simd reduction value, create a scalar +// reduction, and possibly apply post processing to it (e.g. div by number of +// elements). +// +// For simdReduce2D: only the post processing. Reduction is done before. +template +using SimdPostReductionBodyFn = std::function; + +// Function used for (nearly) all loops, where there is typically one value in +// the provided ValueRange per loop nest. +template +using LoopBodyFn = mlir::function_ref; + +} // namespace impl + //===----------------------------------------------------------------------===// // Structured Control Flow (SCF) Builder //===----------------------------------------------------------------------===// @@ -430,17 +479,53 @@ struct SCFBuilder final : DialectBuilder { /// Create an if then with optional else. Construct does not generate a /// result (unlike some scf::if) and introduces the yields automatically. - void ifThenElse(mlir::Value cond, - mlir::function_ref thenFn, - mlir::function_ref elseFn = nullptr) const; - // Create a for loop. - void forLoop(mlir::Value lowerBound, mlir::Value upperBound, int64_t step, - mlir::function_ref bodyFn) const; - // Create a parallel for loop. - void parallelLoop(mlir::ValueRange lowerBounds, mlir::ValueRange upperBounds, - mlir::ValueRange steps, - mlir::function_ref bodyFn) const; + using SCFThenElseBodyFn = mlir::function_ref; + void ifThenElse(mlir::Value cond, SCFThenElseBodyFn thenFn, + SCFThenElseBodyFn elseFn = nullptr) const; + // Common loop interface (krnl/affine/scf). + using SCFLoopBodyFn = impl::LoopBodyFn; + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, + SCFLoopBodyFn bodyFn) const; + void forLoopsIE(mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, mlir::ArrayRef useParallel, + SCFLoopBodyFn builderFn) const; + // Custom interface + void forLoop( + mlir::Value lb, mlir::Value ub, int64_t step, SCFLoopBodyFn bodyFn) const; + void parallelLoops(mlir::ValueRange lbs, mlir::ValueRange ubs, + mlir::ValueRange steps, SCFLoopBodyFn bodyFn) const; + void yield() const; + + // Common simd loop interface (krnl/affine/scf). + // For detailed description, see KrnlBuilder.hpp file. + using SCFSimdIterateBodyFn = impl::SimdIterateBodyFn; + void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + bool useParallel, mlir::ArrayRef inputs, + mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, + mlir::ArrayRef outputAFs, + mlir::ArrayRef simdIterateBodyList) const; + + // For detailed description, see KrnlBuilder.hpp file. + using SCFSimdReductionBodyFn = impl::SimdReductionBodyFn; + using SCFSimdPostReductionBodyFn = impl::SimdPostReductionBodyFn; + void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef temps, mlir::ArrayRef tempAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef simdReductionBodyFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef simdPostReductionBodyFnList) + const; + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + SCFSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + SCFSimdPostReductionBodyFn postReductionBodyFn) const; }; //===----------------------------------------------------------------------===// @@ -472,21 +557,19 @@ struct VectorBuilder final : DialectBuilder { // Vector load: memref is expected to be scalar, will load a vector's worth // of values: e.g. %result = vector.load %base[%i, %j] : // memref<100x100xf32>, vector<8xf32>. + // Add offsets (if any) to the least significant memref dims. mlir::Value load(mlir::VectorType vecType, mlir::Value memref, - mlir::ValueRange indices = {}) const; - // When ranks of offsets indices, mlir::ValueRange offsets) const; + mlir::ArrayRef indices = {}, + mlir::ValueRange offsets = {}) const; // Vector store: memref can be a scalar, will store the vector values. - void store( - mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}) const; - // When ranks of offsets indices, mlir::ValueRange offsets) const; + mlir::ArrayRef indices = {}, + mlir::ValueRange offsets = {}) const; // Splat: a single value is copied. mlir::Value splat(mlir::VectorType vecType, mlir::Value val) const; @@ -503,9 +586,21 @@ struct VectorBuilder final : DialectBuilder { mlir::Value mergeHigh(mlir::Value lhs, mlir::Value rhs, int64_t step) const; mlir::Value mergeLow(mlir::Value lhs, mlir::Value rhs, int64_t step) const; mlir::Value reduction(CombiningKind kind, mlir::Value value) const; - void multiReduction(llvm::SmallVectorImpl &inputVecArray, + void multiReduction(mlir::ArrayRef inputVecArray, F2 reductionFct, llvm::SmallVectorImpl &outputVecArray); + // Cast vectors to vectors of different shape (e.g. 1D to 2D and back). + mlir::Value shapeCast(mlir::VectorType newType, mlir::Value vector) const; + // Extract and insert 1D vector from/to 2D vector. + mlir::Value extractFrom2D(mlir::Value vector2D, int64_t position) const; + mlir::Value insertInto2D( + mlir::Value vector, mlir::Value vector2D, int64_t position) const; + + // Insert and extract one element (scalar). + mlir::Value extractElement(mlir::Value vector, int64_t position) const; + mlir::Value insertElement( + mlir::Value vector, mlir::Value element, int64_t position) const; + private: bool isPowerOf2(uint64_t num) const; uint64_t getLengthOf1DVector(mlir::Value vec) const; @@ -523,40 +618,71 @@ struct GenericAffineBuilder final : DialectBuilder { GenericAffineBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~GenericAffineBuilder() {} - mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}) const; - // When ranks of offsets indices, - mlir::ValueRange offsets) const; - - void store( - mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}) const; - // When ranks of offsets indices = {}, + mlir::ValueRange offsets = {}) const; + void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, + mlir::ValueRange offsets = {}) const; void storeIE(mlir::Value val, mlir::Value memref, - llvm::ArrayRef indices, mlir::ValueRange offsets) const; + mlir::ArrayRef indices = {}, + mlir::ValueRange offsets = {}) const; mlir::Operation *prefetch(mlir::Value memref, mlir::AffineMap map, mlir::ValueRange indices, bool isWrite, unsigned localityHint, bool isDataCache = true); - void forIE(IndexExpr lb, IndexExpr ub, int64_t step, - mlir::function_ref builderFn) - const; - void forIE(llvm::SmallVectorImpl &lbs, - llvm::SmallVectorImpl &ubs, - llvm::SmallVectorImpl &steps, - mlir::function_ref - builderFn) const; + // Common loop interface (krnl/affine/scf). + using GenericAffineLoopBodyFn = impl::LoopBodyFn; + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, + GenericAffineLoopBodyFn builderFn) const; + void forLoopsIE(mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, mlir::ArrayRef useParallel, + GenericAffineLoopBodyFn builderFn) const; + + // Custom interface + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, + GenericAffineLoopBodyFn builderFn) const; // Sequential only. + + // Common simd loop interface (krnl/affine/scf). + using GenericAffineSimdIterateBodyFn = + impl::SimdIterateBodyFn>; + void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + bool useParallel, mlir::ArrayRef inputs, + mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, + mlir::ArrayRef outputAFs, + mlir::ArrayRef simdIterateBodyList) const; + + using GenericAffineSimdReductionBodyFn = + impl::SimdReductionBodyFn>; + using GenericAffineSimdPostReductionBodyFn = + impl::SimdPostReductionBodyFn>; + void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef temps, mlir::ArrayRef tempAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef simdReductionBodyFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef + simdPostReductionBodyFnList) const; + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + GenericAffineSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + GenericAffineSimdPostReductionBodyFn postReductionBodyFn) const; // This if then else construct has no arguments to the blocks. - void ifThenElse(IndexExprScope &scope, - llvm::SmallVectorImpl &conditions, - mlir::function_ref thenFn, - mlir::function_ref elseFn) - const; + using GenericAffineThenElseBodyFn = + mlir::function_ref &)>; + void ifThenElseIE(IndexExprScope &scope, mlir::ArrayRef conditions, + GenericAffineThenElseBodyFn thenFn, + GenericAffineThenElseBodyFn elseFn) const; // AffineApplyOp mlir::Value apply(mlir::AffineMap map, mlir::ValueRange operands) const; @@ -564,14 +690,6 @@ struct GenericAffineBuilder final : DialectBuilder { void yield() const; private: - // Support for multiple forIE loops. - void recursionForIE(llvm::SmallVectorImpl &lbs, - llvm::SmallVectorImpl &ubs, - llvm::SmallVectorImpl &steps, - llvm::SmallVectorImpl &loopIndices, - mlir::function_ref - builderFn) const; - // Support for adding blocks. void appendToBlock(mlir::Block *block, mlir::function_ref builderFn) const; @@ -589,8 +707,9 @@ using AffineBuilder = GenericAffineBuilder; - using valueFuncRef = mlir::function_ref; + using voidFuncRef = mlir::function_ref; + using valueFuncRef = + mlir::function_ref; LLVMBuilder(mlir::Location loc) : DialectBuilder(loc) {} LLVMBuilder(mlir::OpBuilder &b, mlir::Location loc) @@ -616,7 +735,7 @@ struct LLVMBuilder final : DialectBuilder { // BrOp void br( - llvm::ArrayRef destOperands, mlir::Block *destBlock) const; + mlir::ArrayRef destOperands, mlir::Block *destBlock) const; // CallOp mlir::Value call(mlir::ArrayRef resultTypes, @@ -628,8 +747,8 @@ struct LLVMBuilder final : DialectBuilder { // CondBrOp void condBr(mlir::Value cond, mlir::Block *trueBlock, - llvm::ArrayRef trueOperands, mlir::Block *falseBlock, - llvm::ArrayRef falseOperands) const; + mlir::ArrayRef trueOperands, mlir::Block *falseBlock, + mlir::ArrayRef falseOperands) const; // ConstantOp mlir::Value constant(mlir::Type type, int64_t val) const; @@ -641,7 +760,7 @@ struct LLVMBuilder final : DialectBuilder { // ExtractValueOp mlir::Value extractValue(mlir::Type resultType, mlir::Value container, - llvm::ArrayRef position) const; + mlir::ArrayRef position) const; // FuncOp (assume non-variadic functions, otherwise add support like in // seen in `call` in this file). @@ -650,7 +769,7 @@ struct LLVMBuilder final : DialectBuilder { // GEPOp mlir::Value getElemPtr(mlir::Type resultType, mlir::Type elemType, - mlir::Value base, llvm::ArrayRef indices) const; + mlir::Value base, mlir::ArrayRef indices) const; // GlobalOp mlir::LLVM::GlobalOp globalOp(mlir::Type resultType, bool isConstant, @@ -667,7 +786,7 @@ struct LLVMBuilder final : DialectBuilder { // InsertValueOp mlir::Value insertValue(mlir::Type resultType, mlir::Value container, - mlir::Value val, llvm::ArrayRef position) const; + mlir::Value val, mlir::ArrayRef position) const; // Inttoptr mlir::Value inttoptr(mlir::Type type, mlir::Value val) const; @@ -719,7 +838,7 @@ struct LLVMBuilder final : DialectBuilder { // Get or insert a function declaration at the beginning of the module. mlir::FlatSymbolRefAttr getOrInsertSymbolRef(mlir::ModuleOp module, llvm::StringRef symName, mlir::Type resultType, - llvm::ArrayRef operandTypes, bool isVarArg = false) const; + mlir::ArrayRef operandTypes, bool isVarArg = false) const; /// Generate code that looks like "if then with optional else" at LLVM. /// The following prototype code will be generated: @@ -882,7 +1001,9 @@ struct MultiDialectBuilder : MultiDialectBuilder { }; // Include template implementations. +#ifndef ONNX_MLIR_DIALECT_BUILDER_MLIR_INC #include "DialectBuilder.hpp.inc" +#endif } // namespace onnx_mlir #endif diff --git a/src/Dialect/Mlir/DialectBuilder.hpp.inc b/src/Dialect/Mlir/DialectBuilder.hpp.inc index bf424ae920..ac4e07909c 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp.inc +++ b/src/Dialect/Mlir/DialectBuilder.hpp.inc @@ -1,10 +1,10 @@ //===---- DialectBuilder.hpp.inc - Helper functions for MLIR dialects -----===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // -// This file contains helper functions for building MLIR operations. +// This file contains template helper functions for building MLIR operations. // // Note on usage of template keyword. Since the GenericAffineBuilder is // templated, and we use templated functions (such as create), we must add @@ -13,55 +13,631 @@ // //===----------------------------------------------------------------------===// -// Implementation of GenericAffineBuilder -template -mlir::Value GenericAffineBuilder::load( - mlir::Value memref, mlir::ValueRange indices) const { - return b().template create(loc(), memref, indices); +#ifndef ONNX_MLIR_DIALECT_BUILDER_MLIR_H +// This include is only here to include builder in the editors. Will be skipped +// when actually compiling. +#define ONNX_MLIR_DIALECT_BUILDER_MLIR_INC 1 +#include "DialectBuilder.hpp" +#undef ONNX_MLIR_DIALECT_BUILDER_MLIR_INC +#endif + +//===----------------------------------------------------------------------===// +// Templates for load / store +//===----------------------------------------------------------------------===// + +namespace impl { // Hide support for loads / stores in impl namespace. + +template +mlir::Value load(const BUILDER &b, mlir::Value memref, mlir::ValueRange indices, + mlir::ValueRange offsets) { + // Handle offsets. + llvm::SmallVector computedIndices; + MultiDialectBuilder create(b); + create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); + // Perform load. + if (computedIndices.size() == 0) { + // case memref<1xdtype> + auto type = mlir::cast(memref.getType()); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + mlir::Value iZero = create.math.constantIndex(0); + return b.getBuilder().template create( + b.getLoc(), memref, mlir::ValueRange({iZero})); + } + } + return b.getBuilder().template create( + b.getLoc(), memref, computedIndices); } -template -mlir::Value GenericAffineBuilder::load(mlir::Value memref, - mlir::ValueRange indices, mlir::ValueRange offsets) const { +template +mlir::Value loadIE(const BUILDER &b, mlir::Value memref, + mlir::ArrayRef indices, mlir::ValueRange offsets) { + llvm::SmallVector indexValues; + IndexExpr::getValues(indices, indexValues); + return load(b, memref, indexValues, offsets); +} + +template +void store(const BUILDER &b, mlir::Value val, mlir::Value memref, + mlir::ValueRange indices, mlir::ValueRange offsets) { llvm::SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - return load(memref, computedIndices); + MultiDialectBuilder create(b); + create.math.addOffsetToLeastSignificant(indices, offsets, computedIndices); + if (computedIndices.size() == 0) { + // case memref<1xdtype> + auto type = mlir::cast(memref.getType()); + if (type.getRank() == 1 && type.getShape()[0] == 1) { + mlir::Value iZero = create.math.constantIndex(0); + b.getBuilder().template create( + b.getLoc(), val, memref, mlir::ValueRange({iZero})); + return; + } + } + b.getBuilder().template create( + b.getLoc(), val, memref, computedIndices); +} + +template +void storeIE(const BUILDER &b, mlir::Value val, mlir::Value memref, + mlir::ArrayRef indices, mlir::ValueRange offsets) { + llvm::SmallVector indexValues; + IndexExpr::getValues(indices, indexValues); + store(b, val, memref, indexValues, offsets); } + +//===----------------------------------------------------------------------===// +// Templates for multi-dimensional loop iterator. +//===----------------------------------------------------------------------===// + +template +void recursionForLoopsIE(const BUILDER &builder, mlir::ArrayRef lbs, + mlir::ArrayRef ubs, mlir::ArrayRef steps, + mlir::ArrayRef useParallel, + llvm::SmallVector &loopIndices, + LoopBodyFn builderFn) { + int64_t d = loopIndices.size(); + if (d < (int64_t)lbs.size()) { + // Issue a loop and recurse again. + builder.forLoopIE(lbs[d], ubs[d], steps[d], useParallel[d], + [&](const BUILDER &b, mlir::ValueRange loopInd) { + loopIndices.emplace_back(loopInd[0]); + recursionForLoopsIE( + b, lbs, ubs, steps, useParallel, loopIndices, builderFn); + }); + } else { + // Call lambda function + BUILDER b(builder); + builderFn(b, loopIndices); + } +} + +template +void forLoopsIE(const BUILDER &builder, mlir::ArrayRef lbs, + mlir::ArrayRef ubs, mlir::ArrayRef steps, + mlir::ArrayRef useParallel, LoopBodyFn builderFn) { + assert(lbs.size() == ubs.size() && "expect same size"); + assert(lbs.size() == steps.size() && "expect same size"); + assert(lbs.size() == useParallel.size() && "expect same size"); + llvm::SmallVector loopIndices; + recursionForLoopsIE( + builder, lbs, ubs, steps, useParallel, loopIndices, builderFn); +} + +} // namespace impl + +//===----------------------------------------------------------------------===// +// Templates for SIMD code gen (instantiated for KRNL and SCF builders) +//===----------------------------------------------------------------------===// + +// Forward declaration to keep template testing happy. +struct KrnlBuilder; + +namespace impl { // Hide support for SIMD iterate/reduce in impl namespace. + +/* +Example of how to use the interface: + +Say you have a loop of i=0..256, j=0..128 and want to exploit r[i,j] = a[i,j] + +b[j] + c. For the loops, we will need access functions for a, b, and r. + +Say we already have the loop for the outer loop of i + +krnl.iterate(loop i from 0 to 256) { + ii is the loop index. + + // 1) compute access function for a, b, c + // 2) launch simd loop with + // 3) simd kernel +} + +1) Access functions + Assuming here that we are not blocking the j loop, namely the simd iteration + goes over all j values, the access functions should be defined as follows. + + aAF = {ii, 0} + bAF = {0} + rAF = {ii, 0} + + If the j loop was blocked (say j=0 to 128 by 16), then instead of `0` in the + last dim, we would have 'blocked_jj' + +2) Launch simd loop + + create.krnl.simdIterateIE( + lb=LitIE(0), ub=litIE(128), totVL=8, // loop params + fullySimd=true, useParallel=false, // loop options + inputs={A, B}, inputAFs={aAF, bAF}, // inputs + outputs={R}, outputAFs={rAF}, // outputs + {krnl}) // lambda function for kernel + +3) Krnl for SIMD loop + + The kernel functions has 4 inputs: + a) krnl builder to further build code + b) list of loaded input values, in the same order as in inputs + c) list of results values, that must be enqueued by the kernel + d) totVL used for the loop (VL for simd, 1 for scalar) + + The same kernels will be used in a SIMD context, in which the inputs and + outputs must be vectors of VL elements, or in a scalar context, in which the + inputs and outputs must be scalars. + + In our example, the kernel is as follows + + [&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { + MultiDialectBuilder create(kb); + Value aVal = inputVals[0]; // simd or scalar + Value bVal = inputVals[1]; // simd or scalar + Value cVal = create.krnl.load(C); // scalar always + Value newVal = create.math.add(aVal, bVal); // simd or scalar + newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is + // splatted + return newVal; // Save simd or scalar result. + } + + The krnl.simdIterateIE will be in charge of loading and saving the values in + memory. The create.math functions have been extended so that when a SIMD + value is computed with a scalar, that scalar will be automaticaly splatted + (aka promoted to a vector of identical values). As a result, the kernel can + be written in a SIMD agnostic value. However, in rare situations, we may + want to know if we are in SIMD mode or not. VL will give the totVL used here + (either totVL>1 or 1). +*/ + +// Definition of SimdIterateBodyFn, see Mlir/DialectBuilder.hpp + +template +void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, + int64_t VL, bool fullySimd, bool useParallel, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef> iterateBodyList) { + int64_t inputNum = inputs.size(); + assert(inputAFs.size() == inputs.size() && "expected same size"); + int64_t outputNum = outputs.size(); + assert(outputAFs.size() == outputs.size() && "expected same size"); + int64_t fnNum = iterateBodyList.size(); + assert((int64_t)fnNum == outputNum && "expect 1 loop function per output"); + + if (VL > 1) { + // Want SIMD, execute full SIMD loops blocked by VL. + + // If we are not guaranteed that every iterations are SIMD iterations, + // then we need to reduce the trip count by a bit so as to not over + // compute. If we are not guaranteed that every iterations are SIMD + // iterations, then + IndexExpr simdUb = ub; + if (!fullySimd) + simdUb = simdUb - (VL - 1); + + // Define the loop block + auto simdLoopBody = [&](const BUILDER b, mlir::ValueRange loopInd) { + IndexExprScope scope(b); + VectorBuilder createVec(b); + MEM_BUILDER createMem(b); + IndexExpr ind = DimIE(loopInd[0]); + llvm::SmallVector vecInputVals; + for (int64_t i = 0; i < inputNum; ++i) { + mlir::Value input = inputs[i]; + if (MemRefBuilder::isNoneValue(input)) { + // Simply enqueue the none value. + vecInputVals.emplace_back(input); + continue; + } + auto type = mlir::cast(input.getType()); + int64_t rank = type.getRank(); + DimsExpr AF = SymListIE(inputAFs[i]); + assert(rank == (int64_t)AF.size() && "AF expected input rank refs"); + if (MemRefBuilder::hasOneElementInInnermostDims(input, 1)) { + // Has a reference with a scalar innermost dim, just load as a + // scalar. No need to add the induction variable. + vecInputVals.emplace_back(createMem.loadIE(input, AF)); + } else { + // Have a vector. + auto vecType = mlir::VectorType::get({VL}, type.getElementType()); + AF[rank - 1] = AF[rank - 1] + ind; // Add induction var. + vecInputVals.emplace_back(createVec.loadIE(vecType, input, AF)); + } + } + // Call the method to compute the values. + llvm::SmallVector vecResVals; + for (int64_t f = 0; f < outputNum; ++f) { + vecResVals.emplace_back(iterateBodyList[f](b, vecInputVals, VL)); + } + // Store all the outputs as vectors of VL values, + for (int64_t i = 0; i < outputNum; ++i) { + auto type = mlir::cast(outputs[i].getType()); + DimsExpr AF = SymListIE(outputAFs[i]); + int64_t rank = type.getRank(); + assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs"); + AF[rank - 1] = AF[rank - 1] + ind; + createVec.storeIE(vecResVals[i], outputs[i], AF); + } + }; + + // Invocation of the (possibly parallel) SIMD loop. + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, simdUb, VL, useParallel, simdLoopBody); + else + llvm_unreachable("BUILDER type not supported\n"); + + if (fullySimd) + // Asserted that we only have SIMD iterations, we are done. + return; + // Account for the loop iterations performed above. + IndexExpr tripCount = ub - lb; + IndexExpr missingIters = tripCount % VL; + IndexExpr completedIters = tripCount - missingIters; + if (missingIters.isLiteralAndIdenticalTo(0)) { + // Detect that we only have SIMD iterations, we are also done. + return; + } + // We may have additional iterations to perform, adjust lb to skip the + // completed iterations. + lb = lb + completedIters; + } + // Handle remaining scalar values (from lb to ub without unrolling). + auto scalarLoopBody = [&](const BUILDER b, mlir::ValueRange loopInd) { + IndexExprScope scope(b); + MEM_BUILDER createMem(b); + + IndexExpr ind = DimIE(loopInd[0]); + // Load all the inputs as scalar values, + llvm::SmallVector scalarInputVals; + for (int64_t i = 0; i < inputNum; ++i) { + mlir::Value input = inputs[i]; + if (MemRefBuilder::isNoneValue(input)) { + // Simply enqueue the none value. + scalarInputVals.emplace_back(input); + continue; + } + auto type = mlir::cast(input.getType()); + int64_t rank = type.getRank(); + DimsExpr AF = SymListIE(inputAFs[i]); + if (MemRefBuilder::hasOneElementInInnermostDims(input, 1)) { + // Has a reference with a scalar innermost dim, just load as a + // scalar. No need to add the induction variable. + scalarInputVals.emplace_back(createMem.loadIE(input, AF)); + } else { + AF[rank - 1] = AF[rank - 1] + ind; + scalarInputVals.emplace_back(createMem.loadIE(input, AF)); + } + } + // Call the method to compute the values. + llvm::SmallVector scalarResVals; + for (int64_t f = 0; f < outputNum; ++f) { + scalarResVals.emplace_back(iterateBodyList[f](b, scalarInputVals, 1)); + } + // Store all the outputs as vectors of VL values, + for (int64_t i = 0; i < outputNum; ++i) { + auto type = mlir::cast(outputs[i].getType()); + DimsExpr AF = SymListIE(outputAFs[i]); + int64_t rank = type.getRank(); + assert(rank == (int64_t)AF.size() && "AF expected ouput rank refs"); + AF[rank - 1] = AF[rank - 1] + ind; + createMem.storeIE(scalarResVals[i], outputs[i], AF); + } + }; + + // Invocation of the scalar loop. + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, ub, 1, false /*parallel*/, scalarLoopBody); + else + llvm_unreachable("BUILDER type not supported\n"); +} + +/* + Note that because reductions are always between 2 values, the reduction + function takes 1 input and one temp value, where the temp contains the partial + result. So if we have 2 reductions (aka 2 outputs), we also need 2 inputs and + 2 temp. A call to function reductionBodyFnList[k] (namely the kth entry in the + list) will be instantiated with the kth input value and the kth temp, and its + result is ultimately saved into the kth output. + + This was not the case for simdIterateIE, where all of the inputs are provided + to each of the functions computing one output. Here we only pass a pair of + input & temp value to each function. + + This is reflected in the Body types below. + + Allows calls with no outputs and no post-processing functions. In such case, + only perform the reductions into the tmps. +*/ + +// Definition of SimdReductionBodyFn & SimdPostReductionBodyFn, see +// Mlir/DialectBuilder.hpp + +template +void simdReduceIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, + int64_t VL, bool fullySimd, mlir::ArrayRef inputs, + mlir::ArrayRef inputAFs, mlir::ArrayRef tmps, + mlir::ArrayRef tmpAFs, mlir::ArrayRef outputs, + mlir::ArrayRef outputAFs, mlir::ArrayRef initVals, + /* reduction functions (simd or scalar) */ + mlir::ArrayRef> reductionBodyFnList, + /* post reduction functions (simd to scalar + post processing)*/ + mlir::ArrayRef> postReductionBodyFnList) { + + MultiDialectBuilder create(builder); + MEM_BUILDER createMem(builder); + + uint64_t inputSize = inputs.size(); + uint64_t tmpSize = tmps.size(); + uint64_t outputSize = outputs.size(); + // Test same number of values & AFs. + assert(inputAFs.size() == inputSize && "expect same input size"); + assert(tmpAFs.size() == tmpSize && "expect same tmps size"); + assert(outputAFs.size() == outputSize && "expect output same size"); + // Same number of init, reduction functions, tmps as input. + assert(reductionBodyFnList.size() == inputSize && "1 red fn per input"); + assert(tmpSize == inputSize && "expect 1 tmp per input"); + assert(initVals.size() == inputSize && "expect 1 init per input"); + // Same number of post reductions as output. + assert(postReductionBodyFnList.size() == outputSize && "1 red fn per output"); + // Gather element and vector types and perform the inits. Do it in SIMD mode + // regardless. + llvm::SmallVector vectorTypes; + for (uint64_t i = 0; i < inputSize; ++i) { + mlir::Value initVal = initVals[i]; + mlir::Type elementType = initVal.getType(); + auto vectorType = mlir::VectorType::get({VL}, elementType); + vectorTypes.emplace_back(vectorType); + mlir::Value initVec = create.vec.splat(vectorType, initVal); + create.vec.storeIE(initVec, tmps[i], tmpAFs[i]); + } + if (VL > 1) { + // Logic: see simdIterateIE. + IndexExpr simdUb = ub; + if (!fullySimd) + simdUb = simdUb - (VL - 1); + + auto simdLoopBody = [&](const BUILDER &b, mlir::ValueRange loopInd) { + IndexExprScope scope(b); + MultiDialectBuilder create(b); + // Load inputs in SIMD mode, indexed by loopInd[0] in innermost dim. + llvm::SmallVector inputVals; + for (uint64_t i = 0; i < inputSize; ++i) { + auto inputType = mlir::cast(inputs[i].getType()); + auto vecType = mlir::VectorType::get({VL}, inputType.getElementType()); + inputVals.emplace_back( + create.vec.loadIE(vecType, inputs[i], inputAFs[i], {loopInd[0]})); + } + // Load tmp value in SIMD mode (no indexing, same value over & over). + llvm::SmallVector tmpVals; + for (uint64_t i = 0; i < inputSize; ++i) { + tmpVals.emplace_back( + create.vec.loadIE(vectorTypes[i], tmps[i], tmpAFs[i])); + } + // Call reduction, one per function each with their input and tmp value. + llvm::SmallVector resultVals; + for (uint64_t i = 0; i < inputSize; ++i) { + resultVals.emplace_back( + reductionBodyFnList[i](b, inputVals[i], tmpVals[i], VL)); + } + // Save tmp values in SIMD mode. + for (uint64_t i = 0; i < inputSize; ++i) { + create.vec.storeIE(resultVals[i], tmps[i], tmpAFs[i]); + } + }; + + // Want SIMD, execute full SIMD loops reductions blocked by VL. + // Perform SIMD reduction: iterates over all SIMD vectors. + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, simdUb, VL, false /*parallel*/, simdLoopBody); + else + llvm_unreachable("BUILDER type not supported"); + + if (fullySimd) { + // No leftovers, no additional iterations to be done. + } else { + // Account for the loop iterations performed above. + IndexExpr tripCount = ub - lb; + IndexExpr missingIters = tripCount % VL; + IndexExpr completedIters = tripCount - missingIters; + if (missingIters.isLiteralAndIdenticalTo(0)) { + // Detected that we have no missing iterations. Ee are done, namely + // fullySimd is true. + fullySimd = true; + } else { + // We may have additional iterations to perform, adjust lb to skip the + // completed iterations. + lb = lb + completedIters; + } + } + } else { + // VL was 1, set fullySimd to false so that we execute all iterations + // sequentially. + fullySimd = false; + } + if (!fullySimd) { + // We have leftover iterations to be done in sequential mode. + // Handle remaining scalar values (from lb to ub without unrolling). + + auto scalarLoopBody = [&](const BUILDER &b, mlir::ValueRange loopInd) { + IndexExprScope scope(b); + MEM_BUILDER createMem(b); + IndexExpr ind = DimIE(loopInd[0]); + // We now perform sequential reduction in the tmps 1st element. Load + // inputs in sequential mode indexed by loopInd[0] in innermost dim. + llvm::SmallVector inputVals; + for (uint64_t i = 0; i < inputSize; ++i) { + inputVals.emplace_back( + createMem.loadIE(inputs[i], inputAFs[i], {loopInd[0]})); + } + // Load tmps in scalar mode (no indexing, same value over & over). + llvm::SmallVector tmpVals; + for (uint64_t i = 0; i < inputSize; ++i) { + tmpVals.emplace_back(createMem.loadIE(tmps[i], tmpAFs[i])); + } + // Call reduction. + llvm::SmallVector resultVals; + for (uint64_t i = 0; i < inputSize; ++i) { + resultVals.emplace_back( + reductionBodyFnList[i](b, inputVals[i], tmpVals[i], 1)); + } + // Save tmp values in sequential mode. + for (uint64_t i = 0; i < inputSize; ++i) { + createMem.storeIE(resultVals[i], tmps[i], tmpAFs[i]); + } + }; + + // Perform scalar loop. + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, ub, 1, false /*parallel*/, scalarLoopBody); + else + llvm_unreachable("BUILDER type not supported"); + } + + if (outputSize == 0) + return; // No outputs, we are done. + + // Now perform post processing. Load all tmps. + assert(tmpSize == outputSize && "expect one tmp per output"); + llvm::SmallVector tmpVals; + for (uint64_t o = 0; o < outputSize; ++o) { + // Load tmp in vector mode. + tmpVals.emplace_back(create.vec.loadIE(vectorTypes[o], tmps[o], tmpAFs[o])); + } + llvm::SmallVector finalResults; + // Invoke the post processing operations, which takes each tmp vector and + // reduces it to a scalar. + for (uint64_t o = 0; o < outputSize; ++o) { + finalResults.emplace_back( + postReductionBodyFnList[o](builder, tmpVals[o], 1)); + } + // Store the scalar reductions. + for (uint64_t o = 0; o < outputSize; ++o) { + createMem.storeIE(finalResults[o], outputs[o], outputAFs[o]); + } +} + +template +void simdReduce2DIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, + int64_t VL, bool fullySimd, mlir::Value input, DimsExpr inputAF, + mlir::Value tmp, DimsExpr tmpAF, mlir::Value output, DimsExpr outputAF, + mlir::Value initVal, + /* reduction functions (simd or scalar) */ + SimdReductionBodyFn reductionBodyFn, + /* post reduction functions (simd to scalar + post processing)*/ + SimdPostReductionBodyFn postReductionBodyFn) { + // Expect 2D or more input and tmp. + auto inputType = mlir::cast(input.getType()); + auto tmpType = mlir::cast(tmp.getType()); + uint64_t inputRank = inputType.getRank(); + uint64_t tmpRank = tmpType.getRank(); + assert(inputRank == inputAF.size() && "expected same size"); + assert(tmpRank == tmpAF.size() && "expected same size"); + assert(inputRank >= 2 && "expected rank 2D+"); + assert(tmpRank >= 2 && "expected rank 2D+"); + mlir::Type elementType = inputType.getElementType(); + + // Perform a VL x VL reduction along the innermost 2 dimensions. + // Reuse the simdReduceIE functionality to do so. + llvm::SmallVector newInputs(VL, input); + llvm::SmallVector newInputAFs(VL, inputAF); + llvm::SmallVector newTmps(VL, tmp); + llvm::SmallVector newTmpAFs(VL, tmpAF); + llvm::SmallVector newInitVals(VL, initVal); + llvm::SmallVector, 8> newReductionBodyFnList( + VL, reductionBodyFn); + + // Init the new data structures for VL reductions of VL values + uint64_t inputM2 = inputRank - 2; + uint64_t tmpM2 = tmpRank - 2; + for (int64_t v = 0; v < VL; ++v) { + // Each inputs/tmp is offset by 1 in the second to last dim; + newInputAFs[v][inputM2] = newInputAFs[v][inputM2] + v; + newTmpAFs[v][tmpM2] = newTmpAFs[v][tmpM2] + v; + } + // Step 1: perform the reduction of VL vectors into VL tmps. No output & post + // reduction as we will do it here. + builder.simdReduceIE(lb, ub, VL, fullySimd, newInputs, newInputAFs, newTmps, + newTmpAFs, {}, {}, newInitVals, newReductionBodyFnList, {}); + + // Step 2, perform reduction of VL vectors of VL values into 1 vector of VL. + // Load all temp vectors. + llvm::SmallVector redIn, redOut; + MultiDialectBuilder create(builder); + mlir::VectorType vecType = mlir::VectorType::get({VL}, elementType); + for (int64_t v = 0; v < VL; ++v) { + redIn.emplace_back(create.vec.loadIE(vecType, newTmps[v], newTmpAFs[v])); + } + // Reduce all of the temp vectors at once. + auto redFct = [&](mlir::Value a, mlir::Value b) -> mlir::Value { + return reductionBodyFn(builder, a, b, VL); + }; + create.vec.multiReduction(redIn, redFct, redOut); + // The redOut list should have one value with SIMD of VL. + assert(redOut.size() == 1 && "expected only one val"); + mlir::Value accumulatedVal = redOut[0]; + // Perform post processing (e.g. division by number of elements). + accumulatedVal = postReductionBodyFn(builder, accumulatedVal, VL); + // Store final values. + create.vec.storeIE(accumulatedVal, output, outputAF); +} + +} // namespace impl + +//===----------------------------------------------------------------------===// +// Templates for GenericAffineBuilder +//===----------------------------------------------------------------------===// + template -mlir::Value GenericAffineBuilder::loadIE(mlir::Value memref, - llvm::ArrayRef indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - return load(memref, computedIndices); +mlir::Value GenericAffineBuilder::load(mlir::Value memref, + mlir::ValueRange indices, mlir::ValueRange offsets) const { + return onnx_mlir::impl::load( + *this, memref, indices, offsets); } template -inline void GenericAffineBuilder::store( - mlir::Value val, mlir::Value memref, mlir::ValueRange indices) const { - b().template create(loc(), val, memref, indices); +mlir::Value GenericAffineBuilder::loadIE(mlir::Value memref, + mlir::ArrayRef indices, mlir::ValueRange offsets) const { + return onnx_mlir::impl::loadIE( + *this, memref, indices, offsets); } template inline void GenericAffineBuilder::store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - store(val, memref, computedIndices); + onnx_mlir::impl::store( + *this, val, memref, indices, offsets); } template inline void GenericAffineBuilder::storeIE(mlir::Value val, - mlir::Value memref, llvm::ArrayRef indices, + mlir::Value memref, mlir::ArrayRef indices, mlir::ValueRange offsets) const { - llvm::SmallVector computedIndices; - MathBuilder createMath(*this); - createMath.addOffsetToLeastSignificant(indices, offsets, computedIndices); - store(val, memref, computedIndices); + onnx_mlir::impl::storeIE( + *this, val, memref, indices, offsets); } template @@ -74,45 +650,114 @@ inline mlir::Operation *GenericAffineBuilder::prefetch( } template -inline void GenericAffineBuilder::forIE(IndexExpr lb, - IndexExpr ub, int64_t step, - mlir::function_ref builderFn) - const { - // Transform IndexExpressions into value maps and list of operands. +inline void GenericAffineBuilder::forLoopIE(IndexExpr lb, + IndexExpr ub, int64_t step, bool useParallel, + GenericAffineLoopBodyFn builderFn) const { + // Transform IndexExpressions into value maps and list of + // operands. mlir::AffineMap lbMap, ubMap; llvm::SmallVector lbOperands, ubOperands; lb.getAffineMapAndOperands(lbMap, lbOperands); ub.getAffineMapAndOperands(ubMap, ubOperands); - // Create affine for. - b().template create(loc(), lbOperands, lbMap, - ubOperands, ubMap, step, mlir::ValueRange{}, - [&](mlir::OpBuilder &b, mlir::Location loc, mlir::Value index, - mlir::ValueRange args) { - GenericAffineBuilder createAffine(b, loc); - builderFn(createAffine, index); - createAffine.yield(); - }); + + if (useParallel) { + // Create affine parallel for. + llvm::SmallVector types; + llvm::SmallVector reds; + llvm::SmallVector lbs, ubs; + llvm::SmallVector steps; + lbs.emplace_back(lbMap); + ubs.emplace_back(ubMap); + steps.emplace_back(step); + auto parallelLoop = b().template create( + loc(), types, reds, lbs, lbOperands, ubs, ubOperands, steps); + mlir::Block *bodyBlock = parallelLoop.getBody(); + // From extractInductionVars in AffineOps.cpp. + assert(bodyBlock->getNumArguments() == 1 && "expected one loop index"); + mlir::Value index = bodyBlock->getArgument(0); + // Code inspired from AffineForOp::build in AffineOps.cpp. + mlir::OpBuilder::InsertionGuard guard(b()); + b().setInsertionPointToStart(bodyBlock); + GenericAffineBuilder createAffine(b(), loc()); + builderFn(createAffine, {index}); + createAffine.yield(); + } else { + // Create affine for. + b().template create(loc(), lbOperands, lbMap, + ubOperands, ubMap, step, mlir::ValueRange{}, + [&](mlir::OpBuilder &b, mlir::Location loc, mlir::Value index, + mlir::ValueRange args) { + GenericAffineBuilder createAffine(b, loc); + builderFn(createAffine, {index}); + createAffine.yield(); + }); + } } template -inline void GenericAffineBuilder::forIE( - llvm::SmallVectorImpl &lbs, - llvm::SmallVectorImpl &ubs, - llvm::SmallVectorImpl &steps, - mlir::function_ref - builderFn) const { - assert(lbs.size() == ubs.size() && "expected identical sizes"); - assert(lbs.size() == steps.size() && "expected identical sizes"); - llvm::SmallVector loopIndices; - recursionForIE(lbs, ubs, steps, loopIndices, builderFn); +inline void GenericAffineBuilder::forLoopsIE( + mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, mlir::ArrayRef useParallel, + GenericAffineLoopBodyFn builderFn) const { + impl::forLoopsIE(*this, lbs, ubs, steps, useParallel, builderFn); +} + +// Sequential only version. +template +inline void GenericAffineBuilder::forLoopIE(IndexExpr lb, + IndexExpr ub, int64_t step, GenericAffineLoopBodyFn builderFn) const { + forLoopIE(lb, ub, step, false /*use parallel*/, builderFn); +} + +template +inline void GenericAffineBuilder::simdIterateIE(IndexExpr lb, + IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef bodyFnList) const { + onnx_mlir::impl::simdIterateIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, useParallel, inputs, + inputAFs, outputs, outputAFs, bodyFnList); +} + +template +inline void GenericAffineBuilder::simdReduceIE(IndexExpr lb, + IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef tmps, mlir::ArrayRef tmpAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef reductionFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef postReductionFnList) + const { + onnx_mlir::impl::simdReduceIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, inputs, inputAFs, tmps, + tmpAFs, outputs, outputAFs, initVals, reductionFnList, + postReductionFnList); +} + +template +inline void GenericAffineBuilder::simdReduce2DIE( + IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, mlir::Value input, + DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, mlir::Value output, + DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + GenericAffineSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + GenericAffineSimdPostReductionBodyFn postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, input, inputAF, tmp, tmpAF, + output, outputAF, initVal, reductionBodyFn, postReductionBodyFn); } // This if then else construct has no arguments to the blocks. template -inline void GenericAffineBuilder::ifThenElse( - IndexExprScope &scope, llvm::SmallVectorImpl &conditions, - mlir::function_ref thenFn, - mlir::function_ref elseFn) const { +inline void GenericAffineBuilder::ifThenElseIE( + IndexExprScope &scope, mlir::ArrayRef conditions, + GenericAffineThenElseBodyFn thenFn, + GenericAffineThenElseBodyFn elseFn) const { int64_t rank = conditions.size(); llvm::SmallVector affineCond; bool allTrue = true; @@ -153,32 +798,14 @@ inline void GenericAffineBuilder::ifThenElse( } template -inline void GenericAffineBuilder::yield() const { - b().template create(loc()); +mlir::Value GenericAffineBuilder::apply( + mlir::AffineMap map, mlir::ValueRange operands) const { + return b().template create(loc(), map, operands); } -// Support for multiple forIE loops. template -void GenericAffineBuilder::recursionForIE( - llvm::SmallVectorImpl &lbs, - llvm::SmallVectorImpl &ubs, - llvm::SmallVectorImpl &steps, - llvm::SmallVectorImpl &loopIndices, - mlir::function_ref - builderFn) const { - int d = loopIndices.size(); - if (d < (int)lbs.size()) { - // Issue a loop and recurse again. - forIE(lbs[d], ubs[d], steps[d], - [&](GenericAffineBuilder &createAffine, mlir::Value i) { - loopIndices.emplace_back(i); - recursionForIE(lbs, ubs, steps, loopIndices, builderFn); - }); - } else { - // Call lambda function - GenericAffineBuilder createAffine(b(), loc()); - builderFn(createAffine, loopIndices); - } +inline void GenericAffineBuilder::yield() const { + b().template create(loc()); } // Support for adding blocks. @@ -194,9 +821,3 @@ inline void GenericAffineBuilder::appendToBlock( b().setInsertionPoint(&block->back()); builderFn(block->getArguments()); } - -template -mlir::Value GenericAffineBuilder::apply( - mlir::AffineMap map, mlir::ValueRange operands) const { - return b().template create(loc(), map, operands); -} diff --git a/src/Dialect/Mlir/IndexExpr.cpp b/src/Dialect/Mlir/IndexExpr.cpp index a72cb1321a..65ac1f30ce 100644 --- a/src/Dialect/Mlir/IndexExpr.cpp +++ b/src/Dialect/Mlir/IndexExpr.cpp @@ -49,7 +49,7 @@ IndexExprScope::IndexExprScope(OpBuilder *rewriter, Location loc) getCurrentScopePtr() = this; } -IndexExprScope::IndexExprScope(DialectBuilder &db) +IndexExprScope::IndexExprScope(const DialectBuilder &db) : IndexExprScope(&db.getBuilder(), db.getLoc()) {} // Nested scopes. @@ -73,7 +73,7 @@ IndexExprScope::IndexExprScope( } IndexExprScope::IndexExprScope( - DialectBuilder &innerDb, IndexExprScope *enclosingScope) + const DialectBuilder &innerDb, IndexExprScope *enclosingScope) : IndexExprScope(&innerDb.getBuilder(), enclosingScope) {} IndexExprScope::~IndexExprScope() { @@ -338,7 +338,7 @@ bool IndexExpr::isLiteralAndSmallerThan(IndexExpr const b) const { } // All element in list are literals. -/*static*/ bool IndexExpr::isLiteral(SmallVectorImpl &list) { +/*static*/ bool IndexExpr::isLiteral(ArrayRef list) { for (IndexExpr i : list) if (!i.isLiteral()) return false; @@ -346,8 +346,7 @@ bool IndexExpr::isLiteralAndSmallerThan(IndexExpr const b) const { } // All element in list are literals and non-negative (i.e. >= 0). -/*static*/ bool IndexExpr::isNonNegativeLiteral( - SmallVectorImpl &list) { +/*static*/ bool IndexExpr::isNonNegativeLiteral(ArrayRef list) { for (IndexExpr i : list) if (!i.isLiteral() || i.getLiteral() < 0) return false; @@ -371,7 +370,7 @@ bool IndexExpr::canBeUsedInScope() const { switch (getKind()) { case IndexExprKind::NonAffine: case IndexExprKind::Predicate: - // Its ok to use a nonaffine index expressions from enclosing scopes. + // Its ok to use a non-affine index expressions from enclosing scopes. assert(hasValue() && "must have value to be used from enclosing scopes"); return getScope().isEnclosingScope(); break; @@ -462,7 +461,7 @@ void IndexExpr::debugPrint(const std::string &msg) const { } void IndexExpr::debugPrint( - const std::string &msg, const SmallVectorImpl &list) { + const std::string &msg, const ArrayRef list) { LLVM_DEBUG({ int s = list.size(); llvm::dbgs() << msg.c_str() << " (" << s << " elements)\n"; @@ -525,7 +524,7 @@ void IndexExpr::debugPrint( /* static*/ void IndexExpr::getAffineMapAndOperands( ArrayRef indexExprArray, AffineMap &map, - SmallVectorImpl &operands) { + SmallVectorImpl &operands) { assert(indexExprArray.size() > 0 && "expected at least one index expr"); SmallVector affineExprList; for (IndexExpr expr : indexExprArray) { @@ -559,10 +558,12 @@ static bool isIdentical(const IndexExpr litExpr, double dval) { return litExpr.isLiteralAndIdenticalTo(ival); } -// Used for add/sub/mult/ceilDiv/floorDiv -IndexExpr IndexExpr::binaryOp(IndexExpr const b, bool affineWithLitB, - bool hasNeutralA, bool hasNeutralB, double neutralVal, F2 litFct, - F2 affineExprFct, F2 valueFct) const { +// Used for add/sub/mult/ceilDiv/floorDiv. +// Add/sub: B does not need to be a literal for the result to be affine. +// All the other ones (mul, div*, mod) require the B to be a literal. +IndexExpr IndexExpr::binaryOp(IndexExpr const b, bool propagateIntoMinMax, + bool affineWithLitB, bool hasNeutralA, bool hasNeutralB, double neutralVal, + F2 litFct, F2 affineExprFct, F2 valueFct) const { assert(litFct && "expect lit function"); assert(valueFct && "expect value function"); assert(canBeUsedInScope() && "a cannot be used in current scope"); @@ -573,9 +574,10 @@ IndexExpr IndexExpr::binaryOp(IndexExpr const b, bool affineWithLitB, bool canBeAffine = (affineExprFct != nullptr); bool resIsAffine = resIsLit || (canBeAffine && isAffine() && b.isAffine() && (!affineWithLitB || b.isLiteral())); - // Test if we have a neutral value. - if (hasNeutralA && isIdentical(*this, neutralVal)) - return b.deepCopy(); // Copy of the other value (use same questionmark). + if (resIsAffine) + // Test if we have a neutral value. + if (hasNeutralA && isIdentical(*this, neutralVal)) + return b.deepCopy(); // Copy of the other value (use same questionmark). if (hasNeutralB && isIdentical(b, neutralVal)) { return deepCopy(); // Copy of the other value (use same questionmark). } @@ -591,6 +593,43 @@ IndexExpr IndexExpr::binaryOp(IndexExpr const b, bool affineWithLitB, if (resIsAffine) // Use affine values. return affineExprFct(*this, b); + // See if we have a min/max on one side that we can propagate into. + if (canBeAffine && propagateIntoMinMax) { + Value valA = this->getValue(); + bool hasMinMaxA = valA.getDefiningOp() || + valA.getDefiningOp(); + Value valB = b.getValue(); + bool hasMinMaxB = valB.getDefiningOp() || + valB.getDefiningOp(); + // Can handle only cases where either a or b are min/max and the other one + // is affine. + if ((hasMinMaxA && !hasMinMaxB && b.isAffine()) || + (!hasMinMaxA && hasMinMaxB && this->isAffine())) { + // Of the two inputs, find out the one with the min/max. + IndexExpr minMaxIE = hasMinMaxA ? *this : b; + // Retrieve the map and list of dim/symbols in the current scope + bool isMin; + llvm::SmallVector vals; + AffineMap map; + assert(minMaxIE.retrieveAffineMinMax(isMin, vals, map) && "expected one"); + // Perform the affineExprFct for each min/max terms. + llvm::SmallVector updatedMinMaxExprs; + for (AffineExpr affineExpr : map.getResults()) { + IndexExpr oldAffineExpr = AffineIndexExpr(affineExpr); + IndexExpr newAffineExpr; + if (hasMinMaxA) + newAffineExpr = affineExprFct(oldAffineExpr, b); + else + newAffineExpr = affineExprFct(*this, oldAffineExpr); + updatedMinMaxExprs.emplace_back(newAffineExpr); + } + // Create new operation. + if (isMin) { + return IndexExpr::min(updatedMinMaxExprs); + } + return IndexExpr::max(updatedMinMaxExprs); + } + } // Use values. return valueFct(*this, b); } @@ -669,7 +708,8 @@ IndexExpr IndexExpr::compareOp( // Cannot have affine results, disable and pass null lambda function. // Ignore possible neutral values. assert(!areFloat(b) && "integer compare"); - return binaryOp(b, false, false, false, 0.0, litFct, nullptr, valueFct); + return binaryOp( + b, false, false, false, false, 0.0, litFct, nullptr, valueFct); } // Floating point version. @@ -719,7 +759,7 @@ IndexExpr IndexExpr::compareOp( // Ignore possible neutral values. assert(areFloat(b) && "float compare"); return binaryOp( - b, false, false, false, 0.0, litFloatFct, nullptr, valueFloatFct); + b, false, false, false, false, 0.0, litFloatFct, nullptr, valueFloatFct); } // Conjunction of two conditions: And @@ -778,8 +818,8 @@ IndexExpr IndexExpr::operator!() const { // The affine reduction lambda function processes the whole list and must init // the result. Literal and Values treat one operation at a time -/* static*/ IndexExpr IndexExpr::reductionOp(SmallVectorImpl &vals, - F2Self litRed, Flist affineRed, F2Self valueRed) { +/* static*/ IndexExpr IndexExpr::reductionOp( + ArrayRef vals, F2Self litRed, Flist affineRed, F2Self valueRed) { // If no values, result is undefined. int size = vals.size(); if (size == 0) @@ -831,10 +871,10 @@ IndexExpr IndexExpr::operator!() const { IndexExpr IndexExpr::operator+(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getLiteral() + bb.getLiteral()); + return LitIE(aa.getLiteral() + bb.getLiteral()); }; F2 litFloatFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getFloatLiteral() + bb.getFloatLiteral()); + return LitIE(aa.getFloatLiteral() + bb.getFloatLiteral()); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { return AffineIndexExpr(aa.getAffineExpr() + bb.getAffineExpr()); @@ -845,16 +885,18 @@ IndexExpr IndexExpr::operator+(IndexExpr const b) const { }; // Neutral value: a + 0 = a, 0 + b = b. if (areFloat(b)) - return binaryOp(b, false, true, true, 0.0, litFloatFct, nullptr, valueFct); - return binaryOp(b, false, true, true, 0.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, false, true, true, 0.0, litFloatFct, nullptr, valueFct); + return binaryOp( + b, true, false, true, true, 0.0, litFct, affineExprFct, valueFct); } IndexExpr IndexExpr::operator-(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getLiteral() - bb.getLiteral()); + return LitIE(aa.getLiteral() - bb.getLiteral()); }; F2 litFloatFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getFloatLiteral() - bb.getFloatLiteral()); + return LitIE(aa.getFloatLiteral() - bb.getFloatLiteral()); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { return AffineIndexExpr(aa.getAffineExpr() - bb.getAffineExpr()); @@ -865,16 +907,18 @@ IndexExpr IndexExpr::operator-(IndexExpr const b) const { }; // Neutral value: a - 0 = a. if (areFloat(b)) - return binaryOp(b, false, false, true, 0.0, litFloatFct, nullptr, valueFct); - return binaryOp(b, false, false, true, 0.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, false, false, true, 0.0, litFloatFct, nullptr, valueFct); + return binaryOp( + b, true, false, false, true, 0.0, litFct, affineExprFct, valueFct); } IndexExpr IndexExpr::operator*(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getLiteral() * bb.getLiteral()); + return LitIE(aa.getLiteral() * bb.getLiteral()); }; F2 litFloatFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { - return LiteralIndexExpr(aa.getFloatLiteral() * bb.getFloatLiteral()); + return LitIE(aa.getFloatLiteral() * bb.getFloatLiteral()); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { return AffineIndexExpr(aa.getAffineExpr() * bb.getAffineExpr()); @@ -885,12 +929,14 @@ IndexExpr IndexExpr::operator*(IndexExpr const b) const { }; // Neutral value: a * 1 = a, 1 * b = b. if (areFloat(b)) - return binaryOp(b, false, true, true, 1.0, litFloatFct, nullptr, valueFct); + return binaryOp( + b, false, false, true, true, 1.0, litFloatFct, nullptr, valueFct); // For affine, requires one to be a literal, and in "b" (argument). if (isLiteral()) return b.binaryOp( - *this, true, true, true, 1.0, litFct, affineExprFct, valueFct); - return binaryOp(b, true, true, true, 1.0, litFct, affineExprFct, valueFct); + *this, false, true, true, true, 1.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, true, true, true, 1.0, litFct, affineExprFct, valueFct); } // Int operator @@ -898,7 +944,7 @@ IndexExpr IndexExpr::floorDiv(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { int64_t rval = std::floor((1.0 * aa.getLiteral()) / (1.0 * bb.getLiteral())); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { // Operand bb must be a literal. @@ -917,14 +963,15 @@ IndexExpr IndexExpr::floorDiv(IndexExpr const b) const { // Index b must be a literal. // Neutral value: a / 1 = a. assert(!areFloat(b) && "floor div only supports int"); - return binaryOp(b, true, false, true, 1.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, true, false, true, 1.0, litFct, affineExprFct, valueFct); } // Int operator IndexExpr IndexExpr::ceilDiv(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { int64_t rval = std::ceil((1.0 * aa.getLiteral()) / (1.0 * bb.getLiteral())); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { // Operand bb must be a literal. @@ -941,14 +988,15 @@ IndexExpr IndexExpr::ceilDiv(IndexExpr const b) const { // Index b must be a literal. // Neutral value: a / 1 = a. assert(!areFloat(b) && "ceil div only supports int"); - return binaryOp(b, true, false, true, 1.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, true, false, true, 1.0, litFct, affineExprFct, valueFct); } // Int operator IndexExpr IndexExpr::operator%(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { int64_t rval = llvm::mod(aa.getLiteral(), bb.getLiteral()); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F2 affineExprFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { // Operand bb must be a literal. @@ -965,14 +1013,15 @@ IndexExpr IndexExpr::operator%(IndexExpr const b) const { // Index b must be a literal. // Neutral value: ignore here that x % x = 0. assert(!areFloat(b) && "mod only supports int"); - return binaryOp(b, true, false, false, 1.0, litFct, affineExprFct, valueFct); + return binaryOp( + b, false, true, false, false, 1.0, litFct, affineExprFct, valueFct); } // Float operator IndexExpr IndexExpr::operator/(IndexExpr const b) const { F2 litFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { double rval = aa.getFloatLiteral() / bb.getFloatLiteral(); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F2 valueFct = [](IndexExpr const aa, IndexExpr const bb) -> IndexExpr { MathBuilder createMath(aa.getRewriter(), aa.getLoc()); @@ -980,14 +1029,16 @@ IndexExpr IndexExpr::operator/(IndexExpr const b) const { }; // Neutral value: x / 1 = x. assert(areFloat(b) && "float only; int: use ceilDiv or floorDiv"); - return binaryOp(b, false, false, true, 1.0, litFct, nullptr, valueFct); + // Note: there are no affine functions for float, so affineWithLitB==true or + // false is irrelevant. + return binaryOp(b, false, false, false, true, 1.0, litFct, nullptr, valueFct); } // Float operator. IndexExpr IndexExpr::ceil() const { F1 litFct = [](IndexExpr const aa) -> IndexExpr { double rval = std::ceil(aa.getFloatLiteral()); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F1 valueFct = [](IndexExpr const aa) -> IndexExpr { MathBuilder createMath(aa.getRewriter(), aa.getLoc()); @@ -1003,7 +1054,7 @@ IndexExpr IndexExpr::ceil() const { IndexExpr IndexExpr::floor() const { F1 litFct = [](IndexExpr const aa) -> IndexExpr { double rval = std::floor(aa.getFloatLiteral()); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F1 valueFct = [](IndexExpr const aa) -> IndexExpr { MathBuilder createMath(aa.getRewriter(), aa.getLoc()); @@ -1020,7 +1071,7 @@ IndexExpr IndexExpr::floor() const { IndexExpr IndexExpr::convertToFloat() const { F1 litFct = [](IndexExpr const aa) -> IndexExpr { double rval = (double)aa.getLiteral(); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F1 valueFct = [](IndexExpr const aa) -> IndexExpr { MathBuilder createMath(aa.getRewriter(), aa.getLoc()); @@ -1037,7 +1088,7 @@ IndexExpr IndexExpr::convertToFloat() const { IndexExpr IndexExpr::convertToIndex() const { F1 litFct = [](IndexExpr const aa) -> IndexExpr { int64_t rval = (int64_t)aa.getFloatLiteral(); - return LiteralIndexExpr(rval); + return LitIE(rval); }; F1 valueFct = [](IndexExpr const aa) -> IndexExpr { MathBuilder createMath(aa.getRewriter(), aa.getLoc()); @@ -1116,7 +1167,7 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { return NonAffineIndexExpr(results); } -/*static*/ IndexExpr IndexExpr::min(SmallVectorImpl &vals) { +/*static*/ IndexExpr IndexExpr::min(ArrayRef vals) { // Res is already an literal int, we are reducing into it. F2Self litFct = [](IndexExpr res, IndexExpr const aa) -> IndexExpr { if (aa.isLiteralAndSmallerThan(res)) @@ -1124,13 +1175,13 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { return res; }; Flist affineExprFct = [&](IndexExpr res, - SmallVectorImpl &vvals) -> IndexExpr { + ArrayRef vvals) -> IndexExpr { // Create a list of affine expression assert(vvals.size() > 1 && "come here only with 2 or more values"); SmallVector affineExprs; // Important to get the affine expressions before getting the // dims/symbols. - for (IndexExpr &vv : vvals) { + for (IndexExpr vv : vvals) { affineExprs.emplace_back(vv.getAffineExpr()); } // Compute a map including the list of affine expressions. @@ -1171,11 +1222,11 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { /*static*/ IndexExpr IndexExpr::min( IndexExpr const first, int64_t const second) { - SmallVector list = {first, LiteralIndexExpr(second)}; + SmallVector list = {first, LitIE(second)}; return min(list); } -/*static*/ IndexExpr IndexExpr::max(SmallVectorImpl &vals) { +/*static*/ IndexExpr IndexExpr::max(ArrayRef vals) { // Res is already an literal int, we are reducing into it. F2Self litFct = [](IndexExpr res, IndexExpr const aa) -> IndexExpr { if (aa.isLiteralAndGreaterThan(res)) @@ -1183,13 +1234,13 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { return res; }; Flist affineExprFct = [&](IndexExpr res, - SmallVectorImpl &vvals) -> IndexExpr { + ArrayRef vvals) -> IndexExpr { // Create a list of affine expression assert(vvals.size() > 1 && "come here only with 2 or more values"); SmallVector affineExprs; // Important to get the affine expressions before getting the // dims/symbols. - for (IndexExpr &vv : vvals) { + for (IndexExpr vv : vvals) { affineExprs.emplace_back(vv.getAffineExpr()); } // Compute a map including the list of affine expressions. @@ -1230,7 +1281,7 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { /*static*/ IndexExpr IndexExpr::max( IndexExpr const first, int64_t const second) { - SmallVector list = {first, LiteralIndexExpr(second)}; + SmallVector list = {first, LitIE(second)}; return max(list); } @@ -1238,16 +1289,38 @@ IndexExpr IndexExpr::clamp(IndexExpr const min, IndexExpr const max) const { // IndexExpr Ops Derivatives //===----------------------------------------------------------------------===// +bool IndexExpr::retrieveAffineMinMax( + bool &isMin, llvm::SmallVectorImpl &vals, AffineMap &map) const { + Value val = this->getValue(); + auto minOp = val.getDefiningOp(); + auto maxOp = val.getDefiningOp(); + // Expect here the defining op to be either min or max. + if (minOp == nullptr && maxOp == nullptr) + return false; + isMin = minOp != nullptr; + if (isMin) + map = minOp.getAffineMap(); + else + map = maxOp.getAffineMap(); + IndexExprScope &scope = this->getScope(); + scope.getDimAndSymbolList(vals); + return true; +} + +//===----------------------------------------------------------------------===// +// IndexExpr Ops Derivatives +//===----------------------------------------------------------------------===// + IndexExpr IndexExpr::operator+(int64_t const b) const { - return *this + LiteralIndexExpr(b); + return *this + LitIE(b); } IndexExpr IndexExpr::operator-(int64_t const b) const { - return *this - LiteralIndexExpr(b); + return *this - LitIE(b); } IndexExpr IndexExpr::operator*(int64_t const b) const { - return *this * LiteralIndexExpr(b); + return *this * LitIE(b); } IndexExpr IndexExpr::operator==(IndexExpr const b) const { @@ -1257,7 +1330,7 @@ IndexExpr IndexExpr::operator==(IndexExpr const b) const { } IndexExpr IndexExpr::operator==(int64_t const b) const { - return *this == LiteralIndexExpr(b); + return *this == LitIE(b); } IndexExpr IndexExpr::operator!=(IndexExpr const b) const { @@ -1267,7 +1340,7 @@ IndexExpr IndexExpr::operator!=(IndexExpr const b) const { } IndexExpr IndexExpr::operator!=(int64_t const b) const { - return *this != LiteralIndexExpr(b); + return *this != LitIE(b); } IndexExpr IndexExpr::operator<=(IndexExpr const b) const { @@ -1277,7 +1350,7 @@ IndexExpr IndexExpr::operator<=(IndexExpr const b) const { } IndexExpr IndexExpr::operator<=(int64_t const b) const { - return *this <= LiteralIndexExpr(b); + return *this <= LitIE(b); } IndexExpr IndexExpr::operator<(IndexExpr const b) const { @@ -1287,7 +1360,7 @@ IndexExpr IndexExpr::operator<(IndexExpr const b) const { } IndexExpr IndexExpr::operator<(int64_t const b) const { - return *this < LiteralIndexExpr(b); + return *this < LitIE(b); } IndexExpr IndexExpr::operator>=(IndexExpr const b) const { @@ -1297,7 +1370,7 @@ IndexExpr IndexExpr::operator>=(IndexExpr const b) const { } IndexExpr IndexExpr::operator>=(int64_t const b) const { - return *this >= LiteralIndexExpr(b); + return *this >= LitIE(b); } IndexExpr IndexExpr::operator>(IndexExpr const b) const { @@ -1307,36 +1380,36 @@ IndexExpr IndexExpr::operator>(IndexExpr const b) const { } IndexExpr IndexExpr::operator>(int64_t const b) const { - return *this > LiteralIndexExpr(b); + return *this > LitIE(b); } IndexExpr IndexExpr::operator%(int64_t const b) const { - return *this % LiteralIndexExpr(b); + return *this % LitIE(b); } IndexExpr IndexExpr::floorDiv(int64_t const b) const { - return this->floorDiv(LiteralIndexExpr(b)); + return this->floorDiv(LitIE(b)); } IndexExpr IndexExpr::ceilDiv(int64_t const b) const { - return this->ceilDiv(LiteralIndexExpr(b)); + return this->ceilDiv(LitIE(b)); } IndexExpr IndexExpr::clamp(int64_t min, IndexExpr max) { - return clamp(LiteralIndexExpr(min), max); + return clamp(LitIE(min), max); } /*static*/ IndexExpr IndexExpr::select( IndexExpr const compare, int64_t const trueVal, IndexExpr const falseVal) { - return select(compare, LiteralIndexExpr(trueVal), falseVal); + return select(compare, LitIE(trueVal), falseVal); } /*static*/ IndexExpr IndexExpr::select( IndexExpr const compare, IndexExpr const trueVal, int64_t const falseVal) { - return select(compare, trueVal, LiteralIndexExpr(falseVal)); + return select(compare, trueVal, LitIE(falseVal)); } /*static*/ IndexExpr IndexExpr::select( IndexExpr const compare, int64_t const trueVal, int64_t const falseVal) { - return select(compare, LiteralIndexExpr(trueVal), LiteralIndexExpr(falseVal)); + return select(compare, LitIE(trueVal), LitIE(falseVal)); } IndexExpr IndexExpr::selectOrSelf( @@ -1887,7 +1960,7 @@ void getIndexExprListFromInt( ArrayRef inputList, llvm::SmallVectorImpl &outputList) { outputList.clear(); for (int64_t item : inputList) - outputList.emplace_back(LiteralIndexExpr(item)); + outputList.emplace_back(LitIE(item)); } // Create a list of IndexExpr of kind LiteralIndexExpr/Questionmark from a @@ -1900,7 +1973,7 @@ void getIndexExprListFromShape( outputList.emplace_back(QuestionmarkIndexExpr(/*isFloat*/ false)); else { assert(item >= 0 && "expected kDynamic, not -1"); - outputList.emplace_back(LiteralIndexExpr(item)); + outputList.emplace_back(LitIE(item)); } } } diff --git a/src/Dialect/Mlir/IndexExpr.hpp b/src/Dialect/Mlir/IndexExpr.hpp index 0c867d0586..678fb664ea 100644 --- a/src/Dialect/Mlir/IndexExpr.hpp +++ b/src/Dialect/Mlir/IndexExpr.hpp @@ -237,8 +237,8 @@ result in a new Dim variable. for (int ii = 0; ii < outputRank; ++ii) { Value inductionVal = outputLoops.getInductionVar(ii); DimIndexExpr inductionIndex(inductionVal); - IndexExpr start = SymbolIndexExpr(shapeHelper.starts[ii]); - IndexExpr step = SymbolIndexExpr(shapeHelper.steps[ii]); + IndexExpr start = SymIE(shapeHelper.starts[ii]); + IndexExpr step = SymIE(shapeHelper.steps[ii]); loadIndices.emplace_back((step * inductionIndex) + start); storeIndices.emplace_back(inductionIndex); } @@ -337,12 +337,12 @@ class IndexExprScope { // Constructor for a scope. Top level scope must provide rewriter (possibly // null if we cannot generate code at this time) and location. IndexExprScope(mlir::OpBuilder *rewriter, mlir::Location loc); - IndexExprScope(DialectBuilder &db); + IndexExprScope(const DialectBuilder &db); // Constructor for subsequent nested scopes. Providing enclosing scope is // technically not necessary (nullptr can be passed); it is used to allow a // user to explicitly name the enclosing scope. IndexExprScope(mlir::OpBuilder *rewriter, IndexExprScope *enclosingScope); - IndexExprScope(DialectBuilder &db, IndexExprScope *enclosingScope); + IndexExprScope(const DialectBuilder &db, IndexExprScope *enclosingScope); // Destructor which release all IndexExpr associated with this scope. virtual ~IndexExprScope(); @@ -404,6 +404,7 @@ class IndexExprScope { //===----------------------------------------------------------------------===// using DimsExpr = llvm::SmallVector; +using DimsExprRef = mlir::ArrayRef; // Data structure that is the public interface for IndexExpr. It is a shallow // data structure that is simply a pointer to the actual data (IndexExprImpl). @@ -470,8 +471,8 @@ class IndexExpr { bool isLiteralAndSmallerThan(double b) const; // Values smaller. bool isLiteralAndSmallerThan(IndexExpr const b) const; // Values smaller. // Test if all element in list are literals. - static bool isLiteral(llvm::SmallVectorImpl &list); - static bool isNonNegativeLiteral(llvm::SmallVectorImpl &list); + static bool isLiteral(mlir::ArrayRef list); + static bool isNonNegativeLiteral(mlir::ArrayRef list); // Getters. IndexExprScope &getScope() const { return *getScopePtr(); } @@ -564,10 +565,10 @@ class IndexExpr { IndexExpr selectOrSelf(IndexExpr const compare, int64_t const trueVal) const; // Return min or max of a list of IndexExpr. - static IndexExpr min(llvm::SmallVectorImpl &vals); + static IndexExpr min(mlir::ArrayRef vals); static IndexExpr min(IndexExpr const first, IndexExpr const second); static IndexExpr min(IndexExpr const first, int64_t const second); - static IndexExpr max(llvm::SmallVectorImpl &vals); + static IndexExpr max(mlir::ArrayRef vals); static IndexExpr max(IndexExpr const first, IndexExpr const second); static IndexExpr max(IndexExpr const first, int64_t const second); @@ -581,7 +582,7 @@ class IndexExpr { // Debug (enable running with --debug-only=index-expr, for example). void debugPrint(const std::string &msg) const; static void debugPrint( - const std::string &msg, const llvm::SmallVectorImpl &list); + const std::string &msg, const mlir::ArrayRef list); protected: // Private queries. @@ -598,8 +599,7 @@ class IndexExpr { using F1 = std::function; using F2 = std::function; using F2Self = std::function; - using Flist = - std::function &)>; + using Flist = std::function)>; using F3 = std::function; // Support for operations: common handling for multiple operations. @@ -610,15 +610,15 @@ class IndexExpr { IndexExpr unaryOp( bool resIsFloat, F1 litFct, F1 affineExprFct, F1 valueFct) const; // Res is float is the same as a & b. - IndexExpr binaryOp(IndexExpr const b, bool affineWithLitB, bool hasNeutralA, - bool hasNeutralB, double neutralVal, F2 fInteger, F2 fAffine, - F2 fValue) const; + IndexExpr binaryOp(IndexExpr const b, bool propagateIntoMinMax, + bool affineWithLitB, bool hasNeutralA, bool hasNeutralB, + double neutralVal, F2 fInteger, F2 fAffine, F2 fValue) const; IndexExpr compareOp( mlir::arith::CmpIPredicate comparePred, IndexExpr const b) const; IndexExpr compareOp( mlir::arith::CmpFPredicate comparePred, IndexExpr const b) const; - static IndexExpr reductionOp(llvm::SmallVectorImpl &vals, - F2Self litRed, Flist affineRed, F2Self valueRed); + static IndexExpr reductionOp(mlir::ArrayRef vals, F2Self litRed, + Flist affineRed, F2Self valueRed); // Data: pointer to implemented object. IndexExprImpl *indexExprObj = nullptr; }; @@ -828,6 +828,7 @@ class SymbolIndexExpr : public IndexExpr { //===----------------------------------------------------------------------===// using LitIE = LiteralIndexExpr; +using PredIE = PredicateIndexExpr; using SymIE = SymbolIndexExpr; using DimIE = DimIndexExpr; @@ -842,7 +843,7 @@ inline IndexExpr operator*(int64_t const a, const IndexExpr &b) { return b * a; } inline IndexExpr operator-(int64_t const a, const IndexExpr &b) { - return LiteralIndexExpr(a) - b; + return LitIE(a) - b; } //===----------------------------------------------------------------------===// @@ -870,19 +871,19 @@ void getIndexExprList( inline llvm::SmallVector DimListIE(mlir::ValueRange range) { llvm::SmallVector outputList; - getIndexExprList(range, outputList); + getIndexExprList(range, outputList); return outputList; } inline llvm::SmallVector SymListIE(mlir::ValueRange range) { llvm::SmallVector outputList; - getIndexExprList(range, outputList); + getIndexExprList(range, outputList); return outputList; } // Create a list of IndexExpr of kind INDEX_EXPR from another list of IndexExpr. template -void getIndexExprList(const llvm::SmallVectorImpl &inputList, +void getIndexExprList(const mlir::ArrayRef inputList, llvm::SmallVectorImpl &outputList) { outputList.clear(); for (auto item : inputList) @@ -890,16 +891,16 @@ void getIndexExprList(const llvm::SmallVectorImpl &inputList, } inline llvm::SmallVector DimListIE( - const llvm::SmallVectorImpl &inputList) { + const mlir::ArrayRef inputList) { llvm::SmallVector outputList; - getIndexExprList(inputList, outputList); + getIndexExprList(inputList, outputList); return outputList; } inline llvm::SmallVector SymListIE( - const llvm::SmallVectorImpl &inputList) { + const mlir::ArrayRef inputList) { llvm::SmallVector outputList; - getIndexExprList(inputList, outputList); + getIndexExprList(inputList, outputList); return outputList; } diff --git a/src/Dialect/Mlir/IndexExprBuilder.cpp b/src/Dialect/Mlir/IndexExprBuilder.cpp index f19a0dc8c8..d1f4d35442 100644 --- a/src/Dialect/Mlir/IndexExprBuilder.cpp +++ b/src/Dialect/Mlir/IndexExprBuilder.cpp @@ -4,7 +4,7 @@ //===------------ IndexExprBuilder.cpp - builder for index expressions ----===// // -// Copyright 2022-2023 The IBM Research Authors. +// Copyright 2022-2024 The IBM Research Authors. // // ============================================================================= // @@ -47,6 +47,8 @@ APFloat getFloatValue(ElementsAttr elementsAttr, Type elType, uint64_t i) { return APFloat(onnx_mlir::castArrayRef(array)[i]); if (elType.isF64()) return APFloat(onnx_mlir::castArrayRef(array)[i]); + if (elType.isBF16()) + return APFloat(onnx_mlir::castArrayRef(array)[i]); llvm_unreachable("Unexpected float type"); } return elementsAttr.getValues()[i]; @@ -130,14 +132,14 @@ IndexExpr IndexExprBuilder::getIntFromArrayAsLiteral( if (i >= size) return UndefinedIndexExpr(); int64_t val = mlir::cast(intAttrArray.getValue()[i]).getInt(); - return LiteralIndexExpr(val); + return LitIE(val); } IndexExpr IndexExprBuilder::getIntFromArrayAsLiteral( ArrayAttr intAttrArray, uint64_t i, int64_t outOfBoundVal) { IndexExpr indexExpr = getIntFromArrayAsLiteral(intAttrArray, i); // Undefined value are set to default value. - return indexExpr.isUndefined() ? LiteralIndexExpr(outOfBoundVal) : indexExpr; + return indexExpr.isUndefined() ? LitIE(outOfBoundVal) : indexExpr; } void IndexExprBuilder::getIntFromArrayAsLiterals( @@ -147,10 +149,11 @@ void IndexExprBuilder::getIntFromArrayAsLiterals( if (len == -1) // Meaning pick up the full size of the list. len = size; else - assert((uint64_t)len <= size && "requesting too many elements"); + assert( + static_cast(len) <= size && "requesting too many elements"); if (len == 0) return; - for (uint64_t i = 0; i < (uint64_t)len; ++i) { + for (uint64_t i = 0; i < static_cast(len); ++i) { IndexExpr indexExpr = getIntFromArrayAsLiteral(intAttrArray, i); assert(!indexExpr.isUndefined() && "expected defined index expr"); list.emplace_back(indexExpr); @@ -163,7 +166,7 @@ void IndexExprBuilder::getIntFromArrayAsLiterals(ArrayAttr intAttrArray, assert(len >= 0 && "expect a defined size"); if (len == 0) return; - for (uint64_t i = 0; i < (uint64_t)len; ++i) { + for (uint64_t i = 0; i < static_cast(len); ++i) { IndexExpr indexExpr = getIntFromArrayAsLiteral(intAttrArray, i, outOfBoundVal); assert(!indexExpr.isUndefined() && "expected defined index expr"); @@ -197,17 +200,17 @@ IndexExpr IndexExprBuilder::getValFromArray( assert(arraySize == size && "expected given size to be the same as the " "one detected from the array value"); } - if (size == ShapedType::kDynamic || i >= (uint64_t)size) { + if (size == ShapedType::kDynamic || i >= static_cast(size)) { return UndefinedIndexExpr(); } if (ElementsAttr elementsAttr = getConst(array)) { if (isFloat) { double floatVal = getFloatValue(elementsAttr, elType, i).convertToDouble(); - return LiteralIndexExpr(floatVal); + return LitIE(floatVal); } else { int64_t intVal = getIntValue(elementsAttr, elType, i).getSExtValue(); - return LiteralIndexExpr(intVal); + return LitIE(intVal); } } // If our scalar array is not a constant; we have a runtime value. @@ -220,10 +223,10 @@ IndexExpr IndexExprBuilder::getValFromArray( if (isFloat) { double floatVal = getFloatValue(elementsAttr, elType, 0).convertToDouble(); - return LiteralIndexExpr(floatVal); + return LitIE(floatVal); } else { int64_t intVal = getIntValue(elementsAttr, elType, 0).getSExtValue(); - return LiteralIndexExpr(intVal); + return LitIE(intVal); } } // Otherwise, we can write code. @@ -234,9 +237,9 @@ IndexExpr IndexExprBuilder::getValFromArray( } Value castedVal = createMath.castToIndex(val); if (makeSymbol) - return SymbolIndexExpr(castedVal); + return SymIE(castedVal); else - return DimIndexExpr(castedVal); + return DimIE(castedVal); } return QuestionmarkIndexExpr(isFloat); } @@ -278,21 +281,21 @@ IndexExpr IndexExprBuilder::getIntFromArrayAsSymbolWithOutOfBound( Value intArray, uint64_t i, int64_t defaultLiteral) { IndexExpr indexExpr = getIntFromArrayAsSymbol(intArray, i); // Undefined value are set to default value. - return indexExpr.isUndefined() ? LiteralIndexExpr(defaultLiteral) : indexExpr; + return indexExpr.isUndefined() ? LitIE(defaultLiteral) : indexExpr; } IndexExpr IndexExprBuilder::getIntFromArrayAsDimWithOutOfBound( Value intArray, uint64_t i, int64_t defaultLiteral) { IndexExpr indexExpr = getIntFromArrayAsDim(intArray, i); // Undefined value are set to default value. - return indexExpr.isUndefined() ? LiteralIndexExpr(defaultLiteral) : indexExpr; + return indexExpr.isUndefined() ? LitIE(defaultLiteral) : indexExpr; } IndexExpr IndexExprBuilder::getFloatFromArrayAsNonAffineWithOutOfBound( Value floatArray, uint64_t i, double defaultLiteral) { IndexExpr indexExpr = getFloatFromArrayAsNonAffine(floatArray, i); // Undefined value are set to default value. - return indexExpr.isUndefined() ? LiteralIndexExpr(defaultLiteral) : indexExpr; + return indexExpr.isUndefined() ? LitIE(defaultLiteral) : indexExpr; } void IndexExprBuilder::getIntFromArrayAsSymbols( @@ -302,10 +305,11 @@ void IndexExprBuilder::getIntFromArrayAsSymbols( if (len == -1) // Meaning pick up the full size of the list. len = size; else - assert((uint64_t)len <= size && "requesting too many elements"); + assert( + static_cast(len) <= size && "requesting too many elements"); if (len == 0) return; - for (uint64_t i = 0; i < (uint64_t)len; ++i) { + for (uint64_t i = 0; i < static_cast(len); ++i) { IndexExpr indexExpr = getIntFromArrayAsSymbol(intArray, i); assert(!indexExpr.isUndefined() && "expected defined index expr"); list.emplace_back(indexExpr); @@ -319,10 +323,11 @@ void IndexExprBuilder::getIntFromArrayAsDims( if (len == -1) // Meaning pick up the full size of the list. len = size; else - assert((uint64_t)len <= size && "requesting too many elements"); + assert( + static_cast(len) <= size && "requesting too many elements"); if (len == 0) return; - for (uint64_t i = 0; i < (uint64_t)len; ++i) { + for (uint64_t i = 0; i < static_cast(len); ++i) { IndexExpr indexExpr = getIntFromArrayAsDim(intArray, i); assert(!indexExpr.isUndefined() && "expected defined index expr"); list.emplace_back(indexExpr); @@ -336,10 +341,11 @@ void IndexExprBuilder::getFloatFromArrayAsNonAffine( if (len == -1) // Meaning pick up the full size of the list. len = size; else - assert((uint64_t)len <= size && "requesting too many elements"); + assert( + static_cast(len) <= size && "requesting too many elements"); if (len == 0) return; - for (uint64_t i = 0; i < (uint64_t)len; ++i) { + for (uint64_t i = 0; i < static_cast(len); ++i) { IndexExpr indexExpr = getFloatFromArrayAsNonAffine(floatArray, i); assert(!indexExpr.isUndefined() && "expected defined index expr"); list.emplace_back(indexExpr); @@ -373,7 +379,7 @@ IndexExpr IndexExprBuilder::getShapeAsLiteral( int64_t shape = getShape(tensorOrMemrefValue, i); assert( shape != ShapedType::kDynamic && "expected compile time constant shape"); - return LiteralIndexExpr(shape); + return LitIE(shape); } IndexExpr IndexExprBuilder::getShapeAsSymbol( @@ -381,7 +387,7 @@ IndexExpr IndexExprBuilder::getShapeAsSymbol( if (isLiteralShape(tensorOrMemrefValue, i)) return getShapeAsLiteral(tensorOrMemrefValue, i); if (Value val = getShapeVal(tensorOrMemrefValue, i)) - return SymbolIndexExpr(val); + return SymIE(val); return QuestionmarkIndexExpr(tensorOrMemrefValue, i); } @@ -390,7 +396,7 @@ IndexExpr IndexExprBuilder::getShapeAsDim( if (isLiteralShape(tensorOrMemrefValue, i)) return getShapeAsLiteral(tensorOrMemrefValue, i); if (Value val = getShapeVal(tensorOrMemrefValue, i)) - return DimIndexExpr(val); + return DimIE(val); return QuestionmarkIndexExpr(tensorOrMemrefValue, i); } @@ -430,7 +436,7 @@ IndexExpr IndexExprBuilder::isTileFull( // However, if UB is divisible by Block, then its full no matter what. if (UB.isLiteral() && (UB.getLiteral() % block.getLiteral() == 0)) { // Last tile is guaranteed to be full because UB is divisible by block. - return LiteralIndexExpr(1); // 1 >= 0 is true + return LitIE(1); // 1 >= 0 is true } // True if i <= (UB - block), namely UB - block - i >= 0. // Affine expressions compared to >= 0 diff --git a/src/Dialect/Mlir/IndexExprDetail.cpp b/src/Dialect/Mlir/IndexExprDetail.cpp index 1b234bcbe4..d82fca89a1 100644 --- a/src/Dialect/Mlir/IndexExprDetail.cpp +++ b/src/Dialect/Mlir/IndexExprDetail.cpp @@ -429,12 +429,16 @@ void IndexExprImpl::getAffineMapAndOperands( // will extract the correct info. if (auto affineMinOp = getValue().getDefiningOp()) { map = affineMinOp.getAffineMap(); + // Wonder if specialized list is better than all dims and syms + // (scope.getDimAndSymbolList(operands)). for (Value val : affineMinOp.getMapOperands()) operands.emplace_back(val); return; } if (auto affineMaxOp = getValue().getDefiningOp()) { map = affineMaxOp.getAffineMap(); + // Wonder if specialized list is better than all dims and syms + // (scope.getDimAndSymbolList(operands)). for (Value val : affineMaxOp.getMapOperands()) operands.emplace_back(val); return; diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index 3ac5bdcb01..954eb030d1 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -9,6 +9,7 @@ // ============================================================================= #include "src/Dialect/Mlir/VectorMachineSupport.hpp" +#include "src/Compiler/CompilerOptions.hpp" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Debug.h" @@ -28,14 +29,17 @@ namespace onnx_mlir { *VectorMachineSupport::globalVectorMachineSupport = nullptr; /*static*/ void VectorMachineSupport::setGlobalVectorMachineSupport( - std::string arch, std::string cpu, std::string attr) { - // IBM Z servers use mcpu. - if (cpu.compare("z14") == 0) { - globalVectorMachineSupport = new Z14VectorMachineSupport(); - } else if (cpu.compare("z15") == 0) { - globalVectorMachineSupport = new Z15VectorMachineSupport(); - } else if (cpu.compare("z16") == 0) { - globalVectorMachineSupport = new Z16VectorMachineSupport(); + const std::string &arch, const std::string &cpu, const std::string &attr) { + // IBM Z servers use march (deprecated mcpu), process here. + int64_t zArchNum = getZArchNum(arch, cpu); + if (zArchNum == 12) { + globalVectorMachineSupport = new ZArch12VectorMachineSupport(); + } else if (zArchNum == 13) { + globalVectorMachineSupport = new ZArch13VectorMachineSupport(); + } else if (zArchNum == 14) { + globalVectorMachineSupport = new ZArch14VectorMachineSupport(); + } else if (zArchNum == 15) { + globalVectorMachineSupport = new ZArch15VectorMachineSupport(); // Intel uses arch } else if (arch.compare("x86-64") == 0) { // Intel arch @@ -78,21 +82,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { } /*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps, - Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) { + Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure) { int64_t size = genOps.size(); + vectorizedOpNum = maxVectorRegisterPressure = 0; if (!hasSimd()) { - vectorizedOpNum = 0; scalarOpNum = size; return 1; } int64_t totProcessedValues = 0.0; - vectorizedOpNum = 0; scalarOpNum = 0; + bool hasRegisterPressure = false; + // Determine which operations support SIMD and accumulate their vector // lengths. for (auto pair : genOps) { GenericOps genOp = pair.first; int64_t num = pair.second; + // Handle other metrics first. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num); + hasRegisterPressure = true; + continue; + } + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t vl = getArchVectorLength(genOp, elementType); // If past last value, assume 1; otherwise use actual value. // Accumulate weighted scalar/vectorized num and vl length. @@ -101,12 +114,17 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { else scalarOpNum += num; // For VL, when an operation is scalar, it still process 1 element - int64_t processedValues = std::max((int64_t)1, vl); + int64_t processedValues = std::max(static_cast(1), vl); totProcessedValues += processedValues * num; } + // Compute final values int64_t totNum = vectorizedOpNum + scalarOpNum; - scalarOpNum = size - vectorizedOpNum; + if (!hasRegisterPressure) { + // Estimate default register pressure as one per 2 vector operation. + maxVectorRegisterPressure = + std::max(vectorizedOpNum / 2, static_cast(1)); + } return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0; } @@ -114,14 +132,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { // IBM Z servers // ============================================================================= -int64_t Z16VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, Type elementType) { +bool ZArch14VectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + bool isFloat = mlir::isa(elementType); + if (isFloat) { + switch (genOp) { + case GenericOps::roundEvenGop: + return true; + default: + return false; + } + } + // Integer + return false; +} + +int64_t ZArch14VectorMachineSupport::computeArchVectorLength( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); - // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -137,10 +171,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: /* Supported via compare and select */ case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: /* Use load integer & rounding modes*/ @@ -153,6 +187,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: return archVL; default: @@ -161,7 +196,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -189,14 +224,21 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // This may be an approximation of the actual capabilities. // ============================================================================= +bool SSE42x86VectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + return false; +} + int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( - GenericOps Gop, mlir::Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -212,10 +254,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // Supports only 32 and 64 bit Floats; There is support for extended too // but ignore this for now. if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -227,7 +269,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: - case GenericOps::RoundGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: return archVL; @@ -237,7 +279,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -275,14 +317,21 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // This may be an approximation of the actual capabilities. // ============================================================================= +bool NeonVectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + return false; +} + int64_t NeonVectorMachineSupport::computeArchVectorLength( - GenericOps Gop, mlir::Type elementType) { + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); int64_t bitWidth = elementType.getIntOrFloatBitWidth(); int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType); bool isFloat = mlir::isa(elementType); // Support shared between int and float. - switch (Gop) { + switch (genOp) { case GenericOps::ScalarOnlyGop: return 1; // Must be scalar. case GenericOps::SelectGop: @@ -297,10 +346,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( if (isFloat) { // Supports only 32 and 64 bit Floats; if (!(bitWidth == 32 || bitWidth == 64 || - (bitWidth == 16 && Gop == GenericOps::ConversionGop))) + (bitWidth == 16 && genOp == GenericOps::ConversionGop))) return UNSUPPORTED; // Now we have a supported length, test for specific operations. - switch (Gop) { + switch (genOp) { case GenericOps::AbsGop: case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::CeilGop: @@ -312,7 +361,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: - case GenericOps::RoundGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: return archVL; @@ -322,7 +371,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( } } // Support for integer (we consider bit-wide ops as byte wide ops). - switch (Gop) { + switch (genOp) { // 1 - 16 byte operations. case GenericOps::ArithmeticGop: /* Add/sub,... */ case GenericOps::ConversionGop: @@ -367,13 +416,22 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) { u[genOp] = num; } // Merge entries from the second mix. - for (auto pair : mix1) { + for (auto pair : mix2) { GenericOps genOp = pair.first; int64_t num = pair.second; - if (u.find(genOp) != u.end()) - u[genOp] += num; // Has this op already, add to it. - else + if (u.find(genOp) != u.end()) { + // Merge the 2 operation counts/metrics. + if (genOp == GenericOps::EstimatedVectorRegisterPressure) { + // For register pressure, pick the max of both. + u[genOp] = std::max(u[genOp], num); + } else { + // For operation count, use the sum of both + u[genOp] += num; + } + } else { + // First time we have this. u[genOp] = num; + } } return u; } diff --git a/src/Dialect/Mlir/VectorMachineSupport.hpp b/src/Dialect/Mlir/VectorMachineSupport.hpp index bcd2ad1a88..86327f4ae4 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.hpp +++ b/src/Dialect/Mlir/VectorMachineSupport.hpp @@ -32,6 +32,10 @@ namespace onnx_mlir { // (e.g. all the compares). enum class GenericOps { + ///////////////////////////////////// + // Generic ops. + ///////////////////////////////////// + AbsGop, ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */ CeilDivGop, @@ -52,7 +56,7 @@ enum class GenericOps { MulGop, PowGop, RemGop, - RoundGop, + roundEvenGop, /* FP to FP round to nearest even ONNX */ ScalarOnlyGop, /* Any ops that are guaranteed to be scalar on any arch. */ SelectGop, ShiftGop, /* Shift operations: logical/arithmetic. */ @@ -62,6 +66,17 @@ enum class GenericOps { TrigArcGop, /* Arc trigonometry ops: asin, acos, atan. */ TrigGop, /* Trigonometry ops: sin, cos, tan. */ TrigHyperbolicGop, /* Hyperbolic trig. */ + + LastGop, /* Marker of the last op. Used to delineate from other metrics. */ + + ///////////////////////////////////// + // Metrics others than operations. + ///////////////////////////////////// + + // Metric that provides an estimate of the maximum number of vector registers + // used in a kernel. If none is provided, we estimate the pressure based on + // the number of operations. + EstimatedVectorRegisterPressure, }; // Describe the mix of Generic operations in a given kernel. Each generic @@ -83,7 +98,7 @@ class VectorMachineSupport { public: // Must call setGlobalVectorMachineSupport once before using any calls below. static void setGlobalVectorMachineSupport( - std::string arch, std::string cpu, std::string attr); + const std::string &arch, const std::string &cpu, const std::string &attr); static void clearGlobalVectorMachineSupport(); static std::string getArchName() { return vms()->computeArchName(); } @@ -92,6 +107,11 @@ class VectorMachineSupport { // support. static bool hasSimd() { return getArchVectorRegisterNum() > 0; } + // Determine if custom asm is needed (aka operation not supported by llvm). + static bool requireCustomASM(GenericOps gop, mlir::Type elementType) { + return vms()->needCustomASM(gop, elementType); + } + // When querying Vector length for machines with unsupported simd, UNSUPPORTED // (aka 0) is returned. static const int64_t UNSUPPORTED = 1; @@ -132,12 +152,17 @@ class VectorMachineSupport { // number of times that generic operation was found. Note that scalar // operation have a vector length of one in the weighted average as they still // contribute one result. + // Max vector register pressure is also reported, either from an explicit + // mention in the genOps, or estimated as one vector register per vector + // operation. static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType, - int64_t &vectorizedOpNum, int64_t &scalarOpNum); + int64_t &vectorizedOpNum, int64_t &scalarOpNum, + int64_t &maxVectorRegisterPressure); protected: // Virtual functions that do the actual work. Called by the "get" functions. virtual std::string computeArchName() = 0; + virtual bool needCustomASM(GenericOps gop, mlir::Type elementType) = 0; virtual int64_t computeArchVectorRegisterNum() = 0; virtual int64_t computeArchVectorBitWidth() = 0; virtual int64_t computeArchVectorLength(mlir::Type elementType); @@ -160,6 +185,9 @@ class NoVectorMachineSupport : public VectorMachineSupport { virtual ~NoVectorMachineSupport() = default; std::string computeArchName() override { return "no_vector"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override { + return false; + } int64_t computeArchVectorRegisterNum() override { return 0; } int64_t computeArchVectorBitWidth() override { return 0; } int64_t computeArchVectorLength(mlir::Type elementType) override { @@ -173,21 +201,23 @@ class NoVectorMachineSupport : public VectorMachineSupport { // Support for IBM Z servers. -class Z16VectorMachineSupport : public VectorMachineSupport { +class ZArch14VectorMachineSupport : public VectorMachineSupport { public: - Z16VectorMachineSupport() = default; - virtual ~Z16VectorMachineSupport() = default; + ZArch14VectorMachineSupport() = default; + virtual ~ZArch14VectorMachineSupport() = default; - std::string computeArchName() override { return "z16"; } + std::string computeArchName() override { return "z16/arch14 equivalent"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 32; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t computeArchVectorLength( GenericOps gop, mlir::Type elementType) override; }; -// TODO: create models for z14 and z15. -using Z14VectorMachineSupport = Z16VectorMachineSupport; -using Z15VectorMachineSupport = Z16VectorMachineSupport; +// TODO: create models for arch12, arch13, arch15. +using ZArch12VectorMachineSupport = ZArch14VectorMachineSupport; +using ZArch13VectorMachineSupport = ZArch14VectorMachineSupport; +using ZArch15VectorMachineSupport = ZArch14VectorMachineSupport; // Support for x86 processors (SSE 4.2 and AVX2) class SSE42x86VectorMachineSupport : public VectorMachineSupport { @@ -196,6 +226,7 @@ class SSE42x86VectorMachineSupport : public VectorMachineSupport { virtual ~SSE42x86VectorMachineSupport() = default; std::string computeArchName() override { return "x86-sse4.2"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 16; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t computeArchVectorLength( @@ -219,6 +250,7 @@ class NeonVectorMachineSupport : public VectorMachineSupport { virtual ~NeonVectorMachineSupport() = default; std::string computeArchName() override { return "arm64-neon"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 32; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t computeArchVectorLength( diff --git a/src/Dialect/ONNX/AdditionalONNXOps.td b/src/Dialect/ONNX/AdditionalONNXOps.td index a72af4d7c2..11962e1ed7 100644 --- a/src/Dialect/ONNX/AdditionalONNXOps.td +++ b/src/Dialect/ONNX/AdditionalONNXOps.td @@ -369,7 +369,7 @@ def ONNXYieldOp : ONNX_Op<"Yield", [Pure, ReturnLike, Terminator]> { //===----------------------------------------------------------------------===// // BatchNorm in Inference mode. def ONNXBatchNormalizationInferenceModeOp: ONNX_Op<"BatchNormalizationInferenceMode", - [Pure, DeclareOpInterfaceMethods, + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BatchNormalization operation in test mode"; let description = [{ @@ -421,7 +421,8 @@ def ONNXBatchNormalizationInferenceModeOp: ONNX_Op<"BatchNormalizationInferenceM //===----------------------------------------------------------------------===// // MaxPoolSingleOutOp def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", - [Pure, DeclareOpInterfaceMethods, + [Pure, OpVersionTrait<12>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxPool operation with a single output."; let description = [{ diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 3917c94aa4..b68e1cf247 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -80,6 +80,7 @@ add_onnx_mlir_library(OMONNXOps ONNXOps/Tensor/Gather.cpp ONNXOps/Tensor/GatherElements.cpp ONNXOps/Tensor/GatherND.cpp + ONNXOps/Tensor/GridSample.cpp ONNXOps/Tensor/Identity.cpp ONNXOps/Tensor/NonZero.cpp ONNXOps/Tensor/OneHot.cpp diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp index 38cead9931..672b4f6865 100644 --- a/src/Dialect/ONNX/DialectBuilder.cpp +++ b/src/Dialect/ONNX/DialectBuilder.cpp @@ -4,7 +4,7 @@ //===----- DialectBuilder.cpp - Helper functions for ONNX dialects -------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -36,6 +36,12 @@ IntegerAttr OnnxBuilder::getSignedInt64Attr(int64_t n) const { // Basic operations // ============================================================================= +Value OnnxBuilder::abs(Value input) const { + Type outputType = input.getType(); // input == output type. + return createTypedOpAndInferShapes( + toTensor(outputType), toTensor(input)); +} + Value OnnxBuilder::add(Value A, Value B) const { assert((mlir::cast(A.getType()).getElementType() == mlir::cast(B.getType()).getElementType()) && @@ -138,6 +144,13 @@ void OnnxBuilder::dimGroup(Value input, int axis, int groupID) const { b().create(loc(), input, axisAttr, groupIDAttr); } +Value OnnxBuilder::dequantizeLinear( + Type resType, Value X, Value scale, Value zeroPoint, int axis) const { + IntegerAttr axisAttr = getSignedInt64Attr(axis); + return createOpAndInferShapes( + resType, toTensor(X), toTensor(scale), toTensor(zeroPoint), axisAttr); +} + Value OnnxBuilder::div(Value A, Value B) const { assert((mlir::cast(A.getType()).getElementType() == mlir::cast(B.getType()).getElementType()) && @@ -150,6 +163,11 @@ Value OnnxBuilder::expand(Type outputType, Value input, Value shape) const { outputType, toTensor(input), toTensor(shape)); } +Value OnnxBuilder::gelu(Value input, StringAttr approximateAttr) const { + return createOpAndInferShapes( + toTensor(input.getType()), input, approximateAttr); +} + // ONNXLayerNormalizationOp, version with one output only (Y). Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale, Value bias, int64_t axis, FloatAttr epsilon) const { @@ -164,6 +182,19 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale, toTensor(bias), axisAttr, epsilon, stashTypeAttr); return layerNormOp.getY(); } +// In the case of GroupNormalization when stashType can be specified +Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale, + Value bias, int64_t axis, FloatAttr epsilon, IntegerAttr stashType) const { + IntegerAttr axisAttr = getSignedInt64Attr(axis); + Value noneVal = none(); + Type noneType = noneVal.getType(); + ONNXLayerNormalizationOp layerNormOp = + createOpAndInferShapes( + /*Y type*/ toTensor(outputType), /*mean type*/ noneType, + /*std dev Type*/ noneType, toTensor(input), toTensor(scale), + toTensor(bias), axisAttr, epsilon, stashType); + return layerNormOp.getY(); +} Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale, Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale, @@ -296,8 +327,8 @@ Value OnnxBuilder::reshape(Type outputType, Value input, Value shape) const { toTensor(outputType), toTensor(input), toTensor(shape)); } -Value OnnxBuilder::reshape(Type outputType, Value input, Value shape, - mlir::IntegerAttr allowZero) const { +Value OnnxBuilder::reshape( + Type outputType, Value input, Value shape, IntegerAttr allowZero) const { return createTypedOpAndInferShapes( toTensor(outputType), toTensor(input), toTensor(shape), allowZero); } @@ -426,7 +457,6 @@ TensorType OnnxBuilder::toTensor(Type input) const { } TypeRange OnnxBuilder::toTensors(TypeRange inputs) const { - assert(inputs.size() >= 2 && "Expect at least two inputs"); if (llvm::all_of(inputs, [](Type t) { return (mlir::isa(t)); })) return inputs; assert(llvm::all_of(inputs, [](Type t) { diff --git a/src/Dialect/ONNX/DialectBuilder.hpp b/src/Dialect/ONNX/DialectBuilder.hpp index 9ff98a3755..7ade044b50 100644 --- a/src/Dialect/ONNX/DialectBuilder.hpp +++ b/src/Dialect/ONNX/DialectBuilder.hpp @@ -41,6 +41,9 @@ struct OnnxBuilder : DialectBuilder { OnnxOpType createTypedOpAndInferShapes( mlir::Type result_ty, Args &&... args) const; + // ONNXAbsOp + mlir::Value abs(mlir::Value input) const; + // ONNXAddOp mlir::Value add(mlir::Value A, mlir::Value B) const; @@ -74,6 +77,10 @@ struct OnnxBuilder : DialectBuilder { mlir::ArrayRef kernelShape, mlir::ArrayRef pads, mlir::ArrayRef strides) const; + // ONNXDequantizeLinearOp + mlir::Value dequantizeLinear(mlir::Type resType, mlir::Value X, + mlir::Value scale, mlir::Value zeroPoint, int axis = 1) const; + // ONNXDivOp mlir::Value div(mlir::Value A, mlir::Value B) const; @@ -87,10 +94,17 @@ struct OnnxBuilder : DialectBuilder { mlir::Value expand( mlir::Type outputType, mlir::Value input, mlir::Value shape) const; + // ONNXGeluOp + mlir::Value gelu(mlir::Value input, mlir::StringAttr approximateAttr) const; + // ONNXLayerNormalizationOp, version with one output only (Y). mlir::Value layerNorm(mlir::Type outputType, mlir::Value input, mlir::Value scale, mlir::Value bias, int64_t axis, mlir::FloatAttr epsilon) const; + // In the case of GroupNormalization when stashType can be specified + mlir::Value layerNorm(mlir::Type outputType, mlir::Value input, + mlir::Value scale, mlir::Value bias, int64_t axis, + mlir::FloatAttr epsilon, mlir::IntegerAttr stashType) const; // ONNXQLinearMatMulOp mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a, @@ -327,5 +341,11 @@ struct IndexExprBuilderForAnalysis : IndexExprBuilder { // Include inline code definitions. #include "DialectBuilder.hpp.inc" +template +void copySingleResultType(OnnxOp opToCopyFrom, mlir::Value &valueToCopyTo) { + assert(opToCopyFrom->getNumResults() == 1); + valueToCopyTo.setType(opToCopyFrom->getResult(0).getType()); +} + } // namespace onnx_mlir #endif diff --git a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp index 3739904c17..cabfe58c02 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp @@ -116,6 +116,11 @@ class DisposableElementsAttr return *this ? mlir::cast(*this) : nullptr; } + // Clears the buffer payload shared_ptr which decreases the reference count + // and, if it reaches zero, frees or closes the underlying MemoryBuffer's + // heap allocation or file. Called from DisposablePool. + void dispose(); + private: // Called from DisposablePool who calls with a unique id and records the // created instance. @@ -123,11 +128,6 @@ class DisposableElementsAttr BType bufferBType, ArrayRef strides, const Buffer &buffer, Transformer transformer); - // Clears the buffer payload shared_ptr which decreases the reference count - // and, if it reaches zero, frees or closes the underlying MemoryBuffer's - // heap allocation or file. Called from DisposablePool. - void dispose(); - public: //===--------------------------------------------------------------------===// // Instance properties: @@ -259,6 +259,12 @@ class DisposableElementsAttr template void readArray(MutableArrayRef dst) const; + // Returns a pointer to the underlying data as a flat byte array, if + // everything aligns, otherwise makes and returns a copy. + // If the element type is bool the data holds one byte (with value 0 or 1) per + // bool (contrary to how DenseElementsAttr::getRawData() bit packs bools). + onnx_mlir::ArrayBuffer getRawBytes() const; + // Returns a pointer to the underlying data as a flat WideNum array, if // everything aligns, otherwise makes and returns a copy. onnx_mlir::ArrayBuffer getWideNums() const; @@ -313,12 +319,6 @@ class DisposableElementsAttr // bool (contrary to how DenseElementsAttr::getRawData() bit packs bools). void readRawBytes(MutableArrayRef dst) const; - // Returns a pointer to the underlying data as a flat byte array, if - // everything aligns, otherwise makes and returns a copy. - // If the element type is bool the data holds one byte (with value 0 or 1) per - // bool (contrary to how DenseElementsAttr::getRawData() bit packs bools). - onnx_mlir::ArrayBuffer getRawBytes() const; - }; // class DisposableElementsAttr // Include template implementations. diff --git a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp index aa0f24d244..bdc654f540 100644 --- a/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp +++ b/src/Dialect/ONNX/ElementsAttr/DisposablePool.cpp @@ -114,7 +114,7 @@ void DisposablePool::scrub(ModuleOp moduleOp, OpAttrDictionary opsAttrs) { return llvm::make_range(batchBegin, batchEnd); }; // Parallel worker body: Fetch and process batches until there are no more. - auto work = [&fetchBatch](size_t threadNumber) { + auto work = [&fetchBatch, &translationMutex](size_t threadNumber) { for (;;) { auto batch = fetchBatch(); if (batch.empty()) @@ -122,9 +122,10 @@ void DisposablePool::scrub(ModuleOp moduleOp, OpAttrDictionary opsAttrs) { for (auto &[id, translation] : batch) { auto &[disposable, dense] = translation; dense = disposable.toDenseElementsAttr(); - // TODO: Consider calling disposable.dispose() here to free up memory - // on the go to make memory available to create the next - // DenseElementsAttr. In that case should we lock mutex? + { + const std::lock_guard lock(translationMutex); + disposable.dispose(); + } } } }; diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp index 63d2d3f0e9..0237feccdc 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Traits.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp" #include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp" @@ -187,6 +188,7 @@ ElementsAttr ElementsAttrBuilder::fromWideNums( // demonstrates a speedup. ElementsAttr ElementsAttrBuilder::combine(ElementsAttr lhs, ElementsAttr rhs, ShapedType combinedType, WideNum (*combiner)(WideNum, WideNum)) { + assert(combinedType.hasStaticShape()); if (lhs.isSplat()) { WideNum lhsNum = getElementsSplatWideNum(lhs); return expandAndTransform(rhs, combinedType, @@ -227,6 +229,7 @@ ElementsAttr ElementsAttrBuilder::combine(ElementsAttr lhs, ElementsAttr rhs, ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs, ElementsAttr rhs, ShapedType combinedType) { + assert(combinedType.hasStaticShape()); assert(cond.getElementType().isInteger(1)); assert(lhs.getElementType() == rhs.getElementType()); assert(lhs.getElementType() == combinedType.getElementType()); @@ -361,12 +364,12 @@ template WideNum wideCast(WideNum n) { return WideNum::widen>( static_cast(n.narrow>())); -}; +} template double wideToDouble(WideNum n) { return static_cast(n.narrow>()); -}; +} } // namespace @@ -496,6 +499,285 @@ bool isIdentityPermutation(ArrayRef perm) { } } // namespace +// Adapted from +// https://github.com/onnx/onnx/blob/091d3ad16155640a7b56b0aab8a364fb908894a8/onnx/reference/ops/op_reverse_sequence.py#L9 + +// Pseudo code for the reverseSequence implementation: + +// result = input +// for i in enumerate(sequence_lens) +// list dstSrcPositionPairs +// list> timeAxisPosList1 with size input[0] dimsize for +// batch_axis=1 ( it will be input[1] dimsize for batch_axis=0) for idx on input +// begin +// if( batch_axis==1 and idx[1] == i and idx[0] < sequence_lens[i] ) +// dstSrcPositionPairs.push(idx.pos,0) // the destination pos which will +// be replaced is added, the source postion form where it will be +// replaced will be computed later +// timeAxisPosList1[idx[0]].push(idx.pos) // Add this pos to +// the correspoding timeAxis list. +// else if ( batch_axis==0 and idx[0] == i and idx[1] < sequence_lens[i] ) +// dstSrcPositionPairs.push(idx.pos,0) +// timeAxisPosList1[idx[1]].push(idx.pos) +// end + +// list> timeAxisPosList2 with size input[0] dimsize for +// batch_axis=1 ( it will be input[1] dimsize for batch_axis=0) for idx on input +// begin +// if( batch_axis==1 and idx[1] == i and idx[0] < sequence_lens[i] ) +// timeAxisPosList2[idx[0]].push(idx.pos) +// positionWithinList = timeAxisPosList2[idx[0]].size() + +// listAsPerRevSeq = timeAxisPosList1.size()-idx[0]-1 +// sourcePosition = +// timeAxisPosList1[listAsPerRevSeq][positionWithinList] update the pair +// in dstSrcPositionPairs for idx.pos with sourcePosition +// else if( batch_axis==0 and idx[0] == i and idx[1] < sequence_lens[i] ) +// timeAxisPosList2[idx[1]].push(idx.pos) +// positionWithinList = timeAxisPosList2[idx[1]].size() + +// listAsPerRevSeq = timeAxisPosList1.size()-idx[1]-1 +// sourcePosition = +// timeAxisPosList1[listAsPerRevSeq][positionWithinList] update the pair +// in dstSrcPositionPairs for idx.pos with sourcePosition +// end +// get iterator for dstSrcPositionPairs. +// for idx on input +// begin +// continue till the idx.pos equals iterator.destination +// update the result[idx.pos] with value at iterator.source +// increment the iterator +// end + +ElementsAttr ElementsAttrBuilder::reverseSequence( + ElementsAttr input, ElementsAttr sequenceLength, uint64_t batchAxis) { + + ShapedType inputType = input.getShapedType(); + ArrayRef inputShape = inputType.getShape(); + + SmallVector inputStrides; + ArrayBuffer inputNums = + getWideNumsAndExpandedStrides(input, inputShape, inputStrides); + + SmallVector seqLengthStrides; + ArrayRef seqLengthShape = sequenceLength.getShapedType().getShape(); + ArrayBuffer seqLengthNums = getWideNumsAndExpandedStrides( + sequenceLength, seqLengthShape, seqLengthStrides); + + Type elementType = inputType.getElementType(); + + return fromWideNums(inputType, [&](MutableArrayRef dstNums) { + wideZeroDispatchNonBool(elementType, [&](auto wideZero) { + using cpptype = decltype(wideZero); + constexpr BType TAG = toBType; + // This loop copies each element in the input to the dstNums + for (const auto &idxoffs : StridesRange<1>(inputShape, {inputStrides})) { + dstNums[idxoffs.flattenedIndex] = inputNums.get()[idxoffs[0]]; + } + SmallVector sequenceLength; + // Traverse and populate each element into sequenceLength. + for (const auto &idxoffs : + StridesRange<1>(seqLengthShape, {seqLengthStrides})) { + int64_t pos = idxoffs[0]; + auto value = seqLengthNums.get()[pos].narrow(); + sequenceLength.emplace_back(value); + } + // op Length of sequence_lens should match the + // sizeof batch axis of the input + // Iterating through the sequence_lens tensor values. + // This iteration means iterating through the batch axis dimensions. + // For ex: for input with dims (3,3,1,2), and batch_axis=1 + // it will have three iterations at dim[1]. + + // This is the most outer loop, after each iteration it will have + // rearranged the data in correspoding batch_axis dimension. + for (const auto [seqLengthIndex, seqLengthValue] : + llvm::enumerate(sequenceLength)) { + + // dstSrcPositionPairs: maintains the list of positions dst,src. + // Here destination means the position whose value will be + // overriden as part of rearrangement. + // source means the position from where the value will be + // picked up. + // `destination element` pos thats get its value from `source position` + // for ex: (0,1,2) <= (2,1,0),will have entries (0,2),(1,1),(2,0) + // This list will not have positions not affected by the sequence_length + // for example the input is (5,5,1,2) and the sequence_length + // is [3,3,3]. Here the reversal will happen for first 3 entries, the + // last two will be not affacted. And they will not be populated in + // dstSrcPositionPairs + // NOTE: dstSrcPositionPairs is created for every iteration. + // meaning the content of this is for a given batchAxis dimension + // ex: for input (3,3,1,2) seqlength [3,3,3], batch_axis 1 + // for each value in dim[1] we will have ones list of + // dstSrcPostionPairs. and one iteration. + + SmallVector> dstSrcPositionPairs; + + // timeAxisPosList1: + // This will have one list for each of the possible timeAxis values. + // ex: for input (3,3,1,2) seqlength [3,3,3], batch_axis 1 + // time_axis is 0, and at dim[0], we will have three differnet values. + // hence three lists. Each list will have its corresponding element's + // flat position values. + + // Once the list is fully populated. It will be used to find the source + // position for the current destination position. For ex: when at [0]th + // list second element position, we can find the second position in + // [2]nd list second element position value + + SmallVector> timeAxisPosList1; + int timeAxisPosListSize = 0; + int maxValueFortimeAxisPosListSize = 0; + // time_axis dim size + if (batchAxis == 1) { + maxValueFortimeAxisPosListSize = inputShape[0]; + } else { + maxValueFortimeAxisPosListSize = inputShape[1]; + } + // The reverseSequence length cannot be greater than the timeAxis dim + // length + + timeAxisPosListSize = (seqLengthValue > maxValueFortimeAxisPosListSize) + ? maxValueFortimeAxisPosListSize + : seqLengthValue; + // The number of lists trackes the positions of time_axis's which + // are to be reversed, this is defined by the sequence_lens value. + // ex: for input (3,3,1,2) seqlength [2,2,2], batch_axis 1 + // here the reversal is for first two entries of the timeAxis ( 0 and + // 1 + // ). The number of lists in the timeAxisPosList1 will be 2. + + timeAxisPosList1.resize(timeAxisPosListSize); + // Below loop will populate the dstSrcPositionPairs with the dst + // positions only. This dst positions are the one's which need to + // overriden. It will also populate the timeAxisPosList1. + for (const auto &idxoffs : + StridesRange<1>(inputShape, {inputStrides})) { + auto idx = idxoffs.index; + // for batch_axis = 1, the criteria is the elements dim[1] should + // be same as iteration index. and dim[0] should be less than + // seq_length defined's value ( it indicates how many elements should + // be reversed) ex: for input (3,3,1,2) seqlength [2,2,2], batch_axis + // 1 Here this considers only elements with dim[0] value less than 2. + // ( 0 and 1 only) + + if ((batchAxis == 1) && + ((idx[1] == seqLengthIndex) && + (static_cast(idx[0]) < seqLengthValue))) { + // adding only destination pos, source pos will be added in next + // iteration. + dstSrcPositionPairs.emplace_back(idxoffs[0], 0); + SmallVector>::iterator listIter = + timeAxisPosList1.begin(); + // advancing the list iterator by idx[0], + // note, we have one list per timeSliceIndex + // here the advancement is same as idx[0] + std::advance(listIter, idx[0]); + // Add the pos to the correspoding timeSliceIndex's list. + (*listIter).push_back(idxoffs[0]); + } else if ((batchAxis == 0) && + ((idx[0] == seqLengthIndex) && + (static_cast(idx[1]) < seqLengthValue))) { + dstSrcPositionPairs.emplace_back(idxoffs[0], 0); + SmallVector>::iterator iter = + timeAxisPosList1.begin(); + std::advance(iter, idx[1]); + (*iter).push_back(idxoffs[0]); + } + } + // timeAxisPosList2 is simliar to timeAxisPosList1. + SmallVector> timeAxisPosList2; + timeAxisPosList2.resize(timeAxisPosListSize); + // Starting the dstSrcPositionPairs iteration. + // As the dst poisitions are encountered, the source position + // assignement will be done. The assignments are done in the same + // order. Now, we have the knowledge of each timeAxis's position + // lists in the timeAxisPosList1. + + SmallVector>::iterator dstSrcPairsIter = + dstSrcPositionPairs.begin(); + + // In this loop, dstSrcPositionPairs's source position assignement + // will be completed. + + for (const auto &idxoffs : + StridesRange<1>(inputShape, {inputStrides})) { + auto idx = idxoffs.index; + if ((batchAxis == 1) && + ((idx[1] == seqLengthIndex) && + (static_cast(idx[0]) < seqLengthValue))) { + SmallVector>::iterator listIter2 = + timeAxisPosList2.begin(); + std::advance(listIter2, idx[0]); + (*listIter2).push_back(idxoffs[0]); + + // ex: for input (3,3,1,2) seqlength [3,3,3], batch_axis 1 , idx[1] + // == 0 for idx[0] ==0, listIter2 will be list tracking positions + // with idx[0] ==0 This size of the (*listIter2).size() at this + // stage will help us to know the corresponding element position in + // the source list. + + int posIndex = (*listIter2).size() - 1; + + // get the iterator timeAxisPosList1 and advance it to point to the + // correct source list. + + SmallVector>::iterator iter1 = + timeAxisPosList1.begin(); + std::advance(iter1, (timeAxisPosList1.size() - 1 - idx[0])); + // In the source list advance the iter to point to the corresponding + // postion. + SmallVector::iterator innerListIter = (*iter1).begin(); + std::advance(innerListIter, posIndex); + (*dstSrcPairsIter).second = *(innerListIter); + dstSrcPairsIter++; + } else if ((batchAxis == 0) && + ((idx[0] == seqLengthIndex) && + (static_cast(idx[1]) < seqLengthValue))) { + SmallVector>::iterator iter2 = + timeAxisPosList2.begin(); + std::advance(iter2, idx[1]); + (*iter2).push_back(idxoffs[0]); + int posIndex = (*iter2).size() - 1; + SmallVector>::iterator iter1 = + timeAxisPosList1.begin(); + std::advance(iter1, (timeAxisPosList1.size() - 1 - idx[1])); + SmallVector::iterator innerListIter = (*iter1).begin(); + std::advance(innerListIter, posIndex); + (*dstSrcPairsIter).second = *(innerListIter); + dstSrcPairsIter++; + } + } + SmallVector>::iterator dstSrcLookupIter = + dstSrcPositionPairs.begin(); + + // This loop uses the dstSrcPositionPairs, and as the dst position is + // encountered, it copies the source position value to the dst position. + + for (const auto &idxoffs : + StridesRange<1>(inputShape, {inputStrides})) { + + int64_t pos = idxoffs[0]; + if (pos < ((*dstSrcLookupIter).first)) { + continue; + } + if (pos == (*dstSrcLookupIter).first) { + cpptype replacingValue = 0; + replacingValue = + inputNums.get()[(*dstSrcLookupIter).second].narrow(); + dstNums[idxoffs.flattenedIndex] = + WideNum::widen(replacingValue); + dstSrcLookupIter++; + if (dstSrcLookupIter == dstSrcPositionPairs.end()) { + break; + } + } + } + } + }); + }); +} ElementsAttr ElementsAttrBuilder::transpose( ElementsAttr elms, ArrayRef perm) { if (isIdentityPermutation(perm)) diff --git a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp index f7276b6ebb..2105f31ba9 100644 --- a/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp +++ b/src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp @@ -155,6 +155,9 @@ class ElementsAttrBuilder { mlir::ElementsAttr transpose( mlir::ElementsAttr elms, llvm::ArrayRef perm); + mlir::ElementsAttr reverseSequence(mlir::ElementsAttr input, + mlir::ElementsAttr sequenceLength, uint64_t batchAxis); + // Returns a reshaped ElementsAttr. // // Reuses elms' underlying data without a data copy, unless the underlying diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.hpp b/src/Dialect/ONNX/ElementsAttr/Strides.hpp index 1402207a7b..c589d77d97 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.hpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.hpp @@ -111,4 +111,4 @@ void restrideArray(llvm::ArrayRef shape, } } // namespace onnx_mlir -#endif \ No newline at end of file +#endif diff --git a/src/Dialect/ONNX/ONNX.td b/src/Dialect/ONNX/ONNX.td index 53c52a3ffc..5fac4671f7 100644 --- a/src/Dialect/ONNX/ONNX.td +++ b/src/Dialect/ONNX/ONNX.td @@ -232,6 +232,11 @@ def ONNXConstantOpFromDenseAttr: NativeCodeCall< class ONNX_Op traits = []> : Op ; +// Trait to specify which operation set introduced a revision of an operator. +// For multi-versioned operators, the version also appears in the operator's name. +class OpVersionTrait + : ParamNativeOpTrait<"OpVersionTrait", !cast(version)>; + // The tablegen code onnxop.in is generated with gen_doc.py // clone and install onnx // git clone --recursive https://github.com/onnx/onnx.git diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 62e5b51e81..26e5e97a62 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -80,7 +80,7 @@ static std::optional insertDimWhenUseful(const Value tensor, // need to insert it. if (isa(op)) okToInsert = false; - else if (auto dimOp = dyn_cast(op)) { + else if (auto dimOp = mlir::dyn_cast(op)) { // The correct axis is from ONNXDimOp. axis = dimOp.getAxis(); okToInsert = true; @@ -107,9 +107,8 @@ static bool handleAndTestInBound(int64_t &axis, ShapedType type) { /// Given a QuestionMarkIndexExpr representing a dynamic dimension, find the /// same dynamic dimensions in the inputs. static void findAndAddSameDim(const QuestionmarkIndexExpr &qmOuputIE, - mlir::Operation *op, mlir::ValueRange operands, - DimAnalysis::DimSetT &sameDims) { - mlir::Location loc = op->getLoc(); + mlir::Operation *op, ValueRange operands, DimAnalysis::DimSetT &sameDims) { + Location loc = op->getLoc(); IndexExprBuilderForAnalysis createIE(loc); // Cannot process if the question mark is not a specific one. @@ -144,7 +143,7 @@ static void exploreSameDimsFromConsumingOperators( llvm::dbgs() << " - exploring "; op->dump(); }); - if (auto concatOp = dyn_cast(op)) { + if (auto concatOp = mlir::dyn_cast(op)) { // Dimensions on the same axis (except the concatenating axis) are the // same across all inputs. int64_t axis = concatOp.getAxis(); @@ -177,7 +176,7 @@ static void exploreSameDimsFromConsumingOperators( } continue; } - if (auto gemmOp = dyn_cast(op)) { + if (auto gemmOp = mlir::dyn_cast(op)) { Value A = gemmOp.getA(); Value B = gemmOp.getB(); if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) @@ -199,7 +198,7 @@ static void exploreSameDimsFromConsumingOperators( } continue; } - if (auto gruOp = dyn_cast(op)) { + if (auto gruOp = mlir::dyn_cast(op)) { int64_t layout = gruOp.getLayout(); // In LSTM, sequence_lens and batch_size are potentially dynamic. // Only batch_size is used in multiple inputs, so we'll check batch_size. @@ -237,7 +236,7 @@ static void exploreSameDimsFromConsumingOperators( } continue; } - if (auto lstmOp = dyn_cast(op)) { + if (auto lstmOp = mlir::dyn_cast(op)) { int64_t layout = lstmOp.getLayout(); // In LSTM, sequence_lens and batch_size are potentially dynamic. // Only batch_size is used in multiple inputs, so we'll check batch_size. @@ -305,7 +304,7 @@ static void exploreSameDimsFromConsumingOperators( } continue; } - if (auto rnnOp = dyn_cast(op)) { + if (auto rnnOp = mlir::dyn_cast(op)) { int64_t layout = rnnOp.getLayout(); // In LSTM, sequence_lens and batch_size are potentially dynamic. // Only batch_size is used in multiple inputs, so we'll check batch_size. @@ -362,11 +361,11 @@ static bool exploreSameDimsUsingShapeHelper(const DimAnalysis::DimT &dim, ONNXOpShapeHelper *shapeHelper = shape_op.getShapeHelper(op, {}, nullptr, nullptr); // If no shape helper, or unimplemented, just abort. - if (!shapeHelper || !shapeHelper->isImplemented()) + if (!shapeHelper) return false; // Compute shape. - if (failed(shapeHelper->computeShape())) { + if (!shapeHelper->isImplemented() || failed(shapeHelper->computeShape())) { delete shapeHelper; return false; } @@ -425,10 +424,10 @@ static bool exploreSameDimsUsingShapeInput(const DimAnalysis::DimT &dim, // Below are ONNX operations we know that specify the output shape via an // operand. Sorted in the alphabetical order. Value shapeInput = nullptr; - if (auto onnxOp = dyn_cast(op)) { + if (auto onnxOp = mlir::dyn_cast(op)) { // `shape` stores shape information for dimensions specified by `axes`. // `outputDimIndex` must be in `axes` in order to get dim from `shape`. - auto outputType = cast(onnxOp.getResult().getType()); + auto outputType = mlir::cast(onnxOp.getResult().getType()); SmallVector axesInt; ArrayAttr axes = onnxOp.getAxesAttr(); if (axes) { @@ -450,21 +449,21 @@ static bool exploreSameDimsUsingShapeInput(const DimAnalysis::DimT &dim, } if (found) shapeInput = onnxOp.getShape(); - } else if (auto onnxOp = dyn_cast(op)) { + } else if (auto onnxOp = mlir::dyn_cast(op)) { // `input` stores shape information. shapeInput = onnxOp.getInput(); - } else if (auto onnxOp = dyn_cast(op)) { + } else if (auto onnxOp = mlir::dyn_cast(op)) { // `shape` stores shape information. shapeInput = onnxOp.getShape(); - } else if (auto onnxOp = dyn_cast(op)) { + } else if (auto onnxOp = mlir::dyn_cast(op)) { // Optional `output_shape` stores shape information. if (!isNoneValue(onnxOp.getOutputShape())) shapeInput = onnxOp.getOutputShape(); - } else if (auto onnxOp = dyn_cast(op)) { + } else if (auto onnxOp = mlir::dyn_cast(op)) { // `shape` stores shape information. Only support `allow_zero == 0`. if (onnxOp.getAllowzero() == 0) shapeInput = onnxOp.getShape(); - } else if (auto onnxOp = dyn_cast(op)) { + } else if (auto onnxOp = mlir::dyn_cast(op)) { // If input dimension i is 1, `repeats` i stores shape information. Type inputType = onnxOp.getInput().getType(); ArrayRef inputShape = getShape(inputType); @@ -503,7 +502,7 @@ DimAnalysis::DimAnalysis(ArrayRef vals) { DimAnalysis::DimAnalysis(ModuleOp moduleOp) { moduleOp.walk([&](Operation *op) { - if (auto funcOp = dyn_cast(op)) { + if (auto funcOp = mlir::dyn_cast(op)) { // Build dimensions for function arguments and results. buildFunctionArgsRes(funcOp); } else { @@ -587,9 +586,9 @@ void DimAnalysis::buildFunctionArgsRes(func::FuncOp funcOp) { // Build internal mappings for results. Operation *terminator = funcOp.getRegion().back().getTerminator(); ValueRange resVals; - if (auto returnOp = dyn_cast(terminator)) + if (auto returnOp = mlir::dyn_cast(terminator)) resVals = returnOp.getOperands(); - else if (auto returnOp = dyn_cast(terminator)) + else if (auto returnOp = mlir::dyn_cast(terminator)) resVals = returnOp.getOperands(); ArrayAttr resAttrs = funcOp.getResAttrsAttr(); buildFor(resVals, resAttrs); @@ -862,14 +861,14 @@ void DimAnalysis::visitDim( return; // DimOp - if (auto dimOp = dyn_cast(op)) { + if (auto dimOp = mlir::dyn_cast(op)) { DimAnalysis::DimT newSameDim(dimOp.getData(), dimOp.getAxis()); sameDims.insert(newSameDim); return; } // CastOp - if (auto castOp = dyn_cast(op)) { + if (auto castOp = mlir::dyn_cast(op)) { if (auto d = insertDimWhenUseful(castOp.getInput(), dimIndex, sameDims)) LLVM_DEBUG(llvm::dbgs() << " - Added a new dim(" << d.value().first << ", " << d.value().second << ")\n"); @@ -933,7 +932,7 @@ void DimAnalysis::visitDim( } // ReshapeOp has some additional cases. - if (auto reshapeOp = dyn_cast(op)) { + if (auto reshapeOp = mlir::dyn_cast(op)) { if (reshapeOp.getAllowzero() != 0) return; @@ -954,20 +953,22 @@ void DimAnalysis::visitDim( // inputs. // Get the dynamic dimension from data. - auto dataType = cast(data.getType()); - auto outputType = cast(output.getType()); + auto dataType = mlir::cast(data.getType()); + auto outputType = mlir::cast(output.getType()); // Check if there is only one dynamic dimension in the data and output. bool dataHasOneDynamicDim = (llvm::count(dataType.getShape(), ShapedType::kDynamic) == 1); bool outputHasOneDynamicDim = (llvm::count(outputType.getShape(), ShapedType::kDynamic) == 1); // Check if the products of static sizes in the data and output are equal. - // It's ok to count ShapedType::kDynamic (dynamic dimension) in the size. int64_t dataStaticSize = 1, outputStaticSize = 1; - for (int64_t i = 0; i < dataType.getRank(); ++i) - dataStaticSize *= dataType.getShape()[i]; - for (int64_t i = 0; i < outputType.getRank(); ++i) - outputStaticSize *= outputType.getShape()[i]; + for (int64_t i = 0; i < dataType.getRank(); ++i) { + dataStaticSize *= dataType.isDynamicDim(i) ? -1 : dataType.getShape()[i]; + } + for (int64_t i = 0; i < outputType.getRank(); ++i) { + outputStaticSize *= + outputType.isDynamicDim(i) ? -1 : outputType.getShape()[i]; + } // Conditions hold, the dynamic dimension can be from the data. if (dataHasOneDynamicDim && outputHasOneDynamicDim && (dataStaticSize == outputStaticSize)) { @@ -1069,7 +1070,7 @@ void ONNXDimAnalysisPass::runOnOperation() { } else { Operation *op = val.getDefiningOp(); b.setInsertionPointAfter(op); - if (auto dimOp = dyn_cast(op)) + if (auto dimOp = mlir::dyn_cast(op)) val = dimOp.getData(); } DimAnalysis::DimT dim(val, dimAxis); diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 892cf01cf1..1e5eb23747 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -28,7 +28,7 @@ #define UNSUPPORTED_OPS(OP_TYPE) \ /* shape inference interface method */ \ mlir::LogicalResult mlir::OP_TYPE::inferShapes( \ - std::function doShapeInference) { \ + std::function doShapeInference) { \ return mlir::success(); \ } diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index febb5207c0..ab4503fbc2 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -19,6 +19,7 @@ #include "src/Dialect/ONNX/ONNXAttributes.hpp" #include "src/Dialect/ONNX/ONNXDialect.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" +#include "src/Dialect/ONNX/ONNXTraits.hpp" #include "src/Dialect/ONNX/ONNXTypes.hpp" #include "src/Interface/HasOnnxSubgraphOpInterface.hpp" #include "src/Interface/ResultTypeInferenceOpInterface.hpp" diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 3abab3a2e9..cfc70f6716 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -5,7 +5,7 @@ //******************************************************** def ONNXAbsOp:ONNX_Op<"Abs", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Abs operation"; let description = [{ Absolute takes one input data (Tensor) and produces one output data @@ -46,13 +46,13 @@ def ONNXAbsOp:ONNX_Op<"Abs", } def ONNXAcosOp:ONNX_Op<"Acos", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Acos operation"; let description = [{ Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -75,13 +75,13 @@ def ONNXAcosOp:ONNX_Op<"Acos", } def ONNXAcoshOp:ONNX_Op<"Acosh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Acosh operation"; let description = [{ Calculates the hyperbolic arccosine of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -104,7 +104,7 @@ def ONNXAcoshOp:ONNX_Op<"Acosh", } def ONNXAddOp:ONNX_Op<"Add", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Add operation"; let description = [{ @@ -160,7 +160,7 @@ def ONNXAddOp:ONNX_Op<"Add", } def ONNXAndOp:ONNX_Op<"And", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX And operation"; let description = [{ @@ -215,7 +215,7 @@ def ONNXAndOp:ONNX_Op<"And", } def ONNXArgMaxOp:ONNX_Op<"ArgMax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArgMax operation"; let description = [{ Computes the indices of the max elements of the input tensor's element along the @@ -254,7 +254,7 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax", } def ONNXArgMinOp:ONNX_Op<"ArgMin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArgMin operation"; let description = [{ Computes the indices of the min elements of the input tensor's element along the @@ -293,13 +293,13 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin", } def ONNXAsinOp:ONNX_Op<"Asin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Asin operation"; let description = [{ Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -322,13 +322,13 @@ def ONNXAsinOp:ONNX_Op<"Asin", } def ONNXAsinhOp:ONNX_Op<"Asinh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Asinh operation"; let description = [{ Calculates the hyperbolic arcsine of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -351,13 +351,13 @@ def ONNXAsinhOp:ONNX_Op<"Asinh", } def ONNXAtanOp:ONNX_Op<"Atan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Atan operation"; let description = [{ Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -380,13 +380,13 @@ def ONNXAtanOp:ONNX_Op<"Atan", } def ONNXAtanhOp:ONNX_Op<"Atanh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Atanh operation"; let description = [{ Calculates the hyperbolic arctangent of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -409,7 +409,7 @@ def ONNXAtanhOp:ONNX_Op<"Atanh", } def ONNXAveragePoolOp:ONNX_Op<"AveragePool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX AveragePool operation"; let description = [{ AveragePool consumes an input tensor X and applies average pooling across @@ -426,7 +426,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -445,7 +445,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedStrAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, DefaultValuedAttr:$count_include_pad, @@ -453,7 +453,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", I64ArrayAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -477,7 +477,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", } def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BatchNormalization operation"; let description = [{ Carries out batch normalization as described in the paper @@ -552,8 +552,56 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", }]; } +def ONNXBatchNormalizationV9Op:ONNX_Op<"BatchNormalizationV9", + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX BatchNormalization operation"; + let description = [{ + Carries out batch normalization as described in the paper + https://arxiv.org/abs/1502.03167. Depending on the mode it is being run, + there are multiple cases for the number of outputs, which we list below: + + Output case #1: Y, mean, var, saved_mean, saved_var (training mode) + Output case #2: Y (test mode) + + For previous (depreciated) non-spatial cases, implementors are suggested + to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op. + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$B, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$mean, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$var, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$momentum); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$out_mean, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$out_var, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$saved_mean, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$saved_var); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 5; + } + static int getNumberOfResults() { + return 5; + } + static std::vector getTypeMap() { + return {30,30,30,30,30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXBatchNormalizationV9OpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + def ONNXBernoulliOp:ONNX_Op<"Bernoulli", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Bernoulli operation"; let description = [{ Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor @@ -563,10 +611,10 @@ def ONNXBernoulliOp:ONNX_Op<"Bernoulli", This operator is non-deterministic and may not produce the same values in different implementations (even if a seed is specified). }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, OptionalAttr:$dtype, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[I1]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I1]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -589,7 +637,7 @@ def ONNXBernoulliOp:ONNX_Op<"Bernoulli", } def ONNXBitShiftOp:ONNX_Op<"BitShift", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitShift operation"; let description = [{ Bitwise shift operator performs element-wise operation. For each input element, if the @@ -632,7 +680,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", } def ONNXBitwiseAndOp:ONNX_Op<"BitwiseAnd", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseAnd operation"; let description = [{ Returns the tensor resulting from performing the bitwise `and` operation @@ -666,7 +714,7 @@ def ONNXBitwiseAndOp:ONNX_Op<"BitwiseAnd", } def ONNXBitwiseNotOp:ONNX_Op<"BitwiseNot", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseNot operation"; let description = [{ Returns the bitwise not of the input tensor element-wise. @@ -695,7 +743,7 @@ def ONNXBitwiseNotOp:ONNX_Op<"BitwiseNot", } def ONNXBitwiseOrOp:ONNX_Op<"BitwiseOr", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseOr operation"; let description = [{ Returns the tensor resulting from performing the bitwise `or` operation @@ -729,7 +777,7 @@ def ONNXBitwiseOrOp:ONNX_Op<"BitwiseOr", } def ONNXBitwiseXorOp:ONNX_Op<"BitwiseXor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BitwiseXor operation"; let description = [{ Returns the tensor resulting from performing the bitwise `xor` operation @@ -763,7 +811,7 @@ def ONNXBitwiseXorOp:ONNX_Op<"BitwiseXor", } def ONNXBlackmanWindowOp:ONNX_Op<"BlackmanWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX BlackmanWindow operation"; let description = [{ Generates a Blackman window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -794,7 +842,7 @@ def ONNXBlackmanWindowOp:ONNX_Op<"BlackmanWindow", } def ONNXCastOp:ONNX_Op<"Cast", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Cast operation"; let description = [{ @@ -863,10 +911,10 @@ def ONNXCastOp:ONNX_Op<"Cast", | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN | | else | RNE | RNE | RNE | RNE | }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input, DefaultValuedAttr:$saturate, TypeAttr:$to); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -895,7 +943,7 @@ def ONNXCastOp:ONNX_Op<"Cast", } def ONNXCastLikeOp:ONNX_Op<"CastLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CastLike operation"; let description = [{ The operator casts the elements of a given input tensor (the first input) to @@ -928,7 +976,7 @@ def ONNXCastLikeOp:ONNX_Op<"CastLike", } def ONNXCeilOp:ONNX_Op<"Ceil", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Ceil operation"; let description = [{ Ceil takes one input data (Tensor) and produces one output data @@ -959,7 +1007,7 @@ def ONNXCeilOp:ONNX_Op<"Ceil", } def ONNXCeluOp:ONNX_Op<"Celu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Celu operation"; let description = [{ Continuously Differentiable Exponential Linear Units: @@ -970,9 +1018,10 @@ def ONNXCeluOp:ONNX_Op<"Celu", max(0,x) + min(0,alpha*(exp(x/alpha)-1)) ``` }]; - let arguments = (ins TensorOf<[F32]>:$X, + // FIXME(FXML-4138): Remove manual modification of BF16 support when the operation definition is updated upstream + let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[BF16]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs TensorOf<[F32]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[F32]>, TensorOf<[BF16]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -995,7 +1044,7 @@ def ONNXCeluOp:ONNX_Op<"Celu", } def ONNXCenterCropPadOp:ONNX_Op<"CenterCropPad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CenterCropPad operation"; let description = [{ Center crop or pad an input to given dimensions. @@ -1033,7 +1082,7 @@ def ONNXCenterCropPadOp:ONNX_Op<"CenterCropPad", } def ONNXClipOp:ONNX_Op<"Clip", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1066,7 +1115,7 @@ def ONNXClipOp:ONNX_Op<"Clip", } def ONNXClipV12Op:ONNX_Op<"ClipV12", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1099,7 +1148,7 @@ def ONNXClipV12Op:ONNX_Op<"ClipV12", } def ONNXClipV11Op:ONNX_Op<"ClipV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1132,7 +1181,7 @@ def ONNXClipV11Op:ONNX_Op<"ClipV11", } def ONNXClipV6Op:ONNX_Op<"ClipV6", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<6>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Clip operation"; let description = [{ Clip operator limits the given input within an interval. The interval is @@ -1165,7 +1214,7 @@ def ONNXClipV6Op:ONNX_Op<"ClipV6", } def ONNXCol2ImOp:ONNX_Op<"Col2Im", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Col2Im operation"; let description = [{ The operator rearranges column blocks back into a multidimensional image @@ -1208,7 +1257,7 @@ def ONNXCol2ImOp:ONNX_Op<"Col2Im", } def ONNXCompressOp:ONNX_Op<"Compress", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Compress operation"; let description = [{ Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. @@ -1243,7 +1292,7 @@ def ONNXCompressOp:ONNX_Op<"Compress", } def ONNXConcatOp:ONNX_Op<"Concat", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Concat operation"; let description = [{ Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. @@ -1274,7 +1323,7 @@ def ONNXConcatOp:ONNX_Op<"Concat", } def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConcatFromSequence operation"; let description = [{ Concatenate a sequence of tensors into a single tensor. @@ -1309,7 +1358,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", } def ONNXConstantOp:ONNX_Op<"Constant", - [Pure, ConstantLike, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, ConstantLike, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCustomAssemblyFormat = 1; let hasCanonicalizer = 1; let summary = "ONNX Constant operation"; @@ -1362,7 +1411,7 @@ def ONNXConstantOp:ONNX_Op<"Constant", } def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCustomAssemblyFormat = 1; let summary = "ONNX ConstantOfShape operation"; let description = [{ @@ -1394,22 +1443,22 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", } def ONNXConvOp:ONNX_Op<"Conv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Conv operation"; let description = [{ The convolution operator consumes an input tensor and a filter, and computes the output. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, DefaultValuedStrAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, OptionalAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let builders = [ OpBuilder<(ins "Value":$X, "Value":$W, "Value":$B, "StringAttr":$auto_pad, "ArrayAttr":$dilations, "IntegerAttr":$group, "ArrayAttr":$kernel_shape, "ArrayAttr":$pads, "ArrayAttr":$strides), [{ auto resultType = UnrankedTensorType::get(mlir::cast(X.getType()).getElementType()); @@ -1443,7 +1492,7 @@ def ONNXConvOp:ONNX_Op<"Conv", } def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvInteger operation"; let description = [{ The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point, @@ -1482,7 +1531,7 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", } def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvTranspose operation"; let description = [{ The convolution transpose operator consumes an input tensor and a filter, @@ -1500,9 +1549,9 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, DefaultValuedStrAttr:$auto_pad, OptionalAttr:$dilations, DefaultValuedAttr:$group, @@ -1511,7 +1560,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", OptionalAttr:$output_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -1535,13 +1584,13 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", } def ONNXCosOp:ONNX_Op<"Cos", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Cos operation"; let description = [{ Calculates the cosine of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -1564,13 +1613,13 @@ def ONNXCosOp:ONNX_Op<"Cos", } def ONNXCoshOp:ONNX_Op<"Cosh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Cosh operation"; let description = [{ Calculates the hyperbolic cosine of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -1593,7 +1642,7 @@ def ONNXCoshOp:ONNX_Op<"Cosh", } def ONNXCumSumOp:ONNX_Op<"CumSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CumSum operation"; let description = [{ Performs cumulative sum of the input elements along the given axis. @@ -1644,7 +1693,7 @@ def ONNXCumSumOp:ONNX_Op<"CumSum", } def ONNXDFTOp:ONNX_Op<"DFT", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DFT operation"; let description = [{ Computes the discrete Fourier Transform (DFT) of the input. @@ -1693,7 +1742,7 @@ def ONNXDFTOp:ONNX_Op<"DFT", } def ONNXDFTV17Op:ONNX_Op<"DFTV17", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DFT operation"; let description = [{ Computes the discrete Fourier transform of input. @@ -1726,24 +1775,24 @@ def ONNXDFTV17Op:ONNX_Op<"DFTV17", } def ONNXDeformConvOp:ONNX_Op<"DeformConv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DeformConv operation"; let description = [{ Performs deformable convolution as described in https://arxiv.org/abs/1703.06211 and https://arxiv.org/abs/1811.11168. This operator specification supports the general N-D case. Note that most common use cases have 2D or 3D data. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$offset, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$mask, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$offset, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$mask, OptionalAttr:$dilations, DefaultValuedAttr:$group, OptionalAttr:$kernel_shape, DefaultValuedAttr:$offset_group, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 5; @@ -1766,7 +1815,7 @@ def ONNXDeformConvOp:ONNX_Op<"DeformConv", } def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX DepthToSpace operation"; let description = [{ @@ -1821,7 +1870,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", } def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX DequantizeLinear operation"; let description = [{ @@ -1861,7 +1910,7 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", } def ONNXDetOp:ONNX_Op<"Det", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Det operation"; let description = [{ Det calculates determinant of a square matrix or batches of square matrices. @@ -1870,8 +1919,8 @@ def ONNXDetOp:ONNX_Op<"Det", The output is a tensor of shape `[*]`, containing the determinants of all input submatrices. e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`). }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -1894,7 +1943,7 @@ def ONNXDetOp:ONNX_Op<"Det", } def ONNXDivOp:ONNX_Op<"Div", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Div operation"; let description = [{ @@ -1950,7 +1999,7 @@ def ONNXDivOp:ONNX_Op<"Div", } def ONNXDropoutOp:ONNX_Op<"Dropout", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Dropout operation"; let description = [{ @@ -1967,11 +2016,11 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", ``` This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$data, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$ratio, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$data, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$ratio, AnyTypeOf<[TensorOf<[I1]>, NoneType]>:$training_mode, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$output, + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output, AnyTypeOf<[TensorOf<[I1]>, NoneType]>:$mask); let extraClassDeclaration = [{ static int getNumberOfOperands() { @@ -1995,7 +2044,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", } def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DynamicQuantizeLinear operation"; let description = [{ A Function to fuse calculation for Scale, Zero Point and FP32->8Bit conversion of FP32 Input data. @@ -2052,7 +2101,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", } def ONNXEinsumOp:ONNX_Op<"Einsum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<12>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Einsum operation"; let description = [{ An einsum of the form `term1, term2 -> output-term` produces an output tensor using the following equation @@ -2107,7 +2156,7 @@ def ONNXEinsumOp:ONNX_Op<"Einsum", } def ONNXEluOp:ONNX_Op<"Elu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Elu operation"; let description = [{ Elu takes one input data (Tensor) and produces one output data @@ -2115,9 +2164,9 @@ def ONNXEluOp:ONNX_Op<"Elu", 0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -2140,7 +2189,7 @@ def ONNXEluOp:ONNX_Op<"Elu", } def ONNXEqualOp:ONNX_Op<"Equal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Equal operation"; let description = [{ @@ -2197,7 +2246,7 @@ def ONNXEqualOp:ONNX_Op<"Equal", } def ONNXErfOp:ONNX_Op<"Erf", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Erf operation"; let description = [{ Computes the error function of the given input tensor element-wise. @@ -2226,13 +2275,13 @@ def ONNXErfOp:ONNX_Op<"Erf", } def ONNXExpOp:ONNX_Op<"Exp", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Exp operation"; let description = [{ Calculates the exponential of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let builders = [ OpBuilder<(ins "Value":$input), [{ auto resultType = UnrankedTensorType::get(mlir::cast(input.getType()).getElementType()); @@ -2265,7 +2314,7 @@ def ONNXExpOp:ONNX_Op<"Exp", } def ONNXExpandOp:ONNX_Op<"Expand", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Expand operation"; let description = [{ Broadcast the input tensor following the given shape and the broadcast rule. @@ -2303,7 +2352,7 @@ def ONNXExpandOp:ONNX_Op<"Expand", } def ONNXEyeLikeOp:ONNX_Op<"EyeLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX EyeLike operation"; let description = [{ Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D @@ -2314,10 +2363,10 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message and be valid as an output type. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I1]>]>:$input, OptionalAttr:$dtype, DefaultValuedAttr:$k); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I1]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -2328,7 +2377,9 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", static std::vector getTypeMap() { return {-1}; } - }]; + + mlir::Type getResultElementType(); + }]; let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { @@ -2340,7 +2391,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", } def ONNXFlattenOp:ONNX_Op<"Flatten", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Flatten operation"; let description = [{ Flattens the input tensor into a 2D matrix. If input tensor has shape @@ -2373,7 +2424,7 @@ def ONNXFlattenOp:ONNX_Op<"Flatten", } def ONNXFloorOp:ONNX_Op<"Floor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Floor operation"; let description = [{ Floor takes one input data (Tensor) and produces one output data @@ -2404,7 +2455,7 @@ def ONNXFloorOp:ONNX_Op<"Floor", } def ONNXGRUOp:ONNX_Op<"GRU", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GRU operation"; let description = [{ @@ -2456,12 +2507,12 @@ def ONNXGRUOp:ONNX_Op<"GRU", * Ht = (1 - zt) (.) ht + zt (.) Ht-1 This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, AnyTypeOf<[TensorOf<[I32]>, NoneType]>:$sequence_lens, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, OptionalAttr:$activations, @@ -2470,8 +2521,8 @@ def ONNXGRUOp:ONNX_Op<"GRU", OptionalAttr:$hidden_size, DefaultValuedAttr:$layout, DefaultValuedAttr:$linear_before_reset); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 6; @@ -2494,7 +2545,7 @@ def ONNXGRUOp:ONNX_Op<"GRU", } def ONNXGatherOp:ONNX_Op<"Gather", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gather operation"; let description = [{ Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather @@ -2573,7 +2624,7 @@ def ONNXGatherOp:ONNX_Op<"Gather", } def ONNXGatherElementsOp:ONNX_Op<"GatherElements", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GatherElements operation"; let description = [{ GatherElements takes two inputs `data` and `indices` of the same rank r >= 1 @@ -2655,7 +2706,7 @@ def ONNXGatherElementsOp:ONNX_Op<"GatherElements", } def ONNXGatherNDOp:ONNX_Op<"GatherND", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GatherND operation"; let description = [{ Given `data` tensor of rank `r` >= 1, `indices` tensor of rank `q` >= 1, and `batch_dims` integer `b`, this operator gathers @@ -2770,7 +2821,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND", } def ONNXGeluOp:ONNX_Op<"Gelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gelu operation"; let description = [{ Gelu takes one input data (Tensor) and produces one @@ -2807,7 +2858,7 @@ def ONNXGeluOp:ONNX_Op<"Gelu", } def ONNXGemmOp:ONNX_Op<"Gemm", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gemm operation"; let description = [{ General Matrix multiplication: @@ -2853,7 +2904,7 @@ def ONNXGemmOp:ONNX_Op<"Gemm", } def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GlobalAveragePool operation"; let description = [{ @@ -2861,8 +2912,8 @@ def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", the values in the same channel. This is equivalent to AveragePool with kernel size equal to the spatial dimension of input tensor. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -2885,16 +2936,16 @@ def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", } def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GlobalLpPool operation"; let description = [{ GlobalLpPool consumes an input tensor X and applies lp pool pooling across the values in the same channel. This is equivalent to LpPool with kernel size equal to the spatial dimension of input tensor. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedAttr:$p); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -2917,7 +2968,7 @@ def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", } def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX GlobalMaxPool operation"; let description = [{ @@ -2925,8 +2976,8 @@ def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", the values in the same channel. This is equivalent to MaxPool with kernel size equal to the spatial dimension of input tensor. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -2949,7 +3000,7 @@ def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", } def ONNXGreaterOp:ONNX_Op<"Greater", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Greater operation"; let description = [{ @@ -3006,7 +3057,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater", } def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let summary = "ONNX GreaterOrEqual operation"; let description = [{ Returns the tensor resulted from performing the `greater_equal` logical operation @@ -3062,7 +3113,108 @@ def ONNXGreaterOrEqualOp:ONNX_Op<"GreaterOrEqual", } def ONNXGridSampleOp:ONNX_Op<"GridSample", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GridSample operation"; + let description = [{ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. + For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), + the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). + More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), + the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + + The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). + The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values + at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) + and a padding mode (for `grid` positions falling outside the 2-dimensional image). + + For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + + The GridSample operator is often used in doing grid generator and sampler in the + [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). + See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid, + DefaultValuedAttr:$align_corners, + DefaultValuedStrAttr:$mode, + DefaultValuedStrAttr:$padding_mode); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; + let hasVerifier = 1; +} + +def ONNXGridSampleV20Op:ONNX_Op<"GridSampleV20", + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GridSample operation"; + let description = [{ + Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. + For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2), + the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W), + the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out). + More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr), + the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out). + + The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in). + The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values + at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode) + and a padding mode (for `grid` positions falling outside the 2-dimensional image). + + For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`. + They are used to interpolate output values of `Y[n, c, h_out, w_out]`. + + The GridSample operator is often used in doing grid generator and sampler in the + [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). + See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$grid, + DefaultValuedAttr:$align_corners, + DefaultValuedStrAttr:$mode, + DefaultValuedStrAttr:$padding_mode); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV20OpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + +def ONNXGridSampleV16Op:ONNX_Op<"GridSampleV16", + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GridSample operation"; let description = [{ Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from `grid`. @@ -3099,7 +3251,7 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample", let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { - onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleOpShapeHelper(op, oper, ieb, scope); + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGridSampleV16OpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); return sh; } @@ -3107,7 +3259,63 @@ def ONNXGridSampleOp:ONNX_Op<"GridSample", } def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<21>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + let summary = "ONNX GroupNormalization operation"; + let description = [{ + A GroupNormalization function. Carries out group normalization as described in + the paper https://arxiv.org/abs/1803.08494 + + This operator transforms input according to + ``` + y = scale * (x - mean) / sqrt(variance + epsilon) + bias, + ``` + where the mean and variance are computed per instance per group of channels, and + `scale` and `bias` should be specified for each group of channels. The number of + groups `num_groups` should be divisible by the number of channels so that there are + an equal number of channels per group. + + The overall computation has two stages: the first stage normalizes the elements to + have zero mean and unit variance for each instance in each group, and the second + stage scales and shifts the results of the first stage. The floating-point precision + used in the first stage is determined by the `stash_type` attribute. For example, + if `stash_type` is 1, the operator casts all input variables to 32-bit float, + performs the computation, and finally casts the normalized results back to the + original type of `X`. The second stage does not depend on `stash_type`. + + When the number of groups is the same as the number of channels, this operator is + equivalent to InstanceNormalization. When there is only one group, this operator + is equivalent to LayerNormalization. + }]; + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$bias, + DefaultValuedAttr:$epsilon, + SI64Attr:$num_groups, + DefaultValuedAttr:$stash_type); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 3; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {30}; + } + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + +def ONNXGroupNormalizationV18Op:ONNX_Op<"GroupNormalizationV18", + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX GroupNormalization operation"; let description = [{ A GroupNormalization function. Carries out group normalization as described in @@ -3146,15 +3354,16 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization", let extraClassDefinition = [{ onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef oper, onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { - onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope); + onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationV18OpShapeHelper(op, oper, ieb, scope); assert(sh && "failed to allocate shape helper"); return sh; } }]; + let hasVerifier = 1; } def ONNXHammingWindowOp:ONNX_Op<"HammingWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HammingWindow operation"; let description = [{ Generates a Hamming window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -3185,7 +3394,7 @@ def ONNXHammingWindowOp:ONNX_Op<"HammingWindow", } def ONNXHannWindowOp:ONNX_Op<"HannWindow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HannWindow operation"; let description = [{ Generates a Hann window as described in the paper https://ieeexplore.ieee.org/document/1455106. @@ -3216,17 +3425,17 @@ def ONNXHannWindowOp:ONNX_Op<"HannWindow", } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HardSigmoid operation"; let description = [{ HardSigmoid takes one input data (Tensor) and produces one output data (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)), is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -3249,15 +3458,15 @@ def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", } def ONNXHardSwishOp:ONNX_Op<"HardSwish", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX HardSwish operation"; let description = [{ HardSwish takes one input data (Tensor) and produces one output data (Tensor) where the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x), where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -3280,7 +3489,7 @@ def ONNXHardSwishOp:ONNX_Op<"HardSwish", } def ONNXHardmaxOp:ONNX_Op<"Hardmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Hardmax operation"; let description = [{ The operator computes the hardmax values for the given input: @@ -3317,7 +3526,7 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax", } def ONNXIdentityOp:ONNX_Op<"Identity", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Identity operation"; let description = [{ @@ -3357,7 +3566,7 @@ def ONNXIdentityOp:ONNX_Op<"Identity", } def ONNXIfOp:ONNX_Op<"If", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX If operation"; let description = [{ If conditional @@ -3394,7 +3603,7 @@ def ONNXIfOp:ONNX_Op<"If", } def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX InstanceNormalization operation"; let description = [{ Carries out instance normalization as described in the paper @@ -3404,11 +3613,11 @@ def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", where mean and variance are computed per instance per channel. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$B, DefaultValuedAttr:$epsilon); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -3432,7 +3641,7 @@ def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", } def ONNXIsInfOp:ONNX_Op<"IsInf", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX IsInf operation"; let description = [{ Map infinity to true and other values to false. @@ -3464,7 +3673,7 @@ def ONNXIsInfOp:ONNX_Op<"IsInf", } def ONNXIsNaNOp:ONNX_Op<"IsNaN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX IsNaN operation"; let description = [{ Returns which elements of the input are NaN. @@ -3493,7 +3702,7 @@ def ONNXIsNaNOp:ONNX_Op<"IsNaN", } def ONNXLRNOp:ONNX_Op<"LRN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LRN operation"; let description = [{ Local Response Normalization proposed in the [AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). @@ -3535,7 +3744,7 @@ def ONNXLRNOp:ONNX_Op<"LRN", } def ONNXLSTMOp:ONNX_Op<"LSTM", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX LSTM operation"; let description = [{ @@ -3590,14 +3799,14 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", * Ht = ot (.) h(Ct) This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, AnyTypeOf<[TensorOf<[I32]>, NoneType]>:$sequence_lens, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_c, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$P, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_c, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$P, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, OptionalAttr:$activations, @@ -3606,9 +3815,9 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", OptionalAttr:$hidden_size, DefaultValuedAttr:$input_forget, DefaultValuedAttr:$layout); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_c); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_c); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 8; @@ -3631,7 +3840,7 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", } def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LayerNormalization operation"; let description = [{ This is layer normalization defined in ONNX as function. @@ -3672,7 +3881,9 @@ def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization", Let `d[i]` indicate the i-th dimension of `X`. If `X`'s shape is `[d[0], ..., d[axis-1], d[axis], ..., d[rank-1]]`, the shape of `Mean` and `InvStdDev` is `[d[0], ..., d[axis-1], 1, ..., 1]`. - `Y` and `X` have the same shape. + `Y` and `X` have the same shape. This operator supports unidirectional broadcasting + (tensors `Scale` and `B` should be unidirectional broadcastable to tensor `X`); + for more details please check [the doc](Broadcasting.md). }]; let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$X, AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$Scale, @@ -3706,7 +3917,7 @@ def ONNXLayerNormalizationOp:ONNX_Op<"LayerNormalization", } def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LeakyRelu operation"; let description = [{ LeakyRelu takes input data (Tensor) and an argument alpha, and produces one @@ -3738,7 +3949,7 @@ def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", } def ONNXLessOp:ONNX_Op<"Less", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Less operation"; let description = [{ @@ -3795,7 +4006,7 @@ def ONNXLessOp:ONNX_Op<"Less", } def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsElementType]> { let summary = "ONNX LessOrEqual operation"; let description = [{ Returns the tensor resulted from performing the `less_equal` logical operation @@ -3851,7 +4062,7 @@ def ONNXLessOrEqualOp:ONNX_Op<"LessOrEqual", } def ONNXLogOp:ONNX_Op<"Log", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Log operation"; let description = [{ Calculates the natural log of the given input tensor, element-wise. @@ -3880,7 +4091,7 @@ def ONNXLogOp:ONNX_Op<"Log", } def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LogSoftmax operation"; let description = [{ The operator computes the log of softmax values for the given input: @@ -3917,7 +4128,7 @@ def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", } def ONNXLoopOp:ONNX_Op<"Loop", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let hasCanonicalizer = 1; let summary = "ONNX Loop operation"; let description = [{ @@ -4091,15 +4302,15 @@ def ONNXLoopOp:ONNX_Op<"Loop", } def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LpNormalization operation"; let description = [{ Given a matrix, apply Lp-normalization along the provided axis. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, DefaultValuedAttr:$axis, DefaultValuedAttr:$p); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -4122,7 +4333,7 @@ def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", } def ONNXLpPoolOp:ONNX_Op<"LpPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LpPool operation"; let description = [{ LpPool consumes an input tensor X and applies Lp pooling across @@ -4149,7 +4360,7 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + {kernelSpatialShape} - input_spatial_shape[i] ``` }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedStrAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, OptionalAttr:$dilations, @@ -4157,7 +4368,7 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", DefaultValuedAttr:$p, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -4180,10 +4391,10 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", } def ONNXMatMulOp:ONNX_Op<"MatMul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MatMul operation"; let description = [{ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). }]; let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>:$A, AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>:$B); @@ -4210,10 +4421,10 @@ def ONNXMatMulOp:ONNX_Op<"MatMul", } def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MatMulInteger operation"; let description = [{ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. }]; let arguments = (ins AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>]>:$A, @@ -4244,7 +4455,7 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", } def ONNXMaxOp:ONNX_Op<"Max", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Max operation"; let description = [{ Element-wise max of each of the input tensors (with Numpy-style broadcasting support). @@ -4276,7 +4487,7 @@ def ONNXMaxOp:ONNX_Op<"Max", } def ONNXMaxPoolOp:ONNX_Op<"MaxPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxPool operation"; let description = [{ MaxPool consumes an input tensor X and applies max pooling across @@ -4293,7 +4504,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", ``` output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - dilation[i] * (kernel_shape[i] - 1) - 1) / strides_spatial_shape[i] + 1) ``` - if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. + if ceil_mode is enabled. `pad_shape[i]` is the sum of pads along axis `i`. Sliding windows that would start in the right padded region are ignored. `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following when ceil_mode is enabled: ``` @@ -4312,7 +4523,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", The output of each pooling window is maximum number of elements exclude pad. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[UI8]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[UI8]>]>:$X, DefaultValuedStrAttr:$auto_pad, DefaultValuedAttr:$ceil_mode, OptionalAttr:$dilations, @@ -4320,7 +4531,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", OptionalAttr:$pads, DefaultValuedAttr:$storage_order, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[UI8]>]>:$Y, + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[UI8]>]>:$Y, AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$Indices); let extraClassDeclaration = [{ static int getNumberOfOperands() { @@ -4344,18 +4555,18 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", } def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxRoiPool operation"; let description = [{ ROI max pool consumes an input tensor X and region of interests (RoIs) to apply max pooling across each RoI, to produce output 4-D tensor of shape (num_rois, channels, pooled_shape[0], pooled_shape[1]). }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$rois, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$rois, I64ArrayAttr:$pooled_shape, DefaultValuedAttr:$spatial_scale); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 2; @@ -4378,7 +4589,7 @@ def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", } def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MaxUnpool operation"; let description = [{ MaxUnpool essentially computes the partial inverse of the MaxPool op. @@ -4400,13 +4611,13 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", which define the exact unpooling op. The attributes typically have the same values as the corresponding pooling op that the unpooling op is trying to invert. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, TensorOf<[I64]>:$I, AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$output_shape, I64ArrayAttr:$kernel_shape, OptionalAttr:$pads, OptionalAttr:$strides); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -4429,7 +4640,7 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", } def ONNXMeanOp:ONNX_Op<"Mean", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Mean operation"; let description = [{ Element-wise mean of each of the input tensors (with Numpy-style broadcasting support). @@ -4461,7 +4672,7 @@ def ONNXMeanOp:ONNX_Op<"Mean", } def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MeanVarianceNormalization operation"; let description = [{ A MeanVarianceNormalization Function: Perform mean variance normalization @@ -4492,7 +4703,7 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", } def ONNXMelWeightMatrixOp:ONNX_Op<"MelWeightMatrix", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX MelWeightMatrix operation"; let description = [{ Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra (from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on the mel scale. @@ -4533,7 +4744,7 @@ def ONNXMelWeightMatrixOp:ONNX_Op<"MelWeightMatrix", } def ONNXMinOp:ONNX_Op<"Min", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Min operation"; let description = [{ Element-wise min of each of the input tensors (with Numpy-style broadcasting support). @@ -4565,7 +4776,7 @@ def ONNXMinOp:ONNX_Op<"Min", } def ONNXMishOp:ONNX_Op<"Mish", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Mish operation"; let description = [{ Mish: A Self Regularized Non-Monotonic Neural Activation Function. @@ -4576,8 +4787,8 @@ def ONNXMishOp:ONNX_Op<"Mish", mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) ``` }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -4600,7 +4811,7 @@ def ONNXMishOp:ONNX_Op<"Mish", } def ONNXModOp:ONNX_Op<"Mod", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Mod operation"; let description = [{ Performs element-wise binary modulus (with Numpy-style broadcasting support). @@ -4644,7 +4855,7 @@ def ONNXModOp:ONNX_Op<"Mod", } def ONNXMulOp:ONNX_Op<"Mul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Mul operation"; let description = [{ @@ -4700,13 +4911,13 @@ def ONNXMulOp:ONNX_Op<"Mul", } def ONNXMultinomialOp:ONNX_Op<"Multinomial", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Multinomial operation"; let description = [{ Generate a tensor of samples from a multinomial distribution according to the probabilities of each of the possible outcomes. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, DefaultValuedAttr:$dtype, DefaultValuedAttr:$sample_size, OptionalAttr:$seed); @@ -4733,7 +4944,7 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial", } def ONNXNegOp:ONNX_Op<"Neg", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Neg operation"; let description = [{ Neg takes one input data (Tensor) and produces one output data @@ -4774,7 +4985,7 @@ def ONNXNegOp:ONNX_Op<"Neg", } def ONNXNegativeLogLikelihoodLossOp:ONNX_Op<"NegativeLogLikelihoodLoss", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NegativeLogLikelihoodLoss operation"; let description = [{ A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss. @@ -4879,12 +5090,12 @@ def ONNXNegativeLogLikelihoodLossOp:ONNX_Op<"NegativeLogLikelihoodLoss", // -1.57 ``` }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input, AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>]>:$target, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$weight, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$weight, OptionalAttr:$ignore_index, DefaultValuedStrAttr:$reduction); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$loss); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$loss); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -4907,7 +5118,7 @@ def ONNXNegativeLogLikelihoodLossOp:ONNX_Op<"NegativeLogLikelihoodLoss", } def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NonMaxSuppression operation"; let description = [{ Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. @@ -4948,7 +5159,7 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", } def ONNXNonZeroOp:ONNX_Op<"NonZero", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX NonZero operation"; let description = [{ Returns the indices of the elements that are non-zero @@ -4981,7 +5192,7 @@ def ONNXNonZeroOp:ONNX_Op<"NonZero", } def ONNXNotOp:ONNX_Op<"Not", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Not operation"; let description = [{ Returns the negation of the input tensor element-wise. @@ -5010,7 +5221,7 @@ def ONNXNotOp:ONNX_Op<"Not", } def ONNXOneHotOp:ONNX_Op<"OneHot", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OneHot operation"; let description = [{ Produces a one-hot tensor based on inputs. @@ -5061,7 +5272,7 @@ def ONNXOneHotOp:ONNX_Op<"OneHot", } def ONNXOptionalOp:ONNX_Op<"Optional", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Optional operation"; let description = [{ Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute, @@ -5093,7 +5304,7 @@ def ONNXOptionalOp:ONNX_Op<"Optional", } def ONNXOptionalGetElementOp:ONNX_Op<"OptionalGetElement", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OptionalGetElement operation"; let description = [{ If the input is a tensor or sequence type, it returns the input. @@ -5125,7 +5336,7 @@ def ONNXOptionalGetElementOp:ONNX_Op<"OptionalGetElement", } def ONNXOptionalHasElementOp:ONNX_Op<"OptionalHasElement", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OptionalHasElement operation"; let description = [{ Returns true if (1) the input is an optional-type and contains an element, @@ -5157,7 +5368,7 @@ def ONNXOptionalHasElementOp:ONNX_Op<"OptionalHasElement", } def ONNXOrOp:ONNX_Op<"Or", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Or operation"; let description = [{ @@ -5212,7 +5423,7 @@ def ONNXOrOp:ONNX_Op<"Or", } def ONNXPReluOp:ONNX_Op<"PRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX PRelu operation"; let description = [{ PRelu takes input data (Tensor) and slope tensor as input, and produces one @@ -5246,7 +5457,7 @@ def ONNXPReluOp:ONNX_Op<"PRelu", } def ONNXPadOp:ONNX_Op<"Pad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5389,7 +5600,7 @@ def ONNXPadOp:ONNX_Op<"Pad", } def ONNXPadV18Op:ONNX_Op<"PadV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5496,7 +5707,7 @@ def ONNXPadV18Op:ONNX_Op<"PadV18", } def ONNXPadV13Op:ONNX_Op<"PadV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5602,7 +5813,7 @@ def ONNXPadV13Op:ONNX_Op<"PadV13", } def ONNXPadV11Op:ONNX_Op<"PadV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, @@ -5708,7 +5919,7 @@ def ONNXPadV11Op:ONNX_Op<"PadV11", } def ONNXPadV2Op:ONNX_Op<"PadV2", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Pad operation"; let description = [{ Given `data` tensor, pads, mode, and value. @@ -5755,7 +5966,7 @@ def ONNXPadV2Op:ONNX_Op<"PadV2", } def ONNXPowOp:ONNX_Op<"Pow", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<15>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Pow operation"; let description = [{ @@ -5810,7 +6021,7 @@ def ONNXPowOp:ONNX_Op<"Pow", } def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QLinearConv operation"; let description = [{ The convolution operator consumes a quantized input tensor, its scale and zero point, @@ -5859,10 +6070,10 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", } def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QLinearMatMul operation"; let description = [{ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html). It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. @@ -5905,7 +6116,7 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", } def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX QuantizeLinear operation"; let description = [{ The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor. @@ -5946,7 +6157,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", } def ONNXRNNOp:ONNX_Op<"RNN", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX RNN operation"; let description = [{ @@ -5991,12 +6202,12 @@ def ONNXRNNOp:ONNX_Op<"RNN", * Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$W, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$R, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$B, AnyTypeOf<[TensorOf<[I32]>, NoneType]>:$sequence_lens, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$initial_h, OptionalAttr:$activation_alpha, OptionalAttr:$activation_beta, DefaultValuedAttr:$activations, @@ -6004,8 +6215,8 @@ def ONNXRNNOp:ONNX_Op<"RNN", DefaultValuedStrAttr:$direction, OptionalAttr:$hidden_size, DefaultValuedAttr:$layout); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$Y_h); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 6; @@ -6028,7 +6239,7 @@ def ONNXRNNOp:ONNX_Op<"RNN", } def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomNormal operation"; let description = [{ Generate a tensor with random values drawn from a normal distribution. The shape @@ -6044,7 +6255,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", DefaultValuedAttr:$scale, OptionalAttr:$seed, I64ArrayAttr:$shape); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 0; @@ -6067,7 +6278,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", } def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomNormalLike operation"; let description = [{ Generate a tensor with random values drawn from a normal distribution. @@ -6078,12 +6289,12 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message, and be valid as an output type. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$input, OptionalAttr:$dtype, DefaultValuedAttr:$mean, DefaultValuedAttr:$scale, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -6107,7 +6318,7 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", } def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomUniform operation"; let description = [{ Generate a tensor with random values drawn from a uniform distribution. The shape @@ -6122,7 +6333,7 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", DefaultValuedAttr:$low, OptionalAttr:$seed, I64ArrayAttr:$shape); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 0; @@ -6145,7 +6356,7 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", } def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RandomUniformLike operation"; let description = [{ Generate a tensor with random values drawn from a uniform distribution. @@ -6156,12 +6367,12 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the TensorProto message and be valid as an output type. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$input, OptionalAttr:$dtype, DefaultValuedAttr:$high, DefaultValuedAttr:$low, OptionalAttr:$seed); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -6184,7 +6395,7 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", } def ONNXRangeOp:ONNX_Op<"Range", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Range operation"; let description = [{ Generate a tensor containing a sequence of numbers that begin at `start` and extends by increments of `delta` @@ -6245,7 +6456,7 @@ def ONNXRangeOp:ONNX_Op<"Range", } def ONNXReciprocalOp:ONNX_Op<"Reciprocal", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Reciprocal operation"; let description = [{ Reciprocal takes one input data (Tensor) and produces one output data @@ -6276,7 +6487,7 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal", } def ONNXReduceL1Op:ONNX_Op<"ReduceL1", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL1 operation"; let description = [{ Computes the L1 norm of the input tensor's elements along the provided axes. The resulting @@ -6315,7 +6526,7 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1", } def ONNXReduceL1V13Op:ONNX_Op<"ReduceL1V13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL1 operation"; let description = [{ Computes the L1 norm of the input tensor's elements along the provided axes. The resulting @@ -6353,7 +6564,7 @@ def ONNXReduceL1V13Op:ONNX_Op<"ReduceL1V13", } def ONNXReduceL2Op:ONNX_Op<"ReduceL2", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL2 operation"; let description = [{ Computes the L2 norm of the input tensor's elements along the provided axes. The resulting @@ -6392,7 +6603,7 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2", } def ONNXReduceL2V13Op:ONNX_Op<"ReduceL2V13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceL2 operation"; let description = [{ Computes the L2 norm of the input tensor's elements along the provided axes. The resulting @@ -6430,7 +6641,7 @@ def ONNXReduceL2V13Op:ONNX_Op<"ReduceL2V13", } def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSum operation"; let description = [{ Computes the log sum of the input tensor's elements along the provided axes. The resulting @@ -6479,7 +6690,7 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", } def ONNXReduceLogSumV13Op:ONNX_Op<"ReduceLogSumV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSum operation"; let description = [{ Computes the log sum of the input tensor's elements along the provided axes. The resulting @@ -6517,7 +6728,7 @@ def ONNXReduceLogSumV13Op:ONNX_Op<"ReduceLogSumV13", } def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSumExp operation"; let description = [{ Computes the log sum exponent of the input tensor's elements along the provided axes. The resulting @@ -6556,7 +6767,7 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", } def ONNXReduceLogSumExpV13Op:ONNX_Op<"ReduceLogSumExpV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceLogSumExp operation"; let description = [{ Computes the log sum exponent of the input tensor's elements along the provided axes. The resulting @@ -6594,7 +6805,7 @@ def ONNXReduceLogSumExpV13Op:ONNX_Op<"ReduceLogSumExpV13", } def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6645,7 +6856,7 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", } def ONNXReduceMaxV18Op:ONNX_Op<"ReduceMaxV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6694,7 +6905,7 @@ def ONNXReduceMaxV18Op:ONNX_Op<"ReduceMaxV18", } def ONNXReduceMaxV13Op:ONNX_Op<"ReduceMaxV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMax operation"; let description = [{ Computes the max of the input tensor's elements along the provided axes. The resulting @@ -6742,7 +6953,7 @@ def ONNXReduceMaxV13Op:ONNX_Op<"ReduceMaxV13", } def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMean operation"; let description = [{ Computes the mean of the input tensor's elements along the provided axes. The resulting @@ -6778,10 +6989,11 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", return sh; } }]; + let hasFolder = 1; } def ONNXReduceMeanV13Op:ONNX_Op<"ReduceMeanV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMean operation"; let description = [{ Computes the mean of the input tensor's elements along the provided axes. The resulting @@ -6819,7 +7031,7 @@ def ONNXReduceMeanV13Op:ONNX_Op<"ReduceMeanV13", } def ONNXReduceMinOp:ONNX_Op<"ReduceMin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<20>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6860,7 +7072,7 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin", } def ONNXReduceMinV18Op:ONNX_Op<"ReduceMinV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6899,7 +7111,7 @@ def ONNXReduceMinV18Op:ONNX_Op<"ReduceMinV18", } def ONNXReduceMinV13Op:ONNX_Op<"ReduceMinV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceMin operation"; let description = [{ Computes the min of the input tensor's elements along the provided axes. The resulting @@ -6937,7 +7149,7 @@ def ONNXReduceMinV13Op:ONNX_Op<"ReduceMinV13", } def ONNXReduceProdOp:ONNX_Op<"ReduceProd", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceProd operation"; let description = [{ Computes the product of the input tensor's elements along the provided axes. The resulting @@ -6976,7 +7188,7 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd", } def ONNXReduceProdV13Op:ONNX_Op<"ReduceProdV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceProd operation"; let description = [{ Computes the product of the input tensor's elements along the provided axes. The resulting @@ -7014,7 +7226,7 @@ def ONNXReduceProdV13Op:ONNX_Op<"ReduceProdV13", } def ONNXReduceSumOp:ONNX_Op<"ReduceSum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSum operation"; let description = [{ Computes the sum of the input tensor's elements along the provided axes. The resulting @@ -7063,7 +7275,7 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", } def ONNXReduceSumV11Op:ONNX_Op<"ReduceSumV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSum operation"; let description = [{ Computes the sum of the input tensor's element along the provided axes. The resulting @@ -7109,7 +7321,7 @@ def ONNXReduceSumV11Op:ONNX_Op<"ReduceSumV11", } def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSumSquare operation"; let description = [{ Computes the sum square of the input tensor's elements along the provided axes. The resulting @@ -7158,7 +7370,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", } def ONNXReduceSumSquareV13Op:ONNX_Op<"ReduceSumSquareV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReduceSumSquare operation"; let description = [{ Computes the sum square of the input tensor's elements along the provided axes. The resulting @@ -7196,7 +7408,7 @@ def ONNXReduceSumSquareV13Op:ONNX_Op<"ReduceSumSquareV13", } def ONNXReluOp:ONNX_Op<"Relu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Relu operation"; let description = [{ Relu takes one input data (Tensor) and produces one output data @@ -7227,7 +7439,7 @@ def ONNXReluOp:ONNX_Op<"Relu", } def ONNXReshapeOp:ONNX_Op<"Reshape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Reshape operation"; let description = [{ @@ -7272,7 +7484,7 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", } def ONNXResizeOp:ONNX_Op<"Resize", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Resize operation"; let description = [{ @@ -7320,7 +7532,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", } def ONNXResizeV18Op:ONNX_Op<"ResizeV18", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7364,7 +7576,7 @@ def ONNXResizeV18Op:ONNX_Op<"ResizeV18", } def ONNXResizeV13Op:ONNX_Op<"ResizeV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7404,7 +7616,7 @@ def ONNXResizeV13Op:ONNX_Op<"ResizeV13", } def ONNXResizeV11Op:ONNX_Op<"ResizeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. @@ -7444,7 +7656,7 @@ def ONNXResizeV11Op:ONNX_Op<"ResizeV11", } def ONNXResizeV10Op:ONNX_Op<"ResizeV10", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Resize operation"; let description = [{ Resize the input tensor. @@ -7477,7 +7689,7 @@ def ONNXResizeV10Op:ONNX_Op<"ResizeV10", } def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ReverseSequence operation"; let description = [{ Reverse batch of sequences having different lengths specified by `sequence_lens`. @@ -7542,7 +7754,7 @@ def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", } def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX RoiAlign operation"; let description = [{ Region of Interest (RoI) align operation described in the @@ -7557,8 +7769,8 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", the value of the sampled locations are computed directly through bilinear interpolation. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, - AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$rois, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$rois, TensorOf<[I64]>:$batch_indices, DefaultValuedStrAttr:$coordinate_transformation_mode, DefaultValuedStrAttr:$mode, @@ -7566,7 +7778,7 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", DefaultValuedAttr:$output_width, DefaultValuedAttr:$sampling_ratio, DefaultValuedAttr:$spatial_scale); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 3; @@ -7590,7 +7802,7 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", } def ONNXRoundOp:ONNX_Op<"Round", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Round operation"; let description = [{ Round takes one input Tensor and rounds the values, element-wise, meaning @@ -7608,8 +7820,8 @@ def ONNXRoundOp:ONNX_Op<"Round", round([-4.5]) = [-4.0] ``` }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -7632,7 +7844,7 @@ def ONNXRoundOp:ONNX_Op<"Round", } def ONNXSTFTOp:ONNX_Op<"STFT", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX STFT operation"; let description = [{ Computes the Short-time Fourier Transform of the signal. @@ -7665,7 +7877,7 @@ def ONNXSTFTOp:ONNX_Op<"STFT", } def ONNXScanOp:ONNX_Op<"Scan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX Scan operation"; let description = [{ Scan can be used to iterate over one or more scan_input tensors, @@ -7829,7 +8041,7 @@ def ONNXScanOp:ONNX_Op<"Scan", } def ONNXScatterOp:ONNX_Op<"Scatter", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Scatter operation"; let description = [{ This operator is deprecated. Please use ScatterElements, which provides the same functionality. @@ -7913,7 +8125,7 @@ def ONNXScatterOp:ONNX_Op<"Scatter", } def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ScatterElements operation"; let description = [{ ScatterElements takes three inputs `data`, `updates`, and `indices` of the same @@ -8008,7 +8220,7 @@ def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", } def ONNXScatterNDOp:ONNX_Op<"ScatterND", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ScatterND operation"; let description = [{ ScatterND takes three inputs `data` tensor of rank r >= 1, `indices` tensor of rank q >= 1, @@ -8115,7 +8327,7 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND", } def ONNXSeluOp:ONNX_Op<"Selu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Selu operation"; let description = [{ Selu takes one input data (Tensor) and produces one output data @@ -8123,10 +8335,10 @@ def ONNXSeluOp:ONNX_Op<"Selu", `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`, is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedAttr:$alpha, DefaultValuedAttr:$gamma); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -8149,7 +8361,7 @@ def ONNXSeluOp:ONNX_Op<"Selu", } def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceAt operation"; let description = [{ Outputs a tensor copy from the tensor at 'position' in 'input_sequence'. @@ -8181,7 +8393,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", } def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceConstruct operation"; let description = [{ Construct a tensor sequence containing 'inputs' tensors. @@ -8211,7 +8423,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", } def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceEmpty operation"; let description = [{ Construct an empty tensor sequence, with given data type. @@ -8241,7 +8453,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", } def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceErase operation"; let description = [{ Outputs a tensor sequence that removes the tensor at 'position' from 'input_sequence'. @@ -8274,7 +8486,7 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", } def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceInsert operation"; let description = [{ Outputs a tensor sequence that inserts 'tensor' into 'input_sequence' at 'position'. @@ -8310,7 +8522,7 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", } def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SequenceLength operation"; let description = [{ Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'. @@ -8339,7 +8551,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", } def ONNXSequenceMapOp:ONNX_Op<"SequenceMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { + [Pure, OpVersionTrait<17>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, OpInterface<"HasOnnxSubgraphOpInterface">]> { let summary = "ONNX SequenceMap operation"; let description = [{ Applies a sub-graph to each sample in the input sequence(s). @@ -8387,7 +8599,7 @@ def ONNXSequenceMapOp:ONNX_Op<"SequenceMap", } def ONNXShapeOp:ONNX_Op<"Shape", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Shape operation"; let description = [{ @@ -8457,7 +8669,7 @@ def ONNXShapeOp:ONNX_Op<"Shape", } def ONNXShrinkOp:ONNX_Op<"Shrink", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Shrink operation"; let description = [{ Shrink takes one input data (Tensor) and produces one Tensor output, @@ -8491,7 +8703,7 @@ def ONNXShrinkOp:ONNX_Op<"Shrink", } def ONNXSigmoidOp:ONNX_Op<"Sigmoid", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sigmoid operation"; let description = [{ Sigmoid takes one input data (Tensor) and produces one output data @@ -8522,7 +8734,7 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid", } def ONNXSignOp:ONNX_Op<"Sign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sign operation"; let description = [{ Calculate the sign of the given input tensor element-wise. @@ -8552,13 +8764,13 @@ def ONNXSignOp:ONNX_Op<"Sign", } def ONNXSinOp:ONNX_Op<"Sin", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sin operation"; let description = [{ Calculates the sine of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -8581,13 +8793,13 @@ def ONNXSinOp:ONNX_Op<"Sin", } def ONNXSinhOp:ONNX_Op<"Sinh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sinh operation"; let description = [{ Calculates the hyperbolic sine of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -8610,7 +8822,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh", } def ONNXSizeOp:ONNX_Op<"Size", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<19>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Size operation"; let description = [{ @@ -8640,7 +8852,7 @@ def ONNXSizeOp:ONNX_Op<"Size", } def ONNXSliceOp:ONNX_Op<"Slice", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Slice operation"; let description = [{ Produces a slice of the input tensor along multiple axes. Similar to numpy: @@ -8734,7 +8946,7 @@ def ONNXSliceOp:ONNX_Op<"Slice", } def ONNXSoftmaxOp:ONNX_Op<"Softmax", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softmax operation"; let description = [{ The operator computes the normalized exponential values for the given input: @@ -8780,7 +8992,7 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax", } def ONNXSoftmaxV11Op:ONNX_Op<"SoftmaxV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Softmax operation"; let description = [{ @@ -8824,7 +9036,7 @@ def ONNXSoftmaxV11Op:ONNX_Op<"SoftmaxV11", } def ONNXSoftmaxCrossEntropyLossOp:ONNX_Op<"SoftmaxCrossEntropyLoss", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SoftmaxCrossEntropyLoss operation"; let description = [{ Loss function that measures the softmax cross entropy @@ -8897,15 +9109,15 @@ def ONNXSoftmaxCrossEntropyLossOp:ONNX_Op<"SoftmaxCrossEntropyLoss", } def ONNXSoftplusOp:ONNX_Op<"Softplus", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softplus operation"; let description = [{ Softplus takes one input data (Tensor) and produces one output data (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -8928,13 +9140,13 @@ def ONNXSoftplusOp:ONNX_Op<"Softplus", } def ONNXSoftsignOp:ONNX_Op<"Softsign", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Softsign operation"; let description = [{ Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -8957,7 +9169,7 @@ def ONNXSoftsignOp:ONNX_Op<"Softsign", } def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX SpaceToDepth operation"; let description = [{ @@ -8991,7 +9203,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", } def ONNXSplitOp:ONNX_Op<"Split", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<18>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified 'axis'. @@ -9038,7 +9250,7 @@ def ONNXSplitOp:ONNX_Op<"Split", } def ONNXSplitV13Op:ONNX_Op<"SplitV13", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified @@ -9081,7 +9293,7 @@ def ONNXSplitV13Op:ONNX_Op<"SplitV13", } def ONNXSplitV11Op:ONNX_Op<"SplitV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Split operation"; let description = [{ Split a tensor into a list of tensors, along the specified @@ -9114,7 +9326,7 @@ def ONNXSplitV11Op:ONNX_Op<"SplitV11", } def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SplitToSequence operation"; let description = [{ Split a tensor into a sequence of tensors, along the specified 'axis'. @@ -9158,7 +9370,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", } def ONNXSqrtOp:ONNX_Op<"Sqrt", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Sqrt operation"; let description = [{ Square root takes one input data (Tensor) and produces one output data @@ -9199,7 +9411,7 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt", } def ONNXSqueezeOp:ONNX_Op<"Squeeze", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Squeeze operation"; let description = [{ @@ -9244,7 +9456,7 @@ def ONNXSqueezeOp:ONNX_Op<"Squeeze", } def ONNXSqueezeV11Op:ONNX_Op<"SqueezeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Squeeze operation"; let description = [{ @@ -9289,7 +9501,7 @@ def ONNXSqueezeV11Op:ONNX_Op<"SqueezeV11", } def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<10>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX StringNormalizer operation"; let description = [{ StringNormalization performs string operations for basic cleaning. @@ -9330,7 +9542,7 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", } def ONNXSubOp:ONNX_Op<"Sub", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let hasCanonicalizer = 1; let summary = "ONNX Sub operation"; let description = [{ @@ -9386,7 +9598,7 @@ def ONNXSubOp:ONNX_Op<"Sub", } def ONNXSumOp:ONNX_Op<"Sum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { let summary = "ONNX Sum operation"; let description = [{ Element-wise sum of each of the input tensors (with Numpy-style broadcasting support). @@ -9418,13 +9630,13 @@ def ONNXSumOp:ONNX_Op<"Sum", } def ONNXTanOp:ONNX_Op<"Tan", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Tan operation"; let description = [{ Calculates the tangent of the given input tensor, element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -9447,13 +9659,13 @@ def ONNXTanOp:ONNX_Op<"Tan", } def ONNXTanhOp:ONNX_Op<"Tanh", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Tanh operation"; let description = [{ Calculates the hyperbolic tangent of the given input tensor element-wise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$input); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[BF16]>]>:$output); + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$input); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -9476,7 +9688,7 @@ def ONNXTanhOp:ONNX_Op<"Tanh", } def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TfIdfVectorizer operation"; let description = [{ This transform extracts n-grams from the input sequence and save them as a vector. Input can @@ -9540,16 +9752,16 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer", } def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ThresholdedRelu operation"; let description = [{ ThresholdedRelu takes one input data (Tensor) and produces one output data (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, is applied to the tensor elementwise. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X, DefaultValuedAttr:$alpha); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; @@ -9572,7 +9784,7 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu", } def ONNXTileOp:ONNX_Op<"Tile", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Tile operation"; let description = [{ @@ -9605,15 +9817,15 @@ def ONNXTileOp:ONNX_Op<"Tile", } def ONNXTopKOp:ONNX_Op<"TopK", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TopK operation"; let description = [{ Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of - shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: + shape [a_0, a_1, ..., a_{n-1\}\] and integer argument k, return two outputs: - * Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] + * Value tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1\}\] which contains the values of the top k elements along the specified axis - * Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which + * Index tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ... a_{n-1\}\] which contains the indices of the top k elements (original indices from the input tensor). @@ -9654,7 +9866,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", } def ONNXTransposeOp:ONNX_Op<"Transpose", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Transpose operation"; let description = [{ @@ -9687,7 +9899,7 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", } def ONNXTriluOp:ONNX_Op<"Trilu", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<14>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Trilu operation"; let description = [{ Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s). @@ -9729,7 +9941,7 @@ def ONNXTriluOp:ONNX_Op<"Trilu", } def ONNXUniqueOp:ONNX_Op<"Unique", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Unique operation"; let description = [{ Find the unique elements of a tensor. When an optional attribute 'axis' is provided, unique subtensors sliced along the 'axis' are returned. @@ -9860,7 +10072,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique", } def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<13>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Unsqueeze operation"; let description = [{ @@ -9910,7 +10122,7 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", } def ONNXUnsqueezeV11Op:ONNX_Op<"UnsqueezeV11", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<11>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Unsqueeze operation"; let description = [{ @@ -9962,7 +10174,7 @@ def ONNXUnsqueezeV11Op:ONNX_Op<"UnsqueezeV11", } def ONNXUpsampleOp:ONNX_Op<"Upsample", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<9>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Upsample operation"; let description = [{ Upsample the input tensor. @@ -9996,7 +10208,7 @@ def ONNXUpsampleOp:ONNX_Op<"Upsample", } def ONNXUpsampleV7Op:ONNX_Op<"UpsampleV7", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Upsample operation"; let description = [{ Upsample the input tensor. @@ -10029,7 +10241,7 @@ def ONNXUpsampleV7Op:ONNX_Op<"UpsampleV7", } def ONNXWhereOp:ONNX_Op<"Where", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<16>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Where operation"; let description = [{ @@ -10067,7 +10279,7 @@ def ONNXWhereOp:ONNX_Op<"Where", } def ONNXXorOp:ONNX_Op<"Xor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<7>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let hasCanonicalizer = 1; let summary = "ONNX Xor operation"; let description = [{ @@ -10122,7 +10334,7 @@ def ONNXXorOp:ONNX_Op<"Xor", } def ONNXArrayFeatureExtractorOp:ONNX_Op<"ArrayFeatureExtractor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ArrayFeatureExtractor operation"; let description = [{ Select elements of the input tensor based on the indices passed.
@@ -10153,7 +10365,7 @@ def ONNXArrayFeatureExtractorOp:ONNX_Op<"ArrayFeatureExtractor", } def ONNXBinarizerOp:ONNX_Op<"Binarizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Binarizer operation"; let description = [{ Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value. @@ -10183,7 +10395,7 @@ def ONNXBinarizerOp:ONNX_Op<"Binarizer", } def ONNXCastMapOp:ONNX_Op<"CastMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CastMap operation"; let description = [{ Converts a map to a tensor.
The map key must be an int64 and the values will be ordered @@ -10217,7 +10429,7 @@ def ONNXCastMapOp:ONNX_Op<"CastMap", } def ONNXCategoryMapperOp:ONNX_Op<"CategoryMapper", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX CategoryMapper operation"; let description = [{ Converts strings to integers and vice versa.
@@ -10258,7 +10470,7 @@ def ONNXCategoryMapperOp:ONNX_Op<"CategoryMapper", } def ONNXDictVectorizerOp:ONNX_Op<"DictVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX DictVectorizer operation"; let description = [{ Uses an index mapping to convert a dictionary to an array.
@@ -10300,7 +10512,7 @@ def ONNXDictVectorizerOp:ONNX_Op<"DictVectorizer", } def ONNXFeatureVectorizerOp:ONNX_Op<"FeatureVectorizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX FeatureVectorizer operation"; let description = [{ Concatenates input tensors into one continuous output.
@@ -10333,7 +10545,7 @@ def ONNXFeatureVectorizerOp:ONNX_Op<"FeatureVectorizer", } def ONNXImputerOp:ONNX_Op<"Imputer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Imputer operation"; let description = [{ Replaces inputs that equal one value with another, leaving all other elements alone.
@@ -10373,7 +10585,7 @@ def ONNXImputerOp:ONNX_Op<"Imputer", } def ONNXLabelEncoderOp:ONNX_Op<"LabelEncoder", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<2>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LabelEncoder operation"; let description = [{ Maps each element in the input tensor to another value.
@@ -10427,7 +10639,7 @@ def ONNXLabelEncoderOp:ONNX_Op<"LabelEncoder", } def ONNXLinearClassifierOp:ONNX_Op<"LinearClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LinearClassifier operation"; let description = [{ Linear classifier @@ -10463,7 +10675,7 @@ def ONNXLinearClassifierOp:ONNX_Op<"LinearClassifier", } def ONNXLinearRegressorOp:ONNX_Op<"LinearRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX LinearRegressor operation"; let description = [{ Generalized linear regression evaluation.
@@ -10501,7 +10713,7 @@ def ONNXLinearRegressorOp:ONNX_Op<"LinearRegressor", } def ONNXNormalizerOp:ONNX_Op<"Normalizer", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Normalizer operation"; let description = [{ Normalize the input. There are three normalization modes, which have the corresponding formulas, @@ -10540,7 +10752,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer", } def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX OneHotEncoder operation"; let description = [{ Replace each input element with an array of ones and zeros, where a single @@ -10580,7 +10792,7 @@ def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", } def ONNXSVMClassifierOp:ONNX_Op<"SVMClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SVMClassifier operation"; let description = [{ Support Vector Machine classifier @@ -10621,7 +10833,7 @@ def ONNXSVMClassifierOp:ONNX_Op<"SVMClassifier", } def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX SVMRegressor operation"; let description = [{ Support Vector Machine regression prediction and one-class SVM anomaly detection. @@ -10658,7 +10870,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor", } def ONNXScalerOp:ONNX_Op<"Scaler", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Scaler operation"; let description = [{ Rescale input data, for example to standardize features by removing the mean and scaling to unit variance. @@ -10689,7 +10901,7 @@ def ONNXScalerOp:ONNX_Op<"Scaler", } def ONNXTreeEnsembleClassifierOp:ONNX_Op<"TreeEnsembleClassifier", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TreeEnsembleClassifier operation"; let description = [{ Tree Ensemble classifier. Returns the top class for each of N inputs.
@@ -10744,7 +10956,7 @@ def ONNXTreeEnsembleClassifierOp:ONNX_Op<"TreeEnsembleClassifier", } def ONNXTreeEnsembleRegressorOp:ONNX_Op<"TreeEnsembleRegressor", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX TreeEnsembleRegressor operation"; let description = [{ Tree Ensemble regressor. Returns the regressed values for each input in N.
@@ -10799,7 +11011,7 @@ def ONNXTreeEnsembleRegressorOp:ONNX_Op<"TreeEnsembleRegressor", } def ONNXZipMapOp:ONNX_Op<"ZipMap", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX ZipMap operation"; let description = [{ Creates a map from the input and the attributes.
@@ -10833,7 +11045,7 @@ def ONNXZipMapOp:ONNX_Op<"ZipMap", } def ONNXAdagradOp:ONNX_Op<"Adagrad", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Adagrad operation"; let description = [{ Compute one iteration of ADAGRAD, a stochastic gradient based optimization @@ -10916,7 +11128,7 @@ def ONNXAdagradOp:ONNX_Op<"Adagrad", } def ONNXAdamOp:ONNX_Op<"Adam", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Adam operation"; let description = [{ Compute one iteration of Adam, a stochastic gradient based optimization @@ -11012,7 +11224,7 @@ def ONNXAdamOp:ONNX_Op<"Adam", } def ONNXGradientOp:ONNX_Op<"Gradient", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Gradient operation"; let description = [{ Gradient operator computes the partial derivatives of a specific tensor w.r.t. @@ -11166,7 +11378,7 @@ def ONNXGradientOp:ONNX_Op<"Gradient", } def ONNXMomentumOp:ONNX_Op<"Momentum", - [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { + [Pure, OpVersionTrait<1>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "ONNX Momentum operation"; let description = [{ Compute one iteration of stochastic gradient update with momentum. diff --git a/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp b/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp index fdf6991771..c9c9ec77e2 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/ConcatShapeTranspose.cpp @@ -77,7 +77,6 @@ LogicalResult ONNXConcatShapeTransposeOpShapeHelper::computeShape() { outputConcatDims[dim] = createIE->getShapeAsDim(firstInput, dim); } IndexExpr cumulativeAxisSize = createIE->getShapeAsDim(firstInput, axisIndex); - // Handle the rest of input for (unsigned i = 1; i < numInputs; ++i) { Value currInput = operandAdaptor.getInputs()[i]; @@ -113,7 +112,7 @@ LogicalResult ONNXConcatShapeTransposeOpShapeHelper::computeShape() { assert(start <= end && "Start must not be greater than end"); // Output is the actual number of values (1D) - setOutputDims({LiteralIndexExpr(end - start)}, 0); + setOutputDims({LitIE(end - start)}, 0); // For the transpose DimsExpr outputTransposeDims(commonRank); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Return.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Return.cpp index cc85f68c33..36cd53b6d1 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/Return.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/Return.cpp @@ -44,8 +44,8 @@ bool shapeIsSameOrMoreSpecific(ShapedType lhs, ShapedType rhs) { // True if the types are the same up to shape specificity. bool typeIsSameOrMoreSpecific(Type lhs, Type rhs) { - ShapedType lhsShaped = dyn_cast(lhs); - ShapedType rhsShaped = dyn_cast(rhs); + ShapedType lhsShaped = mlir::dyn_cast(lhs); + ShapedType rhsShaped = mlir::dyn_cast(rhs); if (!lhsShaped && !rhsShaped) { return lhs == rhs; @@ -67,7 +67,7 @@ bool typeIsSameOrMoreSpecific(Type lhs, Type rhs) { // Implementation is adapted from mlir/lib/Dialect/Func/IR/FuncOps.cpp // relaxing the type check to allow more specific shapes. LogicalResult ONNXReturnOp::verify() { - auto function = cast((*this)->getParentOp()); + auto function = mlir::cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp b/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp index 26a9e5cd54..dbe815e53d 100644 --- a/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp +++ b/src/Dialect/ONNX/ONNXOps/Additional/ShapeTransform.cpp @@ -67,7 +67,7 @@ LogicalResult ONNXShapeTransformOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXShapeTransformOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { Operation *op = getOperation(); // If any input is not ranked tensor, do nothing. if (!hasShapeAndRank(op)) @@ -93,13 +93,13 @@ LogicalResult ONNXShapeTransformOp::verify() { return emitError("Does not support affine_map with symbols"); // Only support static shape at this moment. - auto inputType = dyn_cast(input.getType()); + auto inputType = mlir::dyn_cast(input.getType()); if (inputType && !inputType.hasStaticShape()) return emitError("Does not support input with dynamic shape"); // If input and output have static shape, check that the same number of // elements are the same. - if (auto outputType = dyn_cast(output.getType())) + if (auto outputType = mlir::dyn_cast(output.getType())) if (outputType.hasStaticShape()) { uint64_t elementsInput = 1; for (uint64_t d : inputType.getShape()) diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp index b28df91bbc..41a2a18af9 100644 --- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp @@ -99,7 +99,7 @@ SmallVector transposeVariadicInput(PatternRewriter &rewriter, assert(inpType && "Type is not ShapedType"); ONNXTransposeOp transposeOp = rewriter.create( loc, UnrankedTensorType::get(inpType.getElementType()), inp, permAttr); - (void)transposeOp.inferShapes([](Region ®ion) {}); + static_cast(transposeOp.inferShapes([](Region ®ion) {})); transposedInputs.emplace_back(transposeOp.getResult()); } return transposedInputs; @@ -114,7 +114,7 @@ SmallVector castVariadicInput(PatternRewriter &rewriter, Location loc, assert(inpType && "Type is not ShapedType"); ONNXCastOp castOp = rewriter.create(loc, UnrankedTensorType::get(inpType.getElementType()), inp, saturate, to); - (void)castOp.inferShapes([](Region ®ion) {}); + static_cast(castOp.inferShapes([](Region ®ion) {})); castInputs.emplace_back(castOp.getResult()); } return castInputs; @@ -183,8 +183,8 @@ bool AreTheSameAxesArrayAttr( bool AreTheSameAxesConstant(int64_t rank, Value lhs, Value rhs) { assert(cast(lhs.getType()).getElementType().isInteger(64)); assert(cast(rhs.getType()).getElementType().isInteger(64)); - auto lhsConstOp = dyn_cast_or_null(lhs.getDefiningOp()); - auto rhsConstOp = dyn_cast_or_null(rhs.getDefiningOp()); + auto lhsConstOp = mlir::dyn_cast_or_null(lhs.getDefiningOp()); + auto rhsConstOp = mlir::dyn_cast_or_null(rhs.getDefiningOp()); return lhsConstOp && rhsConstOp && AreTheSameAxesArrayAttr(rank, createArrayAttrFromConstantOp(lhsConstOp), @@ -202,11 +202,7 @@ bool haveSameStaticShape(Value lhs, Value rhs) { /// Test if the input is a splat constant with a negative value or not. bool isNegativeSplatConstant(Value val) { - if (!isDenseONNXConstant(val)) - return false; - ONNXConstantOp constOp = val.getDefiningOp(); - auto valAttr = - llvm::dyn_cast_or_null(constOp.getValueAttr()); + ElementsAttr valAttr = getElementAttributeFromONNXValue(val); if (!valAttr) return false; @@ -238,9 +234,7 @@ bool areAllDimSizes(ValueRange vals) { Type elemTy = mlir::cast(val.getType()).getElementType(); if (!mlir::isa(elemTy)) return false; - ONNXConstantOp constOp = val.getDefiningOp(); - auto valAttr = - llvm::dyn_cast_or_null(constOp.getValueAttr()); + ElementsAttr valAttr = getElementAttributeFromONNXValue(val); if (!valAttr) return false; int64_t v = (*valAttr.getValues().begin()).getSExtValue(); @@ -321,6 +315,31 @@ bool matchShapeAddMatMul(Value v, Value &matA, Value &biasB, return true; } +// Check if Reshape with allowzero == 1 can be replaced by +// another one with allowzero == 0. Conditions: +// - If no value in the 'shape' input is set to zero. +bool isConstantOpWithNoZeroElements(Value constVal) { + if (!isDenseONNXConstant(constVal)) + return false; + + ONNXConstantOp constOp = constVal.getDefiningOp(); + DenseElementsAttr intElemsAttr; + if (auto elms = + dyn_cast(constOp.getValueAttr())) { + intElemsAttr = elms; + } else if (auto elms = dyn_cast( + constOp.getValueAttr())) { + intElemsAttr = dyn_cast_or_null( + elms.toDenseElementsAttr()); + } + if (!intElemsAttr) + return false; + + auto isZero = [](int64_t val) { return val == 0; }; + + return llvm::none_of(intElemsAttr.getValues(), isZero); +} + } // namespace onnx_mlir // ============================================================================= @@ -357,8 +376,8 @@ class BinaryOpBroadcastAxisPattern : public OpRewritePattern { assert(op->getNumOperands() == 2 && "op must be binary"); Value lhs = op->getOperand(0); Value rhs = op->getOperand(1); - ShapedType lhsType = cast(lhs.getType()); - ShapedType rhsType = cast(rhs.getType()); + ShapedType lhsType = mlir::cast(lhs.getType()); + ShapedType rhsType = mlir::cast(rhs.getType()); if (!lhsType.hasRank() || !rhsType.hasRank()) { return failure(); // Cannot apply pattern until ranks are known. } @@ -485,7 +504,7 @@ class PropagateReshapeThroughBinaryOpPattern Operation *reshapeGenericOp = lhs.getDefiningOp(); if (!reshapeGenericOp) return failure(); - auto reshapeOp = dyn_cast(reshapeGenericOp); + auto reshapeOp = mlir::dyn_cast(reshapeGenericOp); if (!reshapeOp) return failure(); // RHS is a scalar. @@ -591,8 +610,8 @@ struct PropagateConstantScalingInAttentionLayerPattern onnxGemmOp.getLoc(), onnxGemmOp.getC().getType(), B, K)); }); } else { - auto onnxSubMatOp = cast(matmulOrGemmOp); - auto onnxAddOp = cast(addOp); + auto onnxSubMatOp = mlir::cast(matmulOrGemmOp); + auto onnxAddOp = mlir::cast(addOp); // Update in place MatMul and Add. rewriter.modifyOpInPlace(onnxSubMatOp, [&] { rewriter.setInsertionPoint(onnxSubMatOp); @@ -650,7 +669,7 @@ class EmptyTensorInputsResizePattern : public OpRewritePattern { private: bool isEmptyTensor(Value input) const { - if (ShapedType shapedType = dyn_cast(input.getType())) { + if (ShapedType shapedType = mlir::dyn_cast(input.getType())) { return shapedType.hasStaticShape() && shapedType.getNumElements() == 0; } else { return false; @@ -749,12 +768,9 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { bool isDefinedByIntegerConstantOp(Value v) const { if (mlir::isa(v)) return false; - Operation *definingOp = v.getDefiningOp(); if (mlir::isa( mlir::cast(v.getType()).getElementType()) && - isa(definingOp) && - mlir::isa( - cast(definingOp).getValueAttr())) + isDenseONNXConstant(v)) return true; return false; } @@ -795,10 +811,8 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // A helper function to get an integer constant from a value. int64_t getOneIntegerConstant(Value v) const { - Operation *definingOp = v.getDefiningOp(); - DenseElementsAttr valueAttr = mlir::cast( - cast(definingOp).getValueAttr()); - return (*valueAttr.getValues().begin()).getSExtValue(); + return onnx_mlir::getScalarValue( + v.getDefiningOp()); } // A helper function to match the pattern of the given operation. It also @@ -874,7 +888,7 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { // newCounterValue = ONNXAddOp(counterValue, stepValue). // cond = LessOp(newCounterValue, ubValue) // ONNXYieldOp (cond, ..., ubValue, ..., newCounterValue, ...) - Operation *addOp = cast(newCounterValue.getDefiningOp()); + Operation *addOp = mlir::cast(newCounterValue.getDefiningOp()); Value counterValue = addOp->getOperands()[0]; Value stepValue = addOp->getOperands()[1]; // Counter is a block argument and updated at each iteration. @@ -898,13 +912,14 @@ class LoopOpRewriteMaxTripCountPattern : public OpRewritePattern { if (isInvariantBlockArg(ubValue, yieldOp)) ubValue = getFedValue(ubValue, loopOp); else - ubValue = cast(rewriter.clone(*ubValue.getDefiningOp())) - .getResult(); + ubValue = + mlir::cast(rewriter.clone(*ubValue.getDefiningOp())) + .getResult(); if (isInvariantBlockArg(stepValue, yieldOp)) stepValue = getFedValue(stepValue, loopOp); else stepValue = - cast(rewriter.clone(*stepValue.getDefiningOp())) + mlir::cast(rewriter.clone(*stepValue.getDefiningOp())) .getResult(); // Case 1: the upper bound, lower bound and step are constants. @@ -1159,9 +1174,10 @@ class PowToMulRewritePattern : public OpRewritePattern { ShapedType resultType = mlir::cast(powOp.getZ().getType()); Type elementType = getElementType(resultType); if (exponent == 0) { - Attribute one = isa(elementType) - ? (Attribute)rewriter.getFloatAttr(elementType, 1.0) - : (Attribute)rewriter.getIntegerAttr(elementType, 1); + Attribute one = + isa(elementType) + ? static_cast(rewriter.getFloatAttr(elementType, 1.0)) + : static_cast(rewriter.getIntegerAttr(elementType, 1)); result = create.onnx.constant(DenseElementsAttr::get(resultType, one)); } else { // calculate pow(input,exponent) with "exponentiation by squaring" method @@ -1219,7 +1235,7 @@ class ReplaceUnsqueezeOfExpandRewritePattern // 1. data is from ExpandOp, axes is from ConstantOp. if (!definedBy(data) || !definedBy(axes)) return failure(); - auto expandOp = cast(data.getDefiningOp()); + auto expandOp = mlir::cast(data.getDefiningOp()); // 2. ExpandOp's input is a scalar tensor so that it's safe to use a new // shape that do not violate the broadcasting rule.. if (!isScalarTensor(expandOp.getInput())) @@ -1355,10 +1371,26 @@ class FuseTwoReshapesPattern : public OpRewritePattern { {firstReshapeOp.getLoc(), secondReshapeOp.getLoc()}); OnnxBuilder createONNX(rewriter, loc); + auto eraseTriviallyDeadValues = [&](PatternRewriter &rewriter, + SmallVector &values) { + for (auto val : values) { + auto *op = val.getDefiningOp(); + if (!op || !isOpTriviallyDead(op)) + continue; + rewriter.eraseOp(op); + } + }; + // Try to compute a new shape tensor by fusing the two old shapes. SmallVector firstDims, secondDims, fusedDims; if (!getValuesFromShape(createONNX, firstShape, firstDims) || !getValuesFromShape(createONNX, secondShape, secondDims)) { + // New values may be created by getValuesFromShape. Erase newly-created + // values before failing. This avoids that the PatternRewriter notify + // changes and prevent convergence issue. + eraseTriviallyDeadValues(rewriter, firstDims); + eraseTriviallyDeadValues(rewriter, secondDims); + // Not rewrite if we can not read dimension values (0, -1, L) from a shape // tensor. return rewriter.notifyMatchFailure( @@ -1399,6 +1431,12 @@ class FuseTwoReshapesPattern : public OpRewritePattern { minusOnes++; } if (minusOnes > 1) { + // New values may be created by getValuesFromShape. Erase newly-created + // values before failing. This avoids that the PatternRewriter notify + // changes and prevent convergence issue. + eraseTriviallyDeadValues(rewriter, firstDims); + eraseTriviallyDeadValues(rewriter, secondDims); + // The fused shape is invalid because it has two -1s. return rewriter.notifyMatchFailure(op, "Failed to compute a fused shape"); } @@ -1493,7 +1531,6 @@ struct PropagateBiasIntoLayerNormRewritePattern Value y, bias; Operation *yLayerNormOp; Operation *ywbAddOp = addOp.getOperation(); - Location loc = addOp.getLoc(); // Match // %noBias = "onnx.NoValue"() // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %scale, %noBias) @@ -1508,7 +1545,7 @@ struct PropagateBiasIntoLayerNormRewritePattern // used. if (!yLayerNormOp->hasOneUse()) return reportFailure("y/layer norm has too many uses"); - auto lnOp = cast(yLayerNormOp); + auto lnOp = mlir::cast(yLayerNormOp); if (!onnx_mlir::isNoneValue(lnOp.getB())) return reportFailure("layer norm already has a bias"); // We are fine. @@ -1519,7 +1556,8 @@ struct PropagateBiasIntoLayerNormRewritePattern LLVM_DEBUG(llvm::dbgs() << "LayerNorm from add, axis : " << axis << "\n"); // Replace - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create( + rewriter, rewriter.getFusedLoc({lnOp.getLoc(), addOp->getLoc()})); Type xType = x.getType(); Value res; if constexpr (std::is_same::value) @@ -1712,6 +1750,7 @@ void ONNXReshapeOp::getCanonicalizationPatterns( result.insert(context); result.insert(context); result.insert(context); + result.insert(context); } /// on the ONNXResizeOp. @@ -1843,7 +1882,8 @@ void ONNXUnsqueezeV11Op::getCanonicalizationPatterns( void ONNXPowOp::getCanonicalizationPatterns( RewritePatternSet &result, MLIRContext *context) { // Is 64 necessary? Maybe too high? - result.insert(context, 64); + // Changed from upstream 64 to 2 because it can break quantization patterns + result.insert(context, 2); result.insert>(context); } @@ -1861,6 +1901,4 @@ void ONNXWhereOp::getCanonicalizationPatterns( // on the ONNXDequantizeLinearOp. void ONNXDequantizeLinearOp::getCanonicalizationPatterns( - RewritePatternSet &result, MLIRContext *context) { - result.insert(context); -} + RewritePatternSet &result, MLIRContext *context) {} diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td index 694ce93394..957981519c 100644 --- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td +++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td @@ -38,6 +38,21 @@ include "src/Dialect/ONNX/ONNX.td" def createDenseElementsAttrFromFloatAttr : NativeCodeCall< "onnx_mlir::createDenseElementsAttrFromFloatAttr($_builder, mlir::cast($0.getType()).getElementType(), $1)">; +// TODO: Currently this will not saturate. We might need/want this for float8 +// types. +def castWithSameRankDifferentElementType : NativeCodeCall< + "onnx_mlir::castTo($_builder, $0, $1.getType().cast().getElementType(), /*saturate=*/ 0)">; + +def createConstantOpWithOneToRankOfExclusive: NativeCodeCall< + [{onnx_mlir::createConstantOp($_builder, $0.getLoc(), + // Create an ArrayAttr of IntergerAttr(s) of values in [1, N-2]. + onnx_mlir::createArrayAttrOfNToM($_builder, 1, $0.getType().cast().getRank() - 2))}]>; + +def createConstantOpWithOneToRankOf : NativeCodeCall< + [{onnx_mlir::createConstantOp($_builder, $0.getLoc(), + // Create an ArrayAttr of IntergerAttr(s) of values in [1, N-1]. + onnx_mlir::createArrayAttrOfNToM($_builder, 1, $0.getType().cast().getRank() - 1))}]>; + // Create a DenseElementsAttr from the shape of the type of a value. def createDenseElementsAttrFromShape : NativeCodeCall< "onnx_mlir::createDenseElementsAttrFromShape($_builder, $0)">; @@ -70,10 +85,6 @@ def createDenseElementsAttrOf : NativeCodeCall< def createDenseElementsAttrOfOneToRankOf : NativeCodeCall< "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast($0.getType()).getRank() - 1)">; -// Create an ArrayAttr of IntergerAttr(s) of values in [1, N-2]. -def createDenseElementsAttrOfOneToRankOfExclusive : NativeCodeCall< - "onnx_mlir::createDenseElementsAttrOfNToM($_builder, 1, mlir::cast($0.getType()).getRank() - 2)">; - // Create an ArrayAttr of IntergerAttr(s) of values in [2, rank - 1]. def createArrayAttrOfTwoToRankOf : NativeCodeCall< "onnx_mlir::createArrayAttrOfNToM($_builder, 2, mlir::cast($0.getType()).getRank() - 1)">; @@ -81,11 +92,6 @@ def createArrayAttrOfTwoToRankOf : NativeCodeCall< def AttributeIsNotNull : Constraint, "Attribute is not null">; -def IsDenseElementsAttr : - Constraint, - CPred<"mlir::isa(($_self))"> - ]>, "Attribute is not a DenseElementsAttr">; - // Intended to check whether there is at least one not-Null the attributes // However, the current table gen can only support max 4 parameters // Multiple rules are used instead of one rule @@ -188,6 +194,12 @@ def IsStaticShapeTensor: "mlir::cast<::mlir::ShapedType>($_self.getType()).hasStaticShape()">, "hasStaticShape">; +def HasRank: + Constraint< + CPred< + "$_self.getType().cast<::mlir::ShapedType>().hasRank()">, + "is a tensor of static rank">; + def IsNoneValue: Constraint< CPred<"onnx_mlir::isNoneValue($_self)">, "Is the value none">; @@ -197,18 +209,13 @@ def HasSpecifiedConstantShape: Constraint< "Has the specified constant shape">; def IsFromONNXConstantOp: Constraint< - CPred<"llvm::dyn_cast_or_null($0.getDefiningOp())">, + CPred<"onnx_mlir::isDenseONNXConstant($0)">, "Is a value from ONNXConstantOp">; def IsNotFromONNXConstantOp: Constraint< CPred<"!(llvm::dyn_cast_or_null($0.getDefiningOp()))">, "Is a value not from ONNXConstantOp">; -def IsFromONNXConstantOpWithDenseElementsAttr: Constraint< - And<[CPred<" $_self.getDefiningOp() ">, - CPred<" isa(onnx_mlir::getONNXConstantOp($_self).getValueAttr()) "> - ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr">; - def IsNegativeSplatConstant: Constraint< CPred<"onnx_mlir::isNegativeSplatConstant($_self)">, "Is a splat constant with a negative value." @@ -283,6 +290,15 @@ class IntegerAttrIsOf : Constraint< "IntegerAttr is of the given value" >; +def isConstantOpWithNoZeroElements: Constraint< + CPred<"onnx_mlir::isConstantOpWithNoZeroElements($0)">, + "Checks this constant has no zero elements." +>; + +class ValidAxisConstraints: Constraint< + CPred<"(" # axis # " >= 0) && ( " # axis # " < mlir::cast($_self.getType()).getShape().size())">, +"Axis is within the valid range of shape dimensions">; + //===----------------------------------------------------------------------===// // Pattern-Match and Rewrite //===----------------------------------------------------------------------===// @@ -366,6 +382,7 @@ def FuseAddConvPattern: Pat< [(HasShapeAndRank:$res), (NotNoneType $b), (AttributeIsNotNull:$denseAttr), + (ValidAxisConstraints<1>:$y), (AllDimsFromAxisToEndAre<1, 1>:$y), (RankXMinusRankYIs<1> $res, $y)] >; @@ -407,8 +424,8 @@ def FuseMulConvNullBiasPattern: Pat< // unchanged operands and attributes. $b, $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides), [(HasNoneType $b), - (IsDenseElementsAttr:$denseAttr), - (IsFromONNXConstantOpWithDenseElementsAttr:$w), + (AttributeIsNotNull:$denseAttr), + (IsFromONNXConstantOp $w), (HaveSameElementType $w, $y), // multiplier and Conv weight must have the same element type. (HasRankGT<1> $w), // rank of $w must be at least 2. (RankXMinusRankYIs<1> $w, $y), // rank($y) must be equal to rank($w)-1. @@ -474,13 +491,13 @@ def SwapCastSlicePattern: Pat< // Canonicalization for ONNXTileOp //===----------------------------------------------------------------------===// -def IsFromONNXConstantOpWithOnesDenseElementsAttr: Constraint< - And<[IsFromONNXConstantOpWithDenseElementsAttr.predicate, +def IsFromONNXConstantOpWithOnes: Constraint< + And<[CPred<"onnx_mlir::isDenseONNXConstant($_self)">, CPred<"::llvm::all_of(" - "mlir::dyn_cast(onnx_mlir::getONNXConstantOp($_self)" - ".getValueAttr()).getValues(), " + "onnx_mlir::getElementAttributeFromONNXValue($_self)" + ".getValues(), " "[](int64_t repeat) { return repeat == 1;})"> - ]>, "Value is not a ONNXConstantOp with a DenseElementsAttr of ones">; + ]>, "Value is not a ONNXConstantOp with an ElementsAttr of ones">; def RemoveIdentityTilePattern: Pat< // Tile with `repeats` of all constant 1's @@ -488,7 +505,7 @@ def RemoveIdentityTilePattern: Pat< // Remove the tile. (replaceWithValue $val), // Check that we have indeed a identity tile pattern. - [(IsFromONNXConstantOpWithOnesDenseElementsAttr:$r), (HaveSameShapedType $val,$result)]>; + [(IsFromONNXConstantOpWithOnes:$r), (HaveSameShapedType $val,$result)]>; //===----------------------------------------------------------------------===// // Canonicalization for ONNXLayoutTransformOp @@ -668,6 +685,15 @@ def RemoveIdentityReshapePattern2: Pat< // Check that val and out have the same static shape. [(IsIdentityReshape $out, $val)]>; +def ReplaceReshapeAllowZeroByReshape: Pat< + // Reshape with allowzero == 1 + (ONNXReshapeOp:$out $val, $shape, $allowzero), + // Replace Reshape (allowzero == 1) with Reshape (allowzero == 0) + (ONNXReshapeOp $val, $shape, (GetNullIntegerAttr)), + // Check that allowzero == 1, and val and out have static and non zero dimensions. + [(IntegerAttrIsOf<1> $allowzero), + (isConstantOpWithNoZeroElements $shape)]>; + def GetReturnTypeForMatMulOpND2D: NativeCodeCall< "onnx_mlir::getReturnTypeForMatMulOpND2D($0, $1)" >; @@ -821,19 +847,19 @@ def FuseBatchNormInferenceModeConvPattern: Pat< $w, (ONNXUnsqueezeOp (ONNXDivOp:$coefficientW - $scale, + (castWithSameRankDifferentElementType $scale, $x), (ONNXSqrtOp (ONNXAddOp - $var, + (castWithSameRankDifferentElementType $var, $x), (ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromFloatAttr $res, $epsilon))))), - (ONNXConstantOpFromDenseAttr (createDenseElementsAttrOfOneToRankOf $w)))), + (createConstantOpWithOneToRankOf $w))), // b_ (ONNXAddOp $B, (ONNXMulOp $coefficientW, - (subtractOrNeg $res, $b, $mean))), + (subtractOrNeg $res, $b, (castWithSameRankDifferentElementType $mean, $x)))), $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides), [], [], (addBenefit 1) @@ -864,17 +890,21 @@ def RewriteBatchNormInferenceModeConvPattern1: Pat< $x, (ONNXUnsqueezeOp (ONNXDivOp:$a - $scale, + (castWithSameRankDifferentElementType $scale, $x), (ONNXSqrtOp (ONNXAddOp - $var, + (castWithSameRankDifferentElementType $var, $x), (ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromFloatAttr $res, $epsilon))))), - (ONNXConstantOpFromDenseAttr (createDenseElementsAttrOfOneToRankOfExclusive $x)))), + (createConstantOpWithOneToRankOfExclusive $x))), // b (ONNXUnsqueezeOp - (ONNXSubOp $bias, (ONNXMulOp $mean, $a)), - (ONNXConstantOpFromDenseAttr (createDenseElementsAttrOfOneToRankOfExclusive $x)))), + (ONNXSubOp + (castWithSameRankDifferentElementType $bias, $x), + (ONNXMulOp + (castWithSameRankDifferentElementType $mean, $x), + $a)), + (createConstantOpWithOneToRankOfExclusive $x))), [(HasRankGT<2> $x)], [], (addBenefit 0) >; @@ -890,14 +920,15 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat< (ONNXMulOp $x, (ONNXDivOp:$a - $scale, + (castWithSameRankDifferentElementType $scale, $x), (ONNXSqrtOp (ONNXAddOp - $var, + (castWithSameRankDifferentElementType $var, $x), (ONNXConstantOpFromDenseAttr (createDenseElementsAttrFromFloatAttr $res, $epsilon)))))), // b - (ONNXSubOp $bias, (ONNXMulOp $mean, $a))), + (ONNXSubOp (castWithSameRankDifferentElementType $bias, $x), + (ONNXMulOp (castWithSameRankDifferentElementType $mean, $x), $a))), [(HasRankOf<1> $x)], [], (addBenefit 0) >; @@ -937,7 +968,8 @@ def SizeToConstantPattern: Pat< // Rewrite GlobalAveragePool using ReduceMean. def GlobalAveragePoolPattern: Pat< (ONNXGlobalAveragePoolOp $x), - (ONNXReduceMeanV13Op $x, (createArrayAttrOfTwoToRankOf $x), (GetNullAttr)) + (ONNXReduceMeanV13Op $x, (createArrayAttrOfTwoToRankOf $x), (GetNullAttr)), + [(HasRank:$x)] >; // Rewrite GlobalMaxPool using ReduceMax. @@ -1055,15 +1087,4 @@ def AlwaysFalseWherePattern : Pat< [(IsNegativeSplatConstant:$negative_constant), (AreAllDimSizes:$dims)] >; -//===----------------------------------------------------------------------===// -// Canonicalization for ONNXDequantizeLinear -//===----------------------------------------------------------------------===// - -// Convert QuantizeLinear+DequantizeLinear to Identity. -def QuantizeDequantizePattern: Pat< - (ONNXDequantizeLinearOp (ONNXQuantizeLinearOp $x, $x_scale, $x_zeropoint, $x_axis, $x_saturate), - $y_scale, $y_zeropoint, $y_axis), - (replaceWithValue $x) ->; - #endif // ONNX_REWRITE diff --git a/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp b/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp index bcb229e6d0..85d60dbfeb 100644 --- a/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp +++ b/src/Dialect/ONNX/ONNXOps/ControlFlow/Loop.cpp @@ -39,7 +39,7 @@ std::vector ONNXLoopOp::resultTypeInference() { resultTypes.push_back(ty); } else { // scan output // Erase any rank and shape. Shape inference will add a leading dimension. - Type elementType = cast(ty).getElementType(); + Type elementType = mlir::cast(ty).getElementType(); resultTypes.push_back(UnrankedTensorType::get(elementType)); } } diff --git a/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp b/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp index 5c62a5fa95..b875ee29f1 100644 --- a/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp +++ b/src/Dialect/ONNX/ONNXOps/ControlFlow/Scan.cpp @@ -50,7 +50,7 @@ std::vector ONNXScanOp::resultTypeInference() { resultTypes.push_back(ty); } else { // scan output // Erase any rank and shape. Shape inference will add a leading dimension. - Type elementType = cast(ty).getElementType(); + Type elementType = mlir::cast(ty).getElementType(); resultTypes.push_back(UnrankedTensorType::get(elementType)); } } diff --git a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp index 3b3805ea26..47a74a0093 100644 --- a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp +++ b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp @@ -46,7 +46,7 @@ LogicalResult ONNXOneHotEncoderOpShapeHelper::computeShape() { // total category count will determine the size of the extra dimension DimsExpr outputDims; createIE->getShapeAsDims(X, outputDims); - outputDims.emplace_back(LiteralIndexExpr(outDim)); + outputDims.emplace_back(LitIE(outDim)); // Save the final result. setOutputDims(outputDims); diff --git a/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp b/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp index be4b5a6694..94ee2acd2a 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Bernoulli.cpp @@ -4,7 +4,7 @@ //===------------------ Bernoulli.cpp - ONNX Operations -------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -34,8 +34,9 @@ LogicalResult ONNXBernoulliOp::inferShapes( } Type elementType; if (getDtypeAttr()) { - elementType = convertONNXTypeToMLIRType(builder, - (onnx::TensorProto_DataType)getDtypeAttr().getValue().getSExtValue()); + elementType = convertONNXTypeToMLIRType( + builder, static_cast( + getDtypeAttr().getValue().getSExtValue())); } else { elementType = mlir::cast(getInput().getType()).getElementType(); diff --git a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp index 82047008a2..787f18ae70 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/DFT.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXGenericDFTOpShapeHelper::customComputeShape( // Get info about input data operand. Value input = operandAdaptor.getInput(); // Get the rank to compensate for N dimensions. + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(input); // Check if the dimension for axis is a literal and in range. @@ -67,7 +70,7 @@ LogicalResult ONNXGenericDFTOpShapeHelper::customComputeShape( } } } - outputDims.emplace_back(LiteralIndexExpr(2)); + outputDims.emplace_back(LitIE(2)); // Save the final result. setOutputDims(outputDims); @@ -88,7 +91,7 @@ LogicalResult ONNXGenericDFTOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXDFTOp::inferShapes( - std::function doShapeInference) { + std::function doShapeInference) { // Cannot infer the output shape if the operands shape isn't known yet. if (!hasShapeAndRank(getOperation())) return success(); diff --git a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp index af472ec6ba..c6982e870d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/EinsumHelper.cpp @@ -179,7 +179,7 @@ Subscripts extractSubscripts(StringRef parameterEquation, int64_t rank) { } appendLetterSubscripts(suffix, subscripts); } - assert((int64_t)subscripts.size() == rank && + assert(static_cast(subscripts.size()) == rank && "#subscripts == rank after replacing any ellipsis with digits"); return subscripts; } diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp index f3961667e5..d465a367bf 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseBroadcast.cpp @@ -343,15 +343,7 @@ LogicalResult ONNXOrOp::inferShapes( //===----------------------------------------------------------------------===// LogicalResult ONNXPowOp::verify() { - ShapedType lhsTy = mlir::cast(getX().getType()); - ShapedType rhsTy = mlir::cast(getY().getType()); - Type rhsETy = rhsTy.getElementType(); - Type lhsETy = lhsTy.getElementType(); - if (rhsETy != lhsETy) - return emitOpError("Pow with different input type not implemented yet"); - if (mlir::isa(lhsETy) || mlir::isa(lhsETy)) - return emitOpError("Integer power not implemented yet"); - return success(); + return verifyShapeForBroadcastingOps(getOperation()); } LogicalResult ONNXPowOp::inferShapes( diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp index a38ddfcb11..c527a92f5c 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp @@ -395,6 +395,15 @@ LogicalResult ONNXMeanVarianceNormalizationOp::inferShapes( return inferShapeForUnaryOps(this->getOperation()); } +//===----------------------------------------------------------------------===// +// MishOp +//===----------------------------------------------------------------------===// + +LogicalResult ONNXMishOp::inferShapes( + std::function doShapeInference) { + return inferShapeForUnaryOps(this->getOperation()); +} + //===----------------------------------------------------------------------===// // NegOp //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp index ab26c52771..2533561cfe 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Gemm.cpp @@ -66,10 +66,10 @@ LogicalResult ONNXGemmOpShapeHelper::computeShape() { if (hasBias) { if (cRank == 0) { // Broadcast for scalar: both dims are 1. - cDims = {LiteralIndexExpr(1), LiteralIndexExpr(1)}; + cDims = {LitIE(1), LitIE(1)}; } else if (cRank == 1) { // First dim is the one padded. - cDims = {LiteralIndexExpr(1), createIE->getShapeAsDim(C, 0)}; + cDims = {LitIE(1), createIE->getShapeAsDim(C, 0)}; } else { assert(cRank == 2 && "illegal path"); cDims = {createIE->getShapeAsDim(C, 0), createIE->getShapeAsDim(C, 1)}; diff --git a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp index c17b62679d..d9fe81fce1 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/MatMul.cpp @@ -55,6 +55,9 @@ LogicalResult ONNXGenericMatMulOpShapeHelper::computeShape() { std::tie(A, B) = matMulInputs(operandAdaptor); // Size all the arrays to padded length. + if (!hasShapeAndRank(A) || !hasShapeAndRank(B)) { + return failure(); + } uint64_t aRank = createIE->getShapedTypeRank(A); uint64_t bRank = createIE->getShapedTypeRank(B); int paddedRank = std::max(aRank, bRank); diff --git a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp index d988f67bc1..bb4a9fd2cb 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Reduction.cpp @@ -29,6 +29,9 @@ LogicalResult ONNXGenericReductionOpShapeHelper::customComputeShape( DimsExpr &axes, int noopWithEmptyAxes) { typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(data); // Normalize the axes: at present, we only support compile time axes, but // with keep_dim on, it might not be too difficult to generate the code. @@ -104,7 +107,11 @@ LogicalResult ONNXGenericReductionOpShapeHelper::computeShape() { createIE->getIntFromArrayAsSymbols(operandAdaptor.getAxes(), axes); } else { // When the axis is dynamic, try to infer the rank of output tensor - int64_t dataRank = createIE->getShapedTypeRank(operandAdaptor.getData()); + const auto data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } + int64_t dataRank = createIE->getShapedTypeRank(data); int64_t axlesSize = createIE->getArraySize(operandAdaptor.getAxes()); if (!operandAdaptor.getKeepdims() && axlesSize < 0 /*undef shape*/) { // Even though we did not compute the shape in ShapeHelper, return @@ -386,6 +393,27 @@ LogicalResult ONNXReduceSumSquareV13Op::inferShapes( return inferShapeForReductionOps_old(*this); } +//===----------------------------------------------------------------------===// +// Folder +//===----------------------------------------------------------------------===// + +OpFoldResult ONNXReduceMeanOp::fold(FoldAdaptor adaptor) { + typename ONNXReduceMeanOp::Adaptor opAdaptor(*this); + onnx_mlir::ONNXGenericReductionOpShapeHelper shapeHelper( + getOperation(), opAdaptor.getOperands()); + + if (failed(shapeHelper.computeShape())) + return nullptr; + + const bool hasReduction = + llvm::any_of(shapeHelper.isReductionAxis, [](bool axis) { return axis; }); + + if (!hasReduction && opAdaptor.getNoopWithEmptyAxes()) + return getData(); + + return nullptr; +} + //===----------------------------------------------------------------------===// // Template instantiation; keep at the end of the file. //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp index 189d855805..701e03721d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp @@ -76,6 +76,10 @@ LogicalResult ONNXScatterElementsOp::verify() { if (dataDimAtAxis >= 0) { if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) diff --git a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp index 641faa1e4d..98e0ec45f6 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/TopK.cpp @@ -31,6 +31,9 @@ LogicalResult ONNXTopKOpShapeHelper::computeShape() { // Get info about X and K operands. Value X = operandAdaptor.getX(); Value K = operandAdaptor.getK(); + if (!hasShapeAndRank(X)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(X); // Axis to compute TopK. diff --git a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp index 951905f8ad..0539ad446c 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Conv.cpp @@ -4,7 +4,7 @@ //===------------------ Conv.cpp - ONNX Operations ------------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -370,10 +370,13 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() { int64_t groupNum = convTransposeOp.getGroup(); llvm::StringRef autoPad = convTransposeOp.getAutoPad(); - Value xValue = (Value)operandAdaptor.getX(); + Value xValue = static_cast(operandAdaptor.getX()); Value wValue = operandAdaptor.getW(); // Basic information. + if (!hasShapeAndRank(xValue)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(xValue); int64_t spatialOffset = 2; int64_t spatialRank = rank - spatialOffset; @@ -388,8 +391,7 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() { dilationOpt.has_value() ? ArrayAttrIntVal(dilationOpt, i) : 1); // Kernel shape from attribute, default from Weight's spatial dims. if (kernelShapeOpt.has_value()) { - kernelShape.emplace_back( - LiteralIndexExpr(ArrayAttrIntVal(kernelShapeOpt, i))); + kernelShape.emplace_back(LitIE(ArrayAttrIntVal(kernelShapeOpt, i))); } else { int ii = i + spatialOffset; kernelShape.emplace_back(createIE->getShapeAsSymbol(wValue, ii)); @@ -402,15 +404,14 @@ LogicalResult ONNXConvTransposeOpShapeHelper::computeShape() { // Pads, at this stage a given compile-time literal or default 0. for (int i = 0; i < 2 * spatialRank; ++i) { int64_t p = padOpt.has_value() ? ArrayAttrIntVal(padOpt, i) : 0; - pads.emplace_back(LiteralIndexExpr(p)); + pads.emplace_back(LitIE(p)); } // Handle output size: start by inserting batch size and output channels. DimsExpr outputDims; outputDims.emplace_back(createIE->getShapeAsDim(xValue, 0)); - outputDims.emplace_back( - createIE->getShapeAsDim(wValue, 1) * - LiteralIndexExpr(groupNum)); // CO may be different from CI. + outputDims.emplace_back(createIE->getShapeAsDim(wValue, 1) * + LitIE(groupNum)); // CO may be different from CI. LiteralIndexExpr zeroIE(0); LiteralIndexExpr oneIE(1); @@ -532,7 +533,7 @@ LogicalResult ONNXConvOp::verify() { } if (hasShapeAndRank(X)) { auto xShape = mlir::cast(X.getType()).getShape(); - if ((int64_t)xShape.size() - 2 != spatialRank) + if (static_cast(xShape.size()) - 2 != spatialRank) return emitOpError("Input and filter rank mismatch"); if (xShape[1] != ShapedType::kDynamic && xShape[1] % g != 0) return emitOpError( @@ -619,7 +620,7 @@ LogicalResult ONNXConvTransposeOp::verify() { if (hasShapeAndRank(X)) { auto xShape = mlir::cast(X.getType()).getShape(); - if ((int64_t)xShape.size() - 2 != spatialRank) + if (static_cast(xShape.size()) - 2 != spatialRank) return emitOpError("Input and filter rank mismatch"); if (xShape[1] != ShapedType::kDynamic && wShape[0] != ShapedType::kDynamic && xShape[1] != wShape[0]) { diff --git a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc index d742a01bbd..688624d3f3 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc +++ b/src/Dialect/ONNX/ONNXOps/NN/NNHelper.cpp.inc @@ -31,6 +31,9 @@ LogicalResult ONNXGenericPoolOpShapeHelper::customComputeShape( std::optional strideOpt, std::optional dilationOpt, bool hasFilter, bool ceilMode) { // Basic information. + if(!hasShapeAndRank(xValue)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(xValue); int64_t spatialOffset = 2; int64_t spatialRank = rank - spatialOffset; @@ -45,8 +48,7 @@ LogicalResult ONNXGenericPoolOpShapeHelper::customComputeShape( dilationOpt.has_value() ? ArrayAttrIntVal(dilationOpt, i) : 1); // Kernel shape from attribute, default from Weight's spatial dims. if (kernelShapeOpt.has_value()) { - kernelShape.emplace_back( - LiteralIndexExpr(ArrayAttrIntVal(kernelShapeOpt, i))); + kernelShape.emplace_back(LitIE(ArrayAttrIntVal(kernelShapeOpt, i))); } else { assert(hasFilter && "no kernel shape and no filter: unkown kernel shape"); int ii = i + spatialOffset; @@ -56,7 +58,7 @@ LogicalResult ONNXGenericPoolOpShapeHelper::customComputeShape( // Pads, at this stage a given compile-time literal or default 0. for (int i = 0; i < 2 * spatialRank; ++i) { int64_t p = padOpt.has_value() ? ArrayAttrIntVal(padOpt, i) : 0; - pads.emplace_back(LiteralIndexExpr(p)); + pads.emplace_back(LitIE(p)); } // Handle output size: start by inserting batch size and output channels. diff --git a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp index df7f3c2d56..33248a782a 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp @@ -4,7 +4,7 @@ //===------------------ Normalization.cpp - ONNX Operations ---------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -32,22 +32,37 @@ LogicalResult ONNXBatchNormalizationInferenceModeOpShapeHelper::computeShape() { return setOutputDimsFromOperand(operandAdaptor.getX()); } -} // namespace onnx_mlir +template <> +LogicalResult ONNXBatchNormalizationOpShapeHelper::computeShape() { + // Single output in inference mode, Y same shape as X. + ONNXBatchNormalizationOpAdaptor operandAdaptor(operands); + + // Input, RunningMean and RunningVar have the same dimensions as their inputs + // counterparts. + auto inputShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getX()).succeeded(); + auto meanShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getInputMean(), 1).succeeded(); + auto varShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getInputVar(), 2).succeeded(); + return LogicalResult::success( + inputShapeInferRes && meanShapeInferRes && varShapeInferRes); +} -LogicalResult ONNXBatchNormalizationInferenceModeOp::inferShapes( - std::function doShapeInference) { - // Cannot infer shape if no shape exists. - if (!hasShapeAndRank(getX()) || !hasShapeAndRank(getScale()) || - !hasShapeAndRank(getB()) || !hasShapeAndRank(getMean()) || - !hasShapeAndRank(getVar())) +template +LogicalResult inferShapesForBatchNorm(Operation *op, Value input, Value scale, + Value bias, Value mean, Value variance) { + if (!hasShapeAndRank(input) || !hasShapeAndRank(scale) || + !hasShapeAndRank(bias) || !hasShapeAndRank(mean) || + !hasShapeAndRank(variance)) return success(); // Verifier code. - auto inputTensorTy = mlir::cast(getX().getType()); - auto scaleTensorTy = mlir::cast(getScale().getType()); - auto biasTensorTy = mlir::cast(getB().getType()); - auto meanTensorTy = mlir::cast(getMean().getType()); - auto varianceTensorTy = mlir::cast(getVar().getType()); + auto inputTensorTy = cast(input.getType()); + auto scaleTensorTy = cast(scale.getType()); + auto biasTensorTy = cast(bias.getType()); + auto meanTensorTy = cast(mean.getType()); + auto varianceTensorTy = cast(variance.getType()); // Check whether the shapes of scale, bias, mean and variance are valid. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. @@ -69,26 +84,72 @@ LogicalResult ONNXBatchNormalizationInferenceModeOp::inferShapes( auto v = varianceTensorTy.getShape(); if ((s.size() != 1) || (!ShapedType::isDynamic(s[0]) && s[0] != c)) - return emitError("Wrong rank for the scale"); + return op->emitError("Wrong rank for the scale"); if ((b.size() != 1) || (!ShapedType::isDynamic(b[0]) && b[0] != c)) - return emitError("Wrong rank for the bias"); + return op->emitError("Wrong rank for the bias"); if ((m.size() != 1) || (!ShapedType::isDynamic(m[0]) && m[0] != c)) - return emitError("Wrong rank for the mean"); + return op->emitError("Wrong rank for the mean"); if ((v.size() != 1) || (!ShapedType::isDynamic(v[0]) && v[0] != c)) - return emitError("Wrong rank for the variance"); + return op->emitError("Wrong rank for the variance"); } // The output tensor of the same shape as the input. - Type elementType = - mlir::cast(getX().getType()).getElementType(); - ONNXBatchNormalizationInferenceModeOpShapeHelper shapeHelper( - getOperation(), {}); + Type elementType = inputTensorTy.getElementType(); + BatchNormOpShapeHelper shapeHelper(op, {}); return shapeHelper.computeShapeAndUpdateType(elementType); } +template <> +LogicalResult ONNXBatchNormalizationV9OpShapeHelper::computeShape() { + // Single output in inference mode, Y same shape as X. + ONNXBatchNormalizationV9OpAdaptor operandAdaptor(operands); + + // Input, RunningMean and RunningVar have the same dimensions as their inputs + // counterparts. + auto inputShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getX()).succeeded(); + auto meanShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getMean(), 1).succeeded(); + auto varShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getVar(), 2).succeeded(); + auto savedMeanShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getMean(), 3).succeeded(); + auto savedVarShapeInferRes = + setOutputDimsFromOperand(operandAdaptor.getVar(), 4).succeeded(); + return LogicalResult::success(inputShapeInferRes && meanShapeInferRes && + varShapeInferRes && savedMeanShapeInferRes && + savedVarShapeInferRes); +} + +} // namespace onnx_mlir + +LogicalResult ONNXBatchNormalizationInferenceModeOp::inferShapes( + std::function doShapeInference) { + return inferShapesForBatchNorm< + ONNXBatchNormalizationInferenceModeOpShapeHelper>( + getOperation(), getX(), getScale(), getB(), getMean(), getVar()); +} + +LogicalResult ONNXBatchNormalizationOp::inferShapes( + std::function doShapeInference) { + // Cannot infer shape if no shape exists. + + return inferShapesForBatchNorm( + getOperation(), getX(), getScale(), getB(), getInputMean(), + getInputVar()); +} + +LogicalResult ONNXBatchNormalizationV9Op::inferShapes( + std::function doShapeInference) { + return inferShapesForBatchNorm( + getOperation(), getX(), getScale(), getB(), getMean(), getVar()); +} + namespace onnx_mlir { template struct ONNXNonSpecificOpShapeHelper< ONNXBatchNormalizationInferenceModeOp>; + +template struct ONNXNonSpecificOpShapeHelper; } // namespace onnx_mlir //===----------------------------------------------------------------------===// @@ -149,6 +210,21 @@ LogicalResult ONNXInstanceNormalizationOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// GroupNormalizationV18 +//===----------------------------------------------------------------------===// +LogicalResult ONNXGroupNormalizationV18Op::verify() { + ONNXGroupNormalizationV18OpAdaptor(*this); + llvm::outs() + << "\nWarning: The previous understanding of Opset 18 for " + "GroupNormalization " + "is incorrect. As shown in the following issue: " + "https://github.com/onnx/onnx/issues/5466.Rather, use Opset 21 for " + "GroupNormalization instead." + << "\n\n"; + return success(); +} + // TODO: should there be a shape inference for this one? //===----------------------------------------------------------------------===// @@ -191,7 +267,7 @@ LogicalResult verifyShapeForLayerNorm(OP_TYPE *op) { if (!OpTrait::util::getBroadcastedShape(XShape, bShape, BBroadcastShape)) op->emitOpError( "LayerNormalization op with incompatible B shapes (broadcast)"); - if ((int64_t)BBroadcastShape.size() != XRank) + if (static_cast(BBroadcastShape.size()) != XRank) op->emitOpError("LayerNormalization op with incompatible B shapes " "(unidirectional broadcast)"); if (bType.getElementType() != XElementType) @@ -208,7 +284,7 @@ LogicalResult verifyShapeForLayerNorm(OP_TYPE *op) { XShape, scaleShape, scaleBroadcastShape)) op->emitOpError( "LayerNormalization op with incompatible scale shapes (broadcast)"); - if ((int64_t)scaleBroadcastShape.size() != XRank) + if (static_cast(scaleBroadcastShape.size()) != XRank) op->emitOpError("LayerNormalization op with incompatible scale shapes " "(unidirectional broadcast)"); if (scaleType.getElementType() != XElementType) @@ -260,7 +336,7 @@ mlir::LogicalResult ONNXLNOpShapeHelper::computeShape() { if (hasMean) { DimsExpr meanShape(getOutputDims(0)); for (int64_t r = axis; r < XRank; ++r) - meanShape[r] = LiteralIndexExpr(1); + meanShape[r] = LitIE(1); setOutputDims(meanShape, 1, false); } @@ -268,7 +344,7 @@ mlir::LogicalResult ONNXLNOpShapeHelper::computeShape() { if (hasInvStdDev) { DimsExpr invStdDevShape(getOutputDims(0)); for (int64_t r = axis; r < XRank; ++r) - invStdDevShape[r] = LiteralIndexExpr(1); + invStdDevShape[r] = LitIE(1); setOutputDims(invStdDevShape, invStdDevIndex, false); } return success(); diff --git a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp index 27cce02696..59ef14d10f 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp @@ -4,7 +4,7 @@ //===------------------ Pooling.cpp - ONNX Operations ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -38,8 +38,8 @@ LogicalResult ONNXGenericGlobalPoolOpShapeHelper::computeShape() { outputDims.emplace_back(xDims[0]); outputDims.emplace_back(xDims[1]); // Spatial dimensions are reduced to 1. - for (int i = 2; i < (int)xDims.size(); ++i) - outputDims.emplace_back(LiteralIndexExpr(1)); + for (int i = 2; i < static_cast(xDims.size()); ++i) + outputDims.emplace_back(LitIE(1)); // Save the final result. setOutputDims(outputDims); return success(); @@ -48,21 +48,25 @@ LogicalResult ONNXGenericGlobalPoolOpShapeHelper::computeShape() { template <> LogicalResult ONNXMaxRoiPoolOpShapeHelper::computeShape() { ONNXMaxRoiPoolOpAdaptor operandAdaptor(operands, op->getAttrDictionary()); - IndexExpr channel = createIE->getShapeAsDim(operandAdaptor.getX(), 1); - uint64_t roisRank = createIE->getShapedTypeRank(operandAdaptor.getRois()); + + const auto rois = operandAdaptor.getRois(); + if (!hasShapeAndRank(rois)) { + return failure(); + } + uint64_t roisRank = createIE->getShapedTypeRank(rois); if (roisRank != 2) return op->emitError("rois rank is expected to be 2d"); // 2d tensor: (num_rois, 5) - IndexExpr numRois = createIE->getShapeAsDim(operandAdaptor.getRois(), 0); + IndexExpr numRois = createIE->getShapeAsDim(rois, 0); DimsExpr pooledDims; createIE->getIntFromArrayAsLiterals( operandAdaptor.getPooledShape(), pooledDims); // 4-D tensor : (num_rois, channels, pooled_shape[0], pooled_shape[1]). DimsExpr outputDims; - outputDims.push_back(LiteralIndexExpr(numRois)); + outputDims.push_back(LitIE(numRois)); outputDims.push_back(channel); outputDims.push_back(pooledDims[0]); outputDims.push_back(pooledDims[1]); @@ -107,7 +111,7 @@ LogicalResult ONNXAveragePoolOp::verify() { auto X = operandAdaptor.getX(); if (hasShapeAndRank(X)) { auto xShape = mlir::cast(X.getType()).getShape(); - if ((int64_t)xShape.size() - 2 != spatialRank) + if (static_cast(xShape.size()) - 2 != spatialRank) return emitOpError("Input and kernel shape rank mismatch"); } diff --git a/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp b/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp index 1e833aac54..cdfe6faa0c 100644 --- a/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp +++ b/src/Dialect/ONNX/ONNXOps/NN/RoiAlign.cpp @@ -39,8 +39,8 @@ LogicalResult ONNXRoiAlignOpShapeHelper::computeShape() { int64_t height = roiAlignOp.getOutputHeight(); int64_t width = roiAlignOp.getOutputWidth(); - DimsExpr outputDims = {batchIndicesDims[0], xDims[1], - LiteralIndexExpr(height), LiteralIndexExpr(width)}; + DimsExpr outputDims = { + batchIndicesDims[0], xDims[1], LitIE(height), LitIE(width)}; // Save the final result. setOutputDims(outputDims); diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 964f79b436..233383fbe9 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -4,7 +4,7 @@ //===------- ONNXOpsHelper.cpp - Helper functions for ONNX dialects -------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Path.h" @@ -124,8 +125,8 @@ bool hasCustomONNXTensorDataLayout(const Type type) { } bool sameRank(Value tensorOrMemref1, Value tensorOrMemref2) { - auto type1 = dyn_cast_or_null(tensorOrMemref1.getType()); - auto type2 = dyn_cast_or_null(tensorOrMemref2.getType()); + auto type1 = mlir::dyn_cast_or_null(tensorOrMemref1.getType()); + auto type2 = mlir::dyn_cast_or_null(tensorOrMemref2.getType()); if (!type1 || !type2) return false; if (!type1.hasRank() || !type2.hasRank()) @@ -234,7 +235,7 @@ std::vector getIndexExprsForConvWindow( SmallVector endExprs = {end1, end2}; windowEndExpr = IndexExpr::min(endExprs); // kernelOffsetExpr - SmallVector kernelExprs = {LiteralIndexExpr(0), start2}; + SmallVector kernelExprs = {LitIE(0), start2}; kernelOffsetExpr = IndexExpr::min(kernelExprs); return std::vector{ @@ -307,18 +308,19 @@ void ArrayAttrIntVals(ArrayAttr a, mlir::SmallVectorImpl &i) { ElementsAttr getElementAttributeFromONNXValue(Value value) { ONNXConstantOp constantOp = getONNXConstantOp(value); - if (constantOp) + // In case the ConstantOp has not been normalized yet + if (constantOp && constantOp.getValueAttr()) return mlir::dyn_cast(constantOp.getValueAttr()); return nullptr; } // Returns the ConstantOp which defines an MLIR Value or null. ONNXConstantOp getONNXConstantOp(Value value) { - return dyn_cast_or_null(value.getDefiningOp()); + return mlir::dyn_cast_or_null(value.getDefiningOp()); } bool getI64ValuesFromONNXConstantOp( - mlir::Value val, mlir::SmallVectorImpl &iRes) { + Value val, mlir::SmallVectorImpl &iRes) { ElementsAttr elemsAttr = getElementAttributeFromONNXValue(val); if (!elemsAttr) return false; @@ -329,6 +331,18 @@ bool getI64ValuesFromONNXConstantOp( return true; } +//===----------------------------------------------------------------------===// +// Support for BatchNorm + +ONNXConstantOp createConstantOp( + PatternRewriter &rewriter, Location loc, ArrayAttr values) { + return rewriter.create(loc, Attribute(), + DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(values.size())}, rewriter.getI64Type()), + llvm::ArrayRef(values.getValue()))); +} + //===----------------------------------------------------------------------===// // Support for transpose patterns. //===----------------------------------------------------------------------===// @@ -376,7 +390,7 @@ bool HasSpecifiedConstantShape(Value value, Value shape) { return false; int64_t dimensionsOfShape = shapeAttr.getShapedType().getShape()[0]; - if ((int64_t)valueShape.size() != dimensionsOfShape) + if (static_cast(valueShape.size()) != dimensionsOfShape) return false; auto valueIt = shapeAttr.getValues().begin(); @@ -390,11 +404,11 @@ bool HasSpecifiedConstantShape(Value value, Value shape) { /// Test if a value is a scalar constant tensor or not, i.e. tensor or /// tensor<1xdtype>. -bool isScalarConstantTensor(mlir::Value v) { +bool isScalarConstantTensor(Value v) { if (!hasShapeAndRank(v)) return false; - auto t = dyn_cast(v.getType()); + auto t = mlir::dyn_cast(v.getType()); int64_t r = t.getRank(); return isDenseONNXConstant(v) && ((r == 0) || ((r == 1) && (t.getShape()[0] == 1))); @@ -438,7 +452,7 @@ bool hasOneUseExceptDimOp(Value val) { // Create an ArrayAttr from a dense ConstantOp ArrayAttr createArrayAttrFromConstantOp(ONNXConstantOp constOp) { - auto elements = cast(constOp.getValueAttr()); + auto elements = mlir::cast(constOp.getValueAttr()); SmallVector values(elements.getValues()); return ArrayAttr::get(constOp.getContext(), values); } @@ -447,13 +461,21 @@ ArrayAttr createArrayAttrFromConstantOp(ONNXConstantOp constOp) { DenseElementsAttr createDenseElementsAttrFromFloatAttr( PatternRewriter &rewriter, Type elementType, FloatAttr attr) { auto tensorType = RankedTensorType::get({1}, elementType); - auto ftype = cast(elementType); + auto ftype = mlir::cast(elementType); APFloat f = attr.getValue(); bool ignored; f.convert(ftype.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignored); return DenseElementsAttr::get(tensorType, {f}); } +ONNXCastOp castTo( + PatternRewriter &rewriter, Value val, Type newElementTy, int64_t saturate) { + return rewriter.create(val.getLoc(), + val.getType().cast().clone(newElementTy), val, + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), saturate), + TypeAttr::get(newElementTy)); +} + //===----------------------------------------------------------------------===// // Support for dim operations. //===----------------------------------------------------------------------===// @@ -528,7 +550,7 @@ DenseElementsAttr createDenseElementsAttrFromSize( /// Check whether a value is produced by a dense ONNXConstantOp. bool isDenseONNXConstant(Value result) { ONNXConstantOp constOp = - dyn_cast_or_null(result.getDefiningOp()); + mlir::dyn_cast_or_null(result.getDefiningOp()); // Must be a constant. if (!constOp) @@ -556,10 +578,13 @@ RESULT_TYPE getScalarValue(ElementsAttr denseAttr, Type type) { if (elementaryType.isInteger(16) || elementaryType.isInteger(32) || elementaryType.isInteger(64)) { auto valueIt = denseAttr.getValues().begin(); - return (RESULT_TYPE)mlir::cast(*valueIt).getInt(); + return static_cast(mlir::cast(*valueIt).getInt()); } else if (mlir::isa(elementaryType)) { auto valueIt = denseAttr.getValues().begin(); - return (RESULT_TYPE)(*valueIt).convertToDouble(); + return static_cast((*valueIt).convertToDouble()); + } else if (elementaryType.isBF16()) { + auto valueIt = denseAttr.getValues().begin(); + return static_cast((*valueIt).convertToFloat()); } llvm_unreachable("Unexpected type."); return 0; @@ -579,6 +604,24 @@ RESULT_TYPE getScalarValue(ONNXConstantOp constantOp) { template double getScalarValue(ONNXConstantOp constantOp); template int64_t getScalarValue(ONNXConstantOp constantOp); +/// Return the wide type of a value. +WideNum asWideNum(double n, Type elemType) { + return wideZeroDispatch(elemType, [n](auto wideZero) { + using cpptype = decltype(wideZero); + constexpr BType TAG = toBType; + return WideNum::widen(static_cast(n)); + }); +} + +/// Checks whether a constant tensor's elements are all equal to a given scalar. +bool isConstOf(Value constValue, double n) { + ElementsAttr constElements = getElementAttributeFromONNXValue(constValue); + Type elemType = constElements.getElementType(); + assert(!elemType.isInteger(1) && "booleans are not supported"); + WideNum w = asWideNum(n, elemType); + return ElementsAttrBuilder::allEqual(constElements, w); +} + // Convert type to MLIR type. // A complete list of types can be found in: // /third_party/onnx/onnx/onnx.pb.h @@ -622,6 +665,10 @@ Type convertONNXTypeToMLIRType( return builder.getI1Type(); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: return ONNXStringType::get(builder.getContext()); + case onnx::TensorProto_DataType::TensorProto_DataType_INT4: + return builder.getIntegerType(/*width=*/4); + case onnx::TensorProto_DataType::TensorProto_DataType_UINT4: + return builder.getIntegerType(/*width=*/4, false); case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: @@ -674,6 +721,10 @@ int64_t mlirTypeToOnnxType(Type elemType) { ? onnx::TensorProto::UNDEFINED : onnx::TensorProto::BOOL; break; + case 4: + onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4 + : onnx::TensorProto::INT4; + break; case 8: onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8 : onnx::TensorProto::INT8; @@ -721,7 +772,7 @@ bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) { double floatVal = getScalarValue(elementAttr, elementType); if (floatVal == ceil(floatVal)) { // We essentially have an integer value represented as a float. - exponentValue = (int64_t)floatVal; + exponentValue = static_cast(floatVal); return true; } } else if (mlir::isa(elementType)) { @@ -851,4 +902,15 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) { return "NOTSET"; } +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr) { + const auto denseResourceElementsAttr = + mlir::dyn_cast(elementsAttr); + return denseResourceElementsAttr && + !denseResourceElementsAttr.getRawHandle().getBlob(); +} + } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index 3d827f85d5..9a04700a8f 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -44,6 +44,7 @@ #include "src/Support/TypeUtilities.hpp" #include +#include #include namespace onnx_mlir { @@ -187,6 +188,12 @@ bool getI64ValuesFromONNXConstantOp( // Note: It's ok to inline the isa test and not call this function. inline bool isNoneValue(mlir::Value value); +//===----------------------------------------------------------------------===// +// Support for BatchNorm + +mlir::ONNXConstantOp createConstantOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::ArrayAttr values); + //===----------------------------------------------------------------------===// // Support for transpose patterns. //===----------------------------------------------------------------------===// @@ -222,6 +229,12 @@ mlir::DenseElementsAttr createDenseElementsAttrFromFloatAttr( mlir::PatternRewriter &rewriter, mlir::Type elementType, mlir::FloatAttr attr); +mlir::ONNXCastOp castTo(mlir::PatternRewriter &rewriter, mlir::Value val, + mlir::Type newElementTy, int64_t saturate); + +mlir::Value normalizeConstantOp( + mlir::PatternRewriter &rewriter, mlir::Value output, mlir::Attribute attr); + // Create a DenseElementsAttr based on the shape of type at the given index. mlir::DenseElementsAttr createDenseElementsAttrFromShapeAtIndex( mlir::PatternRewriter &rewriter, mlir::Value value, @@ -244,6 +257,12 @@ RESULT_TYPE getScalarValue(mlir::ElementsAttr denseAttr, mlir::Type type); template RESULT_TYPE getScalarValue(mlir::ONNXConstantOp constantOp); +/// Return the wide type of a value. +WideNum asWideNum(double n, mlir::Type elemType); + +/// Checks whether a constant tensor's elements are all equal to a given scalar. +bool isConstOf(mlir::Value constValue, double n); + mlir::Type convertONNXTypeToMLIRType( mlir::Builder &builder, onnx::TensorProto_DataType onnxType); @@ -263,6 +282,36 @@ bool hasIntegerPowerExponent(mlir::ONNXPowOp *op, int64_t &exponentValue); template bool definedBy(mlir::Value v); +// This is to match if two values A and B are bijectively defined by OP1 and +// OP2. In other words, +// - if A is defined by OP1, then B would be defined by OP2. +// - if A is defined by OP2, then B would be defined by OP1. +// +// In both case, the output has two values, +// - the first one is the value defined by OP1, +// - the second one is the value defined by OP2. +// +// For example, to recognize BOTH A*B+C and C+A*B, where C is defined by +// ONNXConstant +// ``` +// %C = onnx.Constant +// %AB = onnx.MatMul(A, B) +// onnx.Add(%AB, %C); +// ``` +// +// We can use: +// Value lhs = addOp.getOperation(0); +// Value rhs = addOp.getOperation(1); +// ValueRange matchedValued; +// +// Value AB, C; +// areDefinedBy(lhs, rhs, AB, C); +// +// Note: The order of A and B are not important, they can be swapped. +template +bool areDefinedBy(mlir::Value A, mlir::Value B, mlir::Value &matchedOP1, + mlir::Value &matchedOP2); + // Check if the operation defining `op->operand[matchThisOperandIndex]` matches // `OP`. If it does, set matchOperand to that operand, and matchOp to that // defining op. Otherwise, don't change the match values. @@ -277,6 +326,43 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op, mlir::Value &matchOperand0, mlir::Value &matchOperand1, int64_t matchThisOperandIndex); +// This is to recognize a binary op, e.g. A*B where one of A and B is a constant +// and the other one is defined by OP. +// Note: this function can handle the communitive property of the binary op. +// +// For example, to recognize this pattern: +// %x = "onnx.Tanh"() +// %y = 0.5 * %x // or %x * 0.5 +// +// we call +// ``` +// ONNXTanhOp tanhOp; +// bool found = matchConstAndOp(A, B, 0.5, tanhOp); +// ``` +// where `A` and `B` are operands of ONNXMul that produces %y. +template +bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &op); + +// This is to recognize a binary op, e.g. A*B where one of A and B is the given +// value and the other one is defined by OP. +// Note: this function can handle the communitive property of the binary op. +// +// For example, to recognize this pattern where %z is one of the inputs of *, +// and the other input of * is defined by onnx.Tanh: +// %x = "onnx.Tanh"() +// %y = %z * %x // or %x * %z +// +// we call +// ``` +// Value z; +// ONNXTanhOp tanhOp; +// bool found = matchConstAndOp(A, B, z, tanhOp); +// ``` +// where `A` and `B` are operands of ONNXMul that produces %y. +template +bool matchValueAndOp( + mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp); + /// Check if a value is to store dimensions, meaning it is a tensor of one /// element or concatenation of one-element tensors. bool areDims(mlir::Value val); @@ -307,6 +393,14 @@ bool isIdentityReshape(mlir::Value input, mlir::Value output, std::string getNodeNameInPresenceOfOpt( mlir::Operation *op, bool useFileLine = true); +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +/// Returns true if elementsAttr is a DenseResourceAttr with a blob that can not +/// be received +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr); + #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc" } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc index b0fa82f8c4..fd3de372d2 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc @@ -24,7 +24,7 @@ mlir::Location ONNXLoc(mlir::Operation *op) { } inline bool isNoneValue(mlir::Value value) { - return llvm::isa(value.getType()); + return MemRefBuilder::isNoneValue(value); } /// Check the defining operation of a value. @@ -33,6 +33,22 @@ bool definedBy(mlir::Value v) { return !mlir::isa(v) && llvm::isa(v.getDefiningOp()); } +template +bool areDefinedBy(mlir::Value A, mlir::Value B, mlir::Value &matchedOP1, mlir::Value + &matchedOP2) { + if (A.getDefiningOp() && B.getDefiningOp()) { + matchedOP1 = A; + matchedOP2 = B; + return true; + } + if (A.getDefiningOp() && B.getDefiningOp()) { + matchedOP1 = B; + matchedOP2 = A; + return true; + } + return false; +} + // Support for recognizing patterns. Detects if the operation "op" has an input // operand number "matchThisOperandIndex" that is defined by an operation of // type "OP". If that is the case, "matchOperand" will be set to that operand, @@ -83,3 +99,65 @@ bool operandOfOpDefinedBy(mlir::Operation *&matchOp, mlir::Operation *op, } return false; } + +// This is to recognize a binary op, e.g. A*B where one of A and B is a constant +// and the other one is defined by OP. +// Note: this function can handle the communitive property of the binary op. +// +// For example, to recognize this pattern: +// %x = "onnx.Tanh"() +// %y = 0.5 * %x // or %x * 0.5 +// +// we call +// ``` +// ONNXTanhOp tanhOp; +// bool found = matchConstAndOp(A, B, 0.5, tanhOp); +// ``` +// where `A` and `B` are operands of ONNXMul that produces %y. +template +bool matchConstAndOp(mlir::Value A, mlir::Value B, double cst, OP &matchOp) { + auto opA = A.getDefiningOp(); + auto opB = B.getDefiningOp(); + if (onnx_mlir::isDenseONNXConstant(A) && onnx_mlir::isConstOf(A, cst) && opB) + { + matchOp = opB; + return true; + } + if (opA && onnx_mlir::isDenseONNXConstant(B) && onnx_mlir::isConstOf(B, cst)) + { + matchOp = opA; + return true; + } + return false; +} + +// This is to recognize a binary op, e.g. A*B where one of A and B is the given +// value and the other one is defined by OP. +// Note: this function can handle the communitive property of the binary op. +// +// For example, to recognize this pattern where %z is one of the inputs of *, +// and the other input of * is defined by onnx.Tanh: +// %x = "onnx.Tanh"() +// %y = %z * %x // or %x * %z +// +// we call +// ``` +// Value z; +// ONNXTanhOp tanhOp; +// bool found = matchConstAndOp(A, B, z, tanhOp); +// ``` +// where `A` and `B` are operands of ONNXMul that produces %y. +template +bool matchValueAndOp(mlir::Value A, mlir::Value B, mlir::Value matchValue, OP &matchOp) { + auto opA = A.getDefiningOp(); + auto opB = B.getDefiningOp(); + if ((A == matchValue) && opB) { + matchOp = opB; + return true; + } + if (opA && (B == matchValue)) { + matchOp = opA; + return true; + } + return false; +} diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp index 4728b8f2a9..51fc4a1cf7 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp @@ -68,7 +68,7 @@ LogicalResult ONNXDequantizeLinearOpShapeHelper::computeShape() { if (a < 0) a += r; if (!outputDims[a].isLiteral()) { - outputDims[a] = LiteralIndexExpr(d); + outputDims[a] = LitIE(d); } LLVM_DEBUG(llvm::dbgs() << "literal: " << outputDims[a].getLiteral() << " d = " << d << "\n"); diff --git a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp index aa536b3df9..f79c18f8f3 100644 --- a/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp +++ b/src/Dialect/ONNX/ONNXOps/RNN/RNN.cpp @@ -58,7 +58,7 @@ LogicalResult ONNXGenericRNNShapeHelper::customComputeShape( // Get hidden size from hidden_size attribute. IndexExpr hiddenSize; if (operandAdaptor.getHiddenSize().has_value()) { - hiddenSize = LiteralIndexExpr(operandAdaptor.getHiddenSize().value()); + hiddenSize = LitIE(operandAdaptor.getHiddenSize().value()); } else { // Infer hidden_size from wShape and rShape if possible. if (rDims[2].isLiteral()) @@ -84,9 +84,9 @@ LogicalResult ONNXGenericRNNShapeHelper::customComputeShape( IndexExpr numDir; if ((operandAdaptor.getDirection() == "forward") || (operandAdaptor.getDirection() == "reverse")) - numDir = LiteralIndexExpr(1); + numDir = LitIE(1); else if (operandAdaptor.getDirection() == "bidirectional") - numDir = LiteralIndexExpr(2); + numDir = LitIE(2); else return op->emitError( "direction attribute must be one of the strings: forward, " diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp index 3e46327a7f..be146078c7 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/Sequence.cpp @@ -4,7 +4,7 @@ //===------------------ Sequence.cpp - ONNX Operations -------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -109,8 +109,9 @@ LogicalResult ONNXSequenceEmptyOp::verify() { auto builder = OpBuilder(getContext()); Type elementType; if (getDtypeAttr()) { - elementType = convertONNXTypeToMLIRType(builder, - (onnx::TensorProto_DataType)getDtypeAttr().getValue().getSExtValue()); + elementType = convertONNXTypeToMLIRType( + builder, static_cast( + getDtypeAttr().getValue().getSExtValue())); } else { elementType = builder.getF32Type(); } diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp index 3a17990e56..38f922f765 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp @@ -58,6 +58,10 @@ LogicalResult ONNXSplitToSequenceOp::verify() { if (splitRank > 1) return emitOpError() << ": split has rank " << splitRank << " > 1"; if (ElementsAttr entries = getElementAttributeFromONNXValue(splitValue)) { + if (isElementAttrUninitializedDenseResource(entries)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } if (splitRank == 0) { auto scalar = getScalarValue(entries, splitType); if (scalar <= 0) diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index 7561575f11..0ca91b44ce 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -102,29 +102,31 @@ static void refineDims(Operation *op, DimsExpr &inferredDims, Value output) { // InferredDim is unknown at shape inference: update it. if (inferredDims[i].isQuestionmark()) { - inferredDims[i] = LiteralIndexExpr(existingDims[i]); + inferredDims[i] = LitIE(existingDims[i]); continue; } // inferredDim is unknown at lowering: use existing dim for efficiency. if (!inferredDims[i].isLiteral()) { - inferredDims[i] = LiteralIndexExpr(existingDims[i]); + inferredDims[i] = LitIE(existingDims[i]); continue; } // inferredDim is different from existingDim. Believe in existingDim. assert(inferredDims[i].isLiteral() && "isLiteral failed"); if (existingDims[i] != inferredDims[i].getLiteral()) { if (op) - llvm::outs() << "Warning for operation " << op->getName() + llvm::errs() << "\nWarning for operation " << op->getName() << ": [Shape inference, dim " << i << "] the inferred dim (" << inferredDims[i].getLiteral() << ") is different from the existing dim (" - << existingDims[i] << "). Use the existing dim instead.\n"; + << existingDims[i] + << "). Use the existing dim instead.\n\n"; else - llvm::outs() << "Warning: [Shape inference, dim " << i + llvm::errs() << "\nWarning: [Shape inference, dim " << i << "] the inferred dim (" << inferredDims[i].getLiteral() << ") is different from the existing dim (" - << existingDims[i] << "). Use the existing dim instead.\n"; - inferredDims[i] = LiteralIndexExpr(existingDims[i]); + << existingDims[i] + << "). Use the existing dim instead.\n\n"; + inferredDims[i] = LitIE(existingDims[i]); } } } @@ -219,10 +221,10 @@ LogicalResult ONNXOpShapeHelper::setOutputDimsFromTypeWithConstantShape( LogicalResult ONNXOpShapeHelper::computeShapeAndUpdateType( Type elementType, Attribute encoding) { // Invoke virtual compute shape. - if (failed(computeShape())) - return op->emitError("Failed to scan parameters successfully"); - assert((mlir::isa(elementType) || - !mlir::isa(elementType)) && + if (failed(computeShape())) { + return failure(); + } + assert((isa(elementType) || !isa(elementType)) && "element type cannot be a shaped type other than vector type"); uint64_t resNum = op->getNumResults(); for (uint64_t i = 0; i < resNum; ++i) { @@ -285,6 +287,11 @@ LogicalResult ONNXBroadcastOpShapeHelper::customComputeShape( DimsExpr dimsExpr; uint64_t numOfInputs = initialOperands.size(); + if (!llvm::all_of(initialOperands, + [](Value initalOperand) { return hasShapeAndRank(initalOperand); })) { + return failure(); + } + // Compute rank of the output. Rank of the output is the maximum rank of all // initial operands. uint64_t additionalOperRank = @@ -460,7 +467,7 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( << " dim analysis\n"); // Keep track of cumulative inner dim sizes. collapsedLiteralSize = 1; - collapsedDynamicSize = LiteralIndexExpr(1); + collapsedDynamicSize = LitIE(1); // Keep track of ones, scalar, and broadcast per input. llvm::SmallBitVector isOne(dimNum, true); llvm::SmallBitVector isScalar(dimNum, true); @@ -708,7 +715,7 @@ LogicalResult ONNXBroadcastOpShapeHelper::getAccessExprs(Value operand, // Compute access index based on broadcasting rules. if (operandDim.isLiteralAndIdenticalTo(1)) { // Dim of size 1: access is always 0. - operandAccessExprs.emplace_back(LiteralIndexExpr(0)); + operandAccessExprs.emplace_back(LitIE(0)); } else if (noBroadcasting || useLoopIndexNoMatterWhat) { // No broadcasting or we can use the loop index no matter what -> just use // the index. @@ -764,7 +771,7 @@ bool ONNXUnaryOpShapeHelper::hasManageableBroadcastForInnerDims( int64_t outputRank = output.size(); // Keep track of cumulative inner dim sizes. collapsedLiteralSize = 1; - collapsedDynamicSize = LiteralIndexExpr(1); + collapsedDynamicSize = LitIE(1); for (int64_t r = 0; r < outputRank; ++r) { if (output[r].isLiteral()) collapsedLiteralSize *= output[r].getLiteral(); @@ -826,7 +833,7 @@ void updateType(Operation *op, Value val, ArrayRef shape, if (ShapedType::isDynamic(d) || d == -1) inferredDims.emplace_back(QuestionmarkIndexExpr(/*isFloat*/ false)); else - inferredDims.emplace_back(LiteralIndexExpr(d)); + inferredDims.emplace_back(LitIE(d)); } refineDims(op, inferredDims, val); IndexExpr::getShape(inferredDims, inferredShape); @@ -906,7 +913,7 @@ ONNXCustomOpShapeHelper::ONNXCustomOpShapeHelper(Operation *op, return; } - std::vector operandsVector; + std::vector operandsVector; for (auto indexAttr : inputIndexAttrs.value()) { operandsVector.push_back( inputs[mlir::cast(indexAttr).getInt()]); diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp index 01a8943ead..3fc36022d6 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -31,6 +31,7 @@ #include "src/Dialect/Mlir/IndexExprBuilder.hpp" #include "src/Dialect/ONNX/ONNXDimAnalysis.hpp" +#define GET_OP_FWD_DEFINES 1 #include "src/Dialect/ONNX/ONNXOps.hpp.inc" // ONNXOpShapeHelper is defined in the interface file below. @@ -265,6 +266,7 @@ using ONNXLessOrEqualOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXMaxOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXMeanOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXMinOpShapeHelper = ONNXBroadcastOpShapeHelper; +using ONNXMishOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXModOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXMulOpShapeHelper = ONNXBroadcastOpShapeHelper; using ONNXOrOpShapeHelper = ONNXBroadcastOpShapeHelper; @@ -855,6 +857,8 @@ struct ONNXNonSpecificOpShapeHelper : public ONNXOpShapeHelper { // Ops listed in alphabetical order. Disable formatting for easier sorting. // clang-format off +using ONNXBatchNormalizationOpShapeHelper = ONNXNonSpecificOpShapeHelper; +using ONNXBatchNormalizationV9OpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXBatchNormalizationInferenceModeOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXCategoryMapperOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXCompressOpShapeHelper = ONNXNonSpecificOpShapeHelper; @@ -868,6 +872,7 @@ using ONNXDimOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXDropoutOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXDynamicQuantizeLinearOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXEinsumOpShapeHelper = ONNXNonSpecificOpShapeHelper; +using ONNXGridSampleOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXEyeLikeOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXFlattenOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXGatherElementsOpShapeHelper = ONNXNonSpecificOpShapeHelper; diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp index 8b6762fd0c..e9a9aec062 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ArgMinMax.cpp @@ -55,8 +55,8 @@ LogicalResult ONNXArgMinMaxOpShapeHelper::computeShape() { outputDims.resize(reducedRank); for (int64_t i = 0; i < reducedRank; i++) { if (isKeepdims) - outputDims[i] = (i != axisValue) ? createIE->getShapeAsDim(data, i) - : LiteralIndexExpr(1); + outputDims[i] = + (i != axisValue) ? createIE->getShapeAsDim(data, i) : LitIE(1); else outputDims[i] = (i < axisValue) ? createIE->getShapeAsDim(data, i) : createIE->getShapeAsDim(data, i + 1); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp index ec6a7a1cb4..0bf3b36999 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Compress.cpp @@ -31,6 +31,9 @@ LogicalResult ONNXCompressOpShapeHelper::computeShape() { ONNXCompressOpAdaptor operandAdaptor(operands); Value input = operandAdaptor.getInput(); Value cond = operandAdaptor.getCondition(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); createIE->assertHasShapeAndRank(cond); std::optional optionalAxis = compressOp.getAxis(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp index bb16aef01c..70ee132682 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp @@ -50,9 +50,9 @@ LogicalResult ONNXConstantOpShapeHelper::computeShape() { std::vector ONNXConstantOp::resultTypeInference() { ShapedType type; if (auto attr = getValueAttr()) { - type = cast(attr).getShapedType(); + type = mlir::cast(attr).getShapedType(); } else if (auto attr = getSparseValueAttr()) { - type = cast(attr).getShapedType(); + type = mlir::cast(attr).getShapedType(); } else if (auto attr = getValueFloatAttr()) { type = RankedTensorType::get({}, FloatType::getF32(getContext())); } else if (auto attr = getValueFloatsAttr()) { diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp index c34683341f..6058adfcdb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp @@ -70,6 +70,10 @@ LogicalResult ONNXConstantOfShapeOp::verify() { if (auto constantOp = getONNXConstantOp(input)) { ElementsAttr valueAttribute = mlir::cast(constantOp.getValueAttr()); + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } // Get repeat values from valueAttribute. auto valueIt = valueAttribute.getValues().begin(); for (int i = 0; i < inputShape[0]; ++i) { @@ -93,7 +97,7 @@ LogicalResult ONNXConstantOfShapeOp::verify() { std::vector ONNXConstantOfShapeOp::resultTypeInference() { Type elementType; if (auto attr = getValueAttr()) { - elementType = cast(attr).getElementType(); + elementType = mlir::cast(attr).getElementType(); } else { elementType = FloatType::getF32(getContext()); } @@ -106,7 +110,7 @@ std::vector ONNXConstantOfShapeOp::resultTypeInference() { LogicalResult ONNXConstantOfShapeOp::inferShapes( std::function doShapeInference) { - ShapedType inputType = cast(getInput().getType()); + ShapedType inputType = mlir::cast(getInput().getType()); if (!inputType.hasStaticShape()) return success(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp index 68c9a213a7..515d947efb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/DepthToSpace.cpp @@ -30,6 +30,9 @@ LogicalResult ONNXDepthToSpaceOpShapeHelper::computeShape() { ONNXDepthToSpaceOp depthOp = llvm::cast(op); ONNXDepthToSpaceOpAdaptor operandAdaptor(operands); Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); assert(inputRank == 4 && "Unexpected input tensor rank"); int64_t blocksize = depthOp.getBlocksize(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp index 9e25039d8f..8650c2af7b 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Expand.cpp @@ -40,7 +40,7 @@ LogicalResult ONNXExpandOpShapeHelper::computeShape() { if (ShapedType::isDynamic(shapeType.getShape()[0])) return op->emitError("expected size of shape parameter to be defined"); - if (ONNXShapeOp shapeOp = dyn_cast_or_null(shapeDefOp)) { + if (ONNXShapeOp shapeOp = mlir::dyn_cast_or_null(shapeDefOp)) { assert(mlir::isa(shapeOp.getData().getType()) && "expected"); // Consider a first case where the expand.shape is produced by a shape op. // Infer its shape and use it as the requested shape. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp index 18ba80810a..9f64524e3c 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/EyeLike.cpp @@ -4,7 +4,7 @@ //===------------------ .cpp - ONNX Operations ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -39,6 +39,33 @@ LogicalResult ONNXEyeLikeOpShapeHelper::computeShape() { // Verify //===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Type Inference +//===----------------------------------------------------------------------===// + +Type ONNXEyeLikeOp::getResultElementType() { + const auto inputType = cast(getInput().getType()); + if (getDtypeAttr()) { + auto builder = OpBuilder(getContext()); + return convertONNXTypeToMLIRType( + builder, static_cast( + getDtypeAttr().getValue().getSExtValue())); + } + return inputType.getElementType(); +} + +std::vector ONNXEyeLikeOp::resultTypeInference() { + Type elementType = getResultElementType(); + std::vector resultTypes; + + if (auto rankedInputType = dyn_cast(getInput().getType())) { + resultTypes.push_back(rankedInputType.clone(elementType)); + } else { + resultTypes.push_back(UnrankedTensorType::get(elementType)); + } + return resultTypes; +} + //===----------------------------------------------------------------------===// // Shape Inference //===----------------------------------------------------------------------===// @@ -48,17 +75,7 @@ LogicalResult ONNXEyeLikeOp::inferShapes( if (!hasShapeAndRank(getInput())) return success(); - RankedTensorType inputType = - mlir::cast(getInput().getType()); - Type elementType; - if (getDtypeAttr()) { - auto builder = OpBuilder(getContext()); - elementType = convertONNXTypeToMLIRType(builder, - (onnx::TensorProto_DataType)getDtypeAttr().getValue().getSExtValue()); - } else { - elementType = inputType.getElementType(); - } - + Type elementType = getResultElementType(); ONNXEyeLikeOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp index 5ff505b754..024f7137b8 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Flatten.cpp @@ -34,7 +34,7 @@ LogicalResult ONNXFlattenOpShapeHelper::computeShape() { ArrayRef inputShape = inputType.getShape(); int64_t inputRank = inputType.getRank(); int64_t axis = flattenOp.getAxis(); - assert(axis >= -inputRank && axis < inputRank && "Invalid inputRank"); + assert(axis >= -inputRank && axis <= inputRank && "Invalid inputRank"); // Negative axis means values are counted from the opposite side. if (axis < 0) @@ -42,13 +42,13 @@ LogicalResult ONNXFlattenOpShapeHelper::computeShape() { // Warning: code does appear to only work for shape inference. // Compute outputDims. - DimsExpr outputDims = {LiteralIndexExpr(1), LiteralIndexExpr(1)}; + DimsExpr outputDims = {LitIE(1), LitIE(1)}; for (int64_t i = 0; i < axis; ++i) { if (ShapedType::isDynamic(inputShape[i])) { outputDims[0] = QuestionmarkIndexExpr(/*isFloat*/ false); break; } - outputDims[0] = outputDims[0] * LiteralIndexExpr(inputShape[i]); + outputDims[0] = outputDims[0] * LitIE(inputShape[i]); } for (int64_t i = axis; i < inputRank; ++i) { @@ -56,7 +56,7 @@ LogicalResult ONNXFlattenOpShapeHelper::computeShape() { outputDims[1] = QuestionmarkIndexExpr(/*isFloat*/ false); break; } - outputDims[1] = outputDims[1] * LiteralIndexExpr(inputShape[i]); + outputDims[1] = outputDims[1] * LitIE(inputShape[i]); } setOutputDims(outputDims); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp index ce35ad81b3..dde8029994 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp @@ -71,8 +71,13 @@ LogicalResult ONNXGatherElementsOp::verify() { // along axis of size s. ArrayRef dataShape = dataType.getShape(); const int64_t dataDimAtAxis = dataShape[axis]; - if (dataDimAtAxis >= 0) - if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) + if (dataDimAtAxis >= 0) { + if (ElementsAttr valueAttribute = + getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) @@ -83,6 +88,8 @@ LogicalResult ONNXGatherElementsOp::verify() { onnx_mlir::Diagnostic::Range( -dataDimAtAxis, dataDimAtAxis - 1)); } + } + } return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp index 7bf23643cd..b388607c12 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp @@ -4,7 +4,7 @@ //===------------------ GatherND.cpp - ONNX Operations --------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -73,7 +73,7 @@ LogicalResult ONNXGatherNDOpShapeHelper::computeShape() { for (int64_t i = b + indicesLastDim; i < dataRank; ++i) outputDims.emplace_back(dataDims[i]); - assert((int64_t)outputDims.size() == outputRank && + assert(static_cast(outputDims.size()) == outputRank && "Incorrect shape computation"); setOutputDims(outputDims); @@ -144,6 +144,10 @@ LogicalResult ONNXGatherNDOp::verify() { // All values in 'indices' are expected to satisfy the inequality: // -data.shape[b + i] <= indices[...,i] <= (data.shape[b + i]-1)]. if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } int flatIndex = 0; for (IntegerAttr value : valueAttribute.getValues()) { int64_t indexValue = value.getInt(); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp new file mode 100644 index 0000000000..c281ca296e --- /dev/null +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GridSample.cpp @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ GridSample.cpp - ONNX Operations ------------------===// +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file provides definition of ONNX dialect GridSample operation. +// +//===----------------------------------------------------------------------===// + +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" + +using namespace mlir; +using namespace mlir::OpTrait::util; +using namespace onnx_mlir; + +//===----------------------------------------------------------------------===// +// Support +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { + +template <> +LogicalResult ONNXGridSampleOpShapeHelper::computeShape() { + + // Read data and indices shapes as dim indices. + ONNXGridSampleOpAdaptor operandAdaptor(operands); + DimsExpr inputDims; + DimsExpr gridDims; + createIE->getShapeAsDims(operandAdaptor.getX(), inputDims); + createIE->getShapeAsDims(operandAdaptor.getGrid(), gridDims); + + int64_t gridRank = gridDims.size(); + + // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr) + // Grid's dimensions should also have rank r+2 and be in the form of + // (N,D1_out,D2_out,...,Dr_out,r). + // The output Y will have shape (N, C, D1_out, D2_out, ..., Dr_out). + DimsExpr outputDims; + outputDims.emplace_back(inputDims[0]); + outputDims.emplace_back(inputDims[1]); + for (int i = 1; i < gridRank - 1; ++i) { + outputDims.emplace_back(gridDims[i]); + } + + setOutputDims(outputDims); + return success(); +} + +} // namespace onnx_mlir + +//===----------------------------------------------------------------------===// +// Verify +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGridSampleOp::verify() { + ONNXGridSampleOpAdaptor operandAdaptor(*this); + auto op = mlir::cast(*this); + + const auto alignCorners = op.getAlignCorners(); + if (alignCorners != 0 && alignCorners != 1) { + return emitOpError("align_corners needs to be 0 or 1"); + } + const auto mode = op.getMode(); + if (mode != "linear" && mode != "nearest" && mode != "cubic") { + return emitOpError("mode needs to be linear, nearest or cubic"); + } + const auto paddingMode = op.getPaddingMode(); + if (paddingMode != "zeros" && paddingMode != "border" && + paddingMode != "reflection") { + return emitOpError("padding_mode needs to be zeros, border or reflection"); + } + + if (!hasShapeAndRank(getOperation())) + return success(); + + auto inputShape = + mlir::cast(operandAdaptor.getX().getType()).getShape(); + int64_t inputRank = inputShape.size(); + auto gridShape = + mlir::cast(operandAdaptor.getGrid().getType()).getShape(); + + // Check whether the ranks of input and grid are valid and are equal. + // Input's dimensions of rank r+2 should be in the form of (N,C,D1,D2,...,Dr) + // Grid's dimensions should also have rank r+2 and be in the form of + // (N,D1_out,D2_out,...,Dr_out,r). + if (inputShape.size() != gridShape.size()) { + return emitOpError() << "Input(=" << inputShape.size() + << ") and grid(=" << gridShape.size() + << ") have different dim sizes."; + } + + if (inputShape[0] != gridShape[0]) { + return emitOpError() << "Input and grid must have the same batch value."; + } + + if (!ShapedType::isDynamic(gridShape.back()) && + gridShape.back() != inputRank - 2) { + return emitOpError() << "Grid last dim must have been '" << inputRank - 2 + << "' instead of '" << gridShape.back() << "'."; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Shape Inference +//===----------------------------------------------------------------------===// + +LogicalResult ONNXGridSampleOp::inferShapes( + std::function /*doShapeInference*/) { + + Type elementType = mlir::cast(getX().getType()).getElementType(); + ONNXGridSampleOpShapeHelper shapeHelper(getOperation(), {}); + return shapeHelper.computeShapeAndUpdateType(elementType); +} + +//===----------------------------------------------------------------------===// +// Template instantiation +//===----------------------------------------------------------------------===// + +namespace onnx_mlir { +template struct ONNXNonSpecificOpShapeHelper; +} // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp index 2c245032f0..bc9809fea0 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/NonZero.cpp @@ -27,7 +27,11 @@ namespace onnx_mlir { template <> LogicalResult ONNXNonZeroOpShapeHelper::computeShape() { ONNXNonZeroOpAdaptor operandAdaptor(operands); - int64_t xRank = createIE->getShapedTypeRank(operandAdaptor.getX()); + auto x = operandAdaptor.getX(); + if (!hasShapeAndRank(x)) { + return failure(); + } + int64_t xRank = createIE->getShapedTypeRank(x); // Cannot refine shape as we may otherwise loose the dynamic dim. return setOutputDimsFromLiterals( {xRank, ShapedType::kDynamic}, 0, /*refineShape*/ false); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp index 3b1699b35e..80474ef9e9 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/OneHot.cpp @@ -28,6 +28,9 @@ LogicalResult ONNXOneHotOpShapeHelper::computeShape() { ONNXOneHotOp oneHotOp = llvm::cast(op); ONNXOneHotOpAdaptor operandAdaptor(operands); Value indices = operandAdaptor.getIndices(); + if (!hasShapeAndRank(indices)) { + return failure(); + } int64_t indicesRank = createIE->getShapedTypeRank(indices); // Axis is a required attribute and should have default value of -1. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp index 16e4713a91..9faa9d464e 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Pad.cpp @@ -13,6 +13,9 @@ //===----------------------------------------------------------------------===// #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallSet.h" +#include using namespace mlir; using namespace mlir::OpTrait::util; @@ -28,38 +31,104 @@ LogicalResult ONNXPadOpShapeHelper::computeShape() { ONNXPadOpAdaptor operandAdaptor(operands); Value dataOperand = operandAdaptor.getData(); Value padsOperand = operandAdaptor.getPads(); - DimsExpr outputDims; + Value axesOperand = operandAdaptor.getAxes(); // Get info about input data operand. + if (!hasShapeAndRank(dataOperand)) { + return failure(); + } uint64_t dataRank = createIE->getShapedTypeRank(dataOperand); - // Initialize context and results (pads & output) - pads.resize(2 * dataRank); // pads two sides of each axis. - outputDims.resize(dataRank); + bool isFloat = isa(getElementType(dataOperand.getType())); + // Initially, output dim sizes are all unknown. + DimsExpr outputDims(dataRank, QuestionmarkIndexExpr(/*IsFloat=*/isFloat)); + + // Compute the values of the "axes" array. If "axes" operand is not provided, + // it is a range from 0 to dataRank. If it is provided, it is a list of + // integers and the values must be in the range [-dataRank, dataRank). + SmallVector axes; + if (isNoneValue(axesOperand)) { + axes.resize(dataRank); + std::iota(axes.begin(), axes.end(), 0); + } else { + auto axesSize = createIE->getArraySize(axesOperand); + + // Bail out: If axes is dynamic, output is also dynamic. + if (axesSize == ShapedType::kDynamic) { + setOutputDims(outputDims); + return success(); + } + + if (axesSize < 0) { + return op->emitError("axes size must be greater than 0"); + } + + // Iterate over axesOperand to figure out the axes that will be padded + for (auto axesOperandIndex : llvm::seq(axesSize)) { + IndexExpr padsAxis = + createIE->getIntFromArrayAsSymbol(axesOperand, axesOperandIndex); + + // If the values of axesOperand cannot be calculated at compile time, bail + // out... + if (!padsAxis.isLiteral()) { + setOutputDims(outputDims); + return success(); + } + + int64_t positiveAxis = padsAxis.getLiteral(); + if (positiveAxis < 0) { + positiveAxis += dataRank; + } + + if (positiveAxis + (int)dataRank < 0 || positiveAxis >= (int)dataRank) { + return op->emitError("axes value is out of bounds"); + } + + axes.push_back(positiveAxis); + } + } - // `pads` format is : [x1_begin, x2_begin,...,x1_end, x2_end,...], - // where - // - xi_begin: the number of pad values added at the beginning of axis `i` - // - xi_end: the number of pad values added at the end of axis `i`. + // Initialize pads according to the most likely case + pads.resize(2 * dataRank); // pads two sides of each axis. - // Calculate output dimension sizes. - for (uint64_t i = 0; i < dataRank; i++) { + llvm::SmallSet visited; + for (auto [idx, axis] : llvm::enumerate(axes)) { + // `pads` format is : [x1_begin, x2_begin,...,x1_end, x2_end,...], + // where + // - xi_begin: the number of pad values added at the beginning of axis `i` + // - xi_end: the number of pad values added at the end of axis `i`. // Get begin/end pads. - SymbolIndexExpr padBegin(createIE->getIntFromArrayAsSymbol(padsOperand, i)); + SymbolIndexExpr padBegin( + createIE->getIntFromArrayAsSymbol(padsOperand, idx)); SymbolIndexExpr padEnd( - createIE->getIntFromArrayAsSymbol(padsOperand, i + dataRank)); - if (padBegin.isUndefined() || padEnd.isUndefined()) + createIE->getIntFromArrayAsSymbol(padsOperand, idx + axes.size())); + + if (padBegin.isUndefined() || padEnd.isUndefined()) { return op->emitError("pad parameter could not be processed"); + } + // Get input dim. - DimIndexExpr dimInput(createIE->getShapeAsDim(dataOperand, i)); + DimIndexExpr dimInput(createIE->getShapeAsDim(dataOperand, axis)); // Calculation for output size. IndexExpr dimOutputFinal = (padBegin + dimInput) + padEnd; - // Save results. - pads[i] = padBegin; - pads[i + dataRank] = padEnd; - outputDims[i] = dimOutputFinal; + visited.insert(axis); + + // Currently "pads" is only used when axes is NoneType and for constant + // propagation + if (isNoneValue(axesOperand)) { + pads[axis] = padBegin; + pads[axis + dataRank] = padEnd; + } + + outputDims[axis] = dimOutputFinal; + } + + for (auto i : llvm::seq(dataRank)) { + if (!visited.count(i)) { + outputDims[i] = createIE->getShapeAsLiteral(dataOperand, i); + } } // Save the final result. @@ -86,10 +155,6 @@ LogicalResult ONNXPadOp::verify() { } } - if (!isNoneValue(getAxes())) { - return emitOpError("Axes input is not currently supported"); - } - return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp index 81da106541..646c4423d9 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Reshape.cpp @@ -51,12 +51,12 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() { // Compute the total number of elements using the input data operand. // dataRank will be 0 if Data is unranked tensor. // The number of element will not be computed - IndexExpr numOfElements = LiteralIndexExpr(1); + IndexExpr numOfElements = LitIE(1); for (unsigned i = 0; i < dataRank; ++i) numOfElements = numOfElements * createIE->getShapeAsDim(data, i); // Compute the total number of elements from the shape values. - IndexExpr numOfElementsFromShape = LiteralIndexExpr(1); + IndexExpr numOfElementsFromShape = LitIE(1); for (unsigned i = 0; i < outputRank; ++i) { IndexExpr dimShape = createIE->getIntFromArrayAsSymbol(shape, i); if (dimShape.isUndefined()) @@ -74,7 +74,7 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() { // dimShape == -1: use 1 to compute the number of elements to avoid // negative value. - dim = dim.selectOrSelf(dim == -1, LiteralIndexExpr(1)); + dim = dim.selectOrSelf(dim == -1, LitIE(1)); numOfElementsFromShape = numOfElementsFromShape * dim; } @@ -85,8 +85,9 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() { for (unsigned i = 0; i < outputRank; ++i) { if (hasShapeAndRank(data)) { IndexExpr dimShape = createIE->getIntFromArrayAsSymbol(shape, i); - outputDims[i] = outputDims[i].selectOrSelf( - dimShape == -1, numOfElements.floorDiv(numOfElementsFromShape)); + if (dimShape.isLiteralAndIdenticalTo(-1)) { + outputDims[i] = numOfElements.floorDiv(numOfElementsFromShape); + } } else { // ToFix: can not check getAllowzero because the operandAdaptor is // constructed without attributes @@ -111,15 +112,38 @@ LogicalResult ONNXReshapeOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXReshapeOp::verify() { + auto shape = getShape(); + // Cannot verify if shape has unknown rank. - if (!hasShapeAndRank(getShape())) + if (!hasShapeAndRank(shape)) return success(); // Only rank 1 shape tensors are supported. - auto shapeTy = cast(getShape().getType()); + auto shapeTy = cast(shape.getType()); if (shapeTy.getRank() != 1) return emitOpError("Shape tensor must have rank one"); + // Cannot verify if shape is not from ConstantOp + if (!isDenseONNXConstant(shape)) + return success(); + + SmallVector dims; + if (!getI64ValuesFromONNXConstantOp(shape, dims)) { + return emitError( + "Shape comes from ConstantOp but cannot get int64_t values from it."); + } + + auto isZero = [](int64_t val) { return val == 0; }; + auto isMinusOne = [](int64_t val) { return val == -1; }; + + if (getAllowzero()) { + if (llvm::any_of(dims, isZero) && llvm::any_of(dims, isMinusOne)) { + return emitOpError( + "Allowzero is set and shape contains both -1 and 0. Dimension " + "corresponding to -1 cannot be determined uniquely."); + } + } + // TODO: Check that any -1 dim is used correctly. // TODO: Check that any 0 dim is used correctly with allowzero. // TODO: Check that data can reshape to shape if data's shape is known. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp index e4e89239c4..d4559e37de 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Resize.cpp @@ -26,7 +26,7 @@ namespace onnx_mlir { namespace { bool isEmptyTensor(Value input) { - if (ShapedType shapedType = dyn_cast(input.getType())) { + if (ShapedType shapedType = mlir::dyn_cast(input.getType())) { return shapedType.hasStaticShape() && shapedType.getNumElements() == 0; } else { return false; @@ -48,9 +48,13 @@ LogicalResult ONNXResizeOpShapeHelper::computeShape() { ONNXResizeOpAdaptor operandAdaptor(operands, cast(op)); if (operandAdaptor.getAxes().has_value()) return op->emitOpError("axes are unsupported"); - uint64_t rank = createIE->getShapedTypeRank(operandAdaptor.getX()); + const auto x = operandAdaptor.getX(); + if (!hasShapeAndRank(x)) { + return failure(); + } + uint64_t rank = createIE->getShapedTypeRank(x); DimsExpr inputDims, outputDims; - createIE->getShapeAsDims(operandAdaptor.getX(), inputDims); + createIE->getShapeAsDims(x, inputDims); bool scalesIsAbsent = isAbsent(operandAdaptor.getScales()); if (!scalesIsAbsent) { // Read and save scales as float. @@ -94,12 +98,13 @@ LogicalResult ONNXResizeOpShapeHelper::computeShape() { //===----------------------------------------------------------------------===// LogicalResult ONNXResizeOp::verify() { - // Cannot verify if scales or sizes have unknown shapes. - if (auto scalesShapedType = dyn_cast(getScales().getType())) { + // Cannot verify if scales or sizes have unknown sha∑pes. + if (auto scalesShapedType = + mlir::dyn_cast(getScales().getType())) { if (!scalesShapedType.hasStaticShape()) return success(); } - if (auto sizesShapedType = dyn_cast(getSizes().getType())) { + if (auto sizesShapedType = mlir::dyn_cast(getSizes().getType())) { if (!sizesShapedType.hasStaticShape()) return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp index 279cbecbeb..187daea201 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Shape.cpp @@ -52,6 +52,9 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape() { Value data = operandAdaptor.getData(); // Compute and store start/end in ONNXShapeOpShapeHelper object. + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(data); start = shapeOp.getStart(); start = normalizeClampedPerSpec(start, rank); @@ -61,7 +64,7 @@ LogicalResult ONNXShapeOpShapeHelper::computeShape() { return op->emitError("Start must not be greater than end"); // Output shape is a 1D vector with "end-start" values - DimsExpr outputDims(1, LiteralIndexExpr(end - start)); + DimsExpr outputDims(1, LitIE(end - start)); setOutputDims(outputDims); return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp index 88b2cdb8a7..348abfaa1d 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Slice.cpp @@ -4,7 +4,7 @@ //===------------------ Slice.cpp - ONNX Operations ---------------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -47,7 +47,7 @@ LogicalResult ONNXSliceOpShapeHelper::computeShape() { int64_t axis = val.getLiteral(); if (axis < 0) axis += dataRank; - if (!(axis >= 0 && axis < (int64_t)dataRank)) + if (!(axis >= 0 && axis < static_cast(dataRank))) return op->emitError("Axes contains an out-of-bound index"); axesIntLit.emplace_back(axis); } @@ -126,8 +126,8 @@ LogicalResult ONNXSliceOpShapeHelper::computeShape() { if (steps[i].isUndefined()) { // have one unset, put the defaults (start was already at zero, so we // are fine). - starts[i] = LiteralIndexExpr(0); - steps[i] = LiteralIndexExpr(1); + starts[i] = LitIE(0); + steps[i] = LitIE(1); DimIndexExpr dimInput = createIE->getShapeAsDim(data, i); ends[i] = dimInput; outputDims[i] = dimInput; @@ -163,7 +163,8 @@ LogicalResult ONNXSliceOp::inferShapes( if (!isNoneValue(axes) && !getONNXConstantOp(axes)) return success(); - const auto startsType = dyn_cast(getStarts().getType()); + const auto startsType = + mlir::dyn_cast(getStarts().getType()); assert(startsType != nullptr && "starts type is not a RankedTensorType"); auto startsDim = startsType.getShape()[0]; { @@ -175,7 +176,7 @@ LogicalResult ONNXSliceOp::inferShapes( // If axes is not specified, default to [0, ..., ndim-1] if (isNoneValue(axes)) { SmallVector vals = {}; - for (size_t s = 0; s < (size_t)startsDim; ++s) + for (size_t s = 0; s < static_cast(startsDim); ++s) vals.emplace_back(s); auto constantDenseAttribute = DenseElementsAttr::get(tensorType, llvm::ArrayRef(vals)); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp index ce3f4f84d0..55cb4bdd35 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/SpaceToDepth.cpp @@ -33,7 +33,9 @@ LogicalResult ONNXSpaceToDepthOpShapeHelper::computeShape() { Value input = operandAdaptor.getInput(); int64_t blocksize = operandAdaptor.getBlocksize(); assert(blocksize > 0 && "blocksize should be strictly positive"); - + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); assert(inputRank == 4 && "Unexpected input tensor rank"); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp index be2eeaa887..e469239240 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Split.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXCommonSplitOpShapeHelper::customComputeShape( unsigned int numOfResults = splitOp.getNumResults(); Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(input); // Checking value of axis parameter. @@ -108,8 +111,8 @@ LogicalResult ONNXSplitOpShapeHelper::computeShape() { // None is fine, indexExprArray will be empty. } else { createIE->getIntFromArrayAsSymbols(split, indexExprArray); - assert(IndexExpr::isLiteral(indexExprArray) && - "dynamic split not yet supported"); + if (!IndexExpr::isLiteral(indexExprArray)) + return failure(); } return customComputeShape(indexExprArray); } @@ -124,8 +127,8 @@ LogicalResult ONNXSplitV13OpShapeHelper::computeShape() { // None is fine, indexExprArray will be empty. } else { createIE->getIntFromArrayAsSymbols(split, indexExprArray); - assert(IndexExpr::isLiteral(indexExprArray) && - "dynamic split not yet supported"); + if (!IndexExpr::isLiteral(indexExprArray)) + return failure(); } return customComputeShape(indexExprArray); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp index 786f1e136a..7a31d94257 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Squeeze.cpp @@ -15,6 +15,8 @@ #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#define DEBUG_TYPE "onnx-squeeze" + using namespace mlir; using namespace mlir::OpTrait::util; using namespace onnx_mlir; @@ -42,6 +44,9 @@ LogicalResult ONNXCommonSqueezeOpShapeHelper::customComputeShape( typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); DimsExpr outputDims; Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t dataRank = createIE->getShapedTypeRank(data); // Init state. @@ -55,9 +60,11 @@ LogicalResult ONNXCommonSqueezeOpShapeHelper::customComputeShape( createIE->getShapeAsSymbols(data, dataShape); for (int i = 0; i < dataRank; ++i) { // Check if the dimension to squeeze is a literal and in range. - if (!dataShape[i].isLiteral()) - return op->emitError( - "Can not squeeze from dynamic dimensions at this time"); + if (!dataShape[i].isLiteral()) { + LLVM_DEBUG(llvm::errs() + << "Can not squeeze from dynamic dimensions at this time"); + return failure(); + } int64_t shape = dataShape[i].getLiteral(); assert(shape != ShapedType::kDynamic && "Compile time shape should be nonnegative"); @@ -71,9 +78,11 @@ LogicalResult ONNXCommonSqueezeOpShapeHelper::customComputeShape( // Normalize the axis values, record modified values in squeezedDims. for (uint64_t i = 0; i < squeezedDims.size(); ++i) { // Check if the dimension to squeeze is a literal and in range. - if (!squeezedDims[i].isLiteral()) - return op->emitError( - "Can not squeeze from dynamic dimensions at this time"); + if (!squeezedDims[i].isLiteral()) { + LLVM_DEBUG(llvm::errs() + << "Can not squeeze from dynamic dimensions at this time"); + return failure(); + } int64_t a = squeezedDims[i].getLiteral(); if (a < -dataRank || a >= dataRank) return op->emitError("Invalid axis value"); @@ -189,9 +198,6 @@ template struct ONNXCommonSqueezeOpShapeHelper; // Folder //===----------------------------------------------------------------------===// OpFoldResult ONNXSqueezeOp::fold(FoldAdaptor adaptor) { - // Fold type - if (failed(inferShapes(nullptr))) - return nullptr; // Fold value if (!adaptor.getData() || !adaptor.getAxes()) { @@ -199,8 +205,8 @@ OpFoldResult ONNXSqueezeOp::fold(FoldAdaptor adaptor) { return nullptr; } - assert(hasStaticShape(getSqueezed().getType()) && - "Shape should be static when the inputs are constant"); + if (!hasStaticShape(getSqueezed().getType())) + return nullptr; OnnxElementsAttrBuilder elementsBuilder(getContext()); return elementsBuilder.reshape(mlir::cast(adaptor.getData()), @@ -208,9 +214,6 @@ OpFoldResult ONNXSqueezeOp::fold(FoldAdaptor adaptor) { } OpFoldResult ONNXSqueezeV11Op::fold(FoldAdaptor adaptor) { - // Fold the type of tensor - if (failed(inferShapes(nullptr))) - return nullptr; // Fold the value in tensor if (!adaptor.getData()) { @@ -218,8 +221,8 @@ OpFoldResult ONNXSqueezeV11Op::fold(FoldAdaptor adaptor) { return nullptr; } - assert(hasStaticShape(getSqueezed().getType()) && - "Shape should be static when the inputs are constant"); + if (!hasStaticShape(getSqueezed().getType())) + return nullptr; OnnxElementsAttrBuilder elementsBuilder(getContext()); return elementsBuilder.reshape(mlir::cast(adaptor.getData()), diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp index 96f403c409..6819b4c81f 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Tile.cpp @@ -29,6 +29,9 @@ LogicalResult ONNXTileOpShapeHelper::computeShape() { ONNXTileOpAdaptor operandAdaptor(operands); // Get info about input data operand. Value input = operandAdaptor.getInput(); + if (!hasShapeAndRank(input)) { + return failure(); + } int64_t inputRank = createIE->getShapedTypeRank(input); Value repeats = operandAdaptor.getRepeats(); // Compute outputDims diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp index 50e8663983..05a11d8189 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Transpose.cpp @@ -30,6 +30,9 @@ LogicalResult ONNXTransposeOpShapeHelper::computeShape() { ONNXTransposeOp transposeOp = llvm::cast(op); Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } auto rank = createIE->getShapedTypeRank(data); // Transposition which handles the default case of diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp index f9177073e0..842ed0b767 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unique.cpp @@ -22,6 +22,9 @@ LogicalResult ONNXUniqueOpShapeHelper::computeShape() { ONNXUniqueOpAdaptor operandAdaptor(operands, op->getAttrDictionary()); // Get info about X and K operands. Value X = operandAdaptor.getX(); + if (!hasShapeAndRank(X)) { + return failure(); + } int64_t rank = createIE->getShapedTypeRank(X); std::optional optionalAxis = operandAdaptor.getAxis(); // Generate the output dims. diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp index 5603ae408b..3bb049aa69 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Unsqueeze.cpp @@ -33,6 +33,9 @@ LogicalResult ONNXCommonUnsqueezeOpShapeHelper::customComputeShape( typename OP_TYPE::Adaptor operandAdaptor(operands, op->getAttrDictionary()); DimsExpr outputDims; Value data = operandAdaptor.getData(); + if (!hasShapeAndRank(data)) { + return failure(); + } int64_t dataRank = createIE->getShapedTypeRank(data); // Init state. @@ -64,7 +67,7 @@ LogicalResult ONNXCommonUnsqueezeOpShapeHelper::customComputeShape( if (std::find(unsqueezedAxes.begin(), unsqueezedAxes.end(), i) != unsqueezedAxes.end()) // found i in unsqueeze axles. - outputDims.emplace_back(LiteralIndexExpr(1)); + outputDims.emplace_back(LitIE(1)); else outputDims.emplace_back(createIE->getShapeAsDim(data, j++)); diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp index bdb06ab04a..ddb18c847b 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Upsample.cpp @@ -4,7 +4,7 @@ //===------------------ Upsample.cpp - ONNX Operations --------------------===// // -// Copyright 2019-2023 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -52,7 +52,7 @@ LogicalResult ONNXUpsampleOpShapeHelper::computeShape() { // When shape is also constant, replace questionmark by actual value. double dim = xShape[i].getLiteral(); double scale = valueAttr.getValues()[i].getValueAsDouble(); - outputDims[i] = LiteralIndexExpr((int64_t)(dim * scale)); + outputDims[i] = LitIE(static_cast(dim * scale)); } } } diff --git a/src/Dialect/ONNX/ONNXTraits.hpp b/src/Dialect/ONNX/ONNXTraits.hpp new file mode 100644 index 0000000000..cf04bd37e8 --- /dev/null +++ b/src/Dialect/ONNX/ONNXTraits.hpp @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- ONNXTraits.hpp - ONNX Op Traits --------------------===// +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. +// +// ============================================================================= +// +// This file defines traits of ONNX ops. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace OpTrait { + +template +class OpVersionTrait { +public: + template + class Impl : public OpTrait::TraitBase { + public: + int getOpVersion() { return version; } + }; +}; + +} // namespace OpTrait +} // namespace mlir diff --git a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp index 3b1318c15f..9072f0e227 100644 --- a/src/Dialect/ONNX/ONNXUnsupportedOps.hpp +++ b/src/Dialect/ONNX/ONNXUnsupportedOps.hpp @@ -25,7 +25,6 @@ UNSUPPORTED_OPS(ONNXAdagradOp) UNSUPPORTED_OPS(ONNXAdamOp) UNSUPPORTED_OPS(ONNXArrayFeatureExtractorOp) -UNSUPPORTED_OPS(ONNXBatchNormalizationOp) UNSUPPORTED_OPS(ONNXBinarizerOp) UNSUPPORTED_OPS(ONNXBlackmanWindowOp) UNSUPPORTED_OPS(ONNXCastMapOp) @@ -37,7 +36,6 @@ UNSUPPORTED_OPS(ONNXDeformConvOp) UNSUPPORTED_OPS(ONNXDictVectorizerOp) UNSUPPORTED_OPS(ONNXFeatureVectorizerOp) UNSUPPORTED_OPS(ONNXGradientOp) -UNSUPPORTED_OPS(ONNXGridSampleOp) UNSUPPORTED_OPS(ONNXHammingWindowOp) UNSUPPORTED_OPS(ONNXHannWindowOp) UNSUPPORTED_OPS(ONNXImputerOp) @@ -48,7 +46,6 @@ UNSUPPORTED_OPS(ONNXLpPoolOp) UNSUPPORTED_OPS(ONNXMaxPoolOp) UNSUPPORTED_OPS(ONNXMaxUnpoolOp) UNSUPPORTED_OPS(ONNXMelWeightMatrixOp) -UNSUPPORTED_OPS(ONNXMishOp) UNSUPPORTED_OPS(ONNXMomentumOp) UNSUPPORTED_OPS(ONNXMultinomialOp) UNSUPPORTED_OPS(ONNXNegativeLogLikelihoodLossOp) @@ -76,7 +73,10 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV11Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op) CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op) CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op) +CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV20Op) +CONVERTED_TO_SUPPORTED_OPS(ONNXGridSampleV16Op) CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp) +CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV13Op) CONVERTED_TO_SUPPORTED_OPS(ONNXPadV11Op) diff --git a/src/Dialect/ONNX/Transforms/ConstProp.cpp b/src/Dialect/ONNX/Transforms/ConstProp.cpp index 49d4855042..7b5aea28ba 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.cpp +++ b/src/Dialect/ONNX/Transforms/ConstProp.cpp @@ -34,6 +34,7 @@ #include "src/Pass/Passes.hpp" #include "src/Support/TypeUtilities.hpp" +#include #include #include @@ -186,23 +187,6 @@ Value createMinimumValueForClip( llvm::APFloat::getLargest, true, llvm::APInt::getMinValue); } -WideNum asWideNum(double n, Type elemType) { - return wideZeroDispatch(elemType, [n](auto wideZero) { - using cpptype = decltype(wideZero); - constexpr BType TAG = toBType; - return WideNum::widen(static_cast(n)); - }); -} - -/// Checks whether a constant tensor's elements are all equal to a given scalar. -bool isConstOf(Value constValue, double n) { - ElementsAttr constElements = getConstValueElements(constValue); - Type elemType = constElements.getElementType(); - assert(!elemType.isInteger(1) && "booleans are not supported"); - WideNum w = asWideNum(n, elemType); - return ElementsAttrBuilder::allEqual(constElements, w); -} - // Extracts number from a scalar constant value. WideNum getScalarNum(Value constValue) { ElementsAttr elements = getConstValueElements(constValue); @@ -431,6 +415,23 @@ struct ElementWiseUnaryOpImpl> { static T eval(T val) { return ~val; } }; +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { + if constexpr (std::is_integral_v) { + // Cast to int64_t to disambiguate abs if T is signed. + // Otherwise, just return the value. + if constexpr (std::is_signed_v) { + return std::abs(static_cast(val)); + } else { + return val; + } + } else { + return std::fabs(val); + }; + } +}; + template struct ElementWiseUnaryOpImpl> { static T eval(T val) { return ceil(val); } @@ -487,10 +488,15 @@ struct ElementWiseUnaryOpImpl> { }; template -struct ElementWiseUnaryOpImpl> { +struct ElementWiseUnaryOpImpl> { static T eval(T val) { return (1 / val); } }; +template +struct ElementWiseUnaryOpImpl> { + static T eval(T val) { return std::nearbyint(val); } +}; + template auto elementwiseUnaryOpFunction(Type elemType) { return getWideNumWrappedTemplateFunction( + replacingValue.getDefiningOp()); + + auto batchAxis = reverseSequenceOP.getBatchAxis(); + + ElementsAttr inputElements = getConstValueElements(inputValue); + ElementsAttr sequenceElements = getConstValueElements(sequenceValue); + OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext()); + ElementsAttr reverseSequencedElements = elementsBuilder.reverseSequence( + inputElements, sequenceElements, batchAxis); + return createReplacingConstantOp( + rewriter, replacingValue, reverseSequencedElements); +} + //===----------------------------------------------------------------------===// // Code to perform constant propagation for unsqueeze. //===----------------------------------------------------------------------===// Value ConstPropUnsqueeze( PatternRewriter &rewriter, Value replacingValue, Value input) { + assert(llvm::cast(replacingValue.getType()).hasStaticShape()); ArrayRef reshapedShape = getShape(replacingValue.getType()); ElementsAttr reshapedElements = ConstPropReshapeImpl(rewriter, replacingValue, input, reshapedShape); @@ -884,6 +912,7 @@ Value ConstPropUnsqueeze( Value ConstPropSqueeze( PatternRewriter &rewriter, Value replacingValue, Value input) { + assert(llvm::cast(replacingValue.getType()).hasStaticShape()); ArrayRef reshapedShape = getShape(replacingValue.getType()); ElementsAttr reshapedElements = ConstPropReshapeImpl(rewriter, replacingValue, input, reshapedShape); @@ -1047,6 +1076,7 @@ Value ConstPropGather(PatternRewriter &rewriter, Value replacingValue, Value ConstPropReshape( PatternRewriter &rewriter, Value replacingValue, Value constValue) { + assert(llvm::cast(replacingValue.getType()).hasStaticShape()); ArrayRef reshapedShape = getShape(replacingValue.getType()); ElementsAttr reshapedElements = ConstPropReshapeImpl(rewriter, replacingValue, constValue, reshapedShape); diff --git a/src/Dialect/ONNX/Transforms/ConstProp.td b/src/Dialect/ONNX/Transforms/ConstProp.td index d712e0a1b2..0869caafd4 100644 --- a/src/Dialect/ONNX/Transforms/ConstProp.td +++ b/src/Dialect/ONNX/Transforms/ConstProp.td @@ -71,6 +71,11 @@ def HasStaticShape: Constraint; +def IsRankedShapedType: Constraint, + "A value has a rank" +>; + def HasIntegerElementType: Constraint(getElementType($_self.getType()))">, "A value has integer element type" @@ -84,10 +89,18 @@ def IsConstOfOnes : Constraint< CPred<"isDenseONNXConstant($_self) && isConstOf($_self, 1.0)">, "Value is an all-ones constant tensor">; +def IsNotScalar: Constraint< + CPred<"!isScalarTensor($_self)">, + "Value is not a scalar value">; + def ValuesHaveSameType : Constraint< CPred<"$0.getType() == $1.getType()">, "Values have same type">; +def ValuesHaveSameDType : Constraint< + CPred<"$0.getType().cast().getElementType() == $1.getType().cast().getElementType()">, + "Values have same dtype">; + def IsMatMulIntegerLhsZero: Constraint< CPred<"isMatMulIntegerLhsZero($0, $1)">, "MatMulInteger lhs matrix is zero for given zero point">; @@ -138,6 +151,9 @@ def CreateCastOfConst : def CreateBitwiseNotOfConst : NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; +def CreateAbsOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + def CreateCeilOfConst : NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; @@ -174,6 +190,9 @@ def CreateReluOfConst : def CreateReciprocalOfConst : NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; +def CreateRoundOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + def CreateMaxOfConst : NativeCodeCall<"ConstPropVariadicElementwiseBinary($_builder, $0, $1)">; @@ -255,6 +274,9 @@ def CreateGemmOfConsts : def CreateTransposeOfConst : NativeCodeCall<"ConstPropTranspose($_builder, $0, $1)">; +def CreateReverseSequenceOfConst : + NativeCodeCall<"ConstPropReverseSequence($_builder, $0, $1, $2)">; + def CreateUnsqueezeOfConst: NativeCodeCall<"ConstPropUnsqueeze($_builder, $0, $1)">; @@ -298,9 +320,9 @@ def CreateScatterNDOfConst : // Use commutativity to normalize constants in the second position of Add. def AddConstCommutative1 : NamedPat<"AddConstCommutative1", // From add(c, x). - (ONNXAddOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), + (ONNXAddOp:$addOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), // To add(x, c). - (ONNXAddOp $x, $c), + (ONNXAddOp $x, $c, (location $addOp)), // To avoid infinite loop, constrain the first arguments to be anything but a constant. [(IsNotAConstant:$x)]>; @@ -316,6 +338,7 @@ def AddConstAssociative1 : NamedPat<"AddConstAssociative1", (ONNXAddOp $c1, $c2)), [(IsNotAConstant:$x), (HasOneUse:$lhs)]>; +// Do not apply this when both x and c are scalar since it is cheap to do. def AddConstAssociative2 : NamedPat<"AddConstAssociative2", // From add(add(x, c), y). (ONNXAddOp @@ -325,8 +348,10 @@ def AddConstAssociative2 : NamedPat<"AddConstAssociative2", (ONNXAddOp (ONNXAddOp $x, $y), $c), - [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs)]>; + [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), + (IsNotScalar:$x), (IsNotScalar:$c)]>; +// Do not apply this when both y and c are scalar since it is cheap to do. def AddConstAssociative3 : NamedPat<"AddConstAssociative3", // From add(x, add(y, c)). (ONNXAddOp @@ -336,7 +361,8 @@ def AddConstAssociative3 : NamedPat<"AddConstAssociative3", (ONNXAddOp (ONNXAddOp $x, $y), $c), - [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs)]>; + [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs), + (IsNotScalar:$y), (IsNotScalar:$c)]>; def AddConstAssociative4 : NamedPat<"AddConstAssociative4", // From add(add(x, c1), add(y, c2)). @@ -358,7 +384,7 @@ def AddConstProp : NamedPat<"AddConstProp", (CreateAddOfTwoConst $addOp, $lhs, $rhs), // Additional constraints (dense) [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$addOp)]>; + (SatisfiesExpansionBound:$addOp), (HasStaticShape:$addOp)]>; // TODO: Expand $x to $result's shape instead of requiring ValuesHaveSameType. def AddZerosOnRhs : NamedPat<"AddZerosOnRhs", @@ -383,7 +409,7 @@ def SubConstProp : NamedPat<"SubConstProp", // To c1-c2 (CreateSubOfTwoConst $subOp, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$subOp)]>; + (SatisfiesExpansionBound:$subOp), (HasStaticShape:$subOp)]>; // TODO: Expand $a to $result's shape instead of requiring ValuesHaveSameType. def SubZerosOnRhs : NamedPat<"SubZerosOnRhs", @@ -412,6 +438,14 @@ def BitwiseNotConstProp : NamedPat<"BitwiseNotofConst", (CreateBitwiseNotOfConst $bitwiseNotOp, $input), [(IsFromDenseONNXConstantOp:$input)]>; +// Constant Propagation for Abs +def AbsConstProp : NamedPat<"AbsofConst", + // From abs(c). + (ONNXAbsOp:$ceilOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c + (CreateAbsOfConst $ceilOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + // Constant Propagation for Ceil def CeilConstProp : NamedPat<"CeilofConst", // From ceil(c). @@ -492,6 +526,14 @@ def ReciprocalConstProp : NamedPat<"ReciprocalofConst", (CreateReciprocalOfConst $reciprocalOp, $input), [(IsFromDenseONNXConstantOp:$input)]>; +// Constant Propagation for Round +def RoundConstProp : NamedPat<"RoundofConst", + // From round(c) + (ONNXRoundOp:$roundOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To new_c. + (CreateRoundOfConst $roundOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + // Change a subtraction of a constant c by an addition of -c. Helpfull to combine // with other add optimizations. def SubConstToNeg : NamedPat<"SubConstToNeg", @@ -525,6 +567,14 @@ def ReciprocalOfConst : NamedPat<"ReciprocalOfConst", (CreateReciprocalOfConst $reciprocalOp, $input), [(IsFromDenseONNXConstantOp:$input)]>; +// Constant Propagation for Round +def RoundofConst : NamedPat<"RoundofConst", + // From onnx.Round(c) + (ONNXRoundOp:$reluOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_)), + // To round_even(c) + (CreateRoundOfConst $reluOp, $input), + [(IsFromDenseONNXConstantOp:$input)]>; + //===----------------------------------------------------------------------===// // Const propagation patterns for variadic elementwise operations. //===----------------------------------------------------------------------===// @@ -537,7 +587,7 @@ def MaxConstProp : NamedPat<"MaxConstProp", (CreateMaxOfConst $maxOp, $operandList), // Constraints [(IsVariadicOperandDenseONNXConstantOp:$operandList), - (SatisfiesExpansionBound:$maxOp)]>; + (SatisfiesExpansionBound:$maxOp), (HasStaticShape:$maxOp)]>; // Constant Propagation for Min def MinConstProp : NamedPat<"MinConstProp", @@ -547,7 +597,7 @@ def MinConstProp : NamedPat<"MinConstProp", (CreateMinOfConst $minOp, $operandList), // Constraints [(IsVariadicOperandDenseONNXConstantOp:$operandList), - (SatisfiesExpansionBound:$minOp)]>; + (SatisfiesExpansionBound:$minOp), (HasStaticShape:$minOp)]>; // Constant Propagation for Sum def SumConstProp : NamedPat<"SumConstProp", @@ -557,7 +607,7 @@ def SumConstProp : NamedPat<"SumConstProp", (CreateSumOfConst $sumOp, $operandList), // Constraints [(IsVariadicOperandDenseONNXConstantOp:$operandList), - (SatisfiesExpansionBound:$sumOp)]>; + (SatisfiesExpansionBound:$sumOp), (HasStaticShape:$sumOp)]>; //===----------------------------------------------------------------------===// // Patterns to enable opportunities with elementwise MUL operations. @@ -567,9 +617,9 @@ def SumConstProp : NamedPat<"SumConstProp", // Use commutativity to normalize constants in the second position of Mul. def MulConstCommutative1 : NamedPat<"MulConstCommutative1", // From mul(c, x). - (ONNXMulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), + (ONNXMulOp:$mulOp (ONNXConstantOp:$c $_, $_, $_, $_, $_, $_, $_, $_), $x), // To mul(x, c). - (ONNXMulOp $x, $c), + (ONNXMulOp $x, $c, (location $mulOp)), // To avoid infinite loop, constrain the first arguments to be anything but a constant. [(IsNotAConstant:$x)]>; @@ -585,6 +635,7 @@ def MulConstAssociative1 : NamedPat<"MulConstAssociative1", (ONNXMulOp $c1, $c2)), [(IsNotAConstant:$x), (HasOneUse:$lhs)]>; +// Do not apply this when both x and c are scalar since it is cheap to do. def MulConstAssociative2 : NamedPat<"MulConstAssociative2", // From mul(mul(x, c), y). (ONNXMulOp @@ -594,8 +645,10 @@ def MulConstAssociative2 : NamedPat<"MulConstAssociative2", (ONNXMulOp (ONNXMulOp $x, $y), $c), - [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs)]>; + [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$lhs), + (IsNotScalar:$x), (IsNotScalar:$c)]>; +// Do not apply this when both y and c are scalar since it is cheap to do. def MulConstAssociative3 : NamedPat<"MulConstAssociative3", // From mul(x, mul(y, c)). (ONNXMulOp @@ -605,7 +658,8 @@ def MulConstAssociative3 : NamedPat<"MulConstAssociative3", (ONNXMulOp (ONNXMulOp $x, $y), $c), - [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs)]>; + [(IsNotAConstant:$x), (IsNotAConstant:$y), (HasOneUse:$rhs), + (IsNotScalar:$y), (IsNotScalar:$c)]>; def MulConstAssociative4 : NamedPat<"MulConstAssociative4", // From mul(mul(x, c1), mul(y, c2)). @@ -627,7 +681,7 @@ def MulConstProp : NamedPat<"MulConstProp", (CreateMulOfTwoConst $mulOp, $lhs, $rhs), // Multiplication constraints [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$mulOp)]>; + (SatisfiesExpansionBound:$mulOp), (HasStaticShape:$mulOp)]>; // TODO: Expand $x to $result's shape instead of requiring ValuesHaveSameType. def MulOnesOnRhs : NamedPat<"MulOnesOnRhs", @@ -652,7 +706,7 @@ def DivConstProp : NamedPat<"DivConstProp", (CreateDivOfTwoConst $divOp, $lhs, $rhs), // Division constraints [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$divOp)]>; + (SatisfiesExpansionBound:$divOp), (HasStaticShape:$divOp)]>; // TODO: Expand $x to $result's shape instead of requiring ValuesHaveSameType. def DivOnesOnRhs : NamedPat<"DivOnesOnRhs", @@ -678,7 +732,7 @@ def BitwiseAndConstPropPattern : NamedPat<"BitwiseAndConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateBitwiseAndOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXBitwiseOrOp @@ -690,7 +744,7 @@ def BitwiseOrConstPropPattern : NamedPat<"BitwiseOrConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateBitwiseOrOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXAndOp @@ -702,7 +756,7 @@ def AndConstPropPattern : NamedPat<"AndConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateAndOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXOrOp @@ -714,7 +768,7 @@ def OrConstPropPattern : NamedPat<"OrConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateOrOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXorOp @@ -726,7 +780,7 @@ def XorConstPropPattern : NamedPat<"XorConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateXorOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXEqualOp @@ -741,7 +795,7 @@ def EqualConstProp : NamedPat<"EqualConstProp", (CreateEqualOfTwoConst $result, $lhs, $rhs), // constraints [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (IsIntOrFloatType:$lhs), (SatisfiesExpansionBound:$result)]>; + (IsIntOrFloatType:$lhs), (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXLessOp @@ -753,7 +807,7 @@ def LessConstPropPattern : NamedPat<"LessConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateLessOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXGreaterOp @@ -765,7 +819,7 @@ def GreaterConstPropPattern : NamedPat<"GreaterConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateGreaterOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXLessOrEqualOp @@ -777,7 +831,7 @@ def LessOrEqualConstPropPattern : NamedPat<"LessOrEqualConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateLessOrEqualOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXGreaterOrEqualOp @@ -789,7 +843,7 @@ def GreaterOrEqualConstPropPattern : NamedPat<"GreaterOrEqualConstPropPattern", (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)), (CreateGreaterOrEqualOfTwoConst $result, $lhs, $rhs), [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$result)]>; + (SatisfiesExpansionBound:$result), (HasStaticShape:$result)]>; //===----------------------------------------------------------------------===// // Constant propagation for ONNXModOp @@ -802,7 +856,7 @@ def ModConstPropPattern : NamedPat<"ModConstPropPattern", $fmod), (CreateModOfTwoConst $modOp, $A, $B), [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), - (SatisfiesExpansionBound:$modOp)]>; + (SatisfiesExpansionBound:$modOp), (HasStaticShape:$modOp)]>; //===----------------------------------------------------------------------===// // Pattern for Clip. @@ -842,7 +896,7 @@ def WhereConstProp : NamedPat<"WhereConstProp", // Where constraints [(IsFromDenseONNXConstantOp:$condition), (IsFromDenseONNXConstantOp:$X), (IsFromDenseONNXConstantOp:$Y), - (SatisfiesExpansionBound:$whereOp)]>; + (SatisfiesExpansionBound:$whereOp), (HasStaticShape:$whereOp)]>; //===----------------------------------------------------------------------===// // Patterns for Reduce ops. @@ -985,7 +1039,7 @@ def GemmConstProp : NamedPat<"GemmConstProp", (CreateGemmOfConsts $gemmOp, $A, $B, $C), [(IsFromDenseONNXConstantOp:$A), (IsFromDenseONNXConstantOp:$B), (IsFromDenseONNXConstantOpOrNone:$C), - (SatisfiesExpansionBound:$gemmOp)]>; + (SatisfiesExpansionBound:$gemmOp), (HasStaticShape:$gemmOp)]>; //===----------------------------------------------------------------------===// // Patterns to enable opportunities with Transpose operations. @@ -998,6 +1052,17 @@ def TransposeofConst : NamedPat<"TransposeofConst", (CreateTransposeOfConst $resOp, $input), [(IsFromDenseONNXConstantOp:$input)]>; +//===----------------------------------------------------------------------===// +// Patterns to enable opportunities with ReverseSequence operations. +//===----------------------------------------------------------------------===// + +def ReverseSequenceofConst : NamedPat<"ReverseSequenceofConst", + // From ReverseSequenceOp(c, ba, ta) + (ONNXReverseSequenceOp:$resOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), + (ONNXConstantOp:$sequence_lens $_, $_, $_, $_, $_, $_, $_, $_), $batch_axis, $time_axis), + (CreateReverseSequenceOfConst $resOp, $input, $sequence_lens), + [(IsFromDenseONNXConstantOp:$input),(IsFromDenseONNXConstantOp:$sequence_lens)]>; + //===----------------------------------------------------------------------===// // Patterns to enable opportunities with Unsqueeze operations. //===----------------------------------------------------------------------===// @@ -1007,14 +1072,14 @@ def UnsqueezeofConst : NamedPat<"UnsqueezeofConst", (ONNXUnsqueezeOp:$resOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), $_), // To c' where c' is the unsqueezed value. (CreateUnsqueezeOfConst $resOp, $input), - [(IsFromDenseONNXConstantOp:$input)]>; + [(IsFromDenseONNXConstantOp:$input), (HasStaticShape:$resOp)]>; def UnsqueezeV11ofConst : NamedPat<"UnsqueezeV11ofConst", // From Unsqueeze (c, axis) (ONNXUnsqueezeV11Op:$resOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), $_), // To c' where c' is the unsqueezed value. (CreateUnsqueezeOfConst $resOp, $input), - [(IsFromDenseONNXConstantOp:$input)]>; + [(IsFromDenseONNXConstantOp:$input), (HasStaticShape:$resOp)]>; //===----------------------------------------------------------------------===// // Patterns to enable opportunities with Squeeze operations. @@ -1025,14 +1090,14 @@ def SqueezeofConst : NamedPat<"SqueezeofConst", (ONNXSqueezeOp:$resOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), $_), // To c' where c' is the unsqueezed value. (CreateSqueezeOfConst $resOp, $input), - [(IsFromDenseONNXConstantOp:$input)]>; + [(IsFromDenseONNXConstantOp:$input), (HasStaticShape:$resOp)]>; def SqueezeV11ofConst : NamedPat<"SqueezeV11ofConst", // From Squeeze (c, axis) (ONNXSqueezeV11Op:$resOp (ONNXConstantOp:$input $_, $_, $_, $_, $_, $_, $_, $_), $_), // To c' where c' is the unsqueezed value. (CreateSqueezeOfConst $resOp, $input), - [(IsFromDenseONNXConstantOp:$input)]>; + [(IsFromDenseONNXConstantOp:$input), (HasStaticShape:$resOp)]>; //===----------------------------------------------------------------------===// // Patterns to enable opportunities with Slice operations. @@ -1071,7 +1136,7 @@ def PadOfConst : NamedPat<"PadOfConst", def ConcatofConst : NamedPat<"ConcatofConst", (ONNXConcatOp:$resOp $input, $axis), (CreateConcatOfConst $resOp, $input, $axis), - [(IsVariadicOperandDenseONNXConstantOp:$input)] + [(IsVariadicOperandDenseONNXConstantOp:$input), (IsRankedShapedType:$resOp)] >; //===----------------------------------------------------------------------===// @@ -1180,6 +1245,7 @@ def PowConstProp : NamedPat<"PowConstProp", (CreatePowOfTwoConst $powOp, $lhs, $rhs), // Power constraints [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs), - (SatisfiesExpansionBound:$powOp)]>; + (SatisfiesExpansionBound:$powOp), (ValuesHaveSameDType $lhs, $rhs), + (HasStaticShape:$powOp)]>; #endif // ONNX_CONSTPROP diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp index e2ffd745e3..4eb35f223b 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.cpp +++ b/src/Dialect/ONNX/Transforms/Decompose.cpp @@ -20,13 +20,19 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/Support/Debug.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrHelper.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" #include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp" @@ -118,7 +124,7 @@ DenseElementsAttr createDenseArrayAttrOrEmpty( } Value createSequenceConstructOp( - PatternRewriter &rewriter, mlir::Value seq, mlir::OperandRange inputs) { + PatternRewriter &rewriter, Value seq, OperandRange inputs) { Type resType = seq.getType(); Location loc = seq.getLoc(); Value position = rewriter.create(loc); @@ -319,7 +325,7 @@ bool hasUnitStrides(ArrayAttr strides) { // Check if v's shape N x C x D1 x D2 ... x Dn has static dims D1 ... Dn. bool hasStaticSpatialDims(Value v) { - ShapedType type = cast(v.getType()); + ShapedType type = mlir::cast(v.getType()); if (!type.hasRank()) return false; // Shape has the form N x C x D1 x D2 ... x Dn. @@ -331,16 +337,141 @@ bool hasStaticSpatialDims(Value v) { return llvm::none_of(Ds, ShapedType::isDynamic); } +// In the following pattern, a SequenceAt can be replaced with Split +// %seq = onnx.SplitToSequence(%input, %split) {%axis : } +// %res = onnx.SequenceAt(%seq, %position) +// We just try to avoid using the sequence related ops, which are less +// optimized, or even not implemented in onnx-mlir. +// In the targeted use case, %split and %position are constant scalar and the +// tensor of %input and %res have static shape. +// This condition greatly reduces the complexity of code generation to replace +// SequenceAt with split op +// %res = onnx.Split(%input, onnx.expand(%split, %input.shape()[%axis])) +// {%axis : } : %position +// onnx.expand(%split, %input.shape()[%axis]) can be a constant under the +// assumed condition. +// Here %position has to be compiler time constant. +// For multiple SequenceAt from the same SplitToSequence result, the onnx.split +// for different SequenceAt are expected to be merged by optimization. +// Alternatively, Slice can be used +// %res = onnx.Slice(%input, %start, %end, %step) +// The start, and end for slice will be onnx.constant: +// start: %position*%split for %axis, 0 for other dimensionis +// end: (%positiion+1)*%split for %axis, upper bound for other dimension +// step: 1 for all dimensions +// The split approach may have better performance than the alternative slice +// approach, because the slicing is done separately. + +bool canSequenceAtBeReplaced(Value sequenceAtResult) { + if (!hasStaticShape(sequenceAtResult.getType())) + return false; + + ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); + + Value inputSequence = op.getInputSequence(); + Value position = op.getPosition(); + + if (!isDenseONNXConstant(position)) + return false; + + // Input sequence should be defined with SplitToSequence + ONNXSplitToSequenceOp splitToSequence = + inputSequence.getDefiningOp(); + if (!splitToSequence) + return false; + + // Check the pattern of the SplitToSequence op + Value input = splitToSequence.getInput(); + if (!hasStaticShape(input.getType())) + return false; + Value split = splitToSequence.getSplit(); + if (!isScalarConstantTensor(split)) + return false; + + return true; +} + +Attribute upgradeGridSampleV16Mode(PatternRewriter &rewriter, Attribute mode) { + const auto stringMode = mlir::cast(mode); + if (stringMode.strref() == "bilinear") { + return rewriter.getStringAttr("linear"); + } + if (stringMode.strref() == "bicubic") { + return rewriter.getStringAttr("cubic"); + } + assert(stringMode.strref() == "nearest"); + return mode; +} + +Value replaceSequenceAt( + PatternRewriter &rewriter, Location loc, Value sequenceAtResult) { + ONNXSequenceAtOp op = sequenceAtResult.getDefiningOp(); + + Value inputSequence = op.getInputSequence(); + Value position = op.getPosition(); + + ONNXConstantOp positionConstant = + mlir::cast(position.getDefiningOp()); + int64_t positionInt = getScalarValue(positionConstant); + + ONNXSplitToSequenceOp splitToSequence = + mlir::cast(inputSequence.getDefiningOp()); + + Value input = splitToSequence.getInput(); + Value split = splitToSequence.getSplit(); + + ONNXConstantOp splitConstant = + mlir::cast(split.getDefiningOp()); + int64_t splitInt = getScalarValue(splitConstant); + int64_t axisInt = splitToSequence.getAxis(); + + auto shape = getShape(input.getType()); + + OnnxBuilder create(rewriter, loc); + + Type sequenceElementType = + mlir::cast(inputSequence.getType()).getElementType(); + mlir::SmallVector outputTypes( + shape[axisInt] / splitInt, sequenceElementType); + auto numSplit = create.constantInt64( + mlir::SmallVector(shape[axisInt] / splitInt, splitInt)); + auto resultRange = create.split(outputTypes, input, numSplit, axisInt); + auto rawResult = resultRange[positionInt]; + + if (rawResult.getType() == sequenceAtResult.getType()) + return rawResult; + + // Temporary code for the error in the model generated by torch.onnx.export + // The the dim of the reuslt of SequenceAt op is different from the element + // type of the sequence.. + // My assumption is that the exporter is confused with squeeze and unsqueeze + // followed by the SequenceAt. There are two cases in the model: + // clang-format off + // Case #1: + // %16 = "onnx.SequenceAt"(%14, %15) {onnx_node_name = "n0"} : + // (!onnx.Seq>, tensor) -> tensor<1x100xf32> + // %23 = "onnx.Unsqueeze"(%16, %22) {onnx_node_name = "n2"} : + // (tensor<1x100xf32>, tensor) -> tensor<1x1x100xf32> + // Case#2: + // %67 = "onnx.SequenceAt"(%66, %15) {onnx_node_name = "n0"} : + // (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + // %71 = "onnx.Sigmoid"(%67) {onnx_node_name = "node_Sigmoid_60"} : + // (tensor<1x1x100xf32>) -> tensor<1x1x100xf32> + // clang-format on + // Thus, the compiler squeeze the tensor if needed. + return create.squeeze( + sequenceAtResult.getType(), rawResult, create.constantInt64(axisInt)); +} + bool shouldDecomposeConvTransposeOp(Value convTransposeResult) { -#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE + if (onnx_mlir::disableConvTransposeDecomposeOption) { + // Disable the ONNXConvTransposeOp decomposition patterns. + return false; + } ONNXConvTransposeOp op = - cast(convTransposeResult.getDefiningOp()); + mlir::cast(convTransposeResult.getDefiningOp()); return hasShapeAndRank(convTransposeResult) && hasStaticSpatialDims(op.getX()) && hasStaticSpatialDims(op.getW()); -#else - // Disable the ONNXConvTransposeOp decomposition patterns. - return false; -#endif } // Split on the specified axis. The length of each output is one. @@ -489,9 +620,7 @@ namespace { /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Dialect/ONNX/Transforms/ONNXDecompose.inc" -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - -RankedTensorType createResultType( +RankedTensorType createReducedType( Type outputType, int64_t axisValue, bool keepDims) { RankedTensorType outputShapeType = mlir::dyn_cast(outputType); @@ -512,6 +641,8 @@ RankedTensorType createResultType( return resultType; } +#ifdef ONNX_MLIR_ENABLE_STABLEHLO + struct SoftmaxPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -531,7 +662,7 @@ struct SoftmaxPattern : public OpRewritePattern { rewriter.getIntegerType(64, /*isSigned=*/true), 1); ArrayAttr axisAttr = rewriter.getI64ArrayAttr({axisValue}); RankedTensorType resultType = - createResultType(inputType, axisValue, /*keepDims=*/true); + createReducedType(inputType, axisValue, /*keepDims=*/true); Value maxInput = rewriter.create( odsLoc, resultType, input, axisAttr, keepDimsAttr); Value subValue = @@ -609,6 +740,453 @@ struct ConcatFusePattern : public OpRewritePattern { } }; +// ONNXHardSwishOp(input) can be decomposed as: +// input * ONNXHardSigmoid input, with alpha = 1/6 and beta = 0.5. +struct DecomposeHardSwishPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXHardSwishOp hardSwishOp, PatternRewriter &rewriter) const final { + + auto input = hardSwishOp.getX(); + auto hardSigmoid = rewriter.create(hardSwishOp->getLoc(), + hardSwishOp.getType(), input, rewriter.getF32FloatAttr(1.0 / 6.0), + rewriter.getF32FloatAttr(0.5)); + rewriter.replaceOpWithNewOp( + hardSwishOp, hardSwishOp.getType(), hardSigmoid, input); + return success(); + } +}; + +/// Decompose BatchNormV9 to BatchNorm +struct DecomposeBatchNormV9ToBatchNorm + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ONNXBatchNormalizationV9Op batchNormOpV9, + PatternRewriter &rewriter) const final { + auto savedMeanRes = batchNormOpV9.getSavedMean(); + auto savedVarRes = batchNormOpV9.getSavedVar(); + if (!savedMeanRes.use_empty() || !savedVarRes.use_empty()) { + return rewriter.notifyMatchFailure(batchNormOpV9.getLoc(), + "saved_mean and saved_variance must have no use."); + } + auto batchNormOp = rewriter.create( + batchNormOpV9.getLoc(), + TypeRange{ + batchNormOpV9.getY().getType(), + batchNormOpV9.getOutMean().getType(), + batchNormOpV9.getOutVar().getType(), + }, + batchNormOpV9.getX(), batchNormOpV9.getScale(), batchNormOpV9.getB(), + batchNormOpV9.getMean(), batchNormOpV9.getVar(), + batchNormOpV9.getEpsilon(), batchNormOpV9.getMomentum()); + rewriter.replaceOp(batchNormOpV9, + {batchNormOp.getY(), batchNormOp.getRunningMean(), + batchNormOp.getRunningVar(), + rewriter.create(batchNormOpV9.getLoc()), + rewriter.create(batchNormOpV9.getLoc())}); + return success(); + } +}; + +/// Decompose BatchNorm to BatchNormInferenceMode +struct DecomposeBatchNormToBatchNormInferenceMode + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ONNXBatchNormalizationOp batchNormOp, + PatternRewriter &rewriter) const final { + + auto meanRes = batchNormOp.getRunningMean(); + auto varianceRes = batchNormOp.getRunningVar(); + if (!meanRes.use_empty() || !varianceRes.use_empty()) { + return rewriter.notifyMatchFailure( + batchNormOp.getLoc(), "mean and variance must have no use."); + } + + rewriter.replaceOp(batchNormOp, + {rewriter.create( + batchNormOp.getLoc(), batchNormOp.getY().getType(), + batchNormOp.getX(), batchNormOp.getScale(), batchNormOp.getB(), + batchNormOp.getInputMean(), batchNormOp.getInputVar(), + batchNormOp.getEpsilon(), batchNormOp.getMomentum()), + rewriter.create(batchNormOp.getLoc()), + rewriter.create(batchNormOp.getLoc())}); + return success(); + } +}; + +// Decompose a pad with negative padding size to slice + pad +// Only supports static shapes +struct DecomposeSlicePadPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXPadOp padOp, PatternRewriter &rewriter) const final { + auto constantPad = padOp.getPads().getDefiningOp(); + if (!constantPad) { + return failure(); + } + std::optional padValues; + if (auto intAttrs = constantPad.getValueInts()) { + padValues = intAttrs; + } else if (auto attrs = constantPad.getValue()) { + padValues = attrs; + } + if (!padValues) { + return failure(); + } + auto elementsAttr = llvm::dyn_cast(*padValues); + if (!elementsAttr) { + return failure(); + } + const auto padElements = onnx_mlir::getElementsArray(elementsAttr); + const auto padElementsArray = padElements.get(); + if (llvm::none_of(padElementsArray, [](const auto v) { return v < 0; })) { + // No slicing needed + return failure(); + } + if (!padOp.getAxes().getDefiningOp()) { + // This is possible to implement but makes the implementation more + // difficult, so skip for now + return failure(); + } + const auto inputType = padOp.getData().getType().cast(); + if (!inputType.hasStaticShape()) { + // We need a static shape to calculate the ends for slice + return failure(); + } + auto sliceOp = buildSliceOp(padOp, rewriter, padElementsArray, inputType); + auto newPadOp = buildPadOp(padOp, rewriter, padElementsArray, sliceOp); + rewriter.replaceOp(padOp, newPadOp); + return success(); + } + +private: + // Builds ands inserts a pad op, that is guaranteed to only pad and not + // slice + static Value buildPadOp(ONNXPadOp orignalPadOp, PatternRewriter &rewriter, + ArrayRef padElementsArray, ONNXSliceOp sliceOp) { + SmallVector pads; + for (const auto padElem : padElementsArray) { + pads.push_back((padElem < 0) ? 0 : padElem); + } + if (llvm::any_of(pads, [](const auto p) { return p > 0; })) { + auto padsConstOp = onnx_mlir::createConstantOp( + rewriter, orignalPadOp->getLoc(), rewriter.getI64ArrayAttr(pads)); + auto padOp = rewriter.create(orignalPadOp->getLoc(), + orignalPadOp.getType(), sliceOp, padsConstOp, + orignalPadOp.getConstantValue(), orignalPadOp.getAxes(), + orignalPadOp.getMode()); + return padOp; + } + return sliceOp; // No pad needed if we only slice + } + + // Builds and inserts a slice op, and its inputs, that handles negative + // pads + static ONNXSliceOp buildSliceOp(ONNXPadOp padOp, PatternRewriter &rewriter, + ArrayRef padElementsArray, ShapedType inputType) { + const auto inputShape = inputType.getShape(); + const size_t dims = padElementsArray.size() / 2; + + assert(inputShape.size() == dims); + SmallVector sliceShape; + for (size_t i = 0; i < dims; ++i) { + auto sliceDimSize = inputShape[i]; + if (padElementsArray[i] < 0) { + sliceDimSize += padElementsArray[i]; + } + if (padElementsArray[i + dims] < 0) { + sliceDimSize += padElementsArray[i + dims]; + } + sliceShape.push_back(sliceDimSize); + } + auto sliceType = inputType.clone(sliceShape); + + SmallVector sliceStarts; + for (size_t i = 0; i < dims; ++i) { + if (padElementsArray[i] < 0) { + sliceStarts.push_back(-padElementsArray[i]); + } else { + sliceStarts.push_back(0); + } + } + auto startsConstOp = onnx_mlir::createConstantOp( + rewriter, padOp->getLoc(), rewriter.getI64ArrayAttr(sliceStarts)); + + SmallVector sliceEnds; + for (size_t i = 0; i < dims; ++i) { + const auto endIdx = inputShape[i]; + if (padElementsArray[i + dims] < 0) { + sliceEnds.push_back(endIdx + padElementsArray[i + dims]); + } else { + sliceEnds.push_back(endIdx); + } + } + auto endsConstOp = onnx_mlir::createConstantOp( + rewriter, padOp->getLoc(), rewriter.getI64ArrayAttr(sliceEnds)); + + auto sliceOp = rewriter.create(padOp->getLoc(), sliceType, + padOp.getData(), startsConstOp, endsConstOp, + rewriter.create(padOp->getLoc()), + rewriter.create(padOp->getLoc())); + return sliceOp; + } +}; + +namespace { +template +class SubArrayAccessHelper { +public: + explicit SubArrayAccessHelper(ArrayRef data, size_t iterArraySize) + : data(data), iterArraySize(iterArraySize) { + assert((data.size() % iterArraySize) == 0); + } + + [[nodiscard]] size_t size() const { return data.size() / iterArraySize; } + + ArrayRef operator[](size_t idx) const { + return data.slice(idx * iterArraySize, iterArraySize); + } + +private: + ArrayRef data; + size_t iterArraySize; +}; + +class IndicesContiguousCounter { +public: + explicit IndicesContiguousCounter( + ArrayRef firstElem, ArrayRef shapeToCheck) + : counter(firstElem), firstElem(firstElem), shapeToCheck(shapeToCheck) {} + + ArrayRef getCounter() const { return counter; } + + void increment() { + // Increment from the back, carry if necessary + for (auto [shapeToCheckDimSize, firstElemDimSize, c] : + llvm::zip(llvm::reverse(shapeToCheck), llvm::reverse(firstElem), + llvm::reverse(counter))) { + if (c == (shapeToCheckDimSize + firstElemDimSize - 1)) { + c = firstElemDimSize; // Carry and keep an eventual shift in mind + } else { + c++; + break; + } + } + } + +private: + SmallVector counter; + ArrayRef firstElem; + ArrayRef shapeToCheck; +}; + +} // namespace + +// Decomposes ScatterNDs into a single Split and Concat. +// We can always split ScatterNDs by splitting the input tensor together with +// the indices and their updates belonging to that part of the input tensor, +// performing the ScatterNDs on each split, and the concatenating the result. +// Here, we handle certain ScatterNDs where after splitting them into three, +// the first and last ScatterND have empty indices (because the indices don't +// affect their parts of the input tensor), and the middle ScatterND overwrites +// the full input with sequential indices (i.e. can be replaced by a copy of its +// update). +// +// Example: +// ` %indices = onnx.Constant dense<[[[[0, 1, 0], [0, 1, 1], [0, 1, 2], +// [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], +// [0, 1, 9]]]]> : tensor<1x1x10x3xi64> +// %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : +// (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> +// tensor<1x6x10x12xf32>` +// gets decomposed to: +// ` %0 = onnx.Constant dense<[1, 1, 4]> : tensor<3xi64> +// %1:3 = "onnx.Split"(%data, %0) {axis = 1 : si64} : (tensor<1x6x10x12xf32>, +// tensor<3xi64>) -> (tensor<1x1x10x12xf32>, tensor<1x1x10x12xf32>, +// tensor<1x4x10x12xf32>) +// %2 = "onnx.Concat"(%1#0, %updates, %1#2) {axis = 1 : si64} : +// (tensor<1x1x10x12xf32>,tensor<1x1x10x12xf32>, tensor<1x4x10x12xf32>) -> +// tensor<1x6x10x12xf32>` +// +// ScatterND pseudo code: +// output = np.copy(data) +// update_indices = indices.shape[:-1] +// for idx in np.ndindex(update_indices): +// output[indices[idx]] = updates[idx] +// +// Inputs: +// data (heterogeneous) - T: Tensor of rank r >= 1. +// indices (heterogeneous) - tensor(int64): Tensor of rank q >= 1. +// updates (heterogeneous) - T: Tensor of rank q + r - indices_shape[-1] - 1. +// +// Outputs: +// output (heterogeneous) - T: Tensor of rank r >= 1. +// +// To ensure that this decomposition to split and concat is +// valid, the following constraints need to hold: +// - r == rank(updates) +// - The shape of data and updates differs only in one dimension 'a' +// -- 'a' is the dimension where the split and concat will happen +// - The update indices need to be contiguous +// -- The update indices are the last dim in indices +// -- We call them contiguous, if each idx in indices is indexing the element +// in data, that is logically directly after the element indexed by the +// previous idx +// --- logically directly after means the element that will be accessed if +// the least significant value of an elements index is increased by one +// - The update indices need to cover/index the complete data, with the +// exception of dimension 'a', where they need to cover only updates[a] +struct DecomposeScatterNDPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXScatterNDOp scatterNDOp, PatternRewriter &rewriter) const final { + // Check preconditions + if (scatterNDOp.getReductionAttr().strref() != "none") { + return rewriter.notifyMatchFailure( + scatterNDOp, "Scatters with reduction are not supported"); + } + const auto data = scatterNDOp.getData(); + const auto indices = scatterNDOp.getIndices(); + const auto updates = scatterNDOp.getUpdates(); + if (!onnx_mlir::hasStaticShape(data.getType()) || + !onnx_mlir::hasStaticShape(indices.getType()) || + !onnx_mlir::hasStaticShape(updates.getType())) { + return rewriter.notifyMatchFailure( + scatterNDOp, "All operands need to have a static shape"); + } + const auto dataType = cast(data.getType()); + const auto dataShape = dataType.getShape(); + const auto updatesType = cast(updates.getType()); + const auto updateShape = updatesType.getShape(); + const auto indicesType = cast(indices.getType()); + const auto indicesShape = indicesType.getShape(); + if (dataType.getRank() != updatesType.getRank()) { + return rewriter.notifyMatchFailure(scatterNDOp, + "Only the case where data and update have the same rank " + "is supported"); + } + + const auto splitAxis = [&]() -> uint64_t { + // Split at the dim where the update and original data have a + // different size + for (auto [idx, dimData, dimUpdates] : + llvm::enumerate(dataShape, updateShape)) { + if (dimData != dimUpdates) { + return idx; + } + } + return dataType.getRank() - + 1; // Edge case, all elements get updated, split on the last dim + }(); + + for (auto [idx, dimData, dimUpdates] : + llvm::enumerate(dataShape, updateShape)) { + if (idx != splitAxis && dimData != dimUpdates) { + return rewriter.notifyMatchFailure( + scatterNDOp, "Only a single differing dimension is supported"); + } + } + + SmallVector indicesAsFlatArray; + if (!onnx_mlir::getI64ValuesFromONNXConstantOp( + indices, indicesAsFlatArray)) { + return rewriter.notifyMatchFailure( + scatterNDOp, "The indices need to be constant"); + } + if (indicesAsFlatArray.empty()) { + return rewriter.notifyMatchFailure( + scatterNDOp, "Empty indices are not supported"); // Skip the edge case + // of empty indices + } + const auto indicesLastDimSize = indicesShape.back(); + SubArrayAccessHelper indicesFlatAccessor( + indicesAsFlatArray, indicesLastDimSize); + const auto firstIndex = + indicesFlatAccessor[0]; // Safe, we have checked the length before + for (auto [idx, firstIndexDim] : llvm::enumerate(firstIndex)) { + if (idx != splitAxis && firstIndexDim != 0) { + return rewriter.notifyMatchFailure( + scatterNDOp, " Shifting is only supported on the split axis"); + } + if (idx == splitAxis && firstIndexDim < 0) { + return rewriter.notifyMatchFailure(scatterNDOp, + "Negative values with wrap around are not yet " + "supported"); // onnx allows negative values with + // wrap-around, this decomposition does + // not (for now) + } + } + + // Check that all indices are contiguous. + // - The check for contiguity and covering works the following way: + // -- Iterated over all idx in indices and compare the idx against the + // expected index, fail if it differs + // -- The expected index is calculated the following way: + // --- The expected index is initialized with the first index in indices and + // then always incremented by one. + // --- The increment works like a manual addition, the least significant + // digit/subindex gets incremented by one. If a digit overflows, it + // gets reset to the first index and the addition carries to the next, + // more significant digit. The addition overflows, if the index for an + // axis is equal to the size of this axis in updates/indices. (By + // definition the shape for indices.shape().drop(-1) must match the + // first dimensions in updates). If the addition overflows , the + // overflowing digit is reset to its value in the first index. This is + // zero for all axes, except for 'a', where it can be a positive number + // if the split/concat is in the middle of the tensor + assert( + updateShape.drop_back(updateShape.size() - (indicesShape.size() - 1)) == + indicesShape.drop_back(1) && + "Update and indicesShape should partially match for scatterNd"); + { + IndicesContiguousCounter counter(firstIndex, indicesShape.drop_back(1)); + for (size_t i = 0; i < indicesFlatAccessor.size(); ++i) { + if (counter.getCounter() != indicesFlatAccessor[i]) { + return rewriter.notifyMatchFailure( + scatterNDOp, "Indices are not contiguous"); + } + counter.increment(); + } + } + + onnx_mlir::MultiDialectBuilder create( + rewriter, scatterNDOp->getLoc()); + // Strategy for the decomposition: + // Split at the split axis, concat the update and part of the split + // a, b = split(input) + // a1, a2 = split(a) + // concat(a1, update, b) + // In onnx this split can be done in one: + // a1, a2, b = split(input) + const auto firstSplitPosition = + (splitAxis < firstIndex.size()) ? firstIndex[splitAxis] : 0; + const auto secondSplitPosition = + updateShape[splitAxis] + firstSplitPosition; + SmallVector splitTyFirstQuarter(dataShape); + splitTyFirstQuarter[splitAxis] = firstSplitPosition; + SmallVector splitTySecondQuarter(dataShape); + splitTySecondQuarter[splitAxis] = updateShape[splitAxis]; + SmallVector splitTySecondHalf(dataShape); + splitTySecondHalf[splitAxis] -= secondSplitPosition; + Value splitSize = create.onnx.constantInt64({firstSplitPosition, + updateShape[splitAxis], splitTySecondHalf[splitAxis]}); + const Type dataElementType = dataType.getElementType(); + ValueRange split = create.onnx.split( + {RankedTensorType::get(splitTyFirstQuarter, dataElementType), + RankedTensorType::get(splitTySecondQuarter, dataElementType), + RankedTensorType::get(splitTySecondHalf, dataElementType)}, + scatterNDOp.getData(), splitSize, splitAxis); + + Value concat = create.onnx.concat( + dataType, {split[0], scatterNDOp.getUpdates(), split[2]}, splitAxis); + rewriter.replaceOp(scatterNDOp, concat); + return success(); + } +}; + // Decompose the custom op FusedMatMul that is produced by ONNXRuntime. // According to FusedMatMul specification, it is the result of fusing MatMul and // Transpose: @@ -624,7 +1202,7 @@ struct ConcatFusePattern : public OpRewritePattern { // to determine the rank of A. // // Example of onnx.Custom: -// ``` +// ``` // "onnx.Custom"(%0, %1) {alpha = 1.250000e-01 : f32, // domain_name = "com.microsoft", // function_name = "FusedMatMul", @@ -783,14 +1361,20 @@ struct InstanceNormIntoLayerNormPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + static bool isDecomposable(ONNXInstanceNormalizationOp instanceNormOp) { + return onnx_mlir::hasStaticShape(instanceNormOp.getInput().getType()) && + onnx_mlir::hasStaticShape(instanceNormOp.getOutput().getType()); + } + LogicalResult matchAndRewrite(ONNXInstanceNormalizationOp instanceNormOp, PatternRewriter &rewriter) const final { // Match. - Value input = instanceNormOp.getInput(); - if (!onnx_mlir::isRankedShapedType(input.getType())) + if (!isDecomposable(instanceNormOp)) { return failure(); + } // Get info. + Value input = instanceNormOp.getInput(); Value scale = instanceNormOp.getScale(); Value bias = instanceNormOp.getB(); ShapedType inputType = mlir::cast(input.getType()); @@ -830,93 +1414,307 @@ struct InstanceNormIntoLayerNormPattern } }; +namespace { +template +bool isGroupNormDecomposable(OP_TYPE groupNormOp) { + const Type inputType = groupNormOp.getX().getType(); + return onnx_mlir::hasStaticShape(inputType) && + onnx_mlir::hasStaticShape(groupNormOp.getResult().getType()); +} +} // namespace + // Transform GroupNormalization into LayerNormalization -struct GroupNormIntoLayerNormPattern +template +constexpr bool scaleAndBiasWithNumGroupShape = + std::is_same_v; + +template +LogicalResult ONNXGroupNormalizationCommon( + OP_TYPE groupNormOp, PatternRewriter &rewriter) { + + // Match. + if (!isGroupNormDecomposable(groupNormOp)) + return failure(); + + // Get info. + Value input = groupNormOp.getX(); + Value scale = groupNormOp.getScale(); + Value bias = groupNormOp.getBias(); + ShapedType inputType = mlir::cast(input.getType()); + Type elementType = inputType.getElementType(); + auto inputShapeVal = inputType.getShape(); + int64_t C = inputShapeVal[1]; + int64_t inputRank = inputType.getRank(); + int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions. + assert(inputRank > nonSpacialRank && + "expected instance norm with input ranks > 2"); + int64_t spacialRank = inputRank - nonSpacialRank; + int64_t layerNormRank = inputRank + 1; // +1 as C is split to NG and C/NG + int64_t numGroups = groupNormOp.getNumGroups(); + + // Rewrite. + onnx_mlir::MultiDialectBuilder create( + rewriter, groupNormOp.getLoc()); + int64_t axis = nonSpacialRank; + int64_t numInNorm = layerNormRank - axis; + Type biasScaleType; + Value axes; + Value newBias; + Value newScale; + + //"numgroups" and "C" should have the same dimension index + llvm::SmallVector axesList, biasScaleVal; + + if constexpr (scaleAndBiasWithNumGroupShape) { + // Opset18 Uses "numgroups" the number of groups of channels for the scale + // and bias + // Unsqueeze scale/bias from [NG] to [1 x NG x 1 x ... x 1] with numInNorm + // 1s. + biasScaleVal.emplace_back(numGroups); + for (int64_t i = 1; i <= numInNorm; ++i) { + biasScaleVal.emplace_back(1); + axesList.emplace_back(i); + } + + axes = create.onnx.constantInt64(axesList); + biasScaleType = RankedTensorType::get(biasScaleVal, elementType); + newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); + newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); + } else { + // Opset21 Uses "C" the number of channels for the scale and bias + // The equivalent of "C" when split is "NG x C/NG" + // Reshape scale/bias from [C] to [NG x C/NG x 1 x ... x 1] with numInNorm + // 1s. + biasScaleVal.emplace_back(numGroups); + // C can be a dynamic or static value, account for that here + if (C != ShapedType::kDynamic) { + assert(C % numGroups == 0 && "expected numGroups to divide C"); + biasScaleVal.emplace_back(C / numGroups); + } else { + biasScaleVal.emplace_back(ShapedType::kDynamic); + } + + for (int64_t i = 2; i <= numInNorm; ++i) { + biasScaleVal.emplace_back(1); + } + + // Calculate the (possible) dynamic dimensions for biasScaleShape + Value NGShape = create.onnx.constantInt64({numGroups}); + Value oneDimShape = + create.onnx.constantInt64(SmallVector(spacialRank, 1)); + Type biasScaleShapeType = + RankedTensorType::get({inputRank}, rewriter.getI64Type()); + Value biasScaleShape = create.onnx.concat( + biasScaleShapeType, {NGShape, NGShape, oneDimShape}, /*axis*/ 0); + + // Reshape instead of unsqueeze (use biasScaleShape) + biasScaleType = RankedTensorType::get(biasScaleVal, elementType); + newScale = create.onnx.reshape(biasScaleType, scale, biasScaleShape); + newBias = create.onnx.reshape(biasScaleType, bias, biasScaleShape); + } + + // Convert input from N x C x D1...Dn to N x (NG x C/NG) x D1...Dn. + // First compute the new (possible dynamic) shape. + Type batchShapeType = RankedTensorType::get({1}, rewriter.getI64Type()); + Value NShape = create.onnx.shape( + batchShapeType, input, /*start*/ 0, /*exclusive end*/ 1); + Value NGandMin1Shape = create.onnx.constantInt64({numGroups, -1}); + Type spacialShapeType = + RankedTensorType::get({spacialRank}, rewriter.getI64Type()); + Value spacialShape = + create.onnx.shape(spacialShapeType, input, /*start*/ nonSpacialRank); + Type layerNormShapeType = + RankedTensorType::get({layerNormRank}, rewriter.getI64Type()); + Value layerNormShape = create.onnx.concat(layerNormShapeType, + {NShape, NGandMin1Shape, spacialShape}, /*axis*/ + 0); + // Compute type of converted input. + llvm::SmallVector layerNormShapeVal; + // Create a new tensor with the following dimensions: N, NG, C/NG, D1, D2, + // Dn... + layerNormShapeVal.emplace_back(inputShapeVal[0]); // N + layerNormShapeVal.emplace_back(numGroups); // NG + if (C != ShapedType::kDynamic) { + assert(C % numGroups == 0 && "expected numGroups to divide C"); + layerNormShapeVal.emplace_back(C / numGroups); // (C/NG) + } else + layerNormShapeVal.emplace_back(ShapedType::kDynamic); + for (int64_t i = 0; i < spacialRank; ++i) + layerNormShapeVal.emplace_back(inputShapeVal[nonSpacialRank + i]); // Dn + RankedTensorType layerNormInputType = + RankedTensorType::get(layerNormShapeVal, elementType); + Value layerNormInput = + create.onnx.reshape(layerNormInputType, input, layerNormShape); + // Create output using layer norm. + Value layerNormY = create.onnx.layerNorm(layerNormInputType, layerNormInput, + newScale, newBias, axis, groupNormOp.getEpsilonAttr()); + // Resize output to original size + Type inputShapeType = + RankedTensorType::get({inputRank}, rewriter.getI64Type()); + Value inputShape = create.onnx.shape(inputShapeType, input); + Type outputType = groupNormOp.getY().getType(); + Value Y = create.onnx.reshape(outputType, layerNormY, inputShape); + // Set the type of the output to be the same as the output of the original + // operation we are trying to replace. + Y.setType(groupNormOp.getResult().getType()); + // Replace operation. + rewriter.replaceOp(groupNormOp, Y); + return success(); +} + +struct GroupNormIntoLayerNormPattern1 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ONNXGroupNormalizationOp groupNormOp, PatternRewriter &rewriter) const final { - // Match. - Value input = groupNormOp.getX(); - if (!onnx_mlir::isRankedShapedType(input.getType())) - return failure(); + return ONNXGroupNormalizationCommon( + groupNormOp, rewriter); + } +}; - // Get info. - Value scale = groupNormOp.getScale(); - Value bias = groupNormOp.getBias(); - ShapedType inputType = mlir::cast(input.getType()); - Type elementType = inputType.getElementType(); - auto inputShapeVal = inputType.getShape(); - int64_t C = inputShapeVal[1]; - int64_t inputRank = inputType.getRank(); - int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions. - assert(inputRank > nonSpacialRank && - "expected instance norm with input ranks > 2"); - int64_t spacialRank = inputRank - nonSpacialRank; - int64_t layerNormRank = inputRank + 1; // +1 as C is split to NG and C/NG - int64_t numGroups = groupNormOp.getNumGroups(); +struct GroupNormIntoLayerNormPattern2 + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - // Rewrite. - onnx_mlir::MultiDialectBuilder create( - rewriter, groupNormOp.getLoc()); - int64_t axis = nonSpacialRank; - int64_t numInNorm = layerNormRank - axis; - // Unsqueeze scale/bias from [NG] to [NG x 1 x 1 x ... x 1] with numInNorm - // 1s. - llvm::SmallVector axesList, biasScaleShape; - biasScaleShape.emplace_back(numGroups); - for (int64_t i = 1; i <= numInNorm; ++i) { - biasScaleShape.emplace_back(1); - axesList.emplace_back(i); + LogicalResult matchAndRewrite(ONNXGroupNormalizationV18Op groupNormOp, + PatternRewriter &rewriter) const final { + return ONNXGroupNormalizationCommon( + groupNormOp, rewriter); + } +}; + +/// Decompose `onnx.SoftmaxCrossEntropyLoss` to the following sequence: +/// In the following we assume classes is in dim=1 of scores. +/// 1. one_hot_encoded = onnx.Castlike(onnx.OneHot(labels, dim=1), scores) +/// 2. log_softmax = onnx.Log(onnx.Softmax(scores, dim=1)) +/// 3. product = onnx.Mul(log_softmax, one_hot_encoded) +/// if `weights` arg is nont `none` then we additionally perform +/// product = onnx.Mul(product, op.Unsqueeze(weights)) +/// where unsqueezing makes the operation broadcastable. +/// 4. reduce_sum = onnx.ReduceSum(product, dim=1) +/// 5. loss = onnx.ReduceMean(reduce_sum) if reduciton == "mean" +/// onnx.ReduceSum(reduce_sum) if reduction == "sum" +/// onnx.Squeeze(reduce_sum) if reduciton == "none" +/// +struct SoftmaxCrossEntropyPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXSoftmaxCrossEntropyLossOp sceOp, + PatternRewriter &rewriter) const final { + auto loc = sceOp.getLoc(); + onnx_mlir::OnnxBuilder create(rewriter, loc); + auto scores = sceOp.getScores(); + auto labels = sceOp.getLabels(); + auto weights = sceOp.getWeights(); + auto scoresTy = cast(scores.getType()); + auto labelsTy = cast(labels.getType()); + SmallVector newLabelsShape(labelsTy.getShape()); + newLabelsShape.insert(newLabelsShape.begin() + 1, scoresTy.getShape()[1]); + auto none = rewriter.create(loc); + auto numClasses = (scoresTy.isDynamicDim(1)) + ? create.dim(scores, 1) + : create.constantInt64({scoresTy.getShape()[1]}); + auto elemTy = scoresTy.getElementType(); + // Compute one hot encoded labels and cast to `scores` element type. + auto oneHotValsAttr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), + ArrayRef{0, 1}); + auto oneHotVals = create.constant(oneHotValsAttr); + auto oneHot = create.cast( + rewriter.create(loc, + RankedTensorType::get(newLabelsShape, labelsTy.getElementType()), + labels, numClasses, oneHotVals, /*axis=*/1), + /*saturate=*/ + rewriter.getIntegerAttr(rewriter.getIntegerType(64, true), 1), + TypeAttr::get(elemTy)); + // Compute logsoftmax of scores. + auto softmax = + rewriter.create(loc, scoresTy, scores, /*axis=*/1); + auto logSoftmax = rewriter.create(loc, scoresTy, softmax); + auto prod = rewriter.create(loc, logSoftmax, oneHot); + // Multiply by `weights` if not none. + if (auto weightTy = dyn_cast(weights.getType())) { + // Unsqueeze weight from [C] to [1 x C x 1 x ... x 1] to make it + // broadcast-compliant. + llvm::SmallVector unsqueezedShape(scoresTy.getRank(), 1); + unsqueezedShape[1] = scoresTy.getShape()[1]; + llvm::SmallVector axesList(scoresTy.getRank() - 1, 0); + std::iota(axesList.begin() + 1, axesList.end(), 2); + auto axes = create.constantInt64(axesList); + auto weightsUnsqueezed = create.unsqueeze( + RankedTensorType::get(unsqueezedShape, elemTy), weights, axes); + prod = rewriter.create(loc, prod, weightsUnsqueezed); } - Value axes = create.onnx.constantInt64(axesList); - Type biasScaleType = RankedTensorType::get(biasScaleShape, elementType); - Value newScale = create.onnx.unsqueeze(biasScaleType, scale, axes); - Value newBias = create.onnx.unsqueeze(biasScaleType, bias, axes); - // Convert input from N x C x D1...Dn to N x (NG x C/NG) x D1...Dn. - // First compute the new (possibly dynamic) shape. - Type batchShapeType = RankedTensorType::get({1}, rewriter.getI64Type()); - Value NShape = create.onnx.shape( - batchShapeType, input, /*start*/ 0, /*exclusive end*/ 1); - Value NGandMin1Shape = create.onnx.constantInt64({numGroups, -1}); - Type spacialShapeType = - RankedTensorType::get({spacialRank}, rewriter.getI64Type()); - Value spacialShape = - create.onnx.shape(spacialShapeType, input, /*start*/ nonSpacialRank); - Type layerNormShapeType = - RankedTensorType::get({layerNormRank}, rewriter.getI64Type()); - Value layerNormShape = create.onnx.concat( - layerNormShapeType, {NShape, NGandMin1Shape, spacialShape}, /*axis*/ 0); - // Compute type of converted input. - llvm::SmallVector layerNormShapeVal; - layerNormShapeVal.emplace_back(inputShapeVal[0]); - layerNormShapeVal.emplace_back(numGroups); - if (C != ShapedType::kDynamic) { - assert(C % numGroups == 0 && "expected numGroups to divide C"); - layerNormShapeVal.emplace_back(C / numGroups); - } else - layerNormShapeVal.emplace_back(ShapedType::kDynamic); - for (int64_t i = 0; i < spacialRank; ++i) - layerNormShapeVal.emplace_back(inputShapeVal[nonSpacialRank + i]); - RankedTensorType layerNormInputType = - RankedTensorType::get(layerNormShapeVal, elementType); - Value layerNormInput = - create.onnx.reshape(layerNormInputType, input, layerNormShape); - // Create output using layer norm. - Value layerNormY = create.onnx.layerNorm(layerNormInputType, layerNormInput, - newScale, newBias, axis, groupNormOp.getEpsilonAttr()); - // Resize output to original size - Type inputShapeType = - RankedTensorType::get({inputRank}, rewriter.getI64Type()); - Value inputShape = create.onnx.shape(inputShapeType, input); - Type outputType = groupNormOp.getY().getType(); - Value Y = create.onnx.reshape(outputType, layerNormY, inputShape); - // Set the type of the output to be the same as the output of the original - // operation we are trying to replace. - Y.setType(groupNormOp.getResult().getType()); - // Replace operation. - rewriter.replaceOp(groupNormOp, Y); + // Reduction across `class` (dim=1) axis. + auto axes = create.constant(onnx_mlir::createDenseArrayAttr( + rewriter, rewriter.getI64ArrayAttr({1}))); + auto reducedType = createReducedType(scoresTy, 1, /*keepdims=*/true); + Value loss = rewriter.create(loc, reducedType, prod, axes); + // ReduceMean/ReduceSum/Squeeze if reduction = mean/sum/none respectively. + // Set `axes=none` to indicate reducing all dims. + auto reduction = cast(sceOp.getReductionAttr()).getValue(); + if (reduction == "mean") { + if (isa(weights.getType())) { + loss = rewriter.create(loc, + RankedTensorType::get({}, elemTy), loss, none, + /*keepdims=*/0); + } else { + auto sumL = rewriter.create(loc, + RankedTensorType::get({}, elemTy), loss, none, + /*keepdims=*/0); + // Perform einsum(one_hot, weights) as a simple way of producing + // W[n][d1][d2]...[dk] = weights[labels[i][d1][d2]...[dk]] + auto scatteredWeights = rewriter.create(loc, + RankedTensorType::get(labelsTy.getShape(), elemTy), + ValueRange{oneHot, weights}, "ij...,j->i..."); + auto sumW = rewriter.create(loc, + RankedTensorType::get({}, elemTy), scatteredWeights, none, + /*keepdims=*/0); + loss = rewriter.create(loc, sumL, sumW); + } + } else if (reduction == "sum") { + loss = rewriter.create(loc, + RankedTensorType::get({}, elemTy), loss, none, + /*keepdims=*/0); + } else if (reduction == "none") { + loss = rewriter.create(loc, + createReducedType(reducedType, 1, /*keepdims=*/false), loss, axes); + } else { + llvm_unreachable("unexpected reduction type"); + } + // Negate. + loss = rewriter.create(loc, loss.getType(), loss); + // Second return value replacement depends if it is `none` or not. + if (isa(sceOp.getLogProb().getType())) + rewriter.replaceOp(sceOp, {loss, none}); + else + rewriter.replaceOp(sceOp, {loss, logSoftmax}); + return success(); + } +}; + +/// Decompose `onnx.Sum` to a sequence of `onnx.Add` +struct SumToAddPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXSumOp sumOp, PatternRewriter &rewriter) const final { + SmallVector inputs(sumOp.getData_0()); + assert(inputs.size() > 0 && "expected at least one input"); + Value result = inputs[0]; + if (inputs.size() > 1) { + inputs.erase(inputs.begin()); + for (auto input : inputs) { + result = rewriter.create(sumOp.getLoc(), result, input); + } + } + auto resultType = mlir::cast(sumOp.getResult().getType()); + if (resultType != result.getType()) + result = rewriter.create( + sumOp.getLoc(), resultType, result, 1, resultType.getElementType()); + rewriter.replaceOp(sumOp, result); return success(); } }; @@ -990,101 +1788,17 @@ struct DecomposeONNXToONNXPass void DecomposeONNXToONNXPass::runOnOperation() { func::FuncOp function = getOperation(); MLIRContext *context = &getContext(); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - - // These ops will be decomposed into other ONNX ops. Hence, they will not be - // available after this pass. - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - target.addDynamicallyLegalOp([](ONNXEinsumOp op) { - return !onnx_mlir::DecomposeEinsumPattern::isDecomposable(op); - }); - - target.addDynamicallyLegalOp([](ONNXConcatOp op) { - ONNXShapeOp shapeOp; - ONNXTransposeOp transposeOp; - return !isConcatFuseMatched(op, shapeOp, transposeOp); - }); - - // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs. - target.addDynamicallyLegalOp([](ONNXConstantOp op) { - return !(op.getValueFloatAttr() || op.getValueFloatsAttr() || - op.getValueIntAttr() || op.getValueIntsAttr() || - op.getValueStringAttr() || op.getValueStringsAttr()); - }); - - // Decompose CustomOp FusedMatMul introduced by onnxruntime: - // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul - target.addDynamicallyLegalOp([](ONNXCustomOp op) { - int64_t rankA, rankB; - FloatAttr alpha; - return !CustomOpFuseMatMulPattern::isCustomOpFusedMatMulMatched( - op, alpha, rankA, rankB); - }); - -#ifdef ONNX_MLIR_DECOMP_ONNX_CONVTRANSPOSE -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - // ONNXtoStablehlo pass has own rewriting for ConvTranspose Op using - // stablehlo ops. To avoid conflict with it, decomposing for ConvTranspose - // is disabled when the target is stablehlo. - if (this->target != "stablehlo") { -#endif - target.addDynamicallyLegalOp( - [](ONNXConvTransposeOp op) { - return !onnx_mlir::shouldDecomposeConvTransposeOp(op); - }); -#ifdef ONNX_MLIR_ENABLE_STABLEHLO - } -#endif -#endif - RewritePatternSet patterns(context); onnx_mlir::getDecomposeONNXToONNXPatterns(patterns); patterns.insert(context); + #ifdef ONNX_MLIR_ENABLE_STABLEHLO if (this->target == "stablehlo") { populateDecomposingONNXBeforeStablehloPatterns(patterns, context); - target.addIllegalOp(); } #endif - if (failed(applyPartialConversion(function, target, std::move(patterns)))) + if (failed(applyPatternsAndFoldGreedily(function, std::move(patterns)))) signalPassFailure(); } @@ -1096,11 +1810,19 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns( populateWithGenerated(patterns); patterns.insert(context); patterns.insert(context); + patterns.insert(context); // Decompose CustomOp FusedMatMul introduced by onnxruntime: // https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.FusedMatMul patterns.insert(context); patterns.insert(context); - patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); + patterns.insert(context); // TODO: consider whether to include SoftmaxPattern here } @@ -1111,4 +1833,4 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns( std::unique_ptr onnx_mlir::createDecomposeONNXToONNXPass( const std::string &target) { return std::make_unique(target); -} +} \ No newline at end of file diff --git a/src/Dialect/ONNX/Transforms/Decompose.td b/src/Dialect/ONNX/Transforms/Decompose.td index 3cea294521..5454f13275 100644 --- a/src/Dialect/ONNX/Transforms/Decompose.td +++ b/src/Dialect/ONNX/Transforms/Decompose.td @@ -71,6 +71,15 @@ def createScalarDenseAttrRank0 def ReshapeElementsAttrToRank0 : NativeCodeCall< "onnx_mlir::OnnxElementsAttrBuilder($0.getContext()).reshape(cast($0), {})">; +def ReplaceSequenceAt : NativeCodeCall< + "onnx_mlir::replaceSequenceAt($_builder, $_loc, $0)">; + +def UpgradeGridSampleV16Mode : NativeCodeCall< + "onnx_mlir::upgradeGridSampleV16Mode($_builder, $0)">; + +def CanSequenceAtBeReplaced : + Constraint, "check whether the SequenceAt can be replaced with split">; + // Create a DenseElementsAttr from a single attribute. def createDenseArrayAttrFromSingleAttr : NativeCodeCall<"::onnx_mlir::createDenseArrayAttr($_builder, $_builder.getArrayAttr($0))">; @@ -359,6 +368,18 @@ def ClipV12Pattern : Pat< (ONNXClipOp $x, $min, $max) >; +// Rewrite GridSample 20 to GridSample 22 +def GridSampleV20Pattern : Pat< + (ONNXGridSampleV20Op $x, $grid, $align_corners, $mode, $padding_mode), + (ONNXGridSampleOp $x, $grid, $align_corners, $mode, $padding_mode) +>; + +// Rewrite GridSample 16 to GridSample 20 +def GridSampleV16Pattern : Pat< + (ONNXGridSampleV16Op $x, $grid, $align_corners, $mode, $padding_mode), + (ONNXGridSampleV20Op $x, $grid, $align_corners, (UpgradeGridSampleV16Mode $mode), $padding_mode) +>; + def DFTV17Pattern : Pat< (ONNXDFTV17Op $x, $dft_length, $axis, $inverse, $onesided), (ONNXDFTOp $x, $dft_length, (ONNXConstantOpFromDenseAttr(createScalarDenseAttrRank0 $axis)), $inverse, $onesided) @@ -620,4 +641,16 @@ def ConstantOpNormalizationPattern6: Pat< [(AttributeIsNotNull:$stringsAttr)] >; +// Optimize for the pattern coming from torch.nn.LSTM exported from pytorch +// %32 = "onnx.SplitToSequence"(%30, %27) {axis = 0 : si64, keepdims = 0 : si64, onnx_node_name = "n1"} : (tensor<1x1x100xf32>, tensor) -> !onnx.Seq> +// %33 = "onnx.SequenceAt"(%32, %26) {onnx_node_name = "n0"} : (!onnx.Seq>, tensor) -> tensor<1x100xf32> +// When shape and size/axis related value are constant, this sequence of code +// can be translated into onnx.slice + +def ReplaceSequenceAtPattern: Pat< + (ONNXSequenceAtOp:$res $seq, $position), + (ReplaceSequenceAt $res), + [(CanSequenceAtBeReplaced:$res)] +>; + #endif // ONNX_DECOMPOSE diff --git a/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp index 0c8864c556..3b3a46e68e 100644 --- a/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp +++ b/src/Dialect/ONNX/Transforms/DecomposeEinsum.cpp @@ -107,7 +107,7 @@ struct Output : public einsum::Parameter { } void eraseAxis(int64_t a) { - assert(0 <= a && a < (int64_t)size() && + assert(0 <= a && a < static_cast(size()) && "axis a should be nonnegative and within range"); shape.erase(shape.begin() + a); subscripts.erase(subscripts.begin() + a); diff --git a/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp index ccfd5fe154..0e58963512 100644 --- a/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp @@ -146,11 +146,11 @@ struct ONNXHybridTransformPass maxNumRewritesOffset + maxNumRewritesMultiplier * numOps; } if (failed(applyPatternsAndFoldGreedily(body, patterns, config))) { - llvm::errs() << "Warning: onnx-hybrid-transform didn't converge with " + llvm::errs() << "\nWarning: onnx-hybrid-transform didn't converge with " << "max-num-rewrites-offset=" << maxNumRewritesOffset.getValue() << ", " << "max-num-rewrites-multiplier=" - << maxNumRewritesMultiplier.getValue() << "\n"; + << maxNumRewritesMultiplier.getValue() << "\n\n"; } inferFunctionReturnShapes(f); diff --git a/src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp b/src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp index aa627e73c3..ce7a04a849 100644 --- a/src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp +++ b/src/Dialect/ONNX/Transforms/ONNXOpTransformPass.cpp @@ -40,8 +40,10 @@ struct ONNXOpTransformPass : public mlir::PassWrapper onnxOpTransformReport{*this, "onnx-op-transform-report", llvm::cl::desc("Report diagnostic info for op transform passes."), llvm::cl::init(false)}; + // NOTE: FlexML changes the default for this flag to false, as we do not want + // to run the CPU specific transformations. Option onnxOpTransformTargetCPU{*this, "onnx-op-transform-target-cpu", - llvm::cl::desc("Target CPU op transform passes."), llvm::cl::init(true)}; + llvm::cl::desc("Target CPU op transform passes."), llvm::cl::init(false)}; Option onnxOpTransformEnableSimdDataLayout{*this, "onnx-op-transform-simd-data-layout", llvm::cl::desc("Enable SIMD data layout opt in op transform passes."), diff --git a/src/Dialect/ONNX/Transforms/Recompose.cpp b/src/Dialect/ONNX/Transforms/Recompose.cpp index 3e32f2ca6c..9a4eb2ace6 100644 --- a/src/Dialect/ONNX/Transforms/Recompose.cpp +++ b/src/Dialect/ONNX/Transforms/Recompose.cpp @@ -48,17 +48,19 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { LogicalResult matchAndRewrite( ONNXMulOp mulOp, PatternRewriter &rewriter) const final { using namespace onnx_mlir; - Location loc = mulOp.getLoc(); // Match Value x, scale; FloatAttr epsilon; int64_t axis; bool isRMSLayerNorm; - if (!matchLayerNormPattern(mulOp, x, scale, axis, epsilon, isRMSLayerNorm)) + SmallVector layerNormLocations; + if (!matchLayerNormPattern( + mulOp, x, scale, axis, epsilon, layerNormLocations, isRMSLayerNorm)) return failure(); // Replace - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create( + rewriter, rewriter.getFusedLoc(layerNormLocations)); Type xType = x.getType(); Value noneVal = create.onnx.none(); Value res; @@ -66,6 +68,7 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { res = create.onnx.RMSLayerNorm(xType, x, scale, noneVal, axis, epsilon); else res = create.onnx.layerNorm(xType, x, scale, noneVal, axis, epsilon); + copySingleResultType(mulOp, res); rewriter.replaceOp(mulOp, res); return success(); } @@ -99,7 +102,7 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { */ static bool matchLayerNormPattern(ONNXMulOp LayerNormOp, Value &x, Value &scale, int64_t &axis, FloatAttr &epsilonAttr, - bool &isRMSLayerNorm) { + SmallVectorImpl &layerNormLocations, bool &isRMSLayerNorm) { using namespace onnx_mlir; Location loc = LayerNormOp.getLoc(); isRMSLayerNorm = false; @@ -110,11 +113,17 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { // Replicate of values, check that they are identical to originals. Value d1, d2; // Operations that will be gathered and kept locally. - Operation *nsMulOp, *ddMulOp, *nDivOp, *nMulOp, *isdRecipOp, *sdSqrtOp, - *veAddOp, *vReduceOp, *mReduceOp, *dSubOp; - nsMulOp = LayerNormOp.getOperation(); + Operation *nsMulOp = LayerNormOp.getOperation(); + Operation *ddMulOp = nullptr; + Operation *nDivOp = nullptr; + Operation *nMulOp = nullptr; + Operation *isdRecipOp = nullptr; + Operation *sdSqrtOp = nullptr; + Operation *veAddOp = nullptr; + Operation *vReduceOp = nullptr; + Operation *mReduceOp = nullptr; + Operation *dSubOp = nullptr; // after this group, we have defined norm, scale, d, and sdSqrtOp. - nDivOp = nMulOp = isdRecipOp = nullptr; if (operandOfOpDefinedBy(nDivOp, nsMulOp, norm, scale, 0) || operandOfOpDefinedBy(nDivOp, nsMulOp, scale, norm, 1)) { // Matched norm = d / stdDev. @@ -126,7 +135,6 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { // %norm = "onnx.Div"(%d, %stdDev) if (!operandOfOpDefinedBy(sdSqrtOp, nDivOp, d, stdDev, 1)) return reportFailure("RMS missing std dev (via div), sqrt op"); - } else if (operandOfOpDefinedBy( nMulOp, nsMulOp, norm, scale, 0) || operandOfOpDefinedBy( @@ -190,15 +198,19 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { // %stdDev = "onnx.Sqrt"(%varEps) if (!operandOfOpDefinedBy(veAddOp, sdSqrtOp, varEps)) return reportFailure("RMS missing var + eps, add op"); - // %var = "onnx.ReduceMeanV13"(%dd) + // %var = "onnx.ReduceMean(V13)"(%dd) // %varEps = "onnx.Add"(%var, %eps) - if (!operandOfOpDefinedBy( - vReduceOp, veAddOp, var, epsilon, 0) && - !operandOfOpDefinedBy( - vReduceOp, veAddOp, epsilon, var, 1)) + if ((!operandOfOpDefinedBy( + vReduceOp, veAddOp, var, epsilon, 0) && + !operandOfOpDefinedBy( + vReduceOp, veAddOp, epsilon, var, 1)) && + (!operandOfOpDefinedBy( + vReduceOp, veAddOp, var, epsilon, 0) && + !operandOfOpDefinedBy( + vReduceOp, veAddOp, epsilon, var, 1))) return reportFailure("RMS missing var, reduce mean op"); // %dd = "onnx.Mul"(%d, %d) - // %var = "onnx.ReduceMeanV13"(%dd) + // %var = "onnx.ReduceMean(V13)"(%dd) if (!operandOfOpDefinedBy(ddMulOp, vReduceOp, dd)) return reportFailure("RMS missing DD, mul op"); @@ -230,10 +242,12 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { if (!isScalarTensor(epsilon)) return reportFailure("RMS epsilon is expected to be scalar"); ONNXConstantOp epsilonOp = - dyn_cast(epsilon.getDefiningOp()); + mlir::dyn_cast(epsilon.getDefiningOp()); if (!epsilonOp) return reportFailure("RMS epsilon needs to be a constant"); - epsilonAttr = epsilonOp.getValueFloatAttr(); + const auto epsilonValue = getScalarValue(epsilonOp); + epsilonAttr = + FloatAttr::get(Float32Type::get(epsilonOp->getContext()), epsilonValue); // Check axes. if (!hasShapeAndRank(dd)) return reportFailure("RMS need rank and shape for input dd"); @@ -253,10 +267,12 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { if (hasFullPattern && !operandOfOpDefinedBy(dSubOp, ddMulOp, d1, d2, 1)) hasFullPattern = reportFailure("LN missing D, sub op"); - // %mean = "onnx.ReduceMeanV13"(%x) + // %mean = "onnx.ReduceMean(V13)"(%x) // %d = "onnx.Sub"(%X, %mean) - if (hasFullPattern && !operandOfOpDefinedBy( - mReduceOp, dSubOp, x1, mean, 1)) + if (hasFullPattern && (!operandOfOpDefinedBy( + mReduceOp, dSubOp, x1, mean, 1) && + !operandOfOpDefinedBy( + mReduceOp, dSubOp, x1, mean, 1))) hasFullPattern = reportFailure("LN missing mean, reduce mean op"); // 4: We have the ops for a traditional LM pattern, now check a few more @@ -266,8 +282,15 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { if (hasFullPattern) { // Verify that the mReduceOp uses x as well. - auto lnOp = cast(mReduceOp); - Value x2 = lnOp.getData(); + Value x2 = [](Operation *op) { + if (auto rmOp = mlir::dyn_cast(op)) { + return rmOp.getData(); + } + if (auto rmV13Op = mlir::dyn_cast(op)) { + return rmV13Op.getData(); + } + llvm_unreachable("Expected ONNXReduceMeanOp or ONNXReduceMeanV13Op"); + }(mReduceOp); if (x1 != x2) hasFullPattern = reportFailure( "LN input x to mean/ReduceMean and sub are different"); @@ -299,21 +322,61 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { LLVM_DEBUG( llvm::dbgs() << "RMSLayerNorm from mult, axis " << axis << "\n"); } + // Collect the locations of the recomposed ops + if (mReduceOp) + layerNormLocations.push_back(mReduceOp->getLoc()); + if (dSubOp) + layerNormLocations.push_back(dSubOp->getLoc()); + layerNormLocations.push_back(ddMulOp->getLoc()); + layerNormLocations.push_back(vReduceOp->getLoc()); + layerNormLocations.push_back(veAddOp->getLoc()); + layerNormLocations.push_back(sdSqrtOp->getLoc()); + if (isdRecipOp) + layerNormLocations.push_back(isdRecipOp->getLoc()); + if (nMulOp) + layerNormLocations.push_back(nMulOp->getLoc()); + if (nDivOp) + layerNormLocations.push_back(nDivOp->getLoc()); + layerNormLocations.push_back(loc); + return true; } private: static bool suitableAxis(Operation *op, int64_t xRank, int64_t &axis) { - ONNXReduceMeanV13Op reduceOp = cast(op); - if (reduceOp.getKeepdims() != 1) - return reportFailure("need keepdims = 1"); - ArrayAttr axesAttr = reduceOp.getAxesAttr(); - int64_t axesSize = axesAttr.size(); + SmallVector axes; // The axes attribute/operand of the ReduceMeanOp + if (auto reduceOpV13 = mlir::dyn_cast(op)) { + if (reduceOpV13.getKeepdims() != 1) + return reportFailure("need keepdims = 1"); + ArrayAttr axesAttr = reduceOpV13.getAxesAttr(); + for (size_t i = 0; i < axesAttr.size(); ++i) { + axes.emplace_back(onnx_mlir::ArrayAttrIntVal(axesAttr, i)); + } + } else if (auto reduceOp = mlir::dyn_cast(op)) { + if (reduceOp.getKeepdims() != 1) + return reportFailure("need keepdims = 1"); + Value axesValue = reduceOp.getAxes(); + if (isa(axesValue.getType())) { + if (reduceOp.getNoopWithEmptyAxes()) { + // No reduction + return reportFailure("needs a reduction on at least one dimension"); + } else { + // Reduction on all dimensions + axis = 0; + return true; + } + } + if (!onnx_mlir::getI64ValuesFromONNXConstantOp(axesValue, axes)) { + return reportFailure("only static axes are supported"); + } + } else { + llvm_unreachable("ReduceMean is the only supported op"); + } + // Record axes value in bit vector. llvm::SmallBitVector reduceAxes(xRank, false); - for (int64_t i = 0; i < axesSize; ++i) { - int64_t a = onnx_mlir::getAxisInRange( - onnx_mlir::ArrayAttrIntVal(axesAttr, i), xRank); + for (int64_t axe : axes) { + int64_t a = onnx_mlir::getAxisInRange(axe, xRank); reduceAxes[a] = true; } // Check that we have a "false"* "true"+ pattern. @@ -340,10 +403,214 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern { } }; +struct RecomposeGeluFromMulPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + ONNXMulOp mulOp, PatternRewriter &rewriter) const final { + using namespace onnx_mlir; + Location loc = mulOp.getLoc(); + // Match: + // - for exact gelu + // gelu(x) = 0.5 * x * (1 + erf(x/1.41421354)) + // where 1.41421354 is sqrt(2). + // + // or + // + // - for approximate gelu + // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]) + // where 0.797884583 is sqrt(2/pi). + Value x; + bool isExactGelu = false; + if (!matchGeluPattern(mulOp, x, isExactGelu)) + return failure(); + + // Replace + MultiDialectBuilder create(rewriter, loc); + StringAttr approximateAttr = + rewriter.getStringAttr(isExactGelu ? "none" : "tanh"); + Value res = create.onnx.gelu(x, approximateAttr); + copySingleResultType(mulOp, res); + rewriter.replaceOp(mulOp, res); + return success(); + } + + static bool matchGeluPattern(ONNXMulOp mulOp, Value &x, bool &isExactGelu) { + using namespace onnx_mlir; + // Subgraph to match: + // - for exact gelu + // gelu(x) = 0.5 * x * (1 + erf(x/1.41421354)) + // where 1.41421354 is sqrt(2). + // + // or + // + // - for approximate gelu + // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]) + // where 0.797884583 is sqrt(2/pi). + // + // Associcative and communitative properties are handled. + + // Helper function. + auto constOf = [](Value v, double n) { + return isDenseONNXConstant(v) && isConstOf(v, n); + }; + + // Match 0.5 * a * b + // Two associative cases depending on which Mul 0.5 belongs to: + // - 0.5 * (a * b) + // - (0.5 * a) * b + // For each case, we have four communitive cases: 2 for the outer Mul and 2 + // for the inner Mul. In total, we handle 8 cases. + Value lhs = mulOp.getOperand(0); + Value rhs = mulOp.getOperand(1); + + Value fstMulVal, sndMulVal; + bool foundHalf = false; + + ONNXMulOp innerMulOp; + if (matchConstAndOp(lhs, rhs, 0.5, innerMulOp)) { + // - 0.5 * (a * b) or (a * b) * 0.5 + fstMulVal = innerMulOp.getOperand(0); + sndMulVal = innerMulOp.getOperand(1); + foundHalf = true; + } + if (!foundHalf && !constOf(lhs, 0.5) && !constOf(rhs, 0.5)) { + if (auto lhsMulOp = lhs.getDefiningOp()) { + // - (0.5 * a) * b + Value l = lhsMulOp.getOperand(0); + Value r = lhsMulOp.getOperand(1); + if (constOf(l, 0.5)) { + fstMulVal = r; + sndMulVal = rhs; + foundHalf = true; + } else if (constOf(r, 0.5)) { + fstMulVal = l; + sndMulVal = rhs; + foundHalf = true; + } + } + if (!foundHalf) { + if (auto rhsMulOp = rhs.getDefiningOp()) { + // - b * (0.5 * a) + Value l = rhsMulOp.getOperand(0); + Value r = rhsMulOp.getOperand(1); + if (constOf(l, 0.5)) { + fstMulVal = lhs; + sndMulVal = r; + foundHalf = true; + } else if (constOf(r, 0.5)) { + fstMulVal = lhs; + sndMulVal = l; + foundHalf = true; + } + } + } + } + if (!foundHalf) + return reportFailure("missing 0.5 * a * b"); + + // Exact gelu. + // Match 1 + erf() + bool foundErf = false; + ONNXErfOp erfOp; + // Try the first operand. + if (auto add1Op = fstMulVal.getDefiningOp()) { + foundErf = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp); + if (foundErf) + x = sndMulVal; + } + if (!foundErf) { + // Try the second operand. + if (auto add1Op = sndMulVal.getDefiningOp()) { + foundErf = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, erfOp); + if (foundErf) + x = fstMulVal; + } + } + if (foundErf) { + // gelu(x) = 0.5 * x * (1 + erf(x/1.41421354)) + Value erfInput = erfOp.getOperand(); + auto divOp = erfInput.getDefiningOp(); + if (!divOp) + return reportFailure("[Exact] missing div op"); + if (divOp.getOperand(0) != x) + return reportFailure("[Exact] missing x in x/1.41421354"); + if (!constOf(divOp.getOperand(1), 1.41421354)) + return reportFailure("[Exact] missing 1.41421354"); + isExactGelu = true; + return true; + } else { + // Do not return here, we still check the approximate case. + reportFailure("[Exact] missing (1 + erf)"); + } + + // Approximate gelu. + // gelu(x) = 0.5 * x * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)]) + // Match 1 + tanh() + bool foundTanh = false; + ONNXTanhOp tanhOp; + // Try the first operand. + if (auto add1Op = fstMulVal.getDefiningOp()) { + foundTanh = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp); + if (foundTanh) + x = sndMulVal; + } + if (!foundTanh) { + // Try the second operand. + if (auto add1Op = sndMulVal.getDefiningOp()) { + foundTanh = matchConstAndOp( + add1Op.getOperand(0), add1Op.getOperand(1), 1.0, tanhOp); + if (foundTanh) + x = fstMulVal; + } + } + if (!foundTanh) + return reportFailure("[Approximate] missing (1 + tanh)"); + + // Match 0.797884583 * (x + 0.044715 * x^3) + auto mul1Op = tanhOp.getOperand().getDefiningOp(); + if (!mul1Op) + return reportFailure("[Approximate] missing mul op for (0.797884583 *)"); + ONNXAddOp add2Op; + if (!matchConstAndOp( + mul1Op.getOperand(0), mul1Op.getOperand(1), 0.797884583, add2Op)) + return reportFailure( + "[Approximate] missing add op for (x + 0.044715*x^3))"); + + // Match x + 0.044715 * x^3 + ONNXMulOp mul2Op; + if (!matchValueAndOp( + add2Op.getOperand(0), add2Op.getOperand(1), x, mul2Op)) + return reportFailure("[Approximate] missing mul op for 0.044715 * x^3"); + + // Match 0.044715 * x^3 + ONNXPowOp powOp; + if (!matchConstAndOp( + mul2Op.getOperand(0), mul2Op.getOperand(1), 0.044715, powOp)) + return reportFailure("[Approximate] missing 0.044715 and/or pow op"); + + // Match x^3 + lhs = powOp.getOperand(0); + rhs = powOp.getOperand(1); + if (lhs == x && constOf(rhs, 3.0)) + return true; + + return reportFailure("subgraph not found"); + } + + static bool reportFailure(std::string msg) { + // Can disable line below if not needed. + LLVM_DEBUG(llvm::dbgs() << "Gelu failure: " << msg << "\n"); + return false; + } +}; + struct RecomposeQLinearMatMulFromQuantizeLinearPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite( ONNXQuantizeLinearOp qlOp, PatternRewriter &rewriter) const final { using namespace onnx_mlir; @@ -360,6 +627,7 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern aZeroPoint, b, bScale, bZeroPoint, outScale, outZeroPoint); rewriter.replaceOp(qlOp, res); + copySingleResultType(qlOp, res); return success(); } @@ -378,8 +646,8 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern matmulOp, quantizeOp, qlX, 0); if (!matchMatMul) return false; - matA = cast(matmulOp).getA(); - matB = cast(matmulOp).getB(); + matA = mlir::cast(matmulOp).getA(); + matB = mlir::cast(matmulOp).getB(); // Matching input A of MatMul. auto dlOpA = matA.getDefiningOp(); if (!dlOpA) @@ -442,8 +710,16 @@ void RecomposeONNXToONNXPass::runOnOperation() { FloatAttr epsilon; int64_t axis; bool isRMSLayerNorm; - return !RecomposeLayerNormFromMulPattern::matchLayerNormPattern( - op, x, scale, axis, epsilon, isRMSLayerNorm); + SmallVector layerNormLocations; + if (RecomposeLayerNormFromMulPattern::matchLayerNormPattern( + op, x, scale, axis, epsilon, layerNormLocations, isRMSLayerNorm)) + return false; + + bool isExactGelu; + if (RecomposeGeluFromMulPattern::matchGeluPattern(op, x, isExactGelu)) + return false; + + return true; }); // Recompose QLinearMatMul, starting from QuantizeLinear. @@ -469,6 +745,7 @@ void RecomposeONNXToONNXPass::runOnOperation() { void onnx_mlir::getRecomposeONNXToONNXPatterns( mlir::RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); + patterns.insert(context); patterns.insert(context); patterns.insert(context); } diff --git a/src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp b/src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp index e490019afe..bdf3c55326 100644 --- a/src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp +++ b/src/Dialect/ONNX/Transforms/SetONNXNodeName.cpp @@ -70,7 +70,7 @@ void SetONNXNodeNamePass::runOnOperation() { std::string s = nodeName.getValue().str(); bool succeeded = nodeNames.insert(s).second; if (!succeeded) { - llvm::outs() << "Duplicated " << nodeNameAttr << ": " << s + llvm::errs() << "Duplicated " << nodeNameAttr << ": " << s << ". It will be updated with a new string.\n"; opsNeedNodeName.insert(op); } diff --git a/src/Dialect/ONNX/Transforms/ShapeInference.cpp b/src/Dialect/ONNX/Transforms/ShapeInference.cpp index ff29a8734b..be9ff5a948 100644 --- a/src/Dialect/ONNX/Transforms/ShapeInference.cpp +++ b/src/Dialect/ONNX/Transforms/ShapeInference.cpp @@ -10,6 +10,8 @@ #include "ShapeInference.hpp" +#define DEBUG_TYPE "onnx-shape-inference" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" @@ -24,17 +26,17 @@ namespace onnx_mlir { namespace { bool hasDynamicOrUnknownShape(Type type) { - if (auto tensorType = dyn_cast(type)) + if (auto tensorType = mlir::dyn_cast(type)) return !tensorType.hasStaticShape(); if (mlir::isa(type)) return false; - if (auto seqType = dyn_cast(type)) + if (auto seqType = mlir::dyn_cast(type)) return ShapedType::isDynamic(seqType.getLength()) || hasDynamicOrUnknownShape(seqType.getElementType()); - if (auto optType = dyn_cast(type)) + if (auto optType = mlir::dyn_cast(type)) return hasDynamicOrUnknownShape(optType.getElementType()); llvm_unreachable("unknown type"); @@ -75,7 +77,8 @@ LogicalResult inferShapes( OperationFingerPrint after(shapeInfOp); if (failed(outcome)) { assert(after == before && "op must be unchanged on failure"); - return shapeInfOp.emitOpError("shape inference failed"); + LLVM_DEBUG(llvm::errs() << "shape inference failed"); + return failure(); } // succeed only shapeInfOp or its result types changed return after == before ? failure() : success(); diff --git a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp index efcf1e0cca..397fe18cd5 100644 --- a/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp +++ b/src/Dialect/ONNX/Transforms/SimplifyShapeRelatedOps.cpp @@ -79,7 +79,7 @@ void getDimsInt64(Value val, SmallVectorImpl &result) { SmallVector dims; getDims(val, dims); for (Value v : dims) { - if (auto constOp = dyn_cast(v.getDefiningOp())) { + if (auto constOp = mlir::dyn_cast(v.getDefiningOp())) { auto valueAttr = mlir::cast(constOp.getValueAttr()); int64_t dim = valueAttr.getSplatValue(); result.emplace_back(dim); diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 166a19217d..e9dbf2e77b 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -86,11 +86,13 @@ std::unique_ptr createONNXPreKrnlVerifyPass(); /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); std::unique_ptr createLowerToKrnlPass(bool enableTiling, - bool enableSIMD, bool enableParallel, std::string opsForCall); + bool enableSIMD, bool enableParallel, bool enableFastMath, + std::string opsForCall); void configureOnnxToKrnlLoweringPass(bool reportOnParallel, bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd, bool simdIsEnabled); std::unique_ptr createProcessScfParallelPrivatePass(); +std::unique_ptr createProcessKrnlParallelClausePass(); #ifdef ONNX_MLIR_ENABLE_STABLEHLO /// Add pass for lowering to Stablehlo IR. @@ -122,6 +124,5 @@ std::unique_ptr createConvertKrnlToLLVMPass(bool verifyInputTensors, /// Pass for lowering Onnx ops to TOSA dialect std::unique_ptr createConvertONNXToTOSAPass(); - } // namespace onnx_mlir #endif \ No newline at end of file diff --git a/src/Runtime/ExecutionSession.cpp b/src/Runtime/ExecutionSession.cpp index c7607b8966..b00446edd7 100644 --- a/src/Runtime/ExecutionSession.cpp +++ b/src/Runtime/ExecutionSession.cpp @@ -125,7 +125,8 @@ const std::string *ExecutionSession::queryEntryPoints( int64_t *numOfEntryPoints) const { if (!isInitialized) throw std::runtime_error(reportInitError()); - return (const std::string *)_queryEntryPointsFunc(numOfEntryPoints); + return reinterpret_cast( + _queryEntryPointsFunc(numOfEntryPoints)); } void ExecutionSession::setEntryPoint(const std::string &entryPointName) { @@ -170,7 +171,8 @@ std::vector ExecutionSession::run( std::vector omts; for (const auto &inOmt : ins) omts.emplace_back(inOmt.get()); - auto *wrappedInput = omTensorListCreate(omts.data(), (int64_t)omts.size()); + auto *wrappedInput = + omTensorListCreate(omts.data(), static_cast(omts.size())); // Run inference. auto *wrappedOutput = _entryPointFunc(wrappedInput); diff --git a/src/Runtime/OMExternalConstant.inc b/src/Runtime/OMExternalConstant.inc index 570deb562c..44b234baa4 100644 --- a/src/Runtime/OMExternalConstant.inc +++ b/src/Runtime/OMExternalConstant.inc @@ -19,6 +19,7 @@ typedef int make_iso_compilers_happy; #include #include +#include #include #include #include @@ -58,54 +59,40 @@ void checkEndianness(const char constPackIsLE) { /// /// This function is thread-safe. /// -void omMMapBinaryFile( - void **constAddr, char *filename, int64_t size, int64_t isLE) { - checkEndianness(isLE); - char *fname = filename; -#ifdef __MVS__ - // Convert the file name to EBCDIC for the open call. - char *tPath; - size_t tLen = strlen(fname); - tPath = (char *)malloc(tLen); - if (!tPath) { - fprintf(stderr, "Error while malloc"); - return; - } - memcpy(tPath, fname, tLen); - __a2e_s(tPath); - fname = tPath; -#endif - +bool omMMapBinaryFile( + void **constAddr, char *fname, int64_t size, int64_t isLE) { if (constAddr == NULL) { - perror("Error: null pointer"); - return; + fprintf(stderr, "Error: null pointer."); + return false; } - if (constAddr[0] == NULL) { - char *filePath; - char *basePath = getenv("OM_CONSTANT_PATH"); - if (basePath) { - size_t baseLen = strlen(basePath); - size_t fnameLen = strlen(fname); - size_t sepLen = strlen(DIR_SEPARATOR); - size_t filePathLen = baseLen + sepLen + fnameLen; - filePath = (char *)malloc(filePathLen); - if (!filePath) { - fprintf(stderr, "Error while malloc"); - return; - } - memcpy(filePath, basePath, baseLen); - memcpy(filePath + baseLen, DIR_SEPARATOR, sepLen); - memcpy(filePath + baseLen + sepLen, fname, fnameLen); - filePath[filePathLen] = '\0'; - } else { - filePath = (char *)fname; - } - int fd = open(filePath, O_RDONLY); - if (fd < 0) { - fprintf(stderr, "Error while opening %s\n", filePath); - return; + // Already mmaped. Nothing to do. + if (constAddr[0] != NULL) + return true; + + char *filePath; + char *basePath = getenv("OM_CONSTANT_PATH"); + if (basePath) { + size_t baseLen = strlen(basePath); + size_t fnameLen = strlen(fname); + size_t sepLen = strlen(DIR_SEPARATOR); + size_t filePathLen = baseLen + sepLen + fnameLen + 1; + filePath = (char *)malloc(filePathLen); + if (!filePath) { + fprintf(stderr, "Error while malloc: %s", strerror(errno)); + return false; } + snprintf(filePath, filePathLen, "%s%s%s", basePath, DIR_SEPARATOR, fname); + } else { + filePath = (char *)fname; + } + int fd = open(filePath, O_RDONLY); + if (fd < 0) { + fprintf(stderr, "Error while opening %s: %s\n", filePath, strerror(errno)); + if (basePath) + free(filePath); + return false; + } #ifdef __MVS__ void *tempAddr = mmap(0, size, PROT_READ, __MAP_MEGA, fd, 0); @@ -113,36 +100,34 @@ void omMMapBinaryFile( void *tempAddr = mmap(0, size, PROT_READ, MAP_SHARED, fd, 0); #endif - if (tempAddr == MAP_FAILED) { - fprintf(stderr, "Error while mmapping %s\n", fname); - close(fd); - return; - } - - /* Prepare to compare-and-swap to setup the shared constAddr. - * If we fail, another thread beat us so free our mmap. - */ -#ifdef __MVS__ - void *expected = NULL; - if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)tempAddr)) - munmap(tempAddr, size); -#else - if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr)) - munmap(tempAddr, size); -#endif - - /* Either we succeeded in setting constAddr or someone else did it. - * Either way, constAddr is now setup. We can close our fd without - * invalidating the mmap. - */ + if (tempAddr == MAP_FAILED) { + fprintf(stderr, "Error while mmapping %s: %s\n", fname, strerror(errno)); close(fd); if (basePath) free(filePath); + return false; } + /* Prepare to compare-and-swap to setup the shared constAddr. + * If we fail, another thread beat us so free our mmap. + */ #ifdef __MVS__ - free(tPath); + void *expected = NULL; + if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)&tempAddr)) + munmap(tempAddr, size); +#else + if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr)) + munmap(tempAddr, size); #endif + + /* Either we succeeded in setting constAddr or someone else did it. + * Either way, constAddr is now setup. We can close our fd without + * invalidating the mmap. + */ + close(fd); + if (basePath) + free(filePath); + return true; } /// Return the address of a constant at a given offset. @@ -156,11 +141,11 @@ void omMMapBinaryFile( void omGetExternalConstantAddr( void **outputAddr, void **baseAddr, int64_t offset) { if (outputAddr == NULL) { - perror("Error: null pointer"); + fprintf(stderr, "Error: null pointer."); return; } if (baseAddr == NULL) { - perror("Error: null pointer"); + fprintf(stderr, "Error: null pointer."); return; } // Constant is already loaded. Nothing to do. diff --git a/src/Runtime/OMSort.inc b/src/Runtime/OMSort.inc index fea3252751..fa65cc8433 100644 --- a/src/Runtime/OMSort.inc +++ b/src/Runtime/OMSort.inc @@ -89,7 +89,7 @@ typedef int( #pragma GCC diagnostic ignored "-Wcast-qual" #endif -#define Load(typeName, to, from) typeName (to) = (from) +#define Load(typeName, to, from) typeName to = from // Convert f16 elements to f32 for comparison because we don't have logic to // compare f16 elements directly on all platforms. diff --git a/src/Runtime/OMTensor.inc b/src/Runtime/OMTensor.inc index cdc2fd05f3..49e24befc3 100644 --- a/src/Runtime/OMTensor.inc +++ b/src/Runtime/OMTensor.inc @@ -480,7 +480,7 @@ static void printData(FILE *fout, const OMTensor *tensor) { /* Helper macros to print data for 1-4D tensors */ #define LOOP_1(INDEX, IV, UB) \ fprintf(fout, "["); \ - for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ indexes[(INDEX)] = (IV); \ @@ -491,7 +491,7 @@ static void printData(FILE *fout, const OMTensor *tensor) { #define LOOP_2(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ indexes[(INDEX)] = (IV); \ @@ -501,7 +501,7 @@ static void printData(FILE *fout, const OMTensor *tensor) { #define LOOP_3(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ indexes[(INDEX)] = (IV); \ @@ -511,7 +511,7 @@ static void printData(FILE *fout, const OMTensor *tensor) { #define LOOP_4(INDEX, IV, UB, ...) \ fprintf(fout, "["); \ - for (int64_t (IV) = 0; (IV) < (UB); ++(IV)) { \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ if (IV) \ fprintf(fout, ", "); \ indexes[(INDEX)] = (IV); \ @@ -519,6 +519,26 @@ static void printData(FILE *fout, const OMTensor *tensor) { } \ fprintf(fout, "]"); +#define LOOP_5(INDEX, IV, UB, ...) \ + fprintf(fout, "["); \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ + if (IV) \ + fprintf(fout, ", "); \ + indexes[(INDEX)] = (IV); \ + LOOP_4((INDEX) + 1, __VA_ARGS__) \ + } \ + fprintf(fout, "]"); + +#define LOOP_6(INDEX, IV, UB, ...) \ + fprintf(fout, "["); \ + for (int64_t IV = 0; (IV) < (UB); ++(IV)) { \ + if (IV) \ + fprintf(fout, ", "); \ + indexes[(INDEX)] = (IV); \ + LOOP_5((INDEX) + 1, __VA_ARGS__) \ + } \ + fprintf(fout, "]"); + const OM_DATA_TYPE dataType = omTensorGetDataType(tensor); const int64_t rank = omTensorGetRank(tensor); const int64_t *shape = omTensorGetShape(tensor); @@ -545,6 +565,14 @@ static void printData(FILE *fout, const OMTensor *tensor) { int64_t indexes[4]; LOOP_4(0, i, shape[0], j, shape[1], k, shape[2], l, shape[3]) } break; + case 5: { + int64_t indexes[5]; + LOOP_5(0, i, shape[0], j, shape[1], k, shape[2], l, shape[3], m, shape[4]) + } break; + case 6: { + int64_t indexes[6]; + LOOP_6(0, i, shape[0], j, shape[1], k, shape[2], l, shape[3], m, shape[4], n, shape[5]) + } break; default: assert(false && "not implemented"); } @@ -577,6 +605,7 @@ void omTensorPrint(const char *msg, const OMTensor *tensor) { msg += 2; len -= 2; } + bool hadOneOrMoreFormats = false; while (len > 0) { if (msg[0] == '%') { if (len < 2) { @@ -586,12 +615,15 @@ void omTensorPrint(const char *msg, const OMTensor *tensor) { if (msg[1] == 'd') { /* Letter `d` for data. */ assert(tensor && "attempt to print a null OMTensor"); printData(fout, tensor); + hadOneOrMoreFormats = true; } else if (msg[1] == 's') { /* Letter `s` for signature. */ assert(tensor && "attempt to print a null OMTensor"); printSignature(fout, tensor); + hadOneOrMoreFormats = true; } else if (msg[1] == 't') { /* Letter `t` for type only. */ assert(tensor && "attempt to print a null OMTensor"); printType(fout, tensor); + hadOneOrMoreFormats = true; } else if (msg[1] == 'e') { /* Letter `e` for end. */ return; } else { @@ -607,6 +639,13 @@ void omTensorPrint(const char *msg, const OMTensor *tensor) { msg++; len--; } + if (!hadOneOrMoreFormats) { + // default per Krnl.td: %s%d + fprintf(fout, "\n"); + printSignature(fout, tensor); + printData(fout, tensor); + fprintf(fout, "\n"); + } } #ifdef __cplusplus diff --git a/src/Runtime/python/CMakeLists.txt b/src/Runtime/python/CMakeLists.txt index 16a93e9eea..a87f1ae0c9 100644 --- a/src/Runtime/python/CMakeLists.txt +++ b/src/Runtime/python/CMakeLists.txt @@ -21,6 +21,7 @@ add_onnx_mlir_library(OMPyExecutionSessionBase OMMlirUtilities pybind11::embed pybind11::python_link_helper + onnx ) if(MSVC) target_link_libraries(OMPyExecutionSessionBase @@ -69,11 +70,11 @@ target_include_directories(OMPyExecutionSessionBase # target_include_directories. pybind11_add_module(PyRuntimeC PyExecutionSession.cpp) add_dependencies(PyRuntimeC onnx_proto) -target_compile_options(PyRuntimeC - PRIVATE - $<$,$,$>:-frtti -fexceptions> - $<$:/EHsc /GR> - ) +if (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "MSVC") + target_compile_options(PyRuntimeC PRIVATE /EHsc /GR) +elseif (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "GNU") + target_compile_options(PyRuntimeC PRIVATE -frtti -fexceptions) +endif() target_compile_definitions(PyRuntimeC PRIVATE $ @@ -88,17 +89,19 @@ target_link_libraries(PyRuntimeC ) llvm_update_compile_flags(PyRuntimeC) -install(TARGETS PyRuntimeC - DESTINATION lib - ) +if(ONNX_MLIR_INSTALL_PYTHON_EXTENSIONS) + install(TARGETS PyRuntimeC + DESTINATION lib + ) +endif() pybind11_add_module(PyCompileAndRuntimeC PyOMCompileExecutionSession.cpp) add_dependencies(PyCompileAndRuntimeC onnx_proto) -target_compile_options(PyCompileAndRuntimeC - PRIVATE - $<$,$,$>:-frtti -fexceptions> - $<$:/EHsc /GR> - ) +if (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "MSVC") + target_compile_options(PyCompileAndRuntimeC PRIVATE /EHsc /GR) +elseif (CMAKE_CXX_COMPILER_FRONTEND_VARIANT STREQUAL "GNU") + target_compile_options(PyCompileAndRuntimeC PRIVATE -frtti -fexceptions) +endif() target_compile_definitions(PyCompileAndRuntimeC PRIVATE $ @@ -115,6 +118,8 @@ target_link_libraries(PyCompileAndRuntimeC ) llvm_update_compile_flags(PyCompileAndRuntimeC) -install(TARGETS PyCompileAndRuntimeC - DESTINATION lib - ) +if(ONNX_MLIR_INSTALL_PYTHON_EXTENSIONS) + install(TARGETS PyCompileAndRuntimeC + DESTINATION lib + ) +endif() diff --git a/src/Support/Diagnostic.cpp b/src/Support/Diagnostic.cpp index cf882c3fc1..58a409158c 100644 --- a/src/Support/Diagnostic.cpp +++ b/src/Support/Diagnostic.cpp @@ -19,19 +19,22 @@ using namespace mlir; namespace onnx_mlir { template -LogicalResult Diagnostic::emitAttributeOutOfRangeError(Operation &op, - const llvm::Twine &attrName, T attrVal, Range validRange) { +LogicalResult Diagnostic::emitAttributeOutOfRangeError( + Operation &op, const llvm::Twine &attrName, T attrVal, Range range) { static_assert(std::is_arithmetic::value, "Expecting an arithmetic type"); llvm::Twine msg(op.getName().getStringRef() + ": "); + std::string rangeMessage = + range.isValid() ? "" : " <>"; return emitError(op.getLoc(), msg.concat("'" + attrName + "'") .concat(" value is ") .concat(std::to_string(attrVal)) .concat(", accepted range is [") - .concat(std::to_string(validRange.min)) + .concat(std::to_string(range.min)) .concat(", ") - .concat(std::to_string(validRange.max)) - .concat("]")); + .concat(std::to_string(range.max)) + .concat("]") + .concat(rangeMessage)); } template diff --git a/src/Support/Diagnostic.hpp b/src/Support/Diagnostic.hpp index 628abfe649..1187d456b0 100644 --- a/src/Support/Diagnostic.hpp +++ b/src/Support/Diagnostic.hpp @@ -36,15 +36,22 @@ class Diagnostic { T max; public: + // Range is used in error situations, so having an assert is not very useful + // as that assert may crash the program instead of reporting the error + // condition. New approach is to report the error with an additional + // warning. Range(T min, T max) : min(min), max(max) { - assert(min <= max && "Illegal range"); + if (!isValid()) + llvm::errs() << "Warning: badly formed range(min=" << min + << ", max=" << max << ")\n"; } + bool isValid() { return min <= max; } }; /// Diagnostic message for attribute value outside of a supplied range. template static mlir::LogicalResult emitAttributeOutOfRangeError(mlir::Operation &op, - const llvm::Twine &attrName, T attrVal, Range validRange); + const llvm::Twine &attrName, T attrVal, Range range); /// Verifies whether 2 inputs have the same rank. template diff --git a/src/Support/SmallVectorHelper.hpp b/src/Support/SmallVectorHelper.hpp index a20047c578..e3ae9c2a3e 100644 --- a/src/Support/SmallVectorHelper.hpp +++ b/src/Support/SmallVectorHelper.hpp @@ -12,6 +12,9 @@ // //===----------------------------------------------------------------------===// +#ifndef ONNX_MLIR_SMALL_VECTOR_HELPER_H +#define ONNX_MLIR_SMALL_VECTOR_HELPER_H + #include "llvm/ADT/SmallVector.h" //===----------------------------------------------------------------------===// @@ -103,3 +106,5 @@ llvm::SmallVector lastFew( res.emplace_back(vec[i]); return res; } + +#endif diff --git a/src/Support/TypeUtilities.cpp b/src/Support/TypeUtilities.cpp index 7d1a05cbb5..ce43527cba 100644 --- a/src/Support/TypeUtilities.cpp +++ b/src/Support/TypeUtilities.cpp @@ -31,7 +31,7 @@ bool isRankedShapedType(Type ty) { } /// Check if a type has static shape. -bool hasStaticShape(mlir::Type ty) { +bool hasStaticShape(Type ty) { if (!isRankedShapedType(ty)) return false; return mlir::cast(ty).hasStaticShape(); diff --git a/src/Tools/binary-decoder/BinaryDecoder.cpp b/src/Tools/binary-decoder/BinaryDecoder.cpp index 6cc6eb39c8..8fc6e1dcad 100644 --- a/src/Tools/binary-decoder/BinaryDecoder.cpp +++ b/src/Tools/binary-decoder/BinaryDecoder.cpp @@ -4,7 +4,7 @@ //===----- BinaryDecoder.cpp - Decode binary files into typed arrays ------===// // -// Copyright 2019-2020 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -63,7 +63,7 @@ llvm::cl::opt DataType( template int printBuffer(std::vector buffer) { - auto *ptr = (T *)&buffer[0]; + auto *ptr = reinterpret_cast(&buffer[0]); auto data = std::vector(ptr, ptr + buffer.size() / sizeof(T)); for (const auto &elem : data) std::cout << elem << " "; diff --git a/src/Tools/onnx-mlir-opt/CMakeLists.txt b/src/Tools/onnx-mlir-opt/CMakeLists.txt index a90a670af9..ada8a839df 100644 --- a/src/Tools/onnx-mlir-opt/CMakeLists.txt +++ b/src/Tools/onnx-mlir-opt/CMakeLists.txt @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS) +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) add_onnx_mlir_executable(onnx-mlir-opt onnx-mlir-opt.cpp @@ -19,4 +20,5 @@ add_onnx_mlir_executable(onnx-mlir-opt MLIROpenMPToLLVM MLIROptLib MLIRSCFToOpenMP + ${dialect_libs} ) diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index b8285fcb69..deb0783bd0 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -90,6 +90,7 @@ void registerOMPasses(int optLevel) { return createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, /*enableSIMD, should consider disableSimdOption*/ optLevel >= 3, /*enableParallel*/ false, + /*enableFastMath*/ false, /*default is still off*/ /*opsForCall*/ ""); }); @@ -97,6 +98,10 @@ void registerOMPasses(int optLevel) { return createProcessScfParallelPrivatePass(); }); + mlir::registerPass([]() -> std::unique_ptr { + return createProcessKrnlParallelClausePass(); + }); + mlir::registerPass([]() -> std::unique_ptr { return krnl::createConvertSeqToMemrefPass(); }); diff --git a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp index e483792ef3..29411aaf68 100644 --- a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp +++ b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp @@ -4,7 +4,7 @@ //===-------------- onnx-mlir-opt.cpp - Optimization Driver ---------------===// // -// Copyright 2019-2022 The IBM Research Authors. +// Copyright 2019-2024 The IBM Research Authors. // // ============================================================================= // @@ -57,14 +57,14 @@ void scanAndSetOptLevel(int argc, char **argv) { num = atoi(&argv[i][2]); // Get the number starting 2 char down. // Silently ignore out of bound opt levels. if (num >= 0 && num <= 3) { - OptimizationLevel = (OptLevel)num; + OptimizationLevel = static_cast(num); return; } } } void scanAndSetMCPU(int argc, char **argv) { - // Scan --mcpu and add them to the mcpu option. + // Scan for (deprecated) --mcpu and add them to the mcpu option. for (int i = argc - 1; i > 0; --i) { std::string currStr(argv[i]); if (currStr.find("--mcpu=") == 0) { @@ -80,6 +80,25 @@ void scanAndSetMCPU(int argc, char **argv) { } } +void scanAndSetMArch(int argc, char **argv) { + // Scan --march and add them to the march option. + for (int i = argc - 1; i > 0; --i) { + std::string currStr(argv[i]); + if (currStr.find("--march=") == 0) { + std::string archKind( + &argv[i][8]); // Get the string starting 8 chars down. + setTargetArch(archKind); + break; + } + if (currStr.find("-march=") == 0) { + std::string archKind( + &argv[i][7]); // Get the string starting 7 chars down. + setTargetArch(archKind); + break; + } + } +} + void scanAndSetMAccel(int argc, char **argv) { // Scan accelerators and add them to the maccel option. for (int i = argc - 1; i > 0; --i) { @@ -106,9 +125,10 @@ int main(int argc, char **argv) { // before command line options are parsed. scanAndSetOptLevel(argc, argv); - // Scan CPU manually now as it is needed to register passes + // Scan CPU and Arch manually now as it is needed to register passes // before command line options are parsed. scanAndSetMCPU(argc, argv); + scanAndSetMArch(argc, argv); // Scan maccel manually now as it is needed to initialize accelerators // before ParseCommandLineOptions() is called. diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 240f74b4e5..cc51752de0 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -8,12 +8,14 @@ add_onnx_mlir_library(OMLowerKrnlRegion MLIRTransformUtils ) - add_onnx_mlir_library(OMScfParallelPrivateRegion +add_onnx_mlir_library(OMScfParallelPrivateRegion ProcessScfParallelPrivate.cpp + ProcessKrnlParallelClause.cpp LINK_LIBS PUBLIC OMSupport MLIRTransformUtils + MLIROpenMPToLLVM ) add_onnx_mlir_library(OMInstrument diff --git a/src/Transform/ProcessKrnlParallelClause.cpp b/src/Transform/ProcessKrnlParallelClause.cpp new file mode 100644 index 0000000000..2f0d99329a --- /dev/null +++ b/src/Transform/ProcessKrnlParallelClause.cpp @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-- ProcessKrnlParallelClause.cpp - handle Krnl Parallel Clauses ------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// This pass seeks KrnlParallelClauseOp and integrate its parameter in the +// enclosing OpenMP Parallel construct. +// +//===----------------------------------------------------------------------===// + +#include "src/Transform/ProcessKrnlParallelClause.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" + +#include "src/Support/TypeUtilities.hpp" + +#define DEBUG_TYPE "krnl-parallel-clause" + +using namespace mlir; + +namespace { + +struct ProcessKrnlParallelClauseWithoutScopePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlParallelClauseOp clauseOp, PatternRewriter &rewriter) const final { + // Get Parallel Krnl Clause + Operation *op = clauseOp.getOperation(); + Value numThreads = clauseOp.getNumThreads(); + auto procBind = clauseOp.getProcBind(); + + Operation *parentParallelOp = op->getParentOp(); + while (!llvm::dyn_cast_or_null(parentParallelOp)) + parentParallelOp = parentParallelOp->getParentOp(); + + if (parentParallelOp) { + // Has an enclosing OpenMP parallel construct (expected). + LLVM_DEBUG(llvm::dbgs() + << "Have a KrnlParallelClause with its OMP Parallel op\n"); + omp::ParallelOp parOp = llvm::cast(parentParallelOp); + if (numThreads) { + LLVM_DEBUG(llvm::dbgs() << " with a specific num_threads clause\n"); + // Set the numbers of threads as indicated by clause op. + // WARNING: by moving the use of numThreads from inside the loop to the + // outer OpenMP parallel construct, we may potentially move the use of + // numThreads before its definition. However, because numThreads is by + // definition loop invariant, it is very unlikely that this case occurs. + // Nevertheless, this warning attests that this might be a possibility. + // In such case, we would get a compiler warning/error of use before + // def. + MutableOperandRange mutableNumThreads = parOp.getNumThreadsMutable(); + mutableNumThreads.assign(numThreads); + } + if (procBind.has_value()) { + auto str = procBind.value().str(); + LLVM_DEBUG(llvm::dbgs() + << " with a specific proc_bind clause: " << str << "\n"); + // Set the affinity as indicated by the clause op. + if (str == "primary") + parOp.setProcBindKind(omp::ClauseProcBindKind::Primary); + else if (str == "close") + parOp.setProcBindKind(omp::ClauseProcBindKind::Close); + else if (str == "spread") + parOp.setProcBindKind(omp::ClauseProcBindKind::Spread); + else + llvm_unreachable("unkown proc_bind clause"); + } + } + // Useful info from KrnlParallelClauseOp was extracted, remove now. + rewriter.eraseOp(op); + return success(); + } +}; + +struct ProcessKrnlParallelClausePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ProcessKrnlParallelClausePass) + + ProcessKrnlParallelClausePass() {} + ProcessKrnlParallelClausePass(const ProcessKrnlParallelClausePass &pass) + : mlir::PassWrapper>() {} + + StringRef getArgument() const override { + return "process-krnl-parallel-clause"; + } + + StringRef getDescription() const override { + return "Migrate info from Krnl Parallel Clause into OpenMP Parallel " + "operation."; + } + + void runOnOperation() final; + + typedef PassWrapper> + BaseType; +}; + +void ProcessKrnlParallelClausePass::runOnOperation() { + func::FuncOp function = getOperation(); + MLIRContext *context = &getContext(); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + // Op that is used and removed here. + target.addIllegalOp(); + + RewritePatternSet patterns(context); + onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns(patterns); + + if (failed(applyPartialConversion(function, target, std::move(patterns)))) + signalPassFailure(); +} + +} // namespace + +void onnx_mlir::getKrnlParallelClauseIntoOpenMPPatterns( + mlir::RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.insert(context); +} + +/*! + * Create a Krnl Parallel Clause pass. + */ +std::unique_ptr onnx_mlir::createProcessKrnlParallelClausePass() { + return std::make_unique(); +} diff --git a/src/Transform/ProcessKrnlParallelClause.hpp b/src/Transform/ProcessKrnlParallelClause.hpp new file mode 100644 index 0000000000..7f5a7bc368 --- /dev/null +++ b/src/Transform/ProcessKrnlParallelClause.hpp @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-- ProcessKrnlParallelClause.cpp - handle Krnl Parallel Clauses ------===// +// +// Copyright 2024 The IBM Research Authors. +// +// ============================================================================= +// This pass seeks KrnlParallelClauseOp and integrate its parameter in the +// enclosing OpenMP Parallel construct. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H +#define ONNX_MLIR_PROCESS_KRNL_PARALLEL_CLAUSE_H + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +// Exports the patterns. They are all plain rewrite patterns that can be used +// with any PatternRewriter, not conversion patterns. +void getKrnlParallelClauseIntoOpenMPPatterns(mlir::RewritePatternSet &patterns); + +} // namespace onnx_mlir +#endif diff --git a/src/Transform/ProcessScfParallelPrivate.cpp b/src/Transform/ProcessScfParallelPrivate.cpp index 998a138878..1996995a37 100644 --- a/src/Transform/ProcessScfParallelPrivate.cpp +++ b/src/Transform/ProcessScfParallelPrivate.cpp @@ -154,7 +154,7 @@ void onnx_mlir::getParallelPrivateScfToScfPatterns( } /*! - * Create a RecomposeONNX pass. + * Create a SCF Parallel Private pass. */ std::unique_ptr onnx_mlir::createProcessScfParallelPrivatePass() { return std::make_unique(); diff --git a/src/Transform/ProcessScfParallelPrivate.hpp b/src/Transform/ProcessScfParallelPrivate.hpp index fe6428c92c..d9450eba47 100644 --- a/src/Transform/ProcessScfParallelPrivate.hpp +++ b/src/Transform/ProcessScfParallelPrivate.hpp @@ -20,8 +20,8 @@ namespace onnx_mlir { -// Exports the RecomposeONNXToONNXPass patterns. They are all plain rewrite -// patterns that can be used with any PatternRewriter, not conversion patterns. +// Exports the patterns. They are all plain rewrite patterns that can be used +// with any PatternRewriter, not conversion patterns. void getParallelPrivateScfToScfPatterns(mlir::RewritePatternSet &patterns); } // namespace onnx_mlir diff --git a/src/Version/CMakeLists.txt b/src/Version/CMakeLists.txt index 782abdf5ed..0f8713b467 100644 --- a/src/Version/CMakeLists.txt +++ b/src/Version/CMakeLists.txt @@ -53,6 +53,10 @@ set_source_files_properties("${version_inc}" HEADER_FILE_ONLY TRUE ) +if (ONNX_MLIR_BUILD_INTREE) + set(LLVM_PACKAGE_VERSION ${LLVM_VERSION}) +endif() + add_onnx_mlir_library(OMVersion Version.cpp ${version_inc} @@ -98,4 +102,6 @@ if (ONNX_MLIR_VENDOR) list(APPEND DEFINITIONS "ONNX_MLIR_VENDOR=\"${ONNX_MLIR_VENDOR}\"") endif() list(APPEND DEFINITIONS "LLVM_PACKAGE_VERSION=\"${LLVM_PACKAGE_VERSION}\"") -target_compile_definitions(OMVersion PUBLIC ${DEFINITIONS}) +# AMD change: Make the definition of ONNX_MLIR_SHA private to avoid recompiling +# all objects when moving to another commit. +target_compile_definitions(OMVersion PRIVATE ${DEFINITIONS}) diff --git a/src/onnx-mlir.cpp b/src/onnx-mlir.cpp index be1d40554e..0a97e7e076 100644 --- a/src/onnx-mlir.cpp +++ b/src/onnx-mlir.cpp @@ -78,7 +78,9 @@ int main(int argc, char *argv[]) { // Add the short inputFilename to the first compile phase printout so that we // may better determine which compilation we are dealing with. std::filesystem::path p(inputFilename); - std::string modelShortName = p.filename(); + std::string modelShortName = p.filename().string(); + // Configure compile phase information. + SET_TOTAL_COMPILE_PHASE(emissionTarget); std::string msg = "Importing ONNX Model to MLIR Module from \"" + modelShortName + "\""; showCompilePhase(msg); diff --git a/test/accelerators/NNPA/backend/CMakeLists.txt b/test/accelerators/NNPA/backend/CMakeLists.txt index cb471514dc..272be114d2 100644 --- a/test/accelerators/NNPA/backend/CMakeLists.txt +++ b/test/accelerators/NNPA/backend/CMakeLists.txt @@ -101,12 +101,14 @@ endif() # "NO_DYNAMIC_SHAPE_TEST", backend test is skipped, otherwise the string is # passed as --dimParams option. "0:0=a,1=b,2=c|1:0=a,1=b,2=c" means that the # first, second and third dimensions of the first and second input arguments # are the same respectively. -set(NNPA_TEST_LIST +set(NNPA_TEST_LIST_z16 + # To rebuild after changes: make onnx_mlir_supported_ops # ==ARCH== NNPA - # ==ADDITIONAL_PARAGRAPH== NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. + # ==ADDITIONAL_PARAGRAPH== NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. NNPA currently only support DLFLOAT16 as its data type. Common data formats like FP32, FP16, BFLOAT need to undergo data conversions to the NNPA internal format DLFLOAT16. Hence ONNX ops which updated their tensors to BFLOAT16 will not be natively supported on NNPA. Onnx-mlir with NNPA utilizes hardware when possible. To accomplish this, the compiler converts ONNX ops to [ZHigh](Dialects/zhigh.md) ops, [ZLow](Dialects/zlow.md) ops, and are processed by the [IBM Z Deep Neural Network Library (zDNN)](https://github.com/IBM/zDNN). # ==OP== Add + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. # Scalar tensor not supported. @@ -114,6 +116,7 @@ set(NNPA_TEST_LIST # test_add_bcast_cpu # bcast not supported # ==OP== AveragePool + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors (N x C x H x W).
- `kernel_shape` must be static.
- `count_include_pad` must be default value(0).
- `ceil_mode` must be default value(0). # test_averagepool_1d_default_cpu @@ -131,12 +134,14 @@ set(NNPA_TEST_LIST # test_averagepool_3d_default_cpu # ==OP== BatchNormalization + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input and output tensor must be 4D(N x C x H x W). test_batchnorm_epsilon_cpu,zdnn_mul_ext,"0:0=a,1=b,2=c,3=d|1:0=b|2:0=b|3:0=b|4:0=b" test_batchnorm_example_cpu,zdnn_mul_ext,"0:0=a,1=b,2=c,3=d|1:0=b|2:0=b|3:0=b|4:0=b" - + # ==OP== Conv + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- Dimension in Height and weight must be static.
- `group` must be default value(1).
- `dilations` must be default value(1).
- Input and output tensors must have 4D (N x C x H x W).
- `kernel_shape` must be static. test_basic_conv_with_padding_cpu,zdnn_conv2d,NO_DYNAMIC_SHAPE_TEST @@ -147,6 +152,7 @@ set(NNPA_TEST_LIST # test_conv_with_strides_and_asymmetric_padding_cpu # ==OP== ConvTranspose + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - 1D and 3D not supported because Conv1D and Conv3D not supported in zDNN. non-default `dilations` not supported because dilated convolution not supported in zDNN. # Spatial dims must be static. @@ -161,6 +167,7 @@ set(NNPA_TEST_LIST test_convtranspose_pads_cpu,zdnn_conv2d # ==OP== Div + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. test_div_cpu,zdnn_div_ext,"0:0=a,1=b,2=c|1:0=a,1=b,2=c" @@ -168,14 +175,20 @@ set(NNPA_TEST_LIST test_div_example_cpu,zdnn_div_ext,"0:0=a|1:0=a" # ==OP== Exp + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input tensor must have 4 dimensions. test_exp_cpu,zdnn_exp_ext test_exp_example_cpu,zdnn_exp_ext # ==OP== Gemm + # ==LEVEL== z16,arch15 # ==MIN== 6 - # ==LIM== - `alpha` and `beta` must be default value(1).
- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same with the seconde dimension of `B`. + # ==LIM== - `alpha` and `beta` must be default value(1).
- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same with the seconde dimension of `B`.
+ + # Commented out for the moment. + # -`gemm_transposeA` and `gemm_transposeB` will require an "--march" or an NNPA level of at least arch15, and the "transA" or "transB" attribute must be non-zero. + # test_gemm_all_attributes_cpu # test_gemm_alpha_cpu # test_gemm_beta_cpu @@ -189,6 +202,7 @@ set(NNPA_TEST_LIST test_gemm_transposeB_cpu,zdnn_matmul_op_ext # ==OP== GlobalAveragePool + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - Input shape must be 4D tensor(NCHW).
- Dimensions in `H` and `W` must be static. test_globalaveragepool_cpu,zdnn_meanreduce2d,NO_DYNAMIC_SHAPE_TEST @@ -198,26 +212,22 @@ set(NNPA_TEST_LIST # test_globalmaxpool_precomputed_cpu # ==OP== GRU + # ==LEVEL== z16,arch15 # ==MIN== 7 # ==LIM== - `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- If `B` and `initial_h` are given, they must have static dimensions.
- `sequence_lens` is not supported for bidirectional GRU.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `linear_before_reset` must be 1.
- `layout` is not supported. # test_gru_defaults_cpu # test_gru_seq_length_cpu # test_gru_with_initial_bias_cpu - # ==OP== LeakyRelu - # ==MIN== 6 - # ==LIM== The operations immediately before and after the LeakyRelu operation must be executed on the NNPA. Otherwise, LeakyRelu is executed on the CPU. This limitation is set to avoid performance degradation. - # Leakyrelu op in following test cases doesn't run on NNPA because single LeakyRelu op is included. - # test_leakyrelu_cpu - # test_leakyrelu_default_cpu - # test_leakyrelu_example_cpu - # ==OP== Log + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input tensor must have 4 dimensions. test_log_example_cpu,zdnn_log_ext test_log_cpu,zdnn_log_ext + # ==OP== LogSoftmax + # ==LEVEL== z16,arch15 # ==MIN== 6 # test_logsoftmax_axis_0_cpu # test_logsoftmax_axis_1_cpu @@ -228,6 +238,7 @@ set(NNPA_TEST_LIST # test_logsoftmax_large_number_cpu # accuracy error in test_logsoftmax_large_number_cpu # ==OP== LSTM + # ==LEVEL== z16,arch15 # ==MIN== 7 # ==LIM== - `direction` and `hidden_size` in `W` must have static dimensions.
- `R` must have static dimensions.
- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.
- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.
- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.
- `clip` is not supported.
- `input_forget` must be default value(0).
- `layout` is not supported. test_lstm_defaults_cpu,zdnn_lstm @@ -235,6 +246,7 @@ set(NNPA_TEST_LIST # test_lstm_with_peepholes_cpu # ==OP== MatMul + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. test_matmul_2d_cpu,zdnn_matmul_op_ext @@ -242,6 +254,7 @@ set(NNPA_TEST_LIST test_matmul_4d_cpu,zdnn_matmul_op_ext,"0:0=a,1=b,2=c,3=d|1:0=a,1=b,2=d,3=c" # ==OP== Max + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. # test_max_example_cpu @@ -260,6 +273,7 @@ set(NNPA_TEST_LIST # test_max_uint64_cpu # ==OP== MaxPool + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.
- `ceil_mode` must be default value(0)
- Input and output tensors must be 4D tensors(N x C x H x W).
- `kernel_shape` must be static.
- `ceil_mode` must be default value(0).
- `dilations` must be default value(1). # test_maxpool_1d_default_cpu @@ -276,6 +290,7 @@ set(NNPA_TEST_LIST # test_maxpool_3d_default_cpu # ==OP== Min + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors must be the same since broadcasting is not supported.
- Input tensors must have static dimensions. # test_min_example_cpu @@ -294,6 +309,7 @@ set(NNPA_TEST_LIST # test_min_uint64_cpu # ==OP== Mul + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. test_mul_cpu,zdnn_mul_ext,"0:0=a,1=b,2=c|1:0=a,1=b,2=c" @@ -301,11 +317,13 @@ set(NNPA_TEST_LIST test_mul_example_cpu,zdnn_mul_ext,"0:0=a|1:0=a" # ==OP== Pow + # ==LEVEL== z16,arch15 # ==MIN== 7 # ==LIM== - Exponent should be a scalar integer and less or equal to 64. test_pow_bcast_scalar_cpu # ==OP== ReduceMean + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - `keepdims` must be 1.
- Input tensor must be 4D tensors and `axis` must be [2, 3]. # test_reduce_mean_default_axes_keepdims_example_cpu @@ -318,11 +336,13 @@ set(NNPA_TEST_LIST # test_reduce_mean_negative_axes_keepdims_random_cpu # ==OP== Relu + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input tensor must be less than or equal to 4 dimensions. test_relu_cpu,zdnn_relu_ext # ==OP== Softmax + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== - `axis` must be the last dimension, i.e. `rank - 1` or -1. # test_softmax_axis_0_cpu @@ -333,6 +353,7 @@ set(NNPA_TEST_LIST # test_softmax_large_number_cpu # accuracy error # ==OP== Softplus + # ==LEVEL== z16,arch15 # ==MIN== 1 # ==LIM== The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. # Softplus op in following test cases doesn't run on NNPA because single Softplus op is included. Softplus is tested not by backend tests but by the TestSoftplus numerical test @@ -340,6 +361,7 @@ set(NNPA_TEST_LIST # test_softplus_example_cpu,zdnn_log # ==OP== Sub + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - Shape of input tensors should be the same since broadcasting is not supported.
- Input tensors must have static dimensions. test_sub_cpu,zdnn_sub_ext,"0:0=a,1=b,2=c|1:0=a,1=b,2=c" @@ -347,6 +369,7 @@ set(NNPA_TEST_LIST test_sub_example_cpu,zdnn_sub_ext,"0:0=a|1:0=a" # ==OP== Sum + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== - All inputs must have the same static shape (Broadcasting not supported.)
- Single input not supported. test_sum_example_cpu,zdnn_add,"0:0=a|1:0=a|2:0=a" @@ -354,37 +377,112 @@ set(NNPA_TEST_LIST test_sum_two_inputs_cpu,zdnn_add,"0:0=a|1:0=a" # ==OP== Tanh + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input tensor must be less than or equal to 4 dimensions. # ==OP== Sigmoid + # ==LEVEL== z16,arch15 # ==MIN== 6 # ==LIM== Input tensor must be less than or equal to 4 dimensions. # Model test_densenet121_cpu,zdnn_conv2d - test_inception_v1_cpu,zdnn_conv2d - test_resnet50_cpu,zdnn_conv2d - test_shufflenet_cpu,zdnn_matmul_op_ext - # test_squeezenet_cpu,zdnn_conv # got NaN results - test_vgg19_cpu,zdnn_conv + # TODO re-enable below 2 tests + # test_inception_v1_cpu,zdnn_conv2d + # test_resnet50_cpu,zdnn_conv2d + + # test_shufflenet_cpu,zdnn_matmul_op_ext # got NaN results in check-onnx-backend-dynamic-jni-nnpa. + # Got NaN results because the last Conv running on NNPA produces dlfloat16 out-of-range values that are represented as NaN. + # test_squeezenet_cpu,zdnn_conv + # TODO re-enable below test + # test_vgg19_cpu,zdnn_conv ) -set(ENV_TEST_CASE_BY_USER "") -foreach(test_name IN LISTS NNPA_TEST_LIST) - set(ENV_TEST_CASE_BY_USER "${ENV_TEST_CASE_BY_USER} ${test_name}") + +set(NNPA_TEST_LIST_ARCH_15 + + # ==OP== Gelu + # ==LEVEL== arch15 + # ==MIN== 20 + test_gelu_default_1_cpu,zdnn_gelu_ext + test_gelu_default_2_cpu,zdnn_gelu_ext + test_gelu_tanh_1_cpu,zdnn_gelu_ext + test_gelu_tanh_2_cpu,zdnn_gelu_ext + + # Gemm Transpose + test_gemm_transposeA_cpu,zdnn_matmul_transpose_op_ext + test_gemm_transposeB_cpu,zdnn_matmul_transpose_op_ext + + # ==OP== LeakyRelu + # ==LEVEL== arch15 + # ==MIN== 6 + # ==LIM== Input tensor must be less than or equal to 4 dimensions. + test_leakyrelu_cpu,zdnn_leaky_relu_ext + test_leakyrelu_default_cpu,zdnn_leaky_relu_ext + test_leakyrelu_example_cpu,zdnn_leaky_relu_ext + + # ==OP== MatMulInteger + # ==LEVEL== arch15 + # ==MIN== 10 + test_matmulinteger_cpu,zdnn_quantized_matmul_op + + # ==OP== QLinearMatMul + # ==LEVEL== arch15 + # ==MIN== 10 + test_qlinearmatmul_2D_uint8_float32_cpu,zdnn_quantized_matmul_op + test_qlinearmatmul_3D_uint8_float32_cpu,zdnn_quantized_matmul_op + # Error: at (1, 0) mismatch 0 (actual) vs 1 (reference) + # test_qlinearmatmul_2D_int8_float32_cpu,zdnn_quantized_matmul_op + # test_qlinearmatmul_3D_int8_float32_cpu,zdnn_quantized_matmul_op + + # ==OP== ReduceMax + # ==LEVEL== arch15 + # ==MIN== 1 + # ==LIM== - We do not support `do_not_keepdims` backend tests. Only support reduction over the innermost dimension. + # Currrently, there is no backend test in ONNX that does reduction on the innermost dimension. + + # ==OP== ReduceMin + # ==LEVEL== arch15 + # ==MIN== 1 + # ==LIM== - We do not support `do_not_keepdims` backend tests. Only support reduction over the innermost dimension. + # Currrently, there is no backend test in ONNX that does reduction on the innermost dimension. + + # ==OP== Sqrt + # ==LEVEL== arch15 + # ==MIN== 6 + test_sqrt_cpu,zdnn_sqrt_ext,zdnn_invsqrt_ext + test_sqrt_example_cpu,zdnn_sqrt_ext,zdnn_invsqrt_ext +) + +set(ENV_TEST_CASE_BY_USER_z16 "") +foreach(test_name IN LISTS NNPA_TEST_LIST_z16) + set(ENV_TEST_CASE_BY_USER_z16 "${ENV_TEST_CASE_BY_USER_z16} ${test_name}") +endforeach() + +set(ENV_TEST_CASE_BY_USER_ARCH_15 "") +foreach(test_name IN LISTS NNPA_TEST_LIST_ARCH_15) + set(ENV_TEST_CASE_BY_USER_ARCH_15 "${ENV_TEST_CASE_BY_USER_ARCH_15} ${test_name}") endforeach() -set(NNPA_TESTS_ENVS TEST_MCPU=z16 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER} TEST_ATOL=0.01 TEST_RTOL=0.05) +set(NNPA_TESTS_ENVS_z16 TEST_MARCH=z16 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER_z16} TEST_ATOL=0.01 TEST_RTOL=0.05) +set(NNPA_TESTS_ENVS_ARCH_15 TEST_MARCH=arch15 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER_ARCH_15} TEST_ATOL=0.01 TEST_RTOL=0.05) -set(ENV_TEST_CASE_BY_USER_DYNAMIC "") -foreach(test_name IN LISTS NNPA_TEST_LIST) +set(ENV_TEST_CASE_BY_USER_DYNAMIC_z16 "") +foreach(test_name IN LISTS NNPA_TEST_LIST_z16) if(NOT ${test_name} MATCHES ",NO_DYNAMIC_SHAPE_TEST$") - set(ENV_TEST_CASE_BY_USER_DYNAMIC "${ENV_TEST_CASE_BY_USER_DYNAMIC} ${test_name}") + set(ENV_TEST_CASE_BY_USER_DYNAMIC_z16 "${ENV_TEST_CASE_BY_USER_DYNAMIC_z16} ${test_name}") endif() endforeach() -set(NNPA_TESTS_ENVS_DYNAMIC TEST_MCPU=z16 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER_DYNAMIC} TEST_ATOL=0.01 TEST_RTOL=0.05) +set(ENV_TEST_CASE_BY_USER_DYNAMIC_ARCH_15 "") +foreach(test_name IN LISTS NNPA_TEST_LIST_ARCH_15) + if(NOT ${test_name} MATCHES ",NO_DYNAMIC_SHAPE_TEST$") + set(ENV_TEST_CASE_BY_USER_DYNAMIC_ARCH_15 "${ENV_TEST_CASE_BY_USER_DYNAMIC_ARCH_15} ${test_name}") + endif() +endforeach() +set(NNPA_TESTS_ENVS_DYNAMIC_z16 TEST_MARCH=z16 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER_DYNAMIC_z16} TEST_ATOL=0.01 TEST_RTOL=0.05) +set(NNPA_TESTS_ENVS_DYNAMIC_ARCH_15 TEST_MARCH=arch15 TEST_MACCEL=NNPA TEST_CASE_BY_USER=${ENV_TEST_CASE_BY_USER_DYNAMIC_ARCH_15} TEST_ATOL=0.01 TEST_RTOL=0.05) # ${ONNX_HOME} is the directory where onnx downloads real model files. # Model files are saved under ${ONNX_HOME}/models/model_name/model.onnx. @@ -397,7 +495,22 @@ add_custom_target(check-onnx-backend-nnpa # Needed for convolution models to avoid NaN outputs. # Remove this if saturation is enabled by default. TEST_COMPILE_ARGS="--nnpa-saturation=true" - ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + DEPENDS + ${FILE_GENERATE_DIR}/test.py + ${FILE_GENERATE_DIR}/test_config.py + ) + +# Ensure check-onnx-backend-ARCH_15-nnpa is backwards compatible +add_custom_target(check-onnx-backend-arch15-nnpa + COMMAND + TEST_INSTRUCTION_CHECK=true + ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-arch15-nnpa + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + && ${NNPA_TESTS_ENVS_ARCH_15} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -411,7 +524,23 @@ add_custom_target(check-onnx-backend-dynamic-nnpa # Needed for convolution models to avoid NaN outputs. # Remove this if saturation is enabled by default. TEST_COMPILE_ARGS="--nnpa-saturation=true" - ${NNPA_TESTS_ENVS_DYNAMIC} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + ${NNPA_TESTS_ENVS_DYNAMIC_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + DEPENDS + ${FILE_GENERATE_DIR}/test.py + ${FILE_GENERATE_DIR}/test_config.py + ) + +# Ensure check-onnx-backend-dynamic-arch15-nnpa is backwards compatible +add_custom_target(check-onnx-backend-dynamic-arch15-nnpa + COMMAND + ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-dynamic-arch15-nnpa + TEST_INSTRUCTION_CHECK=true + TEST_DYNAMIC=true + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" + ${NNPA_TESTS_ENVS_DYNAMIC_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + && ${NNPA_TESTS_ENVS_DYNAMIC_ARCH_15} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -427,7 +556,7 @@ add_custom_target(check-onnx-backend-constant-nnpa # Needed for convolution models to avoid NaN outputs. # Remove this if saturation is enabled by default. TEST_COMPILE_ARGS="--nnpa-saturation=true" - ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -439,7 +568,7 @@ add_custom_target(check-onnx-backend-compilerlib-nnpa # Needed for convolution models to avoid NaN outputs. # Remove this if saturation is enabled by default. TEST_COMPILE_ARGS="--nnpa-saturation=true" - ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config_compilerlib.py @@ -454,14 +583,23 @@ add_custom_target(clean-onnx-backend-nnpa add_dependencies(check-onnx-backend-nnpa onnx-mlir) add_dependencies(check-onnx-backend-nnpa PyRuntimeC) +add_dependencies(check-onnx-backend-arch15-nnpa onnx-mlir) +add_dependencies(check-onnx-backend-arch15-nnpa PyRuntimeC) add_dependencies(check-onnx-backend-dynamic-nnpa onnx-mlir) add_dependencies(check-onnx-backend-dynamic-nnpa PyRuntimeC) +add_dependencies(check-onnx-backend-dynamic-arch15-nnpa onnx-mlir) +add_dependencies(check-onnx-backend-dynamic-arch15-nnpa PyRuntimeC) add_dependencies(check-onnx-backend-constant-nnpa onnx-mlir) add_dependencies(check-onnx-backend-constant-nnpa PyRuntimeC) add_dependencies(check-onnx-backend-compilerlib-nnpa CompilerLibTest) add_dependencies(check-onnx-backend-compilerlib-nnpa PyRuntimeC) add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-nnpa) +# If on arch 15 machines then (TODO: enable once avail on test machines): +# add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-arch15-nnpa) +# else while on an arch 14 machine: +add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-nnpa) +# end if. add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-dynamic-nnpa) add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-constant-nnpa) @@ -472,7 +610,10 @@ if (ONNX_MLIR_ENABLE_JNI) COMMAND ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-jni-nnpa TEST_EMIT=jni JSONITER_JAR=${JSONITER_JAR} - ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -482,7 +623,10 @@ if (ONNX_MLIR_ENABLE_JNI) COMMAND ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-dynamic-jni-nnpa TEST_DYNAMIC=true TEST_EMIT=jni JSONITER_JAR=${JSONITER_JAR} - ${NNPA_TESTS_ENVS_DYNAMIC} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" + ${NNPA_TESTS_ENVS_DYNAMIC_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -492,7 +636,10 @@ if (ONNX_MLIR_ENABLE_JNI) COMMAND ONNX_HOME=${FILE_GENERATE_DIR}/check-onnx-backend-constant-jni-nnpa TEST_CONSTANT=true TEST_EMIT=jni JSONITER_JAR=${JSONITER_JAR} - ${NNPA_TESTS_ENVS} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py + # Needed for convolution models to avoid NaN outputs. + # Remove this if saturation is enabled by default. + TEST_COMPILE_ARGS="--nnpa-saturation=true" + ${NNPA_TESTS_ENVS_z16} ${BACKEND_TEST_COMMAND} ${BACKEND_TEST_ARGS} ${FILE_GENERATE_DIR}/test.py DEPENDS ${FILE_GENERATE_DIR}/test.py ${FILE_GENERATE_DIR}/test_config.py @@ -511,10 +658,12 @@ if (ONNX_MLIR_ENABLE_JNI) add_dependencies(check-onnx-backend-constant-jni-nnpa javaruntime) add_dependencies(check-onnx-backend-constant-jni-nnpa jniruntime) - add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-jni-nnpa) - add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-dynamic-jni-nnpa) - add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-constant-jni-nnpa) + # ONNX models failed with NaN results, so temporarily disable these. + #add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-jni-nnpa) + #add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-dynamic-jni-nnpa) + #add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-constant-jni-nnpa) else() message(STATUS " JNI backend-nnpa tests : OFF") endif() + diff --git a/test/accelerators/NNPA/numerical/CMakeLists.txt b/test/accelerators/NNPA/numerical/CMakeLists.txt index 20f8f28b38..3f81055175 100644 --- a/test/accelerators/NNPA/numerical/CMakeLists.txt +++ b/test/accelerators/NNPA/numerical/CMakeLists.txt @@ -41,7 +41,7 @@ function(add_numerical_test test_name) # Optimization level set by ONNX_MLIR_TEST_OPTLEVEL, defaults to 3 add_test(NAME ${test_name} - COMMAND ${test_name} -O${ONNX_MLIR_TEST_OPTLEVEL} --mcpu=z16 --maccel=NNPA + COMMAND ${test_name} -O${ONNX_MLIR_TEST_OPTLEVEL} --march=z16 --maccel=NNPA WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) set_tests_properties(${test_name} PROPERTIES LABELS numerical-nnpa) @@ -93,7 +93,7 @@ add_numerical_test(TestGemmNNPA # LSTM set(TestLSTMNNPA_instruction zdnn_lstm) # Automatically set following config when using --maccel=NNPA -# set(TestLSTMNNPA_config "-peephole=0") +set(TestLSTMNNPA_config "-peephole=0") add_numerical_test(TestLSTMNNPA ${ONNX_NUMERICALTEST_SRC_DIR}/TestLSTM.cpp LINK_LIBS PRIVATE ${TEST_LINK_LIBS} @@ -102,12 +102,13 @@ add_numerical_test(TestLSTMNNPA # GRU set(TestGRUNNPA_instruction zdnn_gru) # Automatically set following config when using --maccel=NNPA -# set(TestGRUNNPA_config "-linearBeforeReset=1") +set(TestGRUNNPA_config "-linearBeforeReset=1") add_numerical_test(TestGRUNNPA ${ONNX_NUMERICALTEST_SRC_DIR}/TestGRU.cpp LINK_LIBS PRIVATE ${TEST_LINK_LIBS} ) + # LeakyRelu set(TestLeakyReluNNPA_instruction zdnn_mul) # Automatically set following config when using --maccel=NNPA diff --git a/test/backend/CMakeLists.txt b/test/backend/CMakeLists.txt index 8d0bb04991..cfc65f5716 100644 --- a/test/backend/CMakeLists.txt +++ b/test/backend/CMakeLists.txt @@ -226,6 +226,7 @@ add_dependencies(check-onnx-backend-model onnx-mlir) add_dependencies(check-onnx-backend-model PyRuntimeC) add_dependencies(check-onnx-backend-signature onnx-mlir) add_dependencies(check-onnx-backend-signature PyRuntimeC) +add_dependencies(check-onnx-backend-case PyRuntimeC) add_dependencies(check-onnx-backend-input-verification onnx-mlir) add_dependencies(check-onnx-backend-input-verification PyRuntimeC) add_dependencies(check-onnx-backend-compilerlib CompilerLibTest) diff --git a/test/backend/all_test_names.txt b/test/backend/all_test_names.txt index 3b24d9863d..0daf5403ea 100644 --- a/test/backend/all_test_names.txt +++ b/test/backend/all_test_names.txt @@ -1,5 +1,5 @@ # This file is automatically generated by "make check-onnx-backend-case" -# From onnx 1.15.0 +# From onnx 1.17.0 # All test cases for cpu target test_bvlc_alexnet_cpu test_densenet121_cpu @@ -36,6 +36,8 @@ test_ai_onnx_ml_label_encoder_string_int_cpu test_ai_onnx_ml_label_encoder_string_int_no_default_cpu test_ai_onnx_ml_label_encoder_tensor_mapping_cpu test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu +test_ai_onnx_ml_tree_ensemble_set_membership_cpu +test_ai_onnx_ml_tree_ensemble_single_tree_cpu test_and2d_cpu test_and3d_cpu test_and4d_cpu @@ -153,6 +155,8 @@ test_cast_FLOAT16_to_FLOAT8E4M3FN_cpu test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_cpu test_cast_FLOAT16_to_FLOAT8E5M2_cpu test_cast_FLOAT16_to_FLOAT_cpu +test_cast_FLOAT16_to_INT4_cpu +test_cast_FLOAT16_to_UINT4_cpu test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_cpu test_cast_FLOAT8E4M3FNUZ_to_FLOAT_cpu test_cast_FLOAT8E4M3FN_to_FLOAT16_cpu @@ -168,8 +172,16 @@ test_cast_FLOAT_to_FLOAT8E4M3FNUZ_cpu test_cast_FLOAT_to_FLOAT8E4M3FN_cpu test_cast_FLOAT_to_FLOAT8E5M2FNUZ_cpu test_cast_FLOAT_to_FLOAT8E5M2_cpu +test_cast_FLOAT_to_INT4_cpu test_cast_FLOAT_to_STRING_cpu +test_cast_FLOAT_to_UINT4_cpu +test_cast_INT4_to_FLOAT16_cpu +test_cast_INT4_to_FLOAT_cpu +test_cast_INT4_to_INT8_cpu test_cast_STRING_to_FLOAT_cpu +test_cast_UINT4_to_FLOAT16_cpu +test_cast_UINT4_to_FLOAT_cpu +test_cast_UINT4_to_UINT8_cpu test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_cpu test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_cpu test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_cpu @@ -291,6 +303,8 @@ test_convtranspose_3d_cpu test_convtranspose_autopad_same_cpu test_convtranspose_cpu test_convtranspose_dilations_cpu +test_convtranspose_group_2_cpu +test_convtranspose_group_2_image_3_cpu test_convtranspose_kernel_shape_cpu test_convtranspose_output_shape_cpu test_convtranspose_pad_cpu @@ -311,11 +325,16 @@ test_deform_conv_with_multiple_offset_groups_cpu test_depthtospace_crd_mode_example_cpu test_depthtospace_example_cpu test_dequantizelinear_axis_cpu +test_dequantizelinear_blocked_cpu test_dequantizelinear_cpu test_dequantizelinear_e4m3fn_cpu test_dequantizelinear_e4m3fn_float16_cpu test_dequantizelinear_e4m3fn_zero_point_cpu test_dequantizelinear_e5m2_cpu +test_dequantizelinear_int16_cpu +test_dequantizelinear_int4_cpu +test_dequantizelinear_uint16_cpu +test_dequantizelinear_uint4_cpu test_det_2d_cpu test_det_nd_cpu test_dft_axis_cpu @@ -615,6 +634,7 @@ test_max_uint64_cpu test_max_uint8_cpu test_maxpool_1d_default_cpu test_maxpool_2d_ceil_cpu +test_maxpool_2d_ceil_output_size_reduce_by_one_cpu test_maxpool_2d_default_cpu test_maxpool_2d_dilations_cpu test_maxpool_2d_pads_cpu @@ -769,12 +789,24 @@ test_prelu_broadcast_expanded_cpu test_prelu_example_cpu test_prelu_example_expanded_cpu test_qlinearconv_cpu -test_qlinearmatmul_2D_cpu -test_qlinearmatmul_3D_cpu +test_qlinearmatmul_2D_int8_float16_cpu +test_qlinearmatmul_2D_int8_float32_cpu +test_qlinearmatmul_2D_uint8_float16_cpu +test_qlinearmatmul_2D_uint8_float32_cpu +test_qlinearmatmul_3D_int8_float16_cpu +test_qlinearmatmul_3D_int8_float32_cpu +test_qlinearmatmul_3D_uint8_float16_cpu +test_qlinearmatmul_3D_uint8_float32_cpu test_quantizelinear_axis_cpu +test_quantizelinear_blocked_asymmetric_cpu +test_quantizelinear_blocked_symmetric_cpu test_quantizelinear_cpu test_quantizelinear_e4m3fn_cpu test_quantizelinear_e5m2_cpu +test_quantizelinear_int16_cpu +test_quantizelinear_int4_cpu +test_quantizelinear_uint16_cpu +test_quantizelinear_uint4_cpu test_range_float_type_positive_delta_cpu test_range_float_type_positive_delta_expanded_cpu test_range_int32_type_negative_delta_cpu @@ -850,6 +882,7 @@ test_reduce_max_default_axes_keepdim_example_cpu test_reduce_max_default_axes_keepdims_random_cpu test_reduce_max_do_not_keepdims_example_cpu test_reduce_max_do_not_keepdims_random_cpu +test_reduce_max_empty_set_cpu test_reduce_max_keepdims_example_cpu test_reduce_max_keepdims_random_cpu test_reduce_max_negative_axes_keepdims_example_cpu @@ -885,6 +918,7 @@ test_reduce_sum_default_axes_keepdims_example_cpu test_reduce_sum_default_axes_keepdims_random_cpu test_reduce_sum_do_not_keepdims_example_cpu test_reduce_sum_do_not_keepdims_random_cpu +test_reduce_sum_empty_axes_input_noop_cpu test_reduce_sum_empty_axes_input_noop_example_cpu test_reduce_sum_empty_set_cpu test_reduce_sum_empty_set_non_reduced_axis_zero_cpu @@ -945,6 +979,7 @@ test_resize_downsample_sizes_nearest_not_smaller_cpu test_resize_tf_crop_and_resize_axes_2_3_cpu test_resize_tf_crop_and_resize_axes_3_2_cpu test_resize_tf_crop_and_resize_cpu +test_resize_tf_crop_and_resize_extrapolation_value_cpu test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_cpu test_resize_upsample_scales_cubic_align_corners_cpu test_resize_upsample_scales_cubic_asymmetric_cpu @@ -962,6 +997,7 @@ test_resize_upsample_sizes_nearest_ceil_half_pixel_cpu test_resize_upsample_sizes_nearest_cpu test_resize_upsample_sizes_nearest_floor_align_corners_cpu test_resize_upsample_sizes_nearest_not_larger_cpu +test_resize_upsample_sizes_nearest_not_smaller_cpu test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cpu test_reversesequence_batch_cpu test_reversesequence_time_cpu diff --git a/test/backend/common.py b/test/backend/common.py index a9b077eae7..c6f5e8d827 100644 --- a/test/backend/common.py +++ b/test/backend/common.py @@ -5,7 +5,7 @@ # Copyright 2021-2022 The IBM Research Authors. # ################################################################################ -# commom function `compile_model` called by both +# Common function `compile_model` called by both # SignatureExecutionSession and EndiannessAwareExecutionSession ################################################################################ from __future__ import absolute_import @@ -114,7 +114,8 @@ def compile_model(model, emit): command_list = [TEST_DRIVER] if args.Optlevel: command_list.append("-O" + args.Optlevel) - if args.mcpu: + if args.mcpu: # deprecated + print("warning, --mcpu option is deprecated, please use --march instead") command_list.append("--mcpu=" + args.mcpu) if args.march: command_list.append("--march=" + args.march) diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py index 898e5fe939..f3059af931 100644 --- a/test/backend/inference_backend.py +++ b/test/backend/inference_backend.py @@ -6,6 +6,7 @@ # ################################################################################ from __future__ import absolute_import +from __future__ import annotations from __future__ import division from __future__ import print_function from __future__ import unicode_literals @@ -73,6 +74,7 @@ def get_test_models(): ############################################################ # Elementary ops, ordered in the order they are found in # onnx-mlir/third_party/onnx/onnx/backend/test/case/node. + # To rebuild after changes: make onnx_mlir_supported_ops # ==ARCH== cpu # ==OP== Abs # ==MIN== 6 @@ -394,7 +396,7 @@ def get_test_models(): }, # ==OP== Cast # ==MIN== 6 - # ==LIM== Cast only between float and double types. Only ppc64le and MacOS platforms support float16. + # ==LIM== Cast only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. "test_cast_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -433,7 +435,7 @@ def get_test_models(): "test_cast_STRING_to_FLOAT_cpu": {}, # appears unsupported at this time # ==OP== CastLike # ==MIN== 19 - # ==LIM== CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. + # ==LIM== CastLike only between float and double types. Only ppc64le and MacOS platforms support float16. Does not support int4 and uint4. "test_castlike_FLOAT_to_DOUBLE_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -614,10 +616,12 @@ def get_test_models(): }, # ==OP== Constant # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. # By def, no dynamic shapes. "test_constant_cpu": {STATIC_SHAPE: {}}, # ==OP== ConstantOfShape # ==MIN== 9 + # ==LIM== Does not support int4 and uint4. # By def, no dynamic shapes. "test_constantofshape_float_ones_cpu": {STATIC_SHAPE: {}}, "test_constantofshape_int_zeros_cpu": {STATIC_SHAPE: {}}, @@ -789,7 +793,7 @@ def get_test_models(): }, # ==OP== DequantizeLinear # ==MIN== 10 - # ==LIM== Only support for per-tensor or layer dequantization. No support for per-axis dequantization. + # ==LIM== Only support for per-tensor or layer dequantization. No support for per-axis dequantization. Does not support int4 and uint4. # "test_dequantizelinear_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_dequantizelinear_cpu": { STATIC_SHAPE: {}, @@ -980,6 +984,7 @@ def get_test_models(): }, # ==OP== Flatten # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. "test_flatten_axis0_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1104,41 +1109,41 @@ def get_test_models(): DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, - # "test_gelu_default_1_expanded_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_gelu_default_1_expanded_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_gelu_default_2_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, - # "test_gelu_default_2_expanded_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_gelu_default_2_expanded_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_gelu_tanh_1_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, - # "test_gelu_tanh_1_expanded_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_gelu_tanh_1_expanded_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, "test_gelu_tanh_2_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, CONSTANT_INPUT: {-1}, }, - # "test_gelu_tanh_2_expanded_cpu": { - # STATIC_SHAPE: {}, - # DYNAMIC_SHAPE: {-1: {-1}}, - # CONSTANT_INPUT: {-1}, - # }, + "test_gelu_tanh_2_expanded_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, # ==OP== Gemm # ==MIN== 6 "test_gemm_all_attributes_cpu": { @@ -1357,7 +1362,7 @@ def get_test_models(): }, # ==OP== Identity # ==MIN== 16 - # ==LIM== Sequence identity not supported. + # ==LIM== Sequence identity not supported. Does not support int4 and uint4. "test_identity_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1367,7 +1372,7 @@ def get_test_models(): # "test_identity_opt_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== If # ==MIN== 16 - # ==LIM== Sequence and Optional outputs are not supported. + # ==LIM== Sequence and Optional outputs are not supported. Does not support int4 and uint4. "test_if_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -1784,7 +1789,7 @@ def get_test_models(): }, # ==OP== Loop # ==MIN== 1 - # ==LIM== Input must have static shape. + # ==LIM== Input must have static shape. Does not support int4 and uint4. "test_loop11_cpu": { STATIC_SHAPE: {}, # Need to enable ConvertSeqToMemrefPass for dynamic test. @@ -2263,7 +2268,7 @@ def get_test_models(): }, # ==OP== Pad # ==MIN== 2 - # ==LIM== axes input not supported + # ==LIM== axes input not supported. Does not support int4 and uint4. "test_constant_pad_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {0: {-1}}, @@ -2316,7 +2321,7 @@ def get_test_models(): }, # ==OP== QuantizeLinear # ==MIN== 10 - # ==LIM== Do not support per-axis and i8 quantization. + # ==LIM== Does not support per-axis and i8 quantization. Does not support int4 and uint4. # "test_quantizelinear_axis_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_quantizelinear_cpu": { STATIC_SHAPE: {}, @@ -2447,9 +2452,17 @@ def get_test_models(): }, # ==OP== ReduceMax # ==MIN== 1 - # ==LIM== do_not_keep_dim not supported. - # "test_reduce_max_default_axes_keepdim_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - # "test_reduce_max_default_axes_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # ==LIM== do_not_keep_dims not supported. + "test_reduce_max_default_axes_keepdim_example_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, + "test_reduce_max_default_axes_keepdims_random_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, # "test_reduce_max_do_not_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_max_do_not_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_reduce_max_keepdims_example_cpu": { @@ -2474,7 +2487,7 @@ def get_test_models(): }, # ==OP== ReduceMean # ==MIN== 1 - # ==LIM== do_not_keep_dim not supported. + # ==LIM== do_not_keep_dims not supported. # "test_reduce_mean_default_axes_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_mean_default_axes_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_mean_do_not_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -2501,9 +2514,17 @@ def get_test_models(): }, # ==OP== ReduceMin # ==MIN== 1 - # ==LIM== do_not_keep_dim not supported. - # "test_reduce_min_default_axes_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, - # "test_reduce_min_default_axes_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, + # ==LIM== do_not_keep_dims not supported. + "test_reduce_min_default_axes_keepdims_example_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, + "test_reduce_min_default_axes_keepdims_random_cpu": { + STATIC_SHAPE: {}, + DYNAMIC_SHAPE: {-1: {-1}}, + CONSTANT_INPUT: {-1}, + }, # "test_reduce_min_do_not_keepdims_example_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_reduce_min_do_not_keepdims_random_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, "test_reduce_min_keepdims_example_cpu": { @@ -2622,7 +2643,7 @@ def get_test_models(): }, # ==OP== Reshape # ==MIN== 5 - # ==LIM== allowzero not supported. Input `shape` must have static dimension. + # ==LIM== allowzero not supported. Input `shape` must have static dimension. Does not support int4 and uint4. "test_reshape_extended_dims_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {0: {-1}}, @@ -2801,7 +2822,7 @@ def get_test_models(): }, # ==OP== Scan # ==MIN== 8 - # ==LIM== Does not support dynamic shapes. + # ==LIM== Does not support dynamic shapes. Does not support int4 and uint4. # ==TODO== Precision issue with newer opset, maybe just unsupported. Dynamic shape? # "test_scan_sum_cpu": {STATIC_SHAPE:{}}, "test_scan9_sum_cpu": {STATIC_SHAPE: {}}, @@ -2858,7 +2879,7 @@ def get_test_models(): # "test_sequence_insert_at_back_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # ==OP== Shape # ==MIN== 15 - # ==LIM== Does not support start and end attributes. + # ==LIM== Does not support start and end attributes. Does not support int4 and uint4. "test_shape_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -2914,6 +2935,7 @@ def get_test_models(): }, # ==OP== Size # ==MIN== 13 + # ==LIM== Does not support int4 and uint4. "test_size_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -3041,7 +3063,7 @@ def get_test_models(): }, # ==OP== Squeeze # ==MIN== 1 - # ==LIM== Does not support static and dynamic shape. + # ==LIM== Does not support static and dynamic shape. Does not support int4 and uint4. # ==TODO== Temporally removed due to changes in onnx 1.8.1 # "test_squeeze_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_squeeze_negative_axes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -3140,6 +3162,7 @@ def get_test_models(): }, # ==OP== Transpose # ==MIN== 1 + # ==LIM== Does not support int4 and uint4. "test_transpose_default_cpu": { STATIC_SHAPE: {}, DYNAMIC_SHAPE: {-1: {-1}}, @@ -3285,7 +3308,7 @@ def get_test_models(): }, # ==OP== Unsqueeze # ==MIN== 1 - # ==LIM== Does not support static and dynamic shape. + # ==LIM== Does not support static and dynamic shape. Does not support int4 and uint4. # ==TODO== Temporally removed due to changes in onnx 1.8.1 # "test_unsqueeze_axis_0_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, # "test_unsqueeze_axis_1_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}}, @@ -3619,11 +3642,12 @@ def assert_similar_outputs( outputs: Sequence[Any], rtol: float, atol: float, + model_dir: str | None = None, ) -> None: rtol = float(os.getenv("TEST_RTOL", rtol)) atol = float(os.getenv("TEST_ATOL", atol)) super(InferenceBackendTest, cls).assert_similar_outputs( - ref_outputs, outputs, rtol, atol + ref_outputs, outputs, rtol, atol, model_dir ) def _add_onnxmlir_model_test( diff --git a/test/backend/variables.py b/test/backend/variables.py index 6a2594b6af..b13fe96c50 100644 --- a/test/backend/variables.py +++ b/test/backend/variables.py @@ -170,7 +170,7 @@ def get_args_from_env(): "--mcpu", type=str, default=os.getenv("TEST_MCPU", ""), - help="target a specific cpu, passed to the compiler", + help="target a specific cpu, passed to the compiler (deprecated, use --march)", ) parser.add_argument( "--march", diff --git a/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir b/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir index 7a07ed5732..276bde1324 100644 --- a/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir +++ b/test/mlir/accelerators/nnpa/analysis/dyn-dim-analysis.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --onnx-dim-analysis %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --onnx-dim-analysis %s -split-input-file | FileCheck %s // COM: test zdnn unary operations. Use Relu as a sample. func.func @test_stick_unary_unstick(%arg0 : tensor) -> tensor { @@ -104,7 +104,7 @@ func.func @test_stick_matmul_unstick(%arg0 : tensor) -> tensor) -> tensor> %none = "onnx.NoValue"() {value} : () -> none - %4 = "zhigh.MatMul"(%1, %3, %none) : (tensor>, tensor>, none) -> tensor> + %4 = "zhigh.MatMul"(%1, %3, %none) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor> %5 = "zhigh.Unstick"(%4) : (tensor>) -> tensor "onnx.Return"(%5) : (tensor) -> () @@ -131,7 +131,7 @@ func.func @test_stick_matmul_unstick(%arg0 : tensor) -> tensor>) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_3_]]) {axis = 0 : si64, group_id = [[GROUP_0_]] : si64} : (tensor>) -> () // CHECK: [[VAR_4_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_5_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) : (tensor>, tensor>, none) -> tensor> +// CHECK: [[VAR_5_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor> // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 1 : si64, group_id = [[GROUP_1_]] : si64} : (tensor>) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 2 : si64, group_id = [[GROUP_1_]] : si64} : (tensor>) -> () // CHECK-DAG: "onnx.DimGroup"([[VAR_5_]]) {axis = 0 : si64, group_id = [[GROUP_0_]] : si64} : (tensor>) -> () diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir index 261b9e46ae..7e43956edc 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --device-placement --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s +// RUN: onnx-mlir-opt --device-placement --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} { func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> { diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_load_config_file.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_load_config_file.mlir index ac630f16e1..4559df6875 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_load_config_file.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_load_config_file.mlir @@ -1,10 +1,10 @@ -// RUN: cfg_file=$(dirname %s)/load-cfg-all-on-cpu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=ALL-ON-CPU +// RUN: cfg_file=$(dirname %s)/load-cfg-all-on-cpu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=ALL-ON-CPU -// RUN: cfg_file=$(dirname %s)/load-cfg-all-relu-on-cpu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=ALL-RELU-ON-CPU +// RUN: cfg_file=$(dirname %s)/load-cfg-all-relu-on-cpu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=ALL-RELU-ON-CPU -// RUN: cfg_file=$(dirname %s)/load-cfg-not-match-relu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=NOT-MATCH-RELU +// RUN: cfg_file=$(dirname %s)/load-cfg-not-match-relu.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=NOT-MATCH-RELU -// RUN: cfg_file=$(dirname %s)/load-cfg-overlapping-condition.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=OVERLAPPING +// RUN: cfg_file=$(dirname %s)/load-cfg-overlapping-condition.json && onnx-mlir-opt --device-placement=load-config-file=$cfg_file --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s --check-prefix=OVERLAPPING func.func @test_load_config_file_all_on_cpu(%arg0: tensor) -> tensor { %0 = "onnx.Relu"(%arg0) {onnx_node_name = "Relu_0"} : (tensor) -> tensor diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_perf_model.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_perf_model.mlir index f36bb3d2b8..59cdeaeba9 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_perf_model.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_perf_model.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --device-placement=use-faster=true --mcpu=z16 --maccel=NNPA --split-input-file %s | FileCheck %s +// RUN: onnx-mlir-opt --device-placement=use-faster=true --march=z16 --maccel=NNPA --split-input-file %s | FileCheck %s // ----- // Shape is such that this op is nearly guaranteed to be faster on CPU. diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_save_config_file.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_save_config_file.mlir index bd29c6549d..10c8966ea5 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_save_config_file.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/device_placement_pass_save_config_file.mlir @@ -1,4 +1,4 @@ -// RUN: cfg_file=$(dirname %s)/save-cfg.json && onnx-mlir-opt --device-placement=save-config-file=$cfg_file --mcpu=z16 --maccel=NNPA --split-input-file %s && cat $cfg_file | FileCheck %s && rm $cfg_file +// RUN: cfg_file=$(dirname %s)/save-cfg.json && onnx-mlir-opt --device-placement=save-config-file=$cfg_file --march=z16 --maccel=NNPA --split-input-file %s && cat $cfg_file | FileCheck %s && rm $cfg_file func.func @test_save_config_file(%arg0: tensor) -> tensor { %0 = "onnx.Relu"(%arg0) {onnx_node_name = "Relu_0"} : (tensor) -> tensor diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir index 05d14ca167..073204d9fd 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-onnxir.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir --EmitONNXIR --mcpu=z16 --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s +// RUN: onnx-mlir --EmitONNXIR --march=z16 --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s module attributes {llvm.data_layout = "E-m:e-i1:8:16-i8:8:16-i64:64-f128:64-v128:64-a:8:16-n32:64", llvm.target_triple = "s390x-ibm-linux", "onnx-mlir.symbol-postfix" = "model"} { func.func @mnist(%arg0: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> { diff --git a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir index 9b1bd2935d..0667e0e3b0 100644 --- a/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir +++ b/test/mlir/accelerators/nnpa/conversion/device-placement/emit-zhighir.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir --EmitZHighIR --mcpu=z16 --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s +//&& RUN: onnx-mlir --EmitZHighIR --mcpu=z16 --maccel=NNPA --disable-constant-prop=true --printIR %s | FileCheck %s // Note that, we intentionally add `device=cpu` into onnx.Gemm to force it run on CPU. module { diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-level.mlir index 90f9f85341..7ec4ed1ce6 100644 --- a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-level.mlir +++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-level.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR --EmitZHighIR -profile-ir=Onnx %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR --EmitZHighIR -profile-ir=Onnx %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir index 059c3bcfb8..793e3bcd35 100644 --- a/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir +++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-onnx-zhigh-level.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR --EmitZHighIR --profile-ir=ZHigh %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR --EmitZHighIR --profile-ir=ZHigh %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-zhigh-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-zhigh-level.mlir index 89f226d57f..d3edf91c1b 100644 --- a/test/mlir/accelerators/nnpa/conversion/instrument/add-zhigh-level.mlir +++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-zhigh-level.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR --EmitZLowIR --instrument-stage=ZHigh --instrument-ops=zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR --EmitZLowIR --instrument-stage=ZHigh --instrument-ops=zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/instrument/add-zlow-level.mlir b/test/mlir/accelerators/nnpa/conversion/instrument/add-zlow-level.mlir index e71c98e640..a4cc105245 100644 --- a/test/mlir/accelerators/nnpa/conversion/instrument/add-zlow-level.mlir +++ b/test/mlir/accelerators/nnpa/conversion/instrument/add-zlow-level.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR --EmitZLowIR --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR --EmitZLowIR --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-arch15.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-arch15.mlir new file mode 100644 index 0000000000..ffa01a707f --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-arch15.mlir @@ -0,0 +1,148 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s + +// ----- + +// Check stickification with saturation. +func.func @test_stick_with_saturation() -> () { + %0 = memref.alloc() : memref<10x10xf32> + %1 = memref.alloc() : memref<1x1x32x64xf16> + "zlow.stick"(%0, %1) {saturation = -1 : si64} : (memref<10x10xf32>, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_stick_with_saturation + // CHECK: llvm.call @zdnn_transform_ztensor_with_saturation({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 +} + +// ----- + +// Check stickification without saturation. +func.func @test_stick_without_saturation() -> () { + %0 = memref.alloc() : memref<10x10xf32> + %1 = memref.alloc() : memref<1x1x32x64xf16> + "zlow.stick"(%0, %1) {saturation = 0 : si64} : (memref<10x10xf32>, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_stick_without_saturation + // CHECK: llvm.call @zdnn_transform_ztensor({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.gelu calls the correct zDNN API or not. +func.func @test_call_zdnn_gelu() -> () { + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %shape = memref.alloc() : memref<2xi64> + "zlow.gelu"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_gelu + // CHECK: {{.*}} = llvm.call @zdnn_gelu_ext({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.leakyrelu calls the correct zDNN API or not. +func.func @test_call_zdnn_leaky_relu() -> () { + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %shape = memref.alloc() : memref<2xi64> + "zlow.leakyrelu"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_leaky_relu + // CHECK: {{.*}} = llvm.call @zdnn_leaky_relu_ext({{.*}}, {{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr, f32, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.invsqrt calls the correct zDNN API or not. +func.func @test_call_zdnn_invsqrt() -> () { + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %shape = memref.alloc() : memref<2xi64> + "zlow.invsqrt"(%0, %shape, %1) {layout = "2D"} : (memref<1x1x32x64xf16>, memref<2xi64>, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_invsqrt + // CHECK: {{.*}} = llvm.call @zdnn_invsqrt_ext({{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, f32, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.reducemax calls the correct zDNN API or not. +func.func @test_call_zdnn_reducemax() -> () { + %0 = memref.alloc() : memref<1x1x32x64xf16> + %1 = memref.alloc() : memref<1x1x32x64xf16> + %work_area = memref.alloc() {alignment = 4096 : i64} : memref<8192xi8> + %shape = memref.alloc() : memref + "zlow.reducemax"(%0, %work_area, %shape, %1) {layout = "2D", op_type = "REDUCE_OP_MAXIMUM" : i64} : (memref<1x1x32x64xf16>, memref<8192xi8>, memref, memref<1x1x32x64xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_reducemax + // CHECK: {{.*}} = llvm.call @zdnn_reduce_ext({{.*}}, {{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.reducemin calls the correct zDNN API or not. +func.func @test_call_zdnn_reducemin() -> () { + %0 = memref.alloc() : memref<3x2x32x64xf16> + %1 = memref.alloc() : memref<3x2x32x64xf16> + %work_area = memref.alloc() {alignment = 4096 : i64} : memref<8192xi8> + %shape = memref.alloc() : memref + "zlow.reducemin"(%0, %work_area, %shape, %1) {layout = "2D", op_type = "REDUCE_OP_MINIMUM" : i64} : (memref<3x2x32x64xf16>, memref<8192xi8>, memref, memref<3x2x32x64xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_reducemin + // CHECK: {{.*}} = llvm.call @zdnn_reduce_ext({{.*}}, {{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.sqrt calls the correct zDNN API or not. +func.func @test_call_zdnn_sqrt() -> () { + %0 = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> + %1 = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> + %shape = memref.alloc() : memref<4xi64> + "zlow.sqrt"(%0, %shape, %1) {layout = "2D"} : (memref<2048xf16>, memref<4xi64>, memref<2048xf16>) -> () + return + + // CHECK-LABEL: test_call_zdnn_sqrt + // CHECK: {{.*}} = llvm.call @zdnn_sqrt_ext({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.matmul calls the correct zDNN API or not. +func.func @test_matmul_bcast1(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { + %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = -1 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + return %res : memref<2048xf16> + // CHECK-LABEL: test_matmul_bcast1 + // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.quantized_matmul calls the correct zDNN API or not. +func.func @test_call_zdnn_quantized_matmul_op(%arg0: memref<1x1x1x1x32x64xf16>, %arg1: memref, %arg2: memref, %arg3: memref<1x1x1x1x32x64xi8>, %arg4: memref, %arg5: memref, %arg6: memref<1x1x1x1x32x64xi8>, %arg7: memref, %arg8: memref, %arg9: memref<1x1x1x1x32x64xf16>, %arg10: memref<4xi64>, %arg11: memref, %arg12: memref) -> memref<1x1x1x1x32x64xf16> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<1x1x1x1x32x64xf16> + "zlow.quantizedMatmul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %alloc, %arg11, %arg12) {bias_q_type = "INT8", dequantize_output = 0 : si64, is_bcast = -1 : si64, is_stacked = 0 : si64, out_q_type = "DLFLOAT16", x_q_type = "DLFLOAT16", y_q_type = "WEIGHTS"} : (memref<1x1x1x1x32x64xf16>, memref, memref, memref<1x1x1x1x32x64xi8>, memref, memref, memref<1x1x1x1x32x64xi8>, memref, memref, memref<1x1x1x1x32x64xf16>, memref<4xi64>, memref<1x1x1x1x32x64xf16>, memref, memref) -> () + return %alloc : memref<1x1x1x1x32x64xf16> + + // CHECK-LABEL: test_call_zdnn_quantized_matmul_op + // CHECK: {{.*}} = llvm.call @zdnn_quantized_matmul_op({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr) -> i32 +} + +// ----- + +// Check whether the lowering of zlow.quantized_matmul calls the correct zDNN API or not. +func.func @test_call_zdnn_quantized_matmul_dequantized_op(%arg0: memref<1x1x1x1x32x64xf16>, %arg1: memref, %arg2: memref, %arg3: memref<1x1x1x1x32x64xi8>, %arg4: memref, %arg5: memref, %arg6: memref<1x1x1x1x32x64xi8>, %arg7: memref, %arg8: memref, %arg9: memref<1x1x1x1x32x64xf16>, %arg10: memref<4xi64>, %arg11: memref, %arg12: memref) -> memref<1x1x1x1x32x64xf16> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<1x1x1x1x32x64xf16> + "zlow.quantizedMatmul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %alloc, %arg11, %arg12) {bias_q_type = "INT8", dequantize_output = -1 : si64, is_bcast = -1 : si64, is_stacked = 0 : si64, out_q_type = "DLFLOAT16", x_q_type = "DLFLOAT16", y_q_type = "WEIGHTS"} : (memref<1x1x1x1x32x64xf16>, memref, memref, memref<1x1x1x1x32x64xi8>, memref, memref, memref<1x1x1x1x32x64xi8>, memref, memref, memref<1x1x1x1x32x64xf16>, memref<4xi64>, memref<1x1x1x1x32x64xf16>, memref, memref) -> () + return %alloc : memref<1x1x1x1x32x64xf16> + + // CHECK-LABEL: test_call_zdnn_quantized_matmul_dequantized_op + // CHECK: {{.*}} = llvm.call @zdnn_quantized_matmul_op({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, i64, i64, i64, i64, i64, !llvm.ptr, !llvm.ptr) -> i32 +} diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir index d4af23b9eb..2b56c8db2b 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir index f0ea3355aa..782bde6e13 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-typed-pointer.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s // ----- @@ -39,19 +39,20 @@ func.func @test_stick() -> () { // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: {{.*}} = llvm.call @zdnn_generate_transformed_desc([[PRE_TRANSFORMED_DESC_I8PTR]], [[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 - // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> : (i64) -> !llvm.ptr + // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> : (i64) -> !llvm.ptr // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: [[BUFFER_SIZE:%.+]] = llvm.call @zdnn_getsize_ztensor([[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr) -> i64 - // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> + // CHECK: llvm.store [[PRE_TRANSFORMED_DESC]], [[ZTENSOR_PRE_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[TRANSFORMED_DESC]], [[ZTENSOR_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[BUFFER_SIZE]], [[ZTENSOR_BUFFER_SIZE]] : i64, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[ALIGNED_BUFFER_I8PTR]], [[ZTENSOR_BUFFER]] : !llvm.ptr, !llvm.ptr // CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 @@ -93,7 +94,7 @@ func.func @test_unstick() -> () { // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: {{.*}} = llvm.call @zdnn_generate_transformed_desc([[PRE_TRANSFORMED_DESC_I8PTR]], [[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 - // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> : (i64) -> !llvm.ptr + // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> : (i64) -> !llvm.ptr // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: [[BUFFER_SIZE:%.+]] = llvm.call @zdnn_getsize_ztensor([[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr) -> i64 // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr @@ -334,7 +335,7 @@ func.func @test_call_zdnn_log() -> () { // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_no_bcast_unstacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -345,7 +346,7 @@ func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_no_bcast_stacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -356,7 +357,7 @@ func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_bcast_stacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -367,7 +368,7 @@ func.func @test_matmul_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,% // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_bcast_unstacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir index 2307680415..7907969c6b 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s // ----- @@ -38,24 +38,24 @@ func.func @test_stick() -> () { // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: {{.*}} = llvm.call @zdnn_generate_transformed_desc([[PRE_TRANSFORMED_DESC_I8PTR]], [[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 - // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> : (i64) -> !llvm.ptr + // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> : (i64) -> !llvm.ptr // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: [[BUFFER_SIZE:%.+]] = llvm.call @zdnn_getsize_ztensor([[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr) -> i64 - // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[PRE_TRANSFORMED_DESC]], [[ZTENSOR_PRE_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[TRANSFORMED_DESC]], [[ZTENSOR_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[BUFFER_SIZE]], [[ZTENSOR_BUFFER_SIZE]] : i64, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[ALIGNED_BUFFER_I8PTR]], [[ZTENSOR_BUFFER]] : !llvm.ptr, !llvm.ptr // CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 - // CHECK: [[IS_TRANSFORMED:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[IS_TRANSFORMED:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[FALSE]], [[IS_TRANSFORMED]] : i1, !llvm.ptr // CHECK: [[UNSTICKIFIED:%.+]] = llvm.extractvalue [[UNSTICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -92,24 +92,24 @@ func.func @test_unstick() -> () { // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: {{.*}} = llvm.call @zdnn_generate_transformed_desc([[PRE_TRANSFORMED_DESC_I8PTR]], [[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 - // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> : (i64) -> !llvm.ptr + // CHECK: [[ZTENSOR:%.+]] = llvm.alloca {{.*}} x !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> : (i64) -> !llvm.ptr // CHECK: [[TRANSFORMED_DESC_I8PTR:%.+]] = llvm.bitcast [[TRANSFORMED_DESC]] : !llvm.ptr to !llvm.ptr // CHECK: [[BUFFER_SIZE:%.+]] = llvm.call @zdnn_getsize_ztensor([[TRANSFORMED_DESC_I8PTR]]) : (!llvm.ptr) -> i64 - // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_PRE_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[PRE_TRANSFORMED_DESC]], [[ZTENSOR_PRE_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_TRANSFORMED_DESC:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[TRANSFORMED_DESC]], [[ZTENSOR_TRANSFORMED_DESC]] : !llvm.ptr, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER_SIZE:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[BUFFER_SIZE]], [[ZTENSOR_BUFFER_SIZE]] : i64, !llvm.ptr - // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[ZTENSOR_BUFFER:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[ALIGNED_BUFFER_I8PTR]], [[ZTENSOR_BUFFER]] : !llvm.ptr, !llvm.ptr // CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1 - // CHECK: [[IS_TRANSFORMED:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<31 x i8>)> + // CHECK: [[IS_TRANSFORMED:%.+]] = llvm.getelementptr [[ZTENSOR]]{{\[}}0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, ptr, i64, ptr, i1, array<3 x i8>, f32, f32, array<20 x i8>)> // CHECK: llvm.store [[TRUE]], [[IS_TRANSFORMED]] : i1, !llvm.ptr // CHECK: [[UNSTICKIFIED:%.+]] = llvm.extractvalue [[UNSTICKIFIED_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -333,7 +333,7 @@ func.func @test_call_zdnn_log() -> () { // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_no_bcast_unstacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -344,7 +344,7 @@ func.func @test_matmul_no_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_no_bcast_stacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -355,7 +355,7 @@ func.func @test_matmul_no_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16 // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = -1 : si64, is_stacked = -1 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_bcast_stacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 @@ -366,7 +366,7 @@ func.func @test_matmul_bcast_stacked(%x: memref<2048xf16>,%y: memref<2048xf16>,% // Check whether the lowering of zlow.matmul calls the correct zDNN API or not. func.func @test_matmul_bcast_unstacked(%x: memref<2048xf16>,%y: memref<2048xf16>,%bias: memref<2048xf16>, %shape: memref<3xi64>) -> memref<2048xf16> { %res = memref.alloc() {alignment = 4096 : i64} : memref<2048xf16> - "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () + "zlow.matmul"(%x, %y, %bias, %shape, %res) {is_bcast1 = 0 : si64, is_bcast23 = -1 : si64, is_stacked = 0 : si64} : (memref<2048xf16>, memref<2048xf16>, memref<2048xf16>, memref<3xi64>, memref<2048xf16>) -> () return %res : memref<2048xf16> // CHECK-LABEL: test_matmul_bcast_unstacked // CHECK: %{{.*}} = llvm.call @zdnn_matmul_bcast_op_ext(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> i32 diff --git a/test/mlir/accelerators/nnpa/conversion/normalize-memref.mlir b/test/mlir/accelerators/nnpa/conversion/normalize-memref.mlir index 6e59e2e39d..d884d6c906 100644 --- a/test/mlir/accelerators/nnpa/conversion/normalize-memref.mlir +++ b/test/mlir/accelerators/nnpa/conversion/normalize-memref.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --normalize-memrefs %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --normalize-memrefs %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-krnl/onnx-on-ztensor.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-krnl/onnx-on-ztensor.mlir index 41728bc5d0..75f81d7843 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-krnl/onnx-on-ztensor.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-krnl/onnx-on-ztensor.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s // Test doing unary element-wise computation directly on zTensor. // Taking ONNXSqrtOp as the example. diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-arch15.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-arch15.mlir new file mode 100644 index 0000000000..595771e326 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-arch15.mlir @@ -0,0 +1,66 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s + +func.func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_add +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>, [[PARAM_1_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Add"([[VAR_0_]], [[VAR_1_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf32> +// CHECK: return [[VAR_3_]] : tensor<10x10xf32> +// CHECK: } +} + +// ----- + +// COM: Binary ops use 3DS by default for rank 3. +func.func @test_add_3ds(%arg0 : tensor<10x10x10xf32>, %arg1 : tensor<10x10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10x10xf32>, tensor<10x10x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_add_3ds +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10x10xf32>, [[PARAM_1_:%.+]]: tensor<10x10x10xf32>) -> tensor<10x10x10xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<10x10x10xf32>) -> tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<10x10x10xf32>) -> tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Add"([[VAR_0_]], [[VAR_1_]]) : (tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<10x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<10x10x10xf32> +// CHECK: return [[VAR_3_]] : tensor<10x10x10xf32> +// CHECK: } +} + +// ----- + +// COM: Do not lower broadcasting onnx.Add to zHigh. +func.func @test_add_not_lowered_diff_shape(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10xf32>) -> tensor<*xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_add_not_lowered_diff_shape +} + +// ----- + +/// Do not lower onnx.Add to zHigh if inputs have unknown dimensions +/// because we cannot statically check whether it is really broadcasting or not. +func.func @test_add_not_lowered_unknown_dims(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x?xf32>) -> tensor<*xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x?xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_add_not_lowered_unknown_dims +} + +// ----- + +/// COM: Test for zdnn limitation. +/// COM: Not lowered when dimensin size exceeds DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE in `third_party/zdnn-lib/zdnn_limit.h` +/// COM: DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE depends on zAIU HW. Please check the value if these tests fails. + +func.func @test_exceed_limit_add(%arg0 : tensor<2097152x10xf32>, %arg1 : tensor<2097152x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<2097152x10xf32>, tensor<2097152x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_exceed_limit_add +// CHECK: "onnx.Add" +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir index d50fd6c291..370f844d8c 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu-opt.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s | FileCheck %s func.func @test_add_force_cpu_opt(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) {device = "cpu", onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir index 173210e1db..c10dd38b9c 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add-exec-cpu.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR --EmitZHighIR -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR --EmitZHighIR -tag="test" %s | FileCheck %s func.func @test_add_force_cpu(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) {device = "cpu", onnx_node_name = "test/add0"} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add.mlir index d637c76b4f..8874702eb0 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/add.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/common-rules.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/common-rules.mlir index fb7a6e13af..a3cf9c34c6 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/common-rules.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/common-rules.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s // COM: Do not lower element-wise ops with scalar tensor since it is not benefical. func.func @test_not_lowered_scalar_tensor(%arg0 : tensor, %arg1 : tensor, %arg2: tensor<2xf32>) -> tensor<*xf32> { diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir index c7857a7588..700b615268 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/conv.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_onnx_conv2d(%arg0: tensor<5x3x32x32xf32>, %arg1 : tensor<2x3x2x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> { %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {kernel_shape = [2, 2]} : (tensor<5x3x32x32xf32>, tensor<2x3x2x2xf32>, tensor<2xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir index 7df7a2cb2a..d11c7f631c 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --nnpa-enable-scalar-bcast-binary %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --nnpa-enable-scalar-bcast-binary %s -split-input-file | FileCheck %s // COM: Division by a scalar in case of dynamic dimensions. func.func @test_div_unknown_scalar1(%arg0 : tensor) -> tensor<*xf32> { diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir index 9cad7a6915..879ec80a61 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_div(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/exp.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/exp.mlir index cd1b115435..2d7f38f57d 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/exp.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/exp.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_exp(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Exp"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gelu.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gelu.mlir new file mode 100644 index 0000000000..647085c4fb --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gelu.mlir @@ -0,0 +1,30 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s | FileCheck %s + +func.func @test_gelu_erf_arch15(%arg0 : tensor<1x2xf32>) -> tensor<1x2xf32>{ + %0 ="onnx.Gelu"(%arg0) {approximate = "none"} : (tensor<1x2xf32>) -> tensor<1x2xf32> + "func.return"(%0) : (tensor<1x2xf32>) -> () + + +// CHECK-LABEL: func @test_gelu_erf_arch15 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<1x2xf32>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Gelu"(%0) {approximate = "none"} : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf32> +// CHECK: return [[VAR_2_]] : tensor<1x2xf32> +// CHECK: } +} + +// ----- + +func.func @test_gelu_tanh_arch15(%arg0 : tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 ="onnx.Gelu"(%arg0) {approximate = "tanh"} : (tensor<1x2xf32>) -> tensor<1x2xf32> + "func.return"(%0) : (tensor<1x2xf32>) -> () + +// CHECK-LABEL: func @test_gelu_tanh_arch15 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<1x2xf32>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Gelu"(%0) {approximate = "tanh"} : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf32> +// CHECK: return [[VAR_2_]] : tensor<1x2xf32> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir index 948995a469..317d51a0e1 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s func.func @test_gemm_bias_none(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> { %bias = "onnx.NoValue"() {value} : () -> none @@ -10,7 +10,7 @@ func.func @test_gemm_bias_none(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<10x10xf32> // CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } @@ -27,7 +27,7 @@ func.func @test_gemm_bias_1d(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<10x10xf32> // CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } @@ -44,7 +44,7 @@ func.func @test_gemm_bias_2d(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<10x10xf32> // CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[VAR_4_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_6_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> @@ -62,13 +62,12 @@ func.func @test_gemm_transA(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, // CHECK-LABEL: func @test_gemm_transA // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x10xf32>, [[PARAM_1_:%.+]]: tensor<5x10xf32>, [[PARAM_2_:%.+]]: tensor<10xf32>) -> tensor<10x10xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [1, 0]} : (tensor<5x10xf32>) -> tensor<10x5xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Unstick"([[VAR_4_]]) : (tensor<*xf16>) -> tensor<10x10xf32> -// CHECK: return [[VAR_5_]] : tensor<10x10xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 1 : si64, transposeB = 0 : si64} : (tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<10x10xf32> +// CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } } @@ -80,14 +79,13 @@ func.func @test_gemm_transB(%arg0 : tensor<10x5xf32>, %arg1 : tensor<10x5xf32>, // CHECK-LABEL: func @test_gemm_transB // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x5xf32>, [[PARAM_1_:%.+]]: tensor<10x5xf32>, [[PARAM_2_:%.+]]: tensor<10xf32>) -> tensor<10x10xf32> { -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [1, 0]} : (tensor<10x5xf32>) -> tensor<5x10xf32> // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_2_]], [[VAR_3_]]) : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Unstick"([[VAR_4_]]) : (tensor<*xf16>) -> tensor<10x10xf32> -// CHECK: return [[VAR_5_]] : tensor<10x10xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 1 : si64} : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<10x10xf32> +// CHECK: return [[VAR_4_]] : tensor<10x10xf32> // CHECK: } } @@ -99,15 +97,13 @@ func.func @test_gemm_transAB(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, // CHECK-LABEL: func @test_gemm_transAB // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x5xf32>, [[PARAM_1_:%.+]]: tensor<5x10xf32>, [[PARAM_2_:%.+]]: tensor<5xf32>) -> tensor<5x5xf32> { -// CHECK: [[VAR_2_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [1, 0]} : (tensor<5x10xf32>) -> tensor<10x5xf32> -// CHECK: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [1, 0]} : (tensor<10x5xf32>) -> tensor<5x10xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<5xf32>) -> tensor<5xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_5_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) : (tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> -// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x5xf32> -// CHECK: return [[VAR_6_]] : tensor<5x5xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<5xf32>) -> tensor<5xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 1 : si64, transposeB = 1 : si64} : (tensor<10x5xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<5xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<5x5xf32> +// CHECK: return [[VAR_4_]] : tensor<5x5xf32> // CHECK: } } @@ -124,7 +120,7 @@ func.func @test_gemm_unknown_dims(%arg0: tensor, %arg1: tensor<5x10xf32 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor<5x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor // CHECK: return [[VAR_4_]] : tensor // CHECK: } @@ -159,8 +155,8 @@ func.func @test_gemm_not_lowered(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf /// COM: Not lowered when dimensin size exceeds DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE in `third_party/zdnn-lib/zdnn_limit.h` /// COM: DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE depends on zAIU HW. Please check the value if these tests fails. -func.func @test_exceed_limit_gemm(%arg0 : tensor<32769x5xf32>, %arg1 : tensor<5x32769xf32>, %arg2: tensor<32769xf32>) -> tensor<*xf32> { - %0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<32769x5xf32>, tensor<5x32769xf32>, tensor<32769xf32>) -> tensor<*xf32> +func.func @test_exceed_limit_gemm(%arg0 : tensor<2097152x5xf32>, %arg1 : tensor<5x2097152xf32>, %arg2: tensor<2097152xf32>) -> tensor<*xf32> { + %0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<2097152x5xf32>, tensor<5x2097152xf32>, tensor<2097152xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: func @test_exceed_limit_gemm diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir index ba5aeca2df..546b3e23de 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gru.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s func.func @test_onnx_to_zhigh_gru0(%X: tensor<7x2000x204xf32>, %W: tensor<1x600x204xf32>, %R: tensor<1x600x200xf32>, %B: tensor<1x1200xf32>) -> (tensor<7x1x2000x200xf32>, tensor<1x2000x200xf32>) { %cst = "onnx.NoValue"() {value} : () -> none @@ -247,6 +247,34 @@ func.func @test_onnx_to_zhigh_gru0_bidir_dyn(%X: tensor, %W: tensor<2 // ----- +func.func @gru_with_len(%arg0: tensor<2x2x1xf32>, %arg1: tensor<1x3x1xf32>, %arg2 : tensor<1x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %lens = onnx.Constant dense<[2, 1]> : tensor<2xi32> + %cst = "onnx.NoValue"() {value} : () -> none + %res:2 = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %lens, %cst) {layout = 0 : si64, linear_before_reset = 1 : si64} + : ( tensor<2x2x1xf32>, tensor<1x3x1xf32>, tensor<1x3x1xf32>, none, tensor<2xi32>, none) -> (tensor<*xf32>, tensor<*xf32>) + onnx.Return %res#0, %res#1 : tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func @gru_with_len +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x2x1xf32>, [[PARAM_1_:%.+]]: tensor<1x3x1xf32>, [[PARAM_2_:%.+]]: tensor<1x3x1xf32>) -> (tensor<2x1x2x1xf32>, tensor<1x2x1xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 1]> : tensor<2xi32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<2x2x1xf32>) -> tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32> +// CHECK: [[VAR_4_:%.+]]:3 = "onnx.SplitV11"([[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) +// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.StickForGRU"([[VAR_4_]]#0, [[VAR_4_]]#1, [[VAR_4_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16> +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_2_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32> +// CHECK: [[VAR_7_:%.+]]:3 = "onnx.SplitV11"([[VAR_6_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) +// CHECK: [[VAR_8_:%.+]] = "zhigh.StickForGRU"([[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16> +// CHECK: [[VAR_9_:%.+]] = "zhigh.GRU"([[VAR_2_]], [[VAR_1_]], [[VAR_5_]], [[VAR_1_]], [[VAR_8_]], [[VAR_1_]]) {direction = "forward", hidden_size = 1 : si64, return_all_steps = -1 : si64} : (tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none, tensor<*xf16>, none, tensor<*xf16>, none) -> tensor<*xf16> +// CHECK: [[VAR_10_:%.+]] = "zhigh.Unstick"([[VAR_9_]]) : (tensor<*xf16>) -> tensor<2x1x2x1xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = "zhigh.FixGRUY"([[VAR_10_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>, none) -> tensor<2x1x2x1xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.FixGRUYh"([[VAR_10_]], [[VAR_0_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>) -> tensor<1x2x1xf32> +// CHECK: onnx.Return [[VAR_11_]], [[VAR_12_]] : tensor<2x1x2x1xf32>, tensor<1x2x1xf32> +// CHECK: } +} + +// ----- + // COM : Maximum hidden_size in GRU is 10880. Not lowered when using 10881. func.func @test_onnx_to_zhigh_gru_exceed_num_hidden(%X: tensor<7x2000x204xf32>, %W: tensor<1x16384x204xf32>, %R: tensor<1x16384x10881xf32>, %B: tensor<1x16386xf32>) -> (tensor<7x1x2000x10881xf32>, tensor<1x2000x10881xf32>) { diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/invsqrt.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/invsqrt.mlir new file mode 100644 index 0000000000..d37990ceb9 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/invsqrt.mlir @@ -0,0 +1,45 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s | FileCheck %s + +func.func @test_invsqrt_reciprocal(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %a = "onnx.Sqrt"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> + %y = "onnx.Reciprocal"(%a) : (tensor<*xf32>) -> tensor<*xf32> + "func.return"(%y) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_invsqrt_reciprocal +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.InvSqrt"([[VAR_0_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf32> +// CHECK: return [[VAR_2_]] : tensor<10x10xf32> +// CHECK: } +} + +func.func @test_invsqrt_div(%arg0 : tensor<1x2xf32>) -> tensor<1x2xf32> { + %x = onnx.Constant dense<[[1.0, 1.0]]> : tensor<1x2xf32> + %a = "onnx.Sqrt"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32> + %y = "onnx.Div"(%x, %a) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + "func.return"(%y) : (tensor<1x2xf32>) -> () + +// CHECK-LABEL: func @test_invsqrt_div +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<1x2xf32>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.InvSqrt"([[VAR_0_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf32> +// CHECK: return [[VAR_2_]] : tensor<1x2xf32> +// CHECK: } +} + +func.func @test_invsqrt_div2(%arg0 : tensor<1x2xf32>) -> tensor<*xf32> { + %x = onnx.Constant dense<[[1.0, 1.0]]> : tensor<1x2xf32> + %a = "onnx.Sqrt"(%arg0) : (tensor<1x2xf32>) -> tensor<*xf32> + %y = "onnx.Div"(%x, %a) : (tensor<1x2xf32>, tensor<*xf32>) -> tensor<*xf32> + "func.return"(%y) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_invsqrt_div +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<1x2xf32>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.InvSqrt"([[VAR_0_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<1x2xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<1x2xf32> +// CHECK: return [[VAR_2_]] : tensor<1x2xf32> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/leakyrelu.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/leakyrelu.mlir new file mode 100644 index 0000000000..b5eb4b09d6 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/leakyrelu.mlir @@ -0,0 +1,42 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s + +func.func @test_leakyrelu(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.LeakyRelu"(%arg0) { alpha = 0.02:f32 } : (tensor<10x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_leakyrelu +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.LeakyRelu"([[VAR_0_]]) {alpha = 2.000000e-02 : f32} : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf32> +// CHECK: return [[VAR_2_]] : tensor<10x10xf32> +// CHECK: } +} + +// ----- + +func.func @test_leakyrelu2(%arg0 : tensor<2x10xf32>) -> tensor<*xf32> { + %0 = "onnx.LeakyRelu"(%arg0) { alpha = 0.01:f32 } : (tensor<2x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_leakyrelu2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x10xf32>) -> tensor<2x10xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<2x10xf32>) -> tensor<2x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.LeakyRelu"([[VAR_0_]]) {alpha = 0.00999999977 : f32} : (tensor<2x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<2x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x10xf32> +// CHECK: return [[VAR_2_]] : tensor<2x10xf32> +// CHECK: } +} + +// ----- + +func.func @test_leakyrelu_default(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.LeakyRelu"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_leakyrelu_default +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.LeakyRelu"([[VAR_0_]]) {alpha = 0.00999999977 : f32} : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf32> +// CHECK: return [[VAR_2_]] : tensor<10x10xf32> +// CHECK: } +} + diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/log.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/log.mlir index 2c51af040e..0a6580ece9 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/log.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/log.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_log(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Log"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/lstm.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/lstm.mlir index cee97d059f..39bc16d49a 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/lstm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/lstm.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s func.func @test_onnx_to_zhigh_ccfd0(%X: tensor<7x2000x204xf32>, %W: tensor<1x800x204xf32>, %R: tensor<1x800x200xf32>, %B: tensor<1x1600xf32>) -> (tensor<7x1x2000x200xf32>, tensor<1x2000x200xf32>, tensor<1x2000x200xf32>) { %cst = "onnx.NoValue"() {value} : () -> none diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul-arch15.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul-arch15.mlir new file mode 100644 index 0000000000..e6bf56beb1 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul-arch15.mlir @@ -0,0 +1,115 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s + +// COM: In these tests, matmul and transpose will be combined together to be lowered to +// COM: zhigh.MatMul. + +func.func @test_onnx_transposea_matmul_to_zhigh(%arg0 : tensor<8x4xf32>, %arg1 : tensor<8x4xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [1, 0]}: (tensor<8x4xf32>) -> tensor<4x8xf32> + %1 = "onnx.MatMul"(%0, %arg1) : (tensor<4x8xf32>,tensor<8x4xf32>) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_onnx_transposea_matmul_to_zhigh +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x4xf32>, [[PARAM_1_:%.+]]: tensor<8x4xf32>) -> tensor<4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<8x4xf32>) -> tensor<8x4xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x4xf32>) -> tensor<8x4xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 1 : si64, transposeB = 0 : si64} : (tensor<8x4xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x4xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<4x4xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<4x4xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x4xf32> +// CHECK: return [[VAR_3_]] : tensor<4x4xf32> +// CHECK: } +} + +// ----- + +func.func @test_onnx_transposeb_matmul_to_zhigh(%arg0 : tensor<4x8xf32>, %arg1 : tensor<4x8xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg1) {perm = [1, 0]}: (tensor<4x8xf32>) -> tensor<8x4xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor<4x8xf32>,tensor<8x4xf32>) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_onnx_transposeb_matmul_to_zhigh +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x8xf32>, [[PARAM_1_:%.+]]: tensor<4x8xf32>) -> tensor<4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {transposeA = 0 : si64, transposeB = 1 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<4x4xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x4xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x4xf32> +// CHECK: return [[VAR_4_]] : tensor<4x4xf32> +// CHECK: } +} + +// ----- + +func.func @test_onnx_transposeab_matmul_to_zhigh(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x4xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {permA = [1, 0]}: (tensor<4x8xf32>) -> tensor<8x4xf32> + %1 = "onnx.Transpose"(%arg1) {permB = [1, 0]}: (tensor<16x4xf32>) -> tensor<4x16xf32> + %2 = "onnx.MatMul"(%0, %1) : (tensor<8x4xf32>,tensor<4x16xf32>) -> tensor<*xf32> + "func.return"(%2) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_onnx_transposeab_matmul_to_zhigh +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x8xf32>, [[PARAM_1_:%.+]]: tensor<16x4xf32>) -> tensor<8x16xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<16x4xf32>) -> tensor<16x4xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {transposeA = 1 : si64, transposeB = 1 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16x4xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<8x16xf32> +// CHECK: return [[VAR_4_]] : tensor<8x16xf32> +// CHECK: } +} + +// ----- + +func.func @test_onnx_to_transposea_matmul_to_zhigh_3d(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100x16x8xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 1]}: (tensor<100x4x8xf32>) -> tensor<100x8x4xf32> + %1 = "onnx.MatMul"(%0, %arg1) : (tensor<100x8x4xf32>, tensor<100x16x8xf32>) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_onnx_to_transposea_matmul_to_zhigh_3d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x4x8xf32>, [[PARAM_1_:%.+]]: tensor<100x16x8xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1]} : (tensor<100x4x8xf32>) -> tensor<100x8x4xf32> +// CHECK: [[VAR_1_:%.+]] = "onnx.MatMul"([[VAR_0_]], [[PARAM_1_]]) : (tensor<100x8x4xf32>, tensor<100x16x8xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_1_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_onnx_to_transposeb_matmul_to_zhigh_3d(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100x16x8xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg1) {perm = [0, 2, 1]}: (tensor<100x16x8xf32>) -> tensor<100x8x16xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<*xf32> + "func.return"(%1) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_onnx_to_transposeb_matmul_to_zhigh_3d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x4x8xf32>, [[PARAM_1_:%.+]]: tensor<100x16x8xf32>) -> tensor<100x4x16xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<100x4x8xf32>) -> tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<100x16x8xf32>) -> tensor<100x16x8xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_1_]], [[VAR_2_]], [[VAR_0_]]) {transposeA = 0 : si64, transposeB = 1 : si64} : (tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<100x16x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<100x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<100x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<100x4x16xf32> +// CHECK: return [[VAR_4_]] : tensor<100x4x16xf32> +// CHECK: } +} + +// ----- + +func.func @test_onnx_to_transposeab_matmul_to_zhigh_3d(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100x8x16xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {permA = [0, 2, 1]}: (tensor<100x4x8xf32>) -> tensor<100x8x4xf32> + %1 = "onnx.Transpose"(%arg1) {permB = [0, 2, 1]}: (tensor<100x8x16xf32>) -> tensor<100x16x8xf32> + %2 = "onnx.MatMul"(%0, %1) : (tensor<100x8x4xf32>,tensor<100x16x8xf32>) -> tensor<*xf32> + "func.return"(%2) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_onnx_to_transposeab_matmul_to_zhigh_3d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x4x8xf32>, [[PARAM_1_:%.+]]: tensor<100x8x16xf32>) -> tensor<*xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [2, 1, 0], permA = [0, 2, 1]} : (tensor<100x4x8xf32>) -> tensor<100x8x4xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 1, 0], permB = [0, 2, 1]} : (tensor<100x8x16xf32>) -> tensor<100x16x8xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[VAR_0_]], [[VAR_1_]]) : (tensor<100x8x4xf32>, tensor<100x16x8xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_2_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir index 857baf98f6..064d27f518 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s func.func @test_onnx_to_matmul2d(%arg0 : tensor<4x8xf32>, %arg1 : tensor<8x16xf32>) -> tensor<*xf32> { %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<*xf32> @@ -9,7 +9,7 @@ func.func @test_onnx_to_matmul2d(%arg0 : tensor<4x8xf32>, %arg1 : tensor<8x16xf3 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<4x16xf32> // CHECK: return [[VAR_3_]] : tensor<4x16xf32> // CHECK: } @@ -26,7 +26,7 @@ func.func @test_onnx_to_matmul3d(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<100x4x8xf32>) -> tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<100x8x16xf32>) -> tensor<100x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<100x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<100x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<100x4x16xf32> // CHECK: return [[VAR_3_]] : tensor<100x4x16xf32> // CHECK: } @@ -43,7 +43,7 @@ func.func @test_onnx_to_matmul3dbcast(%arg0 : tensor<100x4x8xf32>, %arg1 : tenso // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<100x4x8xf32>) -> tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<100x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<100x4x16xf32> // CHECK: return [[VAR_3_]] : tensor<100x4x16xf32> // CHECK: } @@ -79,7 +79,7 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } @@ -105,7 +105,7 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias_normalized( // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32> // CHECK: return [[VAR_4_]] : tensor<4x16xf32> // CHECK: } @@ -161,7 +161,7 @@ func.func @test_onnx_to_matmul2d_dyn(%arg0 : tensor, %arg1 : tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor>, tensor>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor // CHECK: return [[VAR_3_]] : tensor // CHECK: } @@ -178,7 +178,7 @@ func.func @test_onnx_to_matmul3d_dyn(%arg0 : tensor, %arg1 : tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor>, tensor>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor // CHECK: return [[VAR_3_]] : tensor // CHECK: } @@ -195,7 +195,7 @@ func.func @test_onnx_to_matmul3dbcast_dyn(%arg0 : tensor, %arg1 : ten // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) : (tensor>, tensor>, none) -> tensor<*xf16> +// CHECK: [[VAR_2_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor<*xf16> // CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor // CHECK: return [[VAR_3_]] : tensor // CHECK: } diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmulinteger.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmulinteger.mlir new file mode 100644 index 0000000000..b121ac4628 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/matmulinteger.mlir @@ -0,0 +1,187 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s --check-prefix=CHECK-FUSION + +func.func @matmulinteger(%arg0: tensor, %arg1: tensor<768x768xi8>, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<768x768xi8>, tensor, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL: func.func @matmulinteger +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<768x768xi8>, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[PARAM_2_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Cast"([[VAR_4_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Cast"([[PARAM_3_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[VAR_3_]], [[VAR_2_]], [[VAR_5_]]) {layout = "3DS", quantized_type = "INT8", sym_mode = 0 : i64} : (tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[PARAM_1_]], [[VAR_2_]], [[VAR_6_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_RecScale_]]_1, [[VAR_Offset_]]_2, [[VAR_0_]], [[VAR_0_]], [[VAR_0_]], [[VAR_2_]], [[VAR_1_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = 0 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, none, none, none, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_Out_3_]]) : (tensor>) -> tensor +// CHECK: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_7_]]) {saturate = 1 : si64, to = i32} : (tensor) -> tensor +// CHECK: return [[VAR_8_]] : tensor +// CHECK: } +} + +// ----- + +// Do not do pre_compute when B is not a constant. +func.func @matmulinteger_no_precompute_bias(%arg0: tensor, %arg1: tensor<768x768xi8>, %arg2: tensor) -> tensor { + %0 = onnx.Constant dense<0> : tensor + %1 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %0) : (tensor, tensor<768x768xi8>, tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @matmulinteger_no_precompute_bias +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<768x768xi8>, [[PARAM_2_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Cast"([[PARAM_2_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Cast"([[VAR_5_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[VAR_4_]], [[VAR_3_]], [[VAR_6_]]) {layout = "3DS", quantized_type = "INT8", sym_mode = 0 : i64} : (tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[PARAM_1_]], [[VAR_3_]], [[VAR_7_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_RecScale_]]_1, [[VAR_Offset_]]_2, [[VAR_1_]], [[VAR_1_]], [[VAR_1_]], [[VAR_3_]], [[VAR_2_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = 0 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, none, none, none, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_8_:%.+]] = "zhigh.Unstick"([[VAR_Out_3_]]) : (tensor>) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Cast"([[VAR_8_]]) {saturate = 1 : si64, to = i32} : (tensor) -> tensor +// CHECK: return [[VAR_9_]] : tensor +// CHECK: } +} + +// ----- + +func.func @matmulinteger_precompute_bias(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = onnx.Constant dense<0> : tensor + %B = onnx.Constant dense<0> : tensor<768x768xi8> + %1 = "onnx.MatMulInteger"(%arg0, %B, %arg1, %0) : (tensor, tensor<768x768xi8>, tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @matmulinteger_precompute_bias +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0> : tensor<768x768xi8> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Cast"([[VAR_6_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_1_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[VAR_5_]], [[VAR_4_]], [[VAR_7_]]) {layout = "3DS", quantized_type = "INT8", sym_mode = 0 : i64} : (tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_2_]], [[VAR_4_]], [[VAR_8_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_9_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = f32} : (tensor<768x768xi8>) -> tensor<768x768xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.ReduceSum"([[VAR_9_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<768x768xf32>, tensor) -> tensor<768xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Div"([[VAR_4_]], [[VAR_4_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_12_:%.+]] = "onnx.Div"([[VAR_11_]], [[VAR_4_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_13_:%.+]] = "onnx.Mul"([[VAR_12_]], [[VAR_7_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_14_:%.+]] = "onnx.Sub"([[VAR_3_]], [[VAR_13_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_15_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_10_]]) : (tensor, tensor<768xf32>) -> tensor<768xf32> +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_15_]], [[VAR_4_]], [[VAR_3_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<768xf32>, tensor, tensor) -> (tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_RecScale_]]_1, [[VAR_Offset_]]_2, [[VAR_Out_]]_3, [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_4_]], [[VAR_3_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_16_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK: [[VAR_17_:%.+]] = "onnx.Cast"([[VAR_16_]]) {saturate = 1 : si64, to = i32} : (tensor) -> tensor +// CHECK: return [[VAR_17_]] : tensor +// CHECK: } +} + +// ----- + +func.func @matmulinteger_rewrite_from_mul_pattern_in_bert(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<5> : tensor<768x768xi8> + %1 = onnx.Constant dense<0.00656270096> : tensor + %2 = onnx.Constant dense<0> : tensor + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor) -> (tensor, tensor, tensor) + %3 = "onnx.MatMulInteger"(%y, %0, %y_zero_point, %2) : (tensor, tensor<768x768xi8>, tensor, tensor) -> tensor + %4 = "onnx.Cast"(%3) {saturate = 1 : si64, to = f32} : (tensor) -> tensor + %5 = "onnx.Mul"(%4, %y_scale) : (tensor, tensor) -> tensor + %6 = "onnx.Mul"(%5, %1) : (tensor, tensor) -> tensor + return %6 : tensor + +// CHECK-LABEL: func.func @matmulinteger_rewrite_from_mul_pattern_in_bert +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<5> : tensor<768x768xi8> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<0.00656270096> : tensor +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_6_]], [[VAR_6_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Reciprocal"([[VAR_4_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_5_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_3_]], [[VAR_7_]], [[VAR_8_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_9_:%.+]] = "onnx.Cast"([[VAR_3_]]) {saturate = 1 : si64, to = f32} : (tensor<768x768xi8>) -> tensor<768x768xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.ReduceSum"([[VAR_9_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<768x768xf32>, tensor) -> tensor<768xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Div"([[VAR_2_]], [[VAR_RecScale_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_12_:%.+]] = "onnx.Div"([[VAR_11_]], [[VAR_7_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_13_:%.+]] = "onnx.Mul"([[VAR_12_]], [[VAR_Offset_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_14_:%.+]] = "onnx.Sub"([[VAR_1_]], [[VAR_1_]]3) : (tensor, tensor) -> tensor +// CHECK: [[VAR_15_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_10_]]) : (tensor, tensor<768xf32>) -> tensor<768xf32> +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_15_]], [[VAR_2_]], [[VAR_1_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<768xf32>, tensor, tensor) -> (tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_7_]], [[VAR_8_]], [[VAR_Out_]]_3, [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_2_]], [[VAR_1_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_16_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK: return [[VAR_16_]] : tensor +// CHECK: } +} + +// ----- + +func.func @matmulinteger_fuse_add_pattern_in_bert(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<-2> : tensor + %1 = onnx.Constant dense<0.000000e+00> : tensor + %2 = onnx.Constant dense<1.000000e+00> : tensor + %3 = onnx.Constant dense<5.000000e+00> : tensor<768xf32> + %4 = onnx.Constant dense<5> : tensor<768x768xi8> + %5 = onnx.Constant dense<0.00656270096> : tensor + %6 = onnx.Constant dense<0> : tensor + %7 = "onnx.NoValue"() {value} : () -> none + %Out, %RecScale, %Offset = "zhigh.QuantizedStick"(%arg0, %7, %7) {layout = "3DS", quantized_type = "DLFLOAT16"} : (tensor, none, none) -> (tensor>, tensor, tensor) + %8 = "onnx.Reciprocal"(%5) : (tensor) -> tensor + %9 = "onnx.Cast"(%6) {saturate = 1 : si64, to = f32} : (tensor) -> tensor + %Out_0, %RecScale_1, %Offset_2 = "zhigh.QuantizedStick"(%4, %8, %9) {layout = "2D", quantized_type = "WEIGHTS"} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) + %10 = "onnx.Cast"(%4) {saturate = 1 : si64, to = f32} : (tensor<768x768xi8>) -> tensor<768x768xf32> + %11 = "onnx.ReduceSum"(%10, %0) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<768x768xf32>, tensor) -> tensor<768xf32> + %12 = "onnx.Div"(%2, %RecScale) : (tensor, tensor) -> tensor + %13 = "onnx.Div"(%12, %8) : (tensor, tensor) -> tensor + %14 = "onnx.Mul"(%13, %Offset) : (tensor, tensor) -> tensor + %15 = "onnx.Sub"(%1, %14) : (tensor, tensor) -> tensor + %16 = "onnx.Mul"(%15, %11) : (tensor, tensor<768xf32>) -> tensor<768xf32> + %17 = "onnx.Add"(%3, %16) : (tensor<768xf32>, tensor<768xf32>) -> tensor<768xf32> + %Out_3, %RecScale_4, %Offset_5 = "zhigh.QuantizedStick"(%17, %2, %1) {layout = "1D", quantized_type = "DLFLOAT16"} : (tensor<768xf32>, tensor, tensor) -> (tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) + %Out_6, %OutRecScale, %OutOffset = "zhigh.QuantizedMatMul"(%Out, %RecScale, %Offset, %Out_0, %8, %9, %Out_3, %RecScale_4, %Offset_5, %2, %1) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) + %18 = "zhigh.Unstick"(%Out_6) : (tensor>) -> tensor + return %18 : tensor + +// CHECK-FUSION-LABEL: func.func @matmulinteger_fuse_add_pattern_in_bert +// CHECK-FUSION-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-FUSION-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2> : tensor +// CHECK-FUSION-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-FUSION-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK-FUSION-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<5.000000e+00> : tensor<768xf32> +// CHECK-FUSION-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<5> : tensor<768x768xi8> +// CHECK-FUSION-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0.00656270096> : tensor +// CHECK-FUSION-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-FUSION-DAG: [[VAR_7_:%.+]] = "onnx.NoValue"() {value} : () -> none + // CHECK-FUSION: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_7_]], [[VAR_7_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK-FUSION-DAG: [[VAR_8_:%.+]] = "onnx.Reciprocal"([[VAR_5_]]) : (tensor) -> tensor +// CHECK-FUSION-DAG: [[VAR_9_:%.+]] = "onnx.Cast"([[VAR_6_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor + // CHECK-FUSION: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_4_]], [[VAR_8_]], [[VAR_9_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<768x768xi8>, tensor, tensor) -> (tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK-FUSION: [[VAR_10_:%.+]] = "onnx.Cast"([[VAR_4_]]) {saturate = 1 : si64, to = f32} : (tensor<768x768xi8>) -> tensor<768x768xf32> +// CHECK-FUSION-DAG: [[VAR_11_:%.+]] = "onnx.ReduceSum"([[VAR_10_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<768x768xf32>, tensor) -> tensor<768xf32> +// CHECK-FUSION-DAG: [[VAR_12_:%.+]] = "onnx.Div"([[VAR_2_]], [[VAR_RecScale_]]) : (tensor, tensor) -> tensor +// CHECK-FUSION: [[VAR_13_:%.+]] = "onnx.Div"([[VAR_12_]], [[VAR_8_]]) : (tensor, tensor) -> tensor +// CHECK-FUSION: [[VAR_14_:%.+]] = "onnx.Mul"([[VAR_13_]], [[VAR_Offset_]]) : (tensor, tensor) -> tensor +// CHECK-FUSION: [[VAR_15_:%.+]] = "onnx.Sub"([[VAR_1_]], [[VAR_1_]]4) : (tensor, tensor) -> tensor +// CHECK-FUSION: [[VAR_16_:%.+]] = "onnx.Mul"([[VAR_15_]], [[VAR_11_]]) : (tensor, tensor<768xf32>) -> tensor<768xf32> +// CHECK-FUSION: [[VAR_17_:%.+]] = "onnx.Add"([[VAR_3_]], [[VAR_16_]]) : (tensor<768xf32>, tensor<768xf32>) -> tensor<768xf32> + // CHECK-FUSION: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_17_]], [[VAR_2_]], [[VAR_1_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<768xf32>, tensor, tensor) -> (tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK-FUSION: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_8_]], [[VAR_9_]], [[VAR_Out_]]_3, [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_2_]], [[VAR_1_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<768x768xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<768xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK-FUSION: [[VAR_18_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK-FUSION: return [[VAR_18_]] : tensor +// CHECK-FUSION: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/max.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/max.mlir index d4b0da8748..d916c93ab5 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/max.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/max.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_max(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Max"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/min.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/min.mlir index 4e59b9e415..8148fad37c 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/min.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/min.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Min"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/mul.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/mul.mlir index 69acfb9e44..0de095bd6d 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/mul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/mul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_mul(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/pool.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/pool.mlir index 5e8e3ad622..f974ae04f0 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/pool.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/pool.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @maxpool_should_lower_to_zhigh_padtype_valid(%arg0: tensor<1x3x32x32xf32>) -> tensor<*xf32> { %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", dilations = [1, 1], kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/qlinearmatmul.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/qlinearmatmul.mlir new file mode 100644 index 0000000000..2a7a11637b --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/qlinearmatmul.mlir @@ -0,0 +1,66 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s + +func.func @qlinearmatmul_i8_f32(%arg0: tensor<2x4xi8> {onnx.name = "a"}, %arg1: tensor {onnx.name = "a_scale"}, %arg2: tensor {onnx.name = "a_zero_point"}, %arg3: tensor<4x3xi8> {onnx.name = "b"}, %arg4: tensor {onnx.name = "b_scale"}, %arg5: tensor {onnx.name = "b_zero_point"}, %arg6: tensor {onnx.name = "y_scale"}, %arg7: tensor {onnx.name = "y_zero_point"}) -> (tensor<2x3xi8> {onnx.name = "y"}) { + %0 = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x4xi8>, tensor, tensor, tensor<4x3xi8>, tensor, tensor, tensor, tensor) -> tensor<2x3xi8> + onnx.Return %0 : tensor<2x3xi8> + +// CHECK-LABEL: func.func @qlinearmatmul_i8_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4xi8> {onnx.name = "a"}, [[PARAM_1_:%.+]]: tensor {onnx.name = "a_scale"}, [[PARAM_2_:%.+]]: tensor {onnx.name = "a_zero_point"}, [[PARAM_3_:%.+]]: tensor<4x3xi8> {onnx.name = "b"}, [[PARAM_4_:%.+]]: tensor {onnx.name = "b_scale"}, [[PARAM_5_:%.+]]: tensor {onnx.name = "b_zero_point"}, [[PARAM_6_:%.+]]: tensor {onnx.name = "y_scale"}, [[PARAM_7_:%.+]]: tensor {onnx.name = "y_zero_point"}) -> (tensor<2x3xi8> {onnx.name = "y"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Reciprocal"([[PARAM_1_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_2_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reciprocal"([[PARAM_4_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[PARAM_5_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reciprocal"([[PARAM_6_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Cast"([[PARAM_7_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {layout = "2D", quantized_type = "INT8", sym_mode = 0 : i64} : (tensor<2x4xi8>, tensor, tensor) -> (tensor<2x4xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "INT8"}>>, tensor, tensor) +// CHECK: [[Out_0_:%.+]], [[RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[PARAM_3_]], [[VAR_3_]], [[VAR_4_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<4x3xi8>, tensor, tensor) -> (tensor<4x3xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[Out_3_:%.+]], [[OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[Out_]], [[RecScale_]], [[VAR_Offset_]], [[Out_0_]], [[RecScale_1_]], [[VAR_Offset_2_]], [[VAR_0_]], [[VAR_0_]], [[VAR_0_]], [[VAR_5_]], [[VAR_6_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = 0 : si64} : (tensor<2x4xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "INT8"}>>, tensor, tensor, tensor<4x3xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, none, none, none, tensor, tensor) -> (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[Out_3_]]) : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D", quantizedType = "DLFLOAT16"}>>) -> tensor<2x3xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_7_]]) {saturate = 1 : si64, to = i8} : (tensor<2x3xf32>) -> tensor<2x3xi8> +// CHECK: onnx.Return [[VAR_8_]] : tensor<2x3xi8> +// CHECK: } +} + +// ----- + +func.func @qlinearmatmul_ui8_f32(%arg0: tensor<2x4xui8> {onnx.name = "a"}, %arg1: tensor {onnx.name = "a_scale"}, %arg2: tensor {onnx.name = "a_zero_point"}, %arg3: tensor<4x3xui8> {onnx.name = "b"}, %arg4: tensor {onnx.name = "b_scale"}, %arg5: tensor {onnx.name = "b_zero_point"}, %arg6: tensor {onnx.name = "y_scale"}, %arg7: tensor {onnx.name = "y_zero_point"}) -> (tensor<2x3xui8> {onnx.name = "y"}) { + %0 = "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2x4xui8>, tensor, tensor, tensor<4x3xui8>, tensor, tensor, tensor, tensor) -> tensor<2x3xui8> + onnx.Return %0 : tensor<2x3xui8> + +// CHECK-LABEL: func.func @qlinearmatmul_ui8_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4xui8> {onnx.name = "a"}, [[PARAM_1_:%.+]]: tensor {onnx.name = "a_scale"}, [[PARAM_2_:%.+]]: tensor {onnx.name = "a_zero_point"}, [[PARAM_3_:%.+]]: tensor<4x3xui8> {onnx.name = "b"}, [[PARAM_4_:%.+]]: tensor {onnx.name = "b_scale"}, [[PARAM_5_:%.+]]: tensor {onnx.name = "b_zero_point"}, [[PARAM_6_:%.+]]: tensor {onnx.name = "y_scale"}, [[PARAM_7_:%.+]]: tensor {onnx.name = "y_zero_point"}) -> (tensor<2x3xui8> {onnx.name = "y"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<128> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i16} : (tensor<2x4xui8>) -> tensor<2x4xi16> +// CHECK: [[VAR_3_:%.+]] = "onnx.Sub"([[VAR_2_]], [[VAR_0_]]) : (tensor<2x4xi16>, tensor) -> tensor<2x4xi16> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[VAR_3_]]) {saturate = 1 : si64, to = i8} : (tensor<2x4xi16>) -> tensor<2x4xi8> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Cast"([[PARAM_3_]]) {saturate = 1 : si64, to = i16} : (tensor<4x3xui8>) -> tensor<4x3xi16> +// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[VAR_5_]], [[VAR_0_]]) : (tensor<4x3xi16>, tensor) -> tensor<4x3xi16> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Cast"([[VAR_6_]]) {saturate = 1 : si64, to = i8} : (tensor<4x3xi16>) -> tensor<4x3xi8> +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Reciprocal"([[PARAM_1_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Cast"([[PARAM_2_]]) {saturate = 1 : si64, to = i16} : (tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Sub"([[VAR_9_]], [[VAR_0_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_11_:%.+]] = "onnx.Cast"([[VAR_10_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Cast"([[VAR_11_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Reciprocal"([[PARAM_4_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Cast"([[PARAM_5_]]) {saturate = 1 : si64, to = i16} : (tensor) -> tensor +// CHECK: [[VAR_15_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_0_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_16_:%.+]] = "onnx.Cast"([[VAR_15_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Cast"([[VAR_16_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Reciprocal"([[PARAM_6_]]) : (tensor) -> tensor +// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.Cast"([[PARAM_7_]]) {saturate = 1 : si64, to = i16} : (tensor) -> tensor +// CHECK: [[VAR_20_:%.+]] = "onnx.Sub"([[VAR_19_]], [[VAR_0_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_21_:%.+]] = "onnx.Cast"([[VAR_20_]]) {saturate = 1 : si64, to = i8} : (tensor) -> tensor +// CHECK: [[VAR_22_:%.+]] = "onnx.Cast"([[VAR_21_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[VAR_4_]], [[VAR_8_]], [[VAR_12_]]) {layout = "2D", quantized_type = "INT8", sym_mode = 0 : i64} : (tensor<2x4xi8>, tensor, tensor) -> (tensor<2x4xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "INT8"}>>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_7_]], [[VAR_13_]], [[VAR_17_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<4x3xi8>, tensor, tensor) -> (tensor<4x3xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_RecScale_]]_1, [[VAR_Offset_]]_2, [[VAR_1_]], [[VAR_1_]], [[VAR_1_]], [[VAR_1_]]8, [[VAR_22_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = 0 : si64} : (tensor<2x4xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "INT8"}>>, tensor, tensor, tensor<4x3xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, none, none, none, tensor, tensor) -> (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_23_:%.+]] = "zhigh.Unstick"([[VAR_Out_3_]]) : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D", quantizedType = "DLFLOAT16"}>>) -> tensor<2x3xf32> +// CHECK: [[VAR_24_:%.+]] = "onnx.Cast"([[VAR_23_]]) {saturate = 1 : si64, to = i16} : (tensor<2x3xf32>) -> tensor<2x3xi16> +// CHECK: [[VAR_25_:%.+]] = "onnx.Add"([[VAR_24_]], [[VAR_0_]]) : (tensor<2x3xi16>, tensor) -> tensor<2x3xi16> +// CHECK: [[VAR_26_:%.+]] = "onnx.Cast"([[VAR_25_]]) {saturate = 1 : si64, to = ui16} : (tensor<2x3xi16>) -> tensor<2x3xui16> +// CHECK: [[VAR_27_:%.+]] = "onnx.Cast"([[VAR_26_]]) {saturate = 1 : si64, to = ui8} : (tensor<2x3xui16>) -> tensor<2x3xui8> +// CHECK: onnx.Return [[VAR_27_]] : tensor<2x3xui8> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir new file mode 100644 index 0000000000..83565c6e42 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir @@ -0,0 +1,179 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh="quantization=DynSymI8" --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh="quantization=SymSymI8" --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s --check-prefix=SYMSYMI8 + +func.func @test_correctness_of_symmetric_quant_for_weight(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<[[-0.00718058366], [5.253110e-01], [-0.0434652828], [-0.305256933], [0.193365857], [0.0105065238], [-0.143788248], [-0.0161222648], [0.0230324212], [-0.34107244], [-0.273072243], [-0.104352467], [0.0164068397], [-1.32305741], [-0.0345043093], [-0.232206389], [-0.150001124], [0.119475454], [0.730642438], [-0.407772154], [-0.0164191965], [-1.625590e-01], [-0.112515017], [0.158920377], [-0.0997497215], [0.0788274407], [1.1542908], [0.492949218], [-0.125796661], [0.0107790371], [0.141159713], [-0.0774109289], [-0.438130081], [-0.0888700857], [0.207725927], [-0.0913108587], [0.258232892], [0.0672571063], [-0.100412264], [1.68460846], [-0.289168775], [-0.686722457], [0.903651654], [0.110602334], [-0.0505490415], [1.31204939], [0.136107579], [0.26376456], [-0.508291602], [-0.0118971812], [-0.0373991691], [0.448705465], [0.00448446581], [-0.165114298], [0.156860754], [0.141124308], [-0.272756487], [-0.0834815949], [0.020905681], [-0.0877983123], [-1.0087887], [-0.353012145], [-0.0439243801], [-0.00592191564], [-0.0637216269], [0.175808683], [-0.193864927], [-0.0574007072], [0.390869558], [0.138100505], [0.429396927], [1.10117233], [-0.362377733], [0.116578773], [0.0540139228], [-5.85162896E-4], [-0.335441321], [-0.0902953073], [0.017575942], [-0.0359748788], [1.50025952], [-0.668821096], [0.0109066488], [9.907780e-01], [0.10227681], [-0.0582750589], [0.0172416102], [0.0429656394], [0.0465254933], [0.350135148], [-0.260139734], [0.199394852], [-0.136131078], [0.241424322], [0.855418264], [-0.160689577], [-0.825074911], [-0.124827594], [0.0153419804], [0.389386117], [0.153694436], [-0.897866904], [-0.292769879], [0.181667477], [-0.188009143], [-0.0245181341], [-2.17088842], [-0.0526076891], [-0.108600065], [0.187120304], [0.171495944], [0.310159177], [2.204240e+00], [0.0506350659], [-0.159419239], [-0.145082235], [-0.0991335287], [-0.0680764392], [-0.311415762], [-0.187137261], [-0.416945577], [0.0703471377], [0.498331547], [-0.41216433], [-0.427900195], [0.102105901], [0.130767033], [-0.440281332], [0.778514624], [-0.253678083], [0.395671815], [0.380029172], [-0.418493837], [-0.288157403], [0.0689846799], [1.269960e+00], [-0.0585722439], [-0.138125435], [-0.191710189], [0.0163070802], [0.159242466], [0.116627224], [0.289637923], [-0.299413532], [-0.0216965247], [0.271396786], [0.250576884], [-0.131420374], [0.137698188], [-0.0102280416], [0.234722644], [-0.0366179943], [-0.105632246], [-0.145528033], [-0.278210133], [-0.247100428], [0.217718393], [0.171669215], [0.0151556451], [0.961385667], [-0.0484847203], [0.434219301], [-0.00167646946], [-0.0308207348], [-0.102328695], [-0.127907664], [-0.185960412], [0.210866481], [0.140434876], [-0.233541235], [-0.123745643], [-0.0113738365], [1.30043447], [0.179708347], [-0.331716627], [0.0133318678], [-0.107284561], [-0.114116102], [-0.478514463], [0.0616452768], [-0.781869769], [-0.121830635], [-0.0684970543], [-6.584100e-02], [-0.131784603], [-0.619898796], [0.160366163], [-0.50115186], [0.0228514839], [0.581515431], [4.220270e-01], [1.944400e-01], [-1.07740963], [3.732520e-01], [0.725471556], [-0.117193311], [-0.105938725], [0.320118755], [-0.484032601], [-0.0467250831]]> : tensor<200x1xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor, tensor<200x1xf32>) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @test_correctness_of_symmetric_quant_for_weight +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2.750000e+02> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}[0], [30], [-3], [-18], [11], [1], [-8], [-1], [1], [-20], [-16], [-6], [1], [-76], [-2], [-13], [-9], [7], [42], [-23], [-1], [-9], [-6], [9], [-6], [5], [67], [28], [-7], [1], [8], [-4], [-25], [-5], [12], [-5], [15], [4], [-6], [97], [-17], [-40], [52], [6], [-3], [76], [8], [15], [-29], [-1], [-2], [26], [0], [-10], [9], [8], [-16], [-5], [1], [-5], [-58], [-20], [-3], [0], [-4], [10], [-11], [-3], [23], [8], [25], [63], [-21], [7], [3], [0], [-19], [-5], [1], [-2], [86], [-39], [1], [57], [6], [-3], [1], [2], [3], [20], [-15], [11], [-8], [14], [49], [-9], [-48], [-7], [1], [22], [9], [-52], [-17], [10], [-11], [-1], [-125], [-3], [-6], [11], [10], [18], [127], [3], [-9], [-8], [-6], [-4], [-18], [-11], [-24], [4], [29], [-24], [-25], [6], [8], [-25], [45], [-15], [23], [22], [-24], [-17], [4], [73], [-3], [-8], [-11], [1], [9], [7], [17], [-17], [-1], [16], [14], [-8], [8], [-1], [14], [-2], [-6], [-8], [-16], [-14], [13], [10], [1], [55], [-3], [25], [0], [-2], [-6], [-7], [-11], [12], [8], [-13], [-7], [-1], [75], [10], [-19], [1], [-6], [-7], [-28], [4], [-45], [-7], [-4], [-4], [-8], [-36], [9], [-29], [1], [34], [24], [11], [-62], [22], [42], [-7], [-6], [18], [-28], [-3]{{.}}> : tensor<200x1xi8> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<57.61623> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_3_]], [[VAR_3_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_1_]], [[VAR_2_]], [[VAR_4_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<200x1xi8>, tensor, tensor) -> (tensor<200x1xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: } +} + +// ----- + +func.func @test_correctness_of_symmetric_quant_for_activation_and_weight(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<[[-0.00718058366], [5.253110e-01], [-0.0434652828], [-0.305256933], [0.193365857], [0.0105065238], [-0.143788248], [-0.0161222648], [0.0230324212], [-0.34107244], [-0.273072243], [-0.104352467], [0.0164068397], [-1.32305741], [-0.0345043093], [-0.232206389], [-0.150001124], [0.119475454], [0.730642438], [-0.407772154], [-0.0164191965], [-1.625590e-01], [-0.112515017], [0.158920377], [-0.0997497215], [0.0788274407], [1.1542908], [0.492949218], [-0.125796661], [0.0107790371], [0.141159713], [-0.0774109289], [-0.438130081], [-0.0888700857], [0.207725927], [-0.0913108587], [0.258232892], [0.0672571063], [-0.100412264], [1.68460846], [-0.289168775], [-0.686722457], [0.903651654], [0.110602334], [-0.0505490415], [1.31204939], [0.136107579], [0.26376456], [-0.508291602], [-0.0118971812], [-0.0373991691], [0.448705465], [0.00448446581], [-0.165114298], [0.156860754], [0.141124308], [-0.272756487], [-0.0834815949], [0.020905681], [-0.0877983123], [-1.0087887], [-0.353012145], [-0.0439243801], [-0.00592191564], [-0.0637216269], [0.175808683], [-0.193864927], [-0.0574007072], [0.390869558], [0.138100505], [0.429396927], [1.10117233], [-0.362377733], [0.116578773], [0.0540139228], [-5.85162896E-4], [-0.335441321], [-0.0902953073], [0.017575942], [-0.0359748788], [1.50025952], [-0.668821096], [0.0109066488], [9.907780e-01], [0.10227681], [-0.0582750589], [0.0172416102], [0.0429656394], [0.0465254933], [0.350135148], [-0.260139734], [0.199394852], [-0.136131078], [0.241424322], [0.855418264], [-0.160689577], [-0.825074911], [-0.124827594], [0.0153419804], [0.389386117], [0.153694436], [-0.897866904], [-0.292769879], [0.181667477], [-0.188009143], [-0.0245181341], [-2.17088842], [-0.0526076891], [-0.108600065], [0.187120304], [0.171495944], [0.310159177], [2.204240e+00], [0.0506350659], [-0.159419239], [-0.145082235], [-0.0991335287], [-0.0680764392], [-0.311415762], [-0.187137261], [-0.416945577], [0.0703471377], [0.498331547], [-0.41216433], [-0.427900195], [0.102105901], [0.130767033], [-0.440281332], [0.778514624], [-0.253678083], [0.395671815], [0.380029172], [-0.418493837], [-0.288157403], [0.0689846799], [1.269960e+00], [-0.0585722439], [-0.138125435], [-0.191710189], [0.0163070802], [0.159242466], [0.116627224], [0.289637923], [-0.299413532], [-0.0216965247], [0.271396786], [0.250576884], [-0.131420374], [0.137698188], [-0.0102280416], [0.234722644], [-0.0366179943], [-0.105632246], [-0.145528033], [-0.278210133], [-0.247100428], [0.217718393], [0.171669215], [0.0151556451], [0.961385667], [-0.0484847203], [0.434219301], [-0.00167646946], [-0.0308207348], [-0.102328695], [-0.127907664], [-0.185960412], [0.210866481], [0.140434876], [-0.233541235], [-0.123745643], [-0.0113738365], [1.30043447], [0.179708347], [-0.331716627], [0.0133318678], [-0.107284561], [-0.114116102], [-0.478514463], [0.0616452768], [-0.781869769], [-0.121830635], [-0.0684970543], [-6.584100e-02], [-0.131784603], [-0.619898796], [0.160366163], [-0.50115186], [0.0228514839], [0.581515431], [4.220270e-01], [1.944400e-01], [-1.07740963], [3.732520e-01], [0.725471556], [-0.117193311], [-0.105938725], [0.320118755], [-0.484032601], [-0.0467250831]]> : tensor<200x1xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor, tensor<200x1xf32>) -> tensor + return %1 : tensor + +// SYMSYMI8-LABEL: func.func @test_correctness_of_symmetric_quant_for_activation_and_weight +// SYMSYMI8-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// SYMSYMI8-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}[0], [30], [-3], [-18], [11], [1], [-8], [-1], [1], [-20], [-16], [-6], [1], [-76], [-2], [-13], [-9], [7], [42], [-23], [-1], [-9], [-6], [9], [-6], [5], [67], [28], [-7], [1], [8], [-4], [-25], [-5], [12], [-5], [15], [4], [-6], [97], [-17], [-40], [52], [6], [-3], [76], [8], [15], [-29], [-1], [-2], [26], [0], [-10], [9], [8], [-16], [-5], [1], [-5], [-58], [-20], [-3], [0], [-4], [10], [-11], [-3], [23], [8], [25], [63], [-21], [7], [3], [0], [-19], [-5], [1], [-2], [86], [-39], [1], [57], [6], [-3], [1], [2], [3], [20], [-15], [11], [-8], [14], [49], [-9], [-48], [-7], [1], [22], [9], [-52], [-17], [10], [-11], [-1], [-125], [-3], [-6], [11], [10], [18], [127], [3], [-9], [-8], [-6], [-4], [-18], [-11], [-24], [4], [29], [-24], [-25], [6], [8], [-25], [45], [-15], [23], [22], [-24], [-17], [4], [73], [-3], [-8], [-11], [1], [9], [7], [17], [-17], [-1], [16], [14], [-8], [8], [-1], [14], [-2], [-6], [-8], [-16], [-14], [13], [10], [1], [55], [-3], [25], [0], [-2], [-6], [-7], [-11], [12], [8], [-13], [-7], [-1], [75], [10], [-19], [1], [-6], [-7], [-28], [4], [-45], [-7], [-4], [-4], [-8], [-36], [9], [-29], [1], [34], [24], [11], [-62], [22], [42], [-7], [-6], [18], [-28], [-3]{{.}}> : tensor<200x1xi8> +// SYMSYMI8-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<57.61623> : tensor +// SYMSYMI8-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none +// SYMSYMI8-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// SYMSYMI8-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// SYMSYMI8: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_2_]], [[VAR_2_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 1 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// SYMSYMI8: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_0_]], [[VAR_1_]], [[VAR_3_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<200x1xi8>, tensor, tensor) -> (tensor<200x1xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// SYMSYMI8: [[VAR_Out_3_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_1_]], [[VAR_3_]], [[VAR_2_]], [[VAR_4_]], [[VAR_3_]], [[VAR_4_]], [[VAR_3_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<200x1xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, none, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// SYMSYMI8: [[VAR_5_:%.+]] = "zhigh.Unstick"([[VAR_Out_3_]]) : (tensor>) -> tensor +// SYMSYMI8: return [[VAR_5_]] : tensor +// SYMSYMI8: } +} + +// ----- + +func.func @test_matmul(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<-0.00718058366> : tensor<200x1xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor, tensor<200x1xf32>) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @test_matmul +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2.540000e+04> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-127> : tensor<200x1xi8> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<17686.584> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_3_]], [[VAR_3_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_1_]], [[VAR_2_]], [[VAR_4_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<200x1xi8>, tensor, tensor) -> (tensor<200x1xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_6_:%.+]] = "onnx.Div"([[VAR_5_]], [[VAR_RecScale_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_6_]], [[VAR_2_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_7_]], [[VAR_Offset_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Sub"([[VAR_4_]], [[VAR_8_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Mul"([[VAR_9_]], [[VAR_0_]]) : (tensor, tensor<1xf32>) -> tensor<1xf32> +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_10_]], [[VAR_5_]], [[VAR_4_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<1xf32>, tensor, tensor) -> (tensor<1xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_2_]], [[VAR_4_]], [[VAR_Out_]]_3, [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_5_]], [[VAR_4_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<200x1xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<1xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_11_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK: return [[VAR_11_]] : tensor +// CHECK: } +} + +// ----- + +func.func @test_matmul_add(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<-0.00718058366> : tensor<200x10xf32> + %1 = onnx.Constant dense<-0.00718058366> : tensor<10xf32> + %2 = "onnx.MatMul"(%arg0, %0) : (tensor, tensor<200x10xf32>) -> tensor + %3 = "onnx.Add"(%2, %1): (tensor, tensor<10xf32>) -> tensor + return %3 : tensor + +// CHECK-LABEL: func.func @test_matmul_add +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2.540000e+04> : tensor<10xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-127> : tensor<200x10xi8> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<17686.584> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<-0.00718058366> : tensor<10xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_4_]], [[VAR_4_]]) {layout = "3DS", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_1_]], [[VAR_2_]], [[VAR_5_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<200x10xi8>, tensor, tensor) -> (tensor<200x10xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_6_]], [[VAR_RecScale_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_8_:%.+]] = "onnx.Div"([[VAR_7_]], [[VAR_2_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Mul"([[VAR_8_]], [[VAR_Offset_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Sub"([[VAR_5_]], [[VAR_9_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_11_:%.+]] = "onnx.Mul"([[VAR_10_]], [[VAR_0_]]) : (tensor, tensor<10xf32>) -> tensor<10xf32> +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_11_]], [[VAR_6_]], [[VAR_5_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<10xf32>, tensor, tensor) -> (tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_12_:%.+]] = "zhigh.Stick"([[VAR_3_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_13_:%.+]] = "zhigh.Add"([[VAR_Out_3_]], [[VAR_12_]]) : (tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>> +// CHECK: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_2_]], [[VAR_5_]], [[VAR_13_]], [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_6_]], [[VAR_5_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<200x10xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_14_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK: return [[VAR_14_]] : tensor +// CHECK: } +} + +// ----- + +func.func @test_gemm(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<-0.00718058366> : tensor<200x10xf32> + %1 = onnx.Constant dense<-0.00718058366> : tensor<10xf32> + %2 = "onnx.Gemm"(%arg0, %0, %1) {transA = 0 : si64, transB = 0 : si64, alpha = 1.0 : f32, beta = 1.0 : f32} : (tensor, tensor<200x10xf32>, tensor<10xf32>) -> tensor + return %2 : tensor + +// CHECK-LABEL: func.func @test_gemm +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-2.540000e+04> : tensor<10xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-127> : tensor<200x10xi8> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<17686.584> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<-0.00718058366> : tensor<10xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor +// CHECK: [[VAR_Out_:%.+]], [[VAR_RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_4_]], [[VAR_4_]]) {layout = "2D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor, none, none) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_Out_0_:%.+]], [[VAR_RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[VAR_1_]], [[VAR_2_]], [[VAR_5_]]) {layout = "2D", quantized_type = "WEIGHTS", sym_mode = 0 : i64} : (tensor<200x10xi8>, tensor, tensor) -> (tensor<200x10xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_6_]], [[VAR_RecScale_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_8_:%.+]] = "onnx.Div"([[VAR_7_]], [[VAR_2_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Mul"([[VAR_8_]], [[VAR_Offset_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Sub"([[VAR_5_]], [[VAR_9_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_11_:%.+]] = "onnx.Mul"([[VAR_10_]], [[VAR_0_]]) : (tensor, tensor<10xf32>) -> tensor<10xf32> +// CHECK: [[VAR_Out_3_:%.+]], [[VAR_RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[VAR_11_]], [[VAR_6_]], [[VAR_5_]]) {layout = "1D", quantized_type = "DLFLOAT16", sym_mode = 0 : i64} : (tensor<10xf32>, tensor, tensor) -> (tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[VAR_12_:%.+]] = "zhigh.Stick"([[VAR_3_]]) {layout = "1D"} : (tensor<10xf32>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>> +// CHECK: [[VAR_13_:%.+]] = "zhigh.Add"([[VAR_Out_3_]], [[VAR_12_]]) : (tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor<10xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>> +// CHECK: [[VAR_Out_6_:%.+]], [[VAR_OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[VAR_Out_]], [[VAR_RecScale_]], [[VAR_Offset_]], [[VAR_Out_]]_0, [[VAR_2_]], [[VAR_5_]], [[VAR_13_]], [[VAR_RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_6_]], [[VAR_5_]]) {DequantizeOutput = 0 : si64, DisableClipping = -1 : si64, PreComputedBias = -1 : si64} : (tensor>, tensor, tensor, tensor<200x10xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<10xf16, #zhigh.layout<{dataLayout = "1D", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor, tensor) -> (tensor>, tensor, tensor) +// CHECK: [[VAR_14_:%.+]] = "zhigh.Unstick"([[VAR_Out_6_]]) : (tensor>) -> tensor +// CHECK: return [[VAR_14_]] : tensor +// CHECK: } +} + +// ----- + +// Do not quantize because B is not a constant. +func.func @test_matmul_not_quantized(%arg0: tensor, %arg1: tensor<200x1xf32>) -> tensor { + %1 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<200x1xf32>) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @test_matmul_not_quantized +// CHECK: "zhigh.MatMul" +// CHECK-NOT: "zhigh.QuantizedMatMul" +} + +// ----- + +// Do not quantize because C is not a constant. +func.func @test_matmul_add_not_quantized(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { + %0 = onnx.Constant dense<-0.00718058366> : tensor<200x10xf32> + %1 = "onnx.MatMul"(%arg0, %0) : (tensor, tensor<200x10xf32>) -> tensor + %2 = "onnx.Add"(%1, %arg1): (tensor, tensor<10xf32>) -> tensor + return %2 : tensor + +// CHECK-LABEL: func.func @test_matmul_add_not_quantized +// CHECK: "zhigh.MatMul" +// CHECK-NOT: "zhigh.QuantizedMatMul" +} + +// ----- + +// Do not quantize because A is transposed. +func.func @test_gemm_not_quantized(%arg0: tensor<200x?xf32>) -> tensor { + %0 = onnx.Constant dense<-0.00718058366> : tensor<200x10xf32> + %1 = onnx.Constant dense<-0.00718058366> : tensor<10xf32> + %2 = "onnx.Gemm"(%arg0, %0, %1) {transA = 1 : si64, transB = 0 : si64, alpha = 1.0 : f32, beta = 1.0 : f32} : (tensor<200x?xf32>, tensor<200x10xf32>, tensor<10xf32>) -> tensor + return %2 : tensor + +// CHECK-LABEL: func.func @test_gemm_not_quantized +// CHECK: "zhigh.MatMul" +// CHECK-NOT: "zhigh.QuantizedMatMul" +} + diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemax.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemax.mlir new file mode 100644 index 0000000000..77c6052512 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemax.mlir @@ -0,0 +1,75 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_reduce_max_axes_defined_noop_0(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<[2]> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMax"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_reduce_max_axes_defined_noop_0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<3x2x2xf32>) -> tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.ReduceMax"([[VAR_1_]]) : (tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<3x2x1xf32> +// CHECK: return [[VAR_3_]] : tensor<3x2x1xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_max_axes_minus_one(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<-1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMax"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_max_axes_minus_one +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<3x2x2xf32>) -> tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.ReduceMax"([[VAR_1_]]) : (tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<3x2x1xf32> +// CHECK: return [[VAR_3_]] : tensor<3x2x1xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_max_not_lowered_unknown_axis(%arg0 : tensor<3x2x2xf32>, %arg1: tensor<1xi64>) -> tensor<*xf32> { + %0 ="onnx.ReduceMax"(%arg0, %arg1) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_max_not_lowered_unknown_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.ReduceMax"([[PARAM_0_]], [[PARAM_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +func.func @test_reduce_max_axes_not_lowered_not_innermost_axis(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<0> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMax"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_max_axes_not_lowered_not_innermost_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<1x2x2xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.ReduceMax"([[PARAM_0_]], [[VAR_0_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>) -> tensor<1x2x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x2x2xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_max_axes_not_lowered_not_multiple_axes(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<[2, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 ="onnx.ReduceMax"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<2xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_max_axes_not_lowered_not_multiple_axes +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<1x2x1xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 0]> : tensor<2xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.ReduceMax"([[PARAM_0_]], [[VAR_0_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<2xi64>) -> tensor<1x2x1xf32> +// CHECK: return [[VAR_1_]] : tensor<1x2x1xf32> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemean.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemean.mlir index 686c684be3..d4f8e87a99 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemean.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemean.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @should_lower_to_zhigh(%arg0 : tensor<1x3x5x7xf32>) -> tensor<*xf32> { %0 = "onnx.ReduceMeanV13"(%arg0) { axes = [2, 3] }: (tensor<1x3x5x7xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemin.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemin.mlir new file mode 100644 index 0000000000..74c825a85c --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/reducemin.mlir @@ -0,0 +1,75 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_reduce_min_axes_defined_noop_0(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<[2]> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMin"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_reduce_min_axes_defined_noop_0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<3x2x2xf32>) -> tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.ReduceMin"([[VAR_1_]]) : (tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<3x2x1xf32> +// CHECK: return [[VAR_3_]] : tensor<3x2x1xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_min_axes_minus_one(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<-1> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMin"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_min_axes_minus_one +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<3x2x1xf32> { +// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<3x2x2xf32>) -> tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.ReduceMin"([[VAR_1_]]) : (tensor<3x2x2xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> +// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<3x2x1xf32> +// CHECK: return [[VAR_3_]] : tensor<3x2x1xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_min_not_lowered_unknown_axis(%arg0 : tensor<3x2x2xf32>, %arg1: tensor<1xi64>) -> tensor<*xf32> { + %0 ="onnx.ReduceMin"(%arg0, %arg1) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_min_not_lowered_unknown_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.ReduceMin"([[PARAM_0_]], [[PARAM_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +func.func @test_reduce_min_axes_not_lowered_not_innermost_axis(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<0> : tensor<1xi64> } : () -> tensor<1xi64> + %0 ="onnx.ReduceMin"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_min_axes_not_lowered_not_innermost_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<1x2x2xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.ReduceMin"([[PARAM_0_]], [[VAR_0_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>) -> tensor<1x2x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x2x2xf32> +// CHECK: } +} + +// ----- + +func.func @test_reduce_min_axes_not_lowered_not_multiple_axes(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { + %cst = "onnx.Constant"() {value = dense<[2, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %0 ="onnx.ReduceMin"(%arg0, %cst) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<2xi64>)-> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_reduce_min_axes_not_lowered_not_multiple_axes +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2x2xf32>) -> tensor<1x2x1xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 0]> : tensor<2xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.ReduceMin"([[PARAM_0_]], [[VAR_0_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x2x2xf32>, tensor<2xi64>) -> tensor<1x2x1xf32> +// CHECK: return [[VAR_1_]] : tensor<1x2x1xf32> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/relu.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/relu.mlir index a330794899..7b2d4c9138 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/relu.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/relu.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sigmoid.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sigmoid.mlir index 3b7c3f9f8a..91e1be781f 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sigmoid.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sigmoid.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_sigmoid(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Sigmoid"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir index 70fbeda772..1fac5a1f4a 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sqrt.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sqrt.mlir new file mode 100644 index 0000000000..48c3fc0b81 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sqrt.mlir @@ -0,0 +1,27 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s + +func.func @test_sqrt(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_sqrt +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_1_:%.+]] = "zhigh.Sqrt"([[VAR_0_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_2_:%.+]] = "zhigh.Unstick"([[VAR_1_]]) : (tensor<10x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<10x10xf32> +// CHECK: return [[VAR_2_]] : tensor<10x10xf32> +// CHECK: } +} + +// ----- + +/// COM: Test for zdnn limitation. +/// COM: Not lowered when dimensin size exceeds DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE in `third_party/zdnn-lib/zdnn_limit.h` +/// COM: DLCPP_MAXIMUM_DIMENSION_INDEX_SIZE depends on zAIU HW. Please check the value if these tests fails. + +func.func @test_exceed_limit_sqrt(%arg0 : tensor<2097152x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor<2097152x10xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + +// CHECK-LABEL: func @test_exceed_limit_sqrt +// CHECK: "onnx.Sqrt" +} diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sub.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sub.mlir index e9e40305e4..c1e5593e36 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sub.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sub.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sum.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sum.mlir index aaff8185bf..6f31e650b4 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sum.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/sum.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s // COM: Check the singleton case of lowering ONNXSumOp to ZHighAddOp, // COM: where ONNXSumOp has two inputs and is lowered to a single ZHighAddOp. diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/tanh.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/tanh.mlir index a1a755da1a..a53daee5a1 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/tanh.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/tanh.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh %s -split-input-file | FileCheck %s func.func @test_tanh(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Tanh"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir new file mode 100644 index 0000000000..3e7d0b4c0d --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh-arch15.mlir @@ -0,0 +1,245 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh --canonicalize %s -split-input-file | FileCheck %s + +// ----- + +// Do not Split MatMul because a dimension does not exceeds NNPAGetMaxForDim for e2 of 1048576. + +func.func @test_matmul_no_splitting_arch15_A(%arg0: tensor, %arg1: tensor<768x1024xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x1024xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_no_splitting_arch15_A +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x1024xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[A_]], [[B_]]) : (tensor, tensor<768x1024xf32>) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// Split MatMul because a dimension exceeds NNPAGetMaxForDim for e2 on arch15 of 1048576: use 2097152 + +func.func @test_matmul_splitting_arch15_A(%arg0: tensor, %arg1: tensor<768x1024xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x1024xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_splitting_arch15_A +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x1024xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1048576> : tensor<2xi64> +// CHECK: [[VAR_1_:%.+]]:2 = "onnx.Split"([[A_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor, tensor<2xi64>) -> (tensor, tensor) +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.MatMul"([[VAR_1_]]#0, [[B_]]) : (tensor, tensor<768x1024xf32>) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_1_]]#1, [[B_]]) : (tensor, tensor<768x1024xf32>) -> tensor +// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 1 : si64} : (tensor, tensor) -> tensor +// CHECK: return [[VAR_4_]] : tensor +// CHECK: } +} + +// ----- + +// Do not split MatMul because a dimension does not exceeds NNPAGetMaxForDim e1 on arch15 of 2097152. + +func.func @test_matmul_no_splitting_arch15_B(%arg0: tensor, %arg1: tensor<768x2097152xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x2097152xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_no_splitting_arch15_B +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x2097152xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[A_]], [[B_]]) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// Split MatMul because a dimension exceeds NNPAGetMaxForDim e1 on arch15 of 2097152: use 4194304. + +func.func @test_matmul_splitting_arch15_B(%arg0: tensor, %arg1: tensor<768x4194304xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x4194304xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_splitting_arch15_B +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x4194304xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<2097152> : tensor<2xi64> +// CHECK: [[VAR_1_:%.+]]:2 = "onnx.Split"([[B_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<768x4194304xf32>, tensor<2xi64>) -> (tensor<768x2097152xf32>, tensor<768x2097152xf32>) +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.MatMul"([[A_]], [[VAR_1_]]#0) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.MatMul"([[A_]], [[VAR_1_]]#1) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 2 : si64} : (tensor, tensor) -> tensor +// CHECK: return [[VAR_4_]] : tensor +// CHECK: } +} + +// ----- + +// No split MatMul because a dimension does not exceeds NNPAGetMaxForDim for e2/e1 on arch15 of 1048576 / 2097152 + +func.func @test_matmul_no_splitting_arch15_A_B(%arg0: tensor, %arg1: tensor<768x2097152xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x2097152xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_no_splitting_arch15_A_B +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x2097152xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[A_]], [[B_]]) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// Split MatMul because a dimension exceeds NNPAGetMaxForDim for e2/e1 on arch15 of 1048576 / 2097152: use 2097152 and 4194304 + +func.func @test_matmul_splitting_arch15_A_B(%arg0: tensor, %arg1: tensor<768x4194304xf32>) -> (tensor) { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor<768x4194304xf32>) -> tensor + return %0 : tensor + +// mlir2FileCheck.py -a '["A","B"]' +// CHECK-LABEL: func.func @test_matmul_splitting_arch15_A_B +// CHECK-SAME: ([[A_:%.+]]: tensor, [[B_:%.+]]: tensor<768x4194304xf32>) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1048576> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]]:2 = "onnx.Split"([[A_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor, tensor<2xi64>) -> (tensor, tensor) +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<2097152> : tensor<2xi64> +// CHECK: [[VAR_3_:%.+]]:2 = "onnx.Split"([[B_]], [[VAR_2_]]) {axis = 1 : si64} : (tensor<768x4194304xf32>, tensor<2xi64>) -> (tensor<768x2097152xf32>, tensor<768x2097152xf32>) +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.MatMul"([[VAR_1_]]#0, [[VAR_3_]]#0) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.MatMul"([[VAR_1_]]#0, [[VAR_3_]]#1) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_4_]], [[VAR_5_]]) {axis = 2 : si64} : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.MatMul"([[VAR_1_]]#1, [[VAR_3_]]#0) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.MatMul"([[VAR_1_]]#1, [[VAR_3_]]#1) : (tensor, tensor<768x2097152xf32>) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Concat"([[VAR_7_]], [[VAR_8_]]) {axis = 2 : si64} : (tensor, tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_6_]], [[VAR_9_]]) {axis = 1 : si64} : (tensor, tensor) -> tensor +// CHECK: return [[VAR_10_]] : tensor +// CHECK: } +} + +// ----- + +// Rewrite N-D QLinearMatMul into 3-D one. + +func.func @test_nd_qlinearmatmul_nd_nd(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg1: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + %2 = "onnx.QLinearMatMul"(%0, %arg2, %arg3, %1, %arg2, %arg3, %arg2, %arg3) : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %3 = "onnx.DequantizeLinear"(%2, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + return %3 : tensor + +// CHECK-LABEL: func.func @test_nd_qlinearmatmul +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_1_:%.+]]: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { + // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> + // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<4> : tensor<1xi64> + // CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> + // CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> + // CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> + // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> + // CHECK: [[VAR_7_:%.+]] = "onnx.Slice"([[VAR_6_]], [[VAR_3_]], [[VAR_2_]], [[VAR_4_]], [[VAR_1_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> + // CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + // CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor, tensor<3xi64>) -> tensor + // CHECK-DAG: [[VAR_10_:%.+]] = "onnx.QuantizeLinear"([[VAR_9_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor + // CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Shape"([[PARAM_1_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> + // CHECK: [[VAR_12_:%.+]] = "onnx.Slice"([[VAR_11_]], [[VAR_3_]], [[VAR_2_]], [[VAR_4_]], [[VAR_1_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> + // CHECK: [[VAR_13_:%.+]] = "onnx.Concat"([[VAR_5_]], [[VAR_12_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + // CHECK: [[VAR_14_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_13_]]) {allowzero = 0 : si64} : (tensor, tensor<3xi64>) -> tensor + // CHECK: [[VAR_15_:%.+]] = "onnx.QuantizeLinear"([[VAR_14_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor + // CHECK: [[VAR_16_:%.+]] = "onnx.QLinearMatMul"([[VAR_10_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_15_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK-DAG: [[VAR_17_:%.+]] = "onnx.DequantizeLinear"([[VAR_16_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor + // CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> + // CHECK-DAG: [[VAR_19_:%.+]] = "onnx.Shape"([[PARAM_1_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> + // CHECK-NOT: separator of consecutive DAGs + // CHECK-DAG: [[VAR_20_:%.+]] = "onnx.Slice"([[VAR_18_]], [[VAR_4_]], [[VAR_0_]], [[VAR_4_]], [[VAR_1_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK-DAG: [[VAR_21_:%.+]] = "onnx.Slice"([[VAR_19_]], [[VAR_0_]], [[VAR_2_]], [[VAR_4_]], [[VAR_1_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + // CHECK: [[VAR_22_:%.+]] = "onnx.Concat"([[VAR_20_]], [[VAR_21_]]) {axis = 0 : si64} : (tensor<3xi64>, tensor<1xi64>) -> tensor<4xi64> + // CHECK: [[VAR_23_:%.+]] = "onnx.Reshape"([[VAR_17_]], [[VAR_22_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor + // CHECK: return [[VAR_23_]] : tensor + // CHECK: } +} + +func.func @test_nd_qlinearmatmul_nd_2d(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg1: tensor<64x384xf32>, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor<64x384xf32>, tensor, tensor) -> tensor<64x384xi8> + %2 = "onnx.QLinearMatMul"(%0, %arg2, %arg3, %1, %arg2, %arg3, %arg2, %arg3) : (tensor, tensor, tensor, tensor<64x384xi8>, tensor, tensor, tensor, tensor) -> tensor + %3 = "onnx.DequantizeLinear"(%2, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + return %3 : tensor + +// CHECK-LABEL: func.func @test_nd_qlinearmatmul_nd_2d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_1_:%.+]]: tensor<64x384xf32>, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[64, 384]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<4> : tensor<1xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.QuantizeLinear"([[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor<64x384xf32>, tensor, tensor) -> tensor<64x384xi8> +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Slice"([[VAR_8_]], [[VAR_4_]], [[VAR_3_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_6_]], [[VAR_9_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_10_]]) {allowzero = 0 : si64} : (tensor, tensor<3xi64>) -> tensor +// CHECK: [[VAR_12_:%.+]] = "onnx.QuantizeLinear"([[VAR_11_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_13_:%.+]] = "onnx.QLinearMatMul"([[VAR_12_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_7_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) : (tensor, tensor, tensor, tensor<64x384xi8>, tensor, tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.DequantizeLinear"([[VAR_13_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Slice"([[VAR_15_]], [[VAR_5_]], [[VAR_1_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Slice"([[VAR_0_]], [[VAR_2_]], [[VAR_4_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]]) {axis = 0 : si64} : (tensor<3xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_19_:%.+]] = "onnx.Reshape"([[VAR_14_]], [[VAR_18_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK: return [[VAR_19_]] : tensor +// CHECK: } +} + +func.func @test_nd_qlinearmatmul_2d_nd(%arg0: tensor<384x64xf32>, %arg1: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor<384x64xf32>, tensor, tensor) -> tensor<384x64xi8> + %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + %2 = "onnx.QLinearMatMul"(%0, %arg2, %arg3, %1, %arg2, %arg3, %arg2, %arg3) : (tensor<384x64xi8>, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + %3 = "onnx.DequantizeLinear"(%2, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + return %3 : tensor + +// CHECK-LABEL: func.func @test_nd_qlinearmatmul_2d_nd +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<384x64xf32>, [[PARAM_1_:%.+]]: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[384, 64]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<4> : tensor<1xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor<384x64xf32>, tensor, tensor) -> tensor<384x64xi8> +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Shape"([[PARAM_1_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Slice"([[VAR_8_]], [[VAR_4_]], [[VAR_3_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: [[VAR_10_:%.+]] = "onnx.Concat"([[VAR_6_]], [[VAR_9_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_10_]]) {allowzero = 0 : si64} : (tensor, tensor<3xi64>) -> tensor +// CHECK: [[VAR_12_:%.+]] = "onnx.QuantizeLinear"([[VAR_11_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_13_:%.+]] = "onnx.QLinearMatMul"([[VAR_7_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_12_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) : (tensor<384x64xi8>, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.DequantizeLinear"([[VAR_13_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.Shape"([[PARAM_1_]]) {start = 0 : si64} : (tensor) -> tensor<4xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Slice"([[VAR_15_]], [[VAR_5_]], [[VAR_4_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Slice"([[VAR_0_]], [[VAR_5_]], [[VAR_2_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Slice"([[VAR_15_]], [[VAR_1_]], [[VAR_3_]], [[VAR_5_]], [[VAR_2_]]) : (tensor<4xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: [[VAR_19_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_18_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<4xi64> +// CHECK: [[VAR_20_:%.+]] = "onnx.Reshape"([[VAR_14_]], [[VAR_19_]]) {allowzero = 0 : si64} : (tensor, tensor<4xi64>) -> tensor +// CHECK: return [[VAR_20_]] : tensor +// CHECK: } +} + +// Do not rewrite because of potential broadcasting. +func.func @test_nd_qlinearmatmul_nd_nd_not_rewriting(%arg0: tensor {onnx.dim_params = "0:bs,1:sl"}, %arg1: tensor<1x?x64x384xf32> {onnx.dim_params = "1:sl"}, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "onnx.QuantizeLinear"(%arg0, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + %1 = "onnx.QuantizeLinear"(%arg1, %arg2, %arg3) : (tensor<1x?x64x384xf32>, tensor, tensor) -> tensor<1x?x64x384xi8> + %2 = "onnx.QLinearMatMul"(%0, %arg2, %arg3, %1, %arg2, %arg3, %arg2, %arg3) : (tensor, tensor, tensor, tensor<1x?x64x384xi8>, tensor, tensor, tensor, tensor) -> tensor + %3 = "onnx.DequantizeLinear"(%2, %arg2, %arg3) : (tensor, tensor, tensor) -> tensor + return %3 : tensor + +// CHECK-LABEL: func.func @test_nd_qlinearmatmul_nd_nd_not_rewriting +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:bs,1:sl"}, [[PARAM_1_:%.+]]: tensor<1x?x64x384xf32> {onnx.dim_params = "1:sl"}, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.QuantizeLinear"([[PARAM_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.QuantizeLinear"([[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64, saturate = 1 : si64} : (tensor<1x?x64x384xf32>, tensor, tensor) -> tensor<1x?x64x384xi8> +// CHECK: [[VAR_2_:%.+]] = "onnx.QLinearMatMul"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_2_]], [[PARAM_3_]]) : (tensor, tensor, tensor, tensor<1x?x64x384xi8>, tensor, tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_3_:%.+]] = "onnx.DequantizeLinear"([[VAR_2_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK: return [[VAR_3_]] : tensor +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir index e0679a6870..f788a0091a 100644 --- a/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir +++ b/test/mlir/accelerators/nnpa/conversion/rewrite-onnx-for-zhigh.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --rewrite-onnx-for-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --rewrite-onnx-for-zhigh --shape-inference --canonicalize --constprop-onnx --shape-inference %s --split-input-file | FileCheck --check-prefix=CONSTPROP %s func.func @test_batchnorm_epsilon(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>) -> tensor<2x3x4x5xf32> { %0 = "onnx.BatchNormalizationInferenceMode"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/add.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/add.mlir index c1f3324bc9..1b58dab4ed 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/add.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/add.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_add() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/max.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/max.mlir index ab1cbbed3c..32b483a241 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/max.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/max.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_max() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/min.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/min.mlir index 8a12a30542..3995ef0f75 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/min.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/min.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_min() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/mul.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/mul.mlir index fecda71427..84aaf9a14d 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/mul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/mul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_mul() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/relu.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/relu.mlir index 7d50da0f86..ae3cff3538 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/relu.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/relu.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_log() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/sub.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/sub.mlir index f4ce23eaa5..e84f89eeff 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/sub.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-onnx/sub.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-zhigh-to-onnx %s -split-input-file | FileCheck %s func.func @test_sub() -> tensor<10x10xf32> { %cst0 = onnx.Constant dense<1.0> : tensor<10x10xf32> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/add.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/add.mlir index eaad58308b..dbfd014dc4 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/add.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/add.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Add"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir index b8cef7cf2b..270940f2b6 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/compiler-stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir index 2d2983ba07..a4002b2227 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/conv.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @conv_valid_padding(%arg0: tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, %arg1: tensor<2x2x3x1xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, %arg2: tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { %0 = "zhigh.Conv2D"(%arg0, %arg1, %arg2) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1], act_func = "ACT_NONE"} : (tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x1xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/div.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/div.mlir index 8ddd718edd..4cb93ea443 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/div.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/div.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Div"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/dlf16_to_f32.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/dlf16_to_f32.mlir index 07ae1fbd06..972adfaf08 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/dlf16_to_f32.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/dlf16_to_f32.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/exp.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/exp.mlir index cd4c6cd6a9..77876561fb 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/exp.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/exp.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Exp"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/f32_to_dlf16.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/f32_to_dlf16.mlir index 696cb8401a..a038b6887e 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/f32_to_dlf16.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/f32_to_dlf16.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gelu.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gelu.mlir new file mode 100644 index 0000000000..44acd88099 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gelu.mlir @@ -0,0 +1,50 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.Gelu"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x5xf16, #map> { +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x5xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.gelu"([[PARAM_0_]], [[RES_1_]], [[RES_]]) {layout = "3D"} : (memref<3x4x5xf16, #map>, memref<3xi64>, memref<3x4x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x5xf16, #map> +// CHECK: } +} + +// ----- + +func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.Gelu"(%arg0) : (tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow_unknown_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x?x5xf16, #map>) -> memref<3x?x5xf16, #map> { +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.gelu"([[PARAM_0_]], [[RES_1_]], [[RES_]]) {layout = "3D"} : (memref<3x?x5xf16, #map>, memref<3xi64>, memref<3x?x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x?x5xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir index b828da2b80..930b1a0179 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/gru.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @gru_return_single_step(%input : tensor<3x5x7xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %h0 : tensor<1x5x9xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %input_weights : tensor<1x7x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %input_bias : tensor<1x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %hidden_weights : tensor<1x9x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %hidden_bias : tensor<1x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>) -> tensor<*xf16> { diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/invsqrt.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/invsqrt.mlir new file mode 100644 index 0000000000..ae34ab5495 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/invsqrt.mlir @@ -0,0 +1,51 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.InvSqrt"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x5xf16, #map> { +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x5xf16, #map> +// CHECK-DAG: [[RES_0_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_0_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_0_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_0_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.invsqrt"([[PARAM_0_]], [[RES_0_]], [[RES_]]) {layout = "3D"} : (memref<3x4x5xf16, #map>, memref<3xi64>, memref<3x4x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x5xf16, #map> +// CHECK: } +} + +// ----- + +func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.InvSqrt"(%arg0) : (tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow_unknown_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x?x5xf16, #map>) -> memref<3x?x5xf16, #map> { +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_0_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_0_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: krnl.store [[VAR_3_]], [[RES_0_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_0_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.invsqrt"([[PARAM_0_]], [[RES_0_]], [[RES_]]) {layout = "3D"} : (memref<3x?x5xf16, #map>, memref<3xi64>, memref<3x?x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x?x5xf16, #map> +// CHECK: } + +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/leakyrelu.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/leakyrelu.mlir new file mode 100644 index 0000000000..b5b2477027 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/leakyrelu.mlir @@ -0,0 +1,50 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.LeakyRelu"(%arg0) {alpha = 0.02 : f32} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x5xf16, #map> { +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x5xf16, #map> +// CHECK-DAG: [[RES_0_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_0_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_0_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_0_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.leakyrelu"([[PARAM_0_]], [[RES_0_]], [[RES_]]) {alpha = 2.000000e-02 : f32, layout = "3D"} : (memref<3x4x5xf16, #map>, memref<3xi64>, memref<3x4x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x5xf16, #map> +// CHECK: } +} + +// ----- + +func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.LeakyRelu"(%arg0) : (tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow_unknown_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x?x5xf16, #map>) -> memref<3x?x5xf16, #map> { +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_0_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_0_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: krnl.store [[VAR_3_]], [[RES_0_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_0_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.leakyrelu"([[PARAM_0_]], [[RES_0_]], [[RES_]]) {alpha = 0.00999999977 : f32, layout = "3D"} : (memref<3x?x5xf16, #map>, memref<3xi64>, memref<3x?x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x?x5xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/log.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/log.mlir index 7263c491da..b9ce99f445 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/log.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/log.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Log"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir index e63d5cee97..1011c632d8 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/lstm.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s func.func @lstm_return_single_step(%input : tensor<3x5x7xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %h0 : tensor<1x5x9xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %c0 : tensor<1x5x9xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %input_weights : tensor<1x7x36xf16, #zhigh.layout<{dataLayout = "FICO"}>>, %input_bias : tensor<1x36xf16, #zhigh.layout<{dataLayout = "FICO"}>>, %hidden_weights : tensor<1x9x36xf16, #zhigh.layout<{dataLayout = "FICO"}>>, %hidden_bias : tensor<1x36xf16, #zhigh.layout<{dataLayout = "FICO"}>>) -> (tensor<*xf16>, tensor<*xf16>) { diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/matmul.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/matmul.mlir index f9ee8d6786..e67f3d5423 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/matmul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/matmul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @matmul(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> @@ -19,13 +19,89 @@ func.func @matmul(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %a // CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> // CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> // CHECK: krnl.store [[VAR_c16_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> -// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<4x8xf16, #map>, memref<8x16xf16, #map>, memref<16xf16, #map1>, memref<3xi64>, memref<4x16xf16, #map>) -> () +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<4x8xf16, #map>, memref<8x16xf16, #map>, memref<16xf16, #map1>, memref<3xi64>, memref<4x16xf16, #map>) -> () // CHECK: return [[RES_]] : memref<4x16xf16, #map> // CHECK: } } // ----- +func.func @matmul_transposeA(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { + %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) {transposeA = 1 : si64, transposeB = 0 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> +// CHECK-DAG: #map1 = affine_map<(d0) -> (0, d0 floordiv 64, 0, 0, 31, d0 mod 64)> +// CHECK-LABEL: func @matmul_transposeA +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4x8xf16, #map>, [[PARAM_1_:%.+]]: memref<8x16xf16, #map>, [[PARAM_2_:%.+]]: memref<16xf16, #map1>) -> memref<8x16xf16, #map> { +// CHECK-DAG: [[VAR_c16_i64_:%.+]] = arith.constant 16 : i64 +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c8_i64_:%.+]] = arith.constant 8 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<8x16xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c16_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 1 : si64, transposeB = 0 : si64} : (memref<4x8xf16, #map>, memref<8x16xf16, #map>, memref<16xf16, #map1>, memref<3xi64>, memref<8x16xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<8x16xf16, #map> +// CHECK: } +} + +// ----- + +func.func @matmul_transposeB(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { + %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) {transposeA = 0 : si64, transposeB = 1 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> +// CHECK-DAG: #map1 = affine_map<(d0) -> (0, d0 floordiv 64, 0, 0, 31, d0 mod 64)> +// CHECK-LABEL: func @matmul_transposeB +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4x8xf16, #map>, [[PARAM_1_:%.+]]: memref<8x16xf16, #map>, [[PARAM_2_:%.+]]: memref<16xf16, #map1>) -> memref<4x8xf16, #map> { +// CHECK-DAG: [[VAR_c8_i64_:%.+]] = arith.constant 8 : i64 +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4x8xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 0 : si64, transposeB = 1 : si64} : (memref<4x8xf16, #map>, memref<8x16xf16, #map>, memref<16xf16, #map1>, memref<3xi64>, memref<4x8xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<4x8xf16, #map> +// CHECK: } +} + +// ----- + +func.func @matmul_transposeAB(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { + %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) {transposeA = 1 : si64, transposeB = 1 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> +// CHECK-DAG: #map1 = affine_map<(d0) -> (0, d0 floordiv 64, 0, 0, 31, d0 mod 64)> +// CHECK-LABEL: func @matmul_transposeAB +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4x8xf16, #map>, [[PARAM_1_:%.+]]: memref<8x16xf16, #map>, [[PARAM_2_:%.+]]: memref<16xf16, #map1>) -> memref<8x8xf16, #map> { +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c8_i64_:%.+]] = arith.constant 8 : i64 +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<8x8xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 1 : si64, transposeB = 1 : si64} : (memref<4x8xf16, #map>, memref<8x16xf16, #map>, memref<16xf16, #map1>, memref<3xi64>, memref<8x8xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<8x8xf16, #map> +// CHECK: } +} + +// ----- + func.func @matmul_stack(%arg0: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %arg1: tensor<2x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %arg2: tensor<2x16xf16, #zhigh.layout<{dataLayout = "2DS"}>>) -> tensor<*xf16> { %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x16xf16, #zhigh.layout<{dataLayout = "2DS"}>>) -> tensor<*xf16> return %0 : tensor<*xf16> @@ -48,21 +124,21 @@ func.func @matmul_stack(%arg0: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3D // CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<4xi64> // CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<4xi64> // CHECK: krnl.store [[VAR_c16_i64_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<4xi64> -// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<2x4x8xf16, #map>, memref<2x8x16xf16, #map>, memref<2x16xf16, #map1>, memref<4xi64>, memref<2x4x16xf16, #map>) -> () +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<2x4x8xf16, #map>, memref<2x8x16xf16, #map>, memref<2x16xf16, #map1>, memref<4xi64>, memref<2x4x16xf16, #map>) -> () // CHECK: return [[RES_]] : memref<2x4x16xf16, #map> // CHECK: } } // ----- -func.func @matmul_broadcast(%arg0: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { +func.func @matmul_broadcast23(%arg0: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %arg1: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> return %0 : tensor<*xf16> // CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> // CHECK-DAG: #map1 = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> // CHECK-DAG: #map2 = affine_map<(d0) -> (0, d0 floordiv 64, 0, 0, 31, d0 mod 64)> -// CHECK-LABEL: func @matmul_broadcast +// CHECK-LABEL: func @matmul_broadcast23 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x4x8xf16, #map>, [[PARAM_1_:%.+]]: memref<8x16xf16, #map1>, [[PARAM_2_:%.+]]: memref<16xf16, #map2>) -> memref<2x4x16xf16, #map> { // CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[VAR_c16_i64_:%.+]] = arith.constant 16 : i64 @@ -78,13 +154,42 @@ func.func @matmul_broadcast(%arg0: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = // CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<4xi64> // CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<4xi64> // CHECK: krnl.store [[VAR_c16_i64_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<4xi64> -// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast = -1 : si64, is_stacked = 0 : si64} : (memref<2x4x8xf16, #map>, memref<8x16xf16, #map1>, memref<16xf16, #map2>, memref<4xi64>, memref<2x4x16xf16, #map>) -> () +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = -1 : si64, is_stacked = 0 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<2x4x8xf16, #map>, memref<8x16xf16, #map1>, memref<16xf16, #map2>, memref<4xi64>, memref<2x4x16xf16, #map>) -> () // CHECK: return [[RES_]] : memref<2x4x16xf16, #map> // CHECK: } } // ----- +func.func @matmul_broadcast1(%arg0: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %arg2: tensor<2x8xf16, #zhigh.layout<{dataLayout = "2DS"}>>) -> tensor<*xf16> { + %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) : (tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x8xf16, #zhigh.layout<{dataLayout = "2DS"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> +// CHECK-DAG: #map1 = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: #map2 = affine_map<(d0, d1) -> (d0, d1 floordiv 64, 0, 0, 31, d1 mod 64)> +// CHECK-LABEL: func @matmul_broadcast1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<8x16xf16, #map>, [[PARAM_1_:%.+]]: memref<2x4x8xf16, #map1>, [[PARAM_2_:%.+]]: memref<2x8xf16, #map2>) -> memref<2x8x8xf16, #map1> { +// CHECK-DAG: [[VAR_c2_i64_:%.+]] = arith.constant 2 : i64 +// CHECK-DAG: [[VAR_c16_i64_:%.+]] = arith.constant 16 : i64 +// CHECK-DAG: [[VAR_c8_i64_:%.+]] = arith.constant 8 : i64 +// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x8x8xf16, #map1> +// CHECK-DAG: [[RES_0_:%.+]] = memref.alloc() {{.*}}: memref<4xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_0_]]{{.}}[[VAR_c0_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[VAR_c16_i64_]], [[RES_0_]]{{.}}[[VAR_c1_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[VAR_c2_i64_]], [[RES_0_]]{{.}}[[VAR_c2_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[VAR_c8_i64_]], [[RES_0_]]{{.}}[[VAR_c3_]]{{.}} : memref<4xi64> +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_0_]], [[RES_]]) {is_bcast1 = -1 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<8x16xf16, #map>, memref<2x4x8xf16, #map1>, memref<2x8xf16, #map2>, memref<4xi64>, memref<2x8x8xf16, #map1>) -> () +// CHECK: return [[RES_]] : memref<2x8x8xf16, #map1> +// CHECK: } +} + +// ----- + func.func @matmul_unknown_dims(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1: tensor<8x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg2: tensor>) -> tensor<*xf16> { %0 ="zhigh.MatMul"(%arg0, %arg1, %arg2) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor>) -> tensor<*xf16> return %0 : tensor<*xf16> @@ -107,7 +212,7 @@ func.func @matmul_unknown_dims(%arg0: tensor<4x8xf16, #zhigh.layout<{dataLayout // CHECK: krnl.store [[VAR_c8_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> // CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 // CHECK: krnl.store [[VAR_4_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> -// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast = 0 : si64, is_stacked = 0 : si64} : (memref<4x8xf16, #map>, memref<8x?xf16, #map>, memref, memref<3xi64>, memref<4x?xf16, #map>) -> () +// CHECK: "zlow.matmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[RES_1_]], [[RES_]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = 0 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<4x8xf16, #map>, memref<8x?xf16, #map>, memref, memref<3xi64>, memref<4x?xf16, #map>) -> () // CHECK: return [[RES_]] : memref<4x?xf16, #map> // CHECK: } } diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/max.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/max.mlir index 42d8c30842..4f009f9326 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/max.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/max.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Max"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/meanreduce.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/meanreduce.mlir index 4242f00d97..7dbc85d42c 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/meanreduce.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/meanreduce.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<1x5x7x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> { %0 = "zhigh.MeanReduce2d"(%arg0) : (tensor<1x5x7x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/min.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/min.mlir index a0baea41cd..9ff5e9cf42 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/min.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/min.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Min"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/mul.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/mul.mlir index 22662aa3ac..2a33115d08 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/mul.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/mul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Mul"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/pool.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/pool.mlir index 9cd499a287..652a0c8e56 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/pool.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/pool.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @maxpool_valid_padding(%arg0: tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> { %0 = "zhigh.MaxPool2D"(%arg0) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1], act_func = "ACT_NONE"} : (tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_matmul.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_matmul.mlir new file mode 100644 index 0000000000..45302c2404 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_matmul.mlir @@ -0,0 +1,38 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_zhigh_quantized_matmul(%arg0: tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, %arg1: tensor, %arg2: tensor, %arg3: tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, %arg4: tensor, %arg5: tensor, %arg6: tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, %arg7: tensor, %arg8: tensor) -> tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> { + %none = "onnx.NoValue"() {value} : () -> none + %Out, %Out_RecScale, %Out_Offset = "zhigh.QuantizedMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %none, %none) {DequantizeOutput = 0 : si64, DisableClipping = 0 : si64, PreComputedBias = 0 : si64} : (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, tensor, tensor, none, none) -> (tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor) + return %Out : tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 64, d0 mod 64, d1 mod 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (0, d0 floordiv 128, 0, 0, 31, d0 mod 128)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (0, d0 floordiv 64, 0, 0, 31, d0 mod 64)> +// CHECK-LABEL: func.func @test_zhigh_quantized_matmul +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf16, #map>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref, [[PARAM_3_:%.+]]: memref<5x7xi8, #map1>, [[PARAM_4_:%.+]]: memref, [[PARAM_5_:%.+]]: memref, [[PARAM_6_:%.+]]: memref<7xi8, #map2>, [[PARAM_7_:%.+]]: memref, [[PARAM_8_:%.+]]: memref) -> memref<1x3x7xf16, #map> { +// CHECK-DAG: [[CST_7_:%.+]] = arith.constant 7 : i64 +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_3_1_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x3x7xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[CST_1_dot_000000_]], [[RES_1_]][] : memref +// CHECK: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_2_]][] : memref +// CHECK: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4xi64> +// CHECK: krnl.store [[CST_1_]], [[RES_3_]]{{.}}[[CST_0_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[CST_3_]], [[RES_3_]]{{.}}[[CST_1_1_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[CST_5_]], [[RES_3_]]{{.}}[[CST_2_]]{{.}} : memref<4xi64> +// CHECK: krnl.store [[CST_7_]], [[RES_3_]]{{.}}[[CST_3_1_]]{{.}} : memref<4xi64> +// CHECK: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<7xf16, #map3> +// CHECK: "zlow.quantizedMatmul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]], [[PARAM_8_]], [[RES_4_]], [[RES_3_]], [[RES_]], [[RES_]]_1, [[RES_]]_2) {bias_q_type = "INT8", dequantize_output = 0 : si64, disable_clipping = 0 : si64, is_bcast = -1 : si64, is_stacked = 0 : si64, out_q_type = "DLFLOAT16", pre_computed_bias = 0 : si64, x_q_type = "DLFLOAT16", y_q_type = "WEIGHTS"} : (memref<1x3x5xf16, #map>, memref, memref, memref<5x7xi8, #map1>, memref, memref, memref<7xi8, #map2>, memref, memref, memref<7xf16, #map3>, memref<4xi64>, memref<1x3x7xf16, #map>, memref, memref) -> () +// CHECK: return [[RES_]] : memref<1x3x7xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick.mlir new file mode 100644 index 0000000000..8d9a623aa9 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick.mlir @@ -0,0 +1,229 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_zhigh_quantized_stick_dlfloat16(%arg0: tensor<1x3x5xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + return %0#0: tensor<*xf16> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_dlfloat16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf32>) -> memref<1x3x5xf16, #map> { +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[CST_minus_1_dot_280000_:%.+]] = arith.constant -1.280000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_]], [[CST_0_1_]] : memref +// CHECK: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK: [[VAR_17_:%.+]] = arith.minnumf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_17_]], [[RES_]][] : memref +// CHECK: } +// CHECK: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_1_]], [[CST_0_]] : memref +// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_17_1_:%.+]] = arith.maxnumf [[LOAD_RES_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_17_1_]], [[RES_1_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_minus_1_dot_280000_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_1_dot_270000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = "krnl.round_even"([[VAR_11_]]) : (f32) -> f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_13_]], [[RES_2_]][] : memref +// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_12_]], [[RES_3_]][] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xf16, #map> +// CHECK: "zlow.stick"([[PARAM_0_]], [[RES_4_]]) {layout = "3DS", saturation = -1 : si64} : (memref<1x3x5xf32>, memref<1x3x5xf16, #map>) -> () +// CHECK: return [[RES_4_]] : memref<1x3x5xf16, #map> +// CHECK: } +} + +// ----- + + +func.func @test_zhigh_quantized_stick_dlfloat16_symmetric(%arg0: tensor<1x3x5xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 1 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + return %0#0: tensor<*xf16> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_dlfloat16_symmetric +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf32>) -> memref<1x3x5xf16, #map> { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<1.270000e+02> : tensor} : () -> memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5){ +// CHECK: [[VAR_7_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2] : memref<1x3x5xf32> +// CHECK: [[VAR_9_:%.+]] = math.absf [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_9_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2] : memref<1x3x5xf32> +// CHECK: } +// CHECK: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_1_]], [[CST_0_]] : memref +// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 5){ +// CHECK: [[VAR_7_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_7_1_]]#0, [[VAR_7_1_]]#1, [[VAR_7_1_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[VAR_9_1_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref +// CHECK: } +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[LOAD_VAR_0_MEM_:%.+]] = krnl.load [[VAR_0_]][] : memref +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_5_:%.+]] = arith.divf [[LOAD_VAR_0_MEM_]], [[LOAD_RES_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_5_]], [[RES_2_]][] : memref +// CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[LOAD_RES_2_MEM_]], [[RES_3_]][] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_4_]][] : memref +// CHECK: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xf16, #map> +// CHECK: "zlow.stick"([[PARAM_0_]], [[RES_5_]]) {layout = "3DS", saturation = -1 : si64} : (memref<1x3x5xf32>, memref<1x3x5xf16, #map>) -> () +// CHECK: return [[RES_5_]] : memref<1x3x5xf16, #map> +// CHECK: } +} + +// ----- + +func.func @test_zhigh_quantized_stick_int8(%arg0: tensor<1x3x5xf32>) -> tensor<*xi8> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "int8"} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + return %0#0: tensor<*xi8> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 128, 0, d1 floordiv 32, d1 mod 32, d2 mod 128)> +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_int8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf32>) -> memref<1x3x5xi8, #map> { +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[CST_minus_1_dot_280000_:%.+]] = arith.constant -1.280000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_]], [[CST_0_1_]] : memref +// CHECK: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK: [[VAR_17_:%.+]] = arith.minnumf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_17_]], [[RES_]][] : memref +// CHECK: } +// CHECK: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_1_]], [[CST_0_]] : memref +// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_17_1_:%.+]] = arith.maxnumf [[LOAD_RES_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_17_1_]], [[RES_1_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_minus_1_dot_280000_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_1_dot_270000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = "krnl.round_even"([[VAR_11_]]) : (f32) -> f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_13_]], [[RES_2_]][] : memref +// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_12_]], [[RES_3_]][] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xi8, #map> +// CHECK: "zlow.quantizedStick"([[PARAM_0_]], [[RES_2_]], [[RES_3_]], [[RES_4_]]) {layout = "3DS", q_type = "int8"} : (memref<1x3x5xf32>, memref, memref, memref<1x3x5xi8, #map>) -> () +// CHECK: return [[RES_4_]] : memref<1x3x5xi8, #map> +// CHECK: } +} + +// ----- + + +func.func @test_zhigh_quantized_stick_weights(%arg0: tensor<1x3x5xf32>) -> tensor<*xi8> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "weights"} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + return %0#0: tensor<*xi8> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 64, d1 mod 64, d2 mod 64)> +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_weights +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf32>) -> memref<1x3x5xi8, #map> { +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[CST_minus_1_dot_280000_:%.+]] = arith.constant -1.280000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_]], [[CST_0_1_]] : memref +// CHECK: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_]]#0, [[VAR_14_]]#1, [[VAR_14_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK: [[VAR_17_:%.+]] = arith.minnumf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_17_]], [[RES_]][] : memref +// CHECK: } +// CHECK: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_1_]], [[CST_0_]] : memref +// CHECK: [[LOOP_1_:%.+]]:3 = krnl.define_loops 3 +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_3_:%.+]] = 0 to 1, [[LOOP_1_]]#1 -> [[I_4_:%.+]] = 0 to 3, [[LOOP_1_]]#2 -> [[I_5_:%.+]] = 0 to 5){ +// CHECK: [[VAR_14_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_14_1_]]#0, [[VAR_14_1_]]#1, [[VAR_14_1_]]#2] : memref<1x3x5xf32> +// CHECK-DAG: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_17_1_:%.+]] = arith.maxnumf [[LOAD_RES_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_17_1_]], [[RES_1_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_minus_1_dot_280000_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_1_dot_270000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = "krnl.round_even"([[VAR_11_]]) : (f32) -> f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_13_]], [[RES_2_]][] : memref +// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_12_]], [[RES_3_]][] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xi8, #map> +// CHECK: "zlow.quantizedStick"([[PARAM_0_]], [[RES_2_]], [[RES_3_]], [[RES_4_]]) {layout = "3DS", q_type = "weights"} : (memref<1x3x5xf32>, memref, memref, memref<1x3x5xi8, #map>) -> () +// CHECK: return [[RES_4_]] : memref<1x3x5xi8, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick_O3.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick_O3.mlir new file mode 100644 index 0000000000..2a45c12b5e --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/quantized_stick_O3.mlir @@ -0,0 +1,54 @@ +// RUN: onnx-mlir-opt -O3 --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_zhigh_quantized_stick_dlfloat16_symmetric(%arg0: tensor<1x3x5xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 1 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + return %0#0: tensor<*xf16> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_dlfloat16_symmetric +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x5xf32>) -> memref<1x3x5xf16, #map> { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<12xf32> +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_15_:%.+]] = arith.constant 15 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_15_]], [[RES_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref<1x3x5xf32>, memref<1xindex>) -> memref<15xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<12xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32>, vector<12xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 12 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4){ +// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_6_]]{{.}} : memref<15xf32>, vector<12xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32>, vector<12xf32> +// CHECK: [[VAR_9_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<12xf32> +// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[VAR_9_]] : vector<12xf32> +// CHECK: vector.store [[VAR_10_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32>, vector<12xf32> +// CHECK: } +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 12 to 15){ +// CHECK: [[VAR_6_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_6_1_]]{{.}} : memref<15xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32> +// CHECK: [[VAR_9_1_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_1_]] : f32 +// CHECK: [[VAR_10_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[VAR_9_1_]] : f32 +// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32> +// CHECK: } +// CHECK: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<12xf32>, vector<12xf32> +// CHECK: [[VAR_3_:%.+]] = vector.reduction , [[LOAD_RES_1_MEM_2_]] : vector<12xf32> into f32 +// CHECK: krnl.store [[VAR_3_]], [[RES_2_]][] : memref +// CHECK: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_5_:%.+]] = arith.divf [[CST_1_dot_270000_]], [[LOAD_RES_2_MEM_]] : f32 +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[VAR_5_]], [[RES_3_]][] : memref +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_4_]][] : memref +// CHECK: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1x3x5xf16, #map> +// CHECK: "zlow.stick"([[PARAM_0_]], [[RES_5_]]) {layout = "3DS", saturation = -1 : si64} : (memref<1x3x5xf32>, memref<1x3x5xf16, #map>) -> () +// CHECK: return [[RES_5_]] : memref<1x3x5xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemax.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemax.mlir new file mode 100644 index 0000000000..06c390c17e --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemax.mlir @@ -0,0 +1,25 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @reduce_max_axes_defined_noop_0(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> { + %0 = "zhigh.ReduceMax"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> + return %0 : tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @reduce_max_axes_defined_noop_0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x1xf16, #map> { +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x1xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[CST_3_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[CST_4_]], [[RES_1_]]{{.}}[[CST_1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[CST_5_]], [[RES_1_]]{{.}}[[CST_2_]]{{.}} : memref<3xi64> +// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<8192xi8> +// CHECK: "zlow.reducemax"([[PARAM_0_]], [[RES_2_]], [[RES_1_]], [[RES_]]) {layout = "3DS"} : (memref<3x4x5xf16, #map>, memref<8192xi8>, memref<3xi64>, memref<3x4x1xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x1xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemin.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemin.mlir new file mode 100644 index 0000000000..faf425537c --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reducemin.mlir @@ -0,0 +1,25 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @reduce_min_axes_defined_noop_0(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> { + %0 = "zhigh.ReduceMin"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> + return %0 : tensor<3x4x1xf16, #zhigh.layout<{dataLayout = "3DS"}>> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @reduce_min_axes_defined_noop_0 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x1xf16, #map> { +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x1xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[CST_3_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[CST_4_]], [[RES_1_]]{{.}}[[CST_1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[CST_5_]], [[RES_1_]]{{.}}[[CST_2_]]{{.}} : memref<3xi64> +// CHECK: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<8192xi8> +// CHECK: "zlow.reducemin"([[PARAM_0_]], [[RES_2_]], [[RES_1_]], [[RES_]]) {layout = "3DS"} : (memref<3x4x5xf16, #map>, memref<8192xi8>, memref<3xi64>, memref<3x4x1xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x1xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/relu.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/relu.mlir index 6e16464c62..9a83ab0f2c 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/relu.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/relu.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Relu"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sigmoid.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sigmoid.mlir index 077198915d..066d27eef1 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sigmoid.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sigmoid.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Sigmoid"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/softmax.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/softmax.mlir index f99a5d5efe..354aa94ef7 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/softmax.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/softmax.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> { %0 = "zhigh.Softmax"(%arg0) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sqrt.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sqrt.mlir new file mode 100644 index 0000000000..441573348b --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sqrt.mlir @@ -0,0 +1,50 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.Sqrt"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x5xf16, #map>) -> memref<3x4x5xf16, #map> { +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c4_i64_:%.+]] = arith.constant 4 : i64 +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3x4x5xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c4_i64_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.sqrt"([[PARAM_0_]], [[RES_1_]], [[RES_]]) {layout = "3D"} : (memref<3x4x5xf16, #map>, memref<3xi64>, memref<3x4x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x4x5xf16, #map> +// CHECK: } +} + +// ----- + +func.func @should_lower_to_zlow_unknown_dims(%arg0: tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { + %0 = "zhigh.Sqrt"(%arg0) : (tensor<3x?x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> + return %0 : tensor<*xf16> + +// CHECK-DAG: #map = affine_map<(d0, d1, d2) -> (0, d2 floordiv 64, d0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func @should_lower_to_zlow_unknown_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x?x5xf16, #map>) -> memref<3x?x5xf16, #map> { +// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[VAR_c5_i64_:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_c3_i64_:%.+]] = arith.constant 3 : i64 +// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<3x?x5xf16, #map> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xi64> +// CHECK: krnl.store [[VAR_c3_i64_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<3xi64> +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<3xi64> +// CHECK: krnl.store [[VAR_c5_i64_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64> +// CHECK: "zlow.sqrt"([[PARAM_0_]], [[RES_1_]], [[RES_]]) {layout = "3D"} : (memref<3x?x5xf16, #map>, memref<3xi64>, memref<3x?x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<3x?x5xf16, #map> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir index 22a67eec40..9696f5afc0 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant-of-shape.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant-of-shape.mlir index 5c8d4a63c9..6bb7cd76ae 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant-of-shape.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant-of-shape.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @test_stickified_constant_of_shape(%arg0: tensor) -> tensor> { %0 = onnx.Constant dense<8.000000e+00> : tensor diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir index 7bf9766d88..1055244b1a 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/stickified-constant.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --convert-onnx-to-krnl %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-onnx-to-krnl %s -split-input-file | FileCheck %s module { func.func @remove_stick_2d() -> tensor<2x3xf32> { @@ -32,3 +32,27 @@ module { // CHECK-NEXT: zhigh: "0x0100000000003E00400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000004100420042800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" // CHECK-NEXT: } // CHECK-NEXT: } + +// ----- + +func.func @splat_stickified_constant() -> tensor<2x3xf32> { + %0 = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense<5> : tensor<4096xi8>} : () -> tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>> + %1 = "zhigh.Unstick"(%0) : (tensor<2x3xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> +// CHECK-LABEL: func.func @splat_stickified_constant +// CHECK-SAME: () -> memref<2x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_1", shape = [1, 1, 1, 1, 32, 64], value = dense_resource : tensor<4096xi8>} : () -> memref<2x3xf16, #map> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x3xf32> +// CHECK: "zlow.unstick"([[VAR_0_]], [[RES_]]) {layout = "2D"} : (memref<2x3xf16, #map>, memref<2x3xf32>) -> () +// CHECK: return [[RES_]] : memref<2x3xf32> +// CHECK: } +// CHECK: dialect_resources: { +// CHECK: builtin: { +// CHECK: zhigh: "0x0100000005050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505050505" +// CHECK: } +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sub.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sub.mlir index b360592392..04a1951709 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sub.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/sub.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Sub"(%arg0, %arg1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/tanh.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/tanh.mlir index a7659b291e..f65a874971 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/tanh.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/tanh.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> { %0 = "zhigh.Tanh"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir index 508a819ccc..be231f26e9 100644 --- a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/test-datalayout.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow_1d(%arg0: tensor<7xf32>) -> tensor<*xf16> { %0 = "zhigh.Stick"(%arg0) {layout = "1D"} : (tensor<7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/driver/ccfd.mlir b/test/mlir/accelerators/nnpa/driver/ccfd.mlir index 3c66f67da7..a6aa1de5b0 100644 --- a/test/mlir/accelerators/nnpa/driver/ccfd.mlir +++ b/test/mlir/accelerators/nnpa/driver/ccfd.mlir @@ -1,67 +1,63 @@ -// RUN: ccfd=$(dirname %s)/ccfd.onnx && curl -L https://github.com/IBM/ai-on-z-fraud-detection/raw/main/onnx%20models/ccf_lstm_static_tf2onnx_OS_new.onnx -o ${ccfd} && onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" ${ccfd} | FileCheck %s && rm -rf ${ccfd} +// RUN: ccfd=$(dirname %s)/ccfd.onnx && curl -L https://github.com/IBM/ai-on-z-fraud-detection/raw/main/onnx%20models/ccf_lstm_static_tf2onnx_OS_new.onnx -o ${ccfd} && onnx-mlir --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" ${ccfd} | FileCheck %s && rm -rf ${ccfd} // COM: This test is to check regression on the IBM CCFD model. // COM: We expect that there are only one zlow.stick for the input and one zlow.unstick for the output. // COM: It is the necessary condition to get the best performance. -CHECK-LABEL: func.func @main_graph -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-NEXT: zlow.stick - -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-DAG: memref.alloc -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-NEXT: zlow.lstm +// CHECK-LABEL: func.func @main_graph +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-NEXT: zlow.stick + +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-DAG: memref.alloc +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-NEXT: zlow.lstm // No stick and unstick between two LSTMs. -CHECK-NOT: zlow.stick -CHECK-NOT: zlow.unstick - -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-DAG: memref.alloc -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-NEXT: zlow.lstm - +// CHECK-NOT: zlow.stick +// CHECK-NOT: zlow.unstick +// +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-DAG: memref.alloc +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-NEXT: zlow.lstm +// // No stick and unstick in between. -CHECK-NOT: zlow.stick -CHECK-NOT: zlow.unstick - -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-DAG: krnl.global -CHECK-DAG: krnl.global -CHECK-NEXT: zlow.matmul - +// CHECK-NOT: zlow.stick +// CHECK-NOT: zlow.unstick +// +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-DAG: krnl.global +// CHECK-DAG: krnl.global +// CHECK-NEXT: zlow.matmul +// // No stick and unstick in between. -CHECK-NOT: zlow.stick -CHECK-NOT: zlow.unstick - -CHECK-DAG: krnl.global -CHECK-DAG: memref.alloc -CHECK-DAG: krnl.global -CHECK-NEXT: zlow.add - +// CHECK-NOT: zlow.stick +// CHECK-NOT: zlow.unstick +// +// CHECK-DAG: krnl.global +// CHECK-DAG: memref.alloc +// CHECK-DAG: krnl.global +// CHECK-NEXT: zlow.add +// // No stick and unstick in between. -CHECK-NOT: zlow.stick -CHECK-NOT: zlow.unstick - -CHECK-DAG: memref.alloc -CHECK-DAG: krnl.global -CHECK-NEXT: zlow.sigmoid - -CHECK: memref.alloc -CHECK-NEXT: zlow.unstick +// CHECK-NOT: zlow.stick +// CHECK-NOT: zlow.unstick +// +// CHECK-DAG: memref.alloc +// CHECK-DAG: krnl.global +// CHECK-NEXT: zlow.sigmoid +// +// CHECK: memref.alloc +// CHECK-NEXT: zlow.unstick diff --git a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir index dc26676eb4..9c24e93acb 100644 --- a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir +++ b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor-num2.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir index d5f40bbc1f..a369f1289e 100644 --- a/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir +++ b/test/mlir/accelerators/nnpa/driver/data-transformation-on-ztensor.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR -tag="test" %s | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir index 863efd1ee4..7b8c9cc748 100644 --- a/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/dense-out-attention-layer.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitZHighIR --printIR %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitZHighIR --printIR %s | FileCheck %s // This pattern is found in bert models, where the output of attention layer is passed through a dense layer, then added with the attention layer's input. // To simplify the test we use the input of MatMul to mimic the input of attention layer. @@ -15,5 +15,5 @@ func.func @test_matmul_add_add(%arg0: tensor, %arg1: tensor<768x768 // CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor) -> tensor> // CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<49152xi8>} : () -> tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor> +// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<768xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor> } diff --git a/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir index cfe6e2b611..5c2b199907 100644 --- a/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --nnpa-enable-scalar-bcast-binary --printIR %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitMLIR --nnpa-enable-scalar-bcast-binary --printIR %s | FileCheck %s // Check whether the compiler can remove unstick/stick so that the output of zdnn matmul is passed directly to zdnn div. func.func @matmul_div(%arg0: tensor) -> tensor { @@ -8,12 +8,12 @@ func.func @matmul_div(%arg0: tensor) -> tensor { %r = "onnx.Div"(%m, %scalar) : (tensor, tensor) -> tensor "onnx.Return"(%r) : (tensor) -> () -// CHECK-LABEL: func.func @matmul_div +// CHECK-LABEL: func.func @matmul_div // CHECK: memref.alloc // CHECK: memref.alloc // CHECK: [[ALLOC:%.+]] = memref.alloc({{.*}}) {{.*}}: memref // CHECK-DAG: [[MATMUL_RES:%.+]] = memref.cast [[ALLOC]] : memref to memref -// CHECK: "zlow.matmul"({{.*}}, {{.*}}, {{.*}}, {{.*}}, [[MATMUL_RES]]) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref, memref, memref, memref<4xi64>, memref) -> () +// CHECK: "zlow.matmul"({{.*}}, {{.*}}, {{.*}}, {{.*}}, [[MATMUL_RES]]) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref, memref, memref, memref<4xi64>, memref) -> () // CHECK-NOT: "zlow.stick" // CHECK-NOT: "zlow.unstick" // CHECK: "zlow.div"([[MATMUL_RES]], {{.*}}, {{.*}}, {{.*}}) {layout = "3DS"} : (memref, memref, memref<3xi64>, memref) -> () diff --git a/test/mlir/accelerators/nnpa/driver/saturation.mlir b/test/mlir/accelerators/nnpa/driver/saturation.mlir index 1023fdfe13..0245786f20 100644 --- a/test/mlir/accelerators/nnpa/driver/saturation.mlir +++ b/test/mlir/accelerators/nnpa/driver/saturation.mlir @@ -1,42 +1,42 @@ -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitZHighIR --nnpa-saturation=false --printIR %s | FileCheck --check-prefix=ZHIGH_OFF %s -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitZHighIR --nnpa-saturation=true --printIR %s | FileCheck --check-prefix=ZHIGH_ON %s -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitZLowIR --nnpa-saturation=false --printIR %s | FileCheck --check-prefix=ZLOW_OFF %s -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitZLowIR --nnpa-saturation=true --printIR %s | FileCheck --check-prefix=ZLOW_ON %s -// RUN: onnx-mlir-opt -mcpu=z16 -maccel=NNPA --nnpa-saturation=false --shape-inference --convert-onnx-to-zhigh --zhigh-decompose-stick-unstick %s | FileCheck --check-prefix=DECOMPOSE_OFF %s -// RUN: onnx-mlir-opt -mcpu=z16 -maccel=NNPA --nnpa-saturation=true --shape-inference --convert-onnx-to-zhigh --zhigh-decompose-stick-unstick %s | FileCheck --check-prefix=DECOMPOSE_ON %s -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitMLIR --nnpa-saturation=false --enable-compiler-stick-unstick --printIR %s | FileCheck --check-prefix=COMPILER_STICK_OFF %s -// RUN: onnx-mlir -mcpu=z16 -maccel=NNPA --EmitMLIR --nnpa-saturation=true --enable-compiler-stick-unstick --printIR %s | FileCheck --check-prefix=COMPILER_STICK_ON %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitZHighIR --nnpa-saturation=false --printIR %s | FileCheck --check-prefix=ZHIGH_OFF %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitZHighIR --nnpa-saturation=true --printIR %s | FileCheck --check-prefix=ZHIGH_ON %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitZLowIR --nnpa-saturation=false --printIR %s | FileCheck --check-prefix=ZLOW_OFF %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitZLowIR --nnpa-saturation=true --printIR %s | FileCheck --check-prefix=ZLOW_ON %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --nnpa-saturation=false --shape-inference --convert-onnx-to-zhigh --zhigh-decompose-stick-unstick %s | FileCheck --check-prefix=DECOMPOSE_OFF %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --nnpa-saturation=true --shape-inference --convert-onnx-to-zhigh --zhigh-decompose-stick-unstick %s | FileCheck --check-prefix=DECOMPOSE_ON %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitMLIR --nnpa-saturation=false --enable-compiler-stick-unstick --printIR %s | FileCheck --check-prefix=COMPILER_STICK_OFF %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --EmitMLIR --nnpa-saturation=true --enable-compiler-stick-unstick --printIR %s | FileCheck --check-prefix=COMPILER_STICK_ON %s // COM: for each case, check saturation ON and OFF. func.func @saturation(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () -// ZHIGH_OFF-LABEL: func @saturation +// ZHIGH_OFF-LABEL: func @saturation // ZHIGH_OFF: "zhigh.Stick"({{.*}}) {layout = "2D"} : {{.*}} -// ZHIGH_ON-LABEL: func @saturation +// ZHIGH_ON-LABEL: func @saturation // ZHIGH_ON: "zhigh.Stick"({{.*}}) {layout = "2D", saturation = -1 : si64} : {{.*}} -// ZLOW_OFF-LABEL: func @saturation +// ZLOW_OFF-LABEL: func @saturation // ZLOW_OFF: "zlow.stick"({{.*}}, {{.*}}) {layout = "2D"} : {{.*}} -// ZLOW_ON-LABEL: func @saturation +// ZLOW_ON-LABEL: func @saturation // ZLOW_ON: "zlow.stick"({{.*}}, {{.*}}) {layout = "2D", saturation = -1 : si64} : {{.*}} -// DECOMPOSE_OFF-LABEL: func @saturation +// DECOMPOSE_OFF-LABEL: func @saturation // DECOMPOSE_OFF: "zhigh.F32ToDLF16"(%arg0) : {{.*}} -// DECOMPOSE_ON-LABEL: func @saturation +// DECOMPOSE_ON-LABEL: func @saturation // DECOMPOSE_ON: "zhigh.F32ToDLF16"(%arg0) {saturation = -1 : si64} : {{.*}} -// COMPILER_STICK_OFF-LABEL: func @saturation +// COMPILER_STICK_OFF-LABEL: func @saturation // COMPILER_STICK_OFF-NOT: arith.minnumf // COMPILER_STICK_OFF-NOT: arith.maxnumf // COMPILER_STICK_OFF: zlow.relu -// COMPILER_STICK_ON-LABEL: func @saturation +// COMPILER_STICK_ON-LABEL: func @saturation // COMPILER_STICK_ON: arith.minnumf // COMPILER_STICK_ON: arith.maxnumf // COMPILER_STICK_ON: zlow.relu diff --git a/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir index 8a6d0d0ede..a685699191 100644 --- a/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir +++ b/test/mlir/accelerators/nnpa/driver/softmax-matmul-in-attention-layer.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --enable-compiler-stick-unstick=false --EmitMLIR --printIR %s | FileCheck %s // Check whether the compiler can remove unstick/stick so that the output of zdnn softmax is passed directly to zdnn matmul. func.func @softmax_matmul(%arg0: tensor) -> tensor { diff --git a/test/mlir/accelerators/nnpa/module_op_be/compiler-config.mlir b/test/mlir/accelerators/nnpa/module_op_be/compiler-config.mlir index 4dc5006ae0..f275622bce 100644 --- a/test/mlir/accelerators/nnpa/module_op_be/compiler-config.mlir +++ b/test/mlir/accelerators/nnpa/module_op_be/compiler-config.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA -v -tag="test" %s -o %t 2>&1 | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA -v -tag="test" %s -o %t 2>&1 | FileCheck %s // ----- @@ -12,5 +12,5 @@ module { "onnx.EntryPoint"() {func = @main_graph} : () -> () } // CHECK: {{.*}} opt {{.*}} -o {{.*}}.bc -// CHECK: {{.*}} llc {{.*}} {{.*}} {{.*}}.bc +// CHECK: {{.*}} llc {{.*}} {{.*}} {{.*}}.bc // CHECK: {{.*}} {{clang|c|g}}++{{.*}} {{.*}}.o -o {{.*}}.so -shared -fPIC -L{{.*}}/lib -lRuntimeNNPA -lzdnn -lcruntime diff --git a/test/mlir/accelerators/nnpa/module_op_be/module_op.mlir b/test/mlir/accelerators/nnpa/module_op_be/module_op.mlir index 2a7ec2359d..ca55de53ce 100644 --- a/test/mlir/accelerators/nnpa/module_op_be/module_op.mlir +++ b/test/mlir/accelerators/nnpa/module_op_be/module_op.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --printIR %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --maccel=NNPA --printIR %s | FileCheck %s // CHECK: module attributes {llvm.data_layout = "E-{{.*}}", llvm.target_triple = "{{.*}}", "onnx-mlir.accels" = ["NNPA-0x10001"], "onnx-mlir.symbol-postfix" = "{{.*}}"} module { diff --git a/test/mlir/accelerators/nnpa/module_op_be/module_op_arch15.mlir b/test/mlir/accelerators/nnpa/module_op_be/module_op_arch15.mlir new file mode 100644 index 0000000000..4cb3cea8e8 --- /dev/null +++ b/test/mlir/accelerators/nnpa/module_op_be/module_op_arch15.mlir @@ -0,0 +1,5 @@ +// RUN: onnx-mlir --march=arch15 --maccel=NNPA --printIR %s | FileCheck %s + +// CHECK: module attributes {llvm.data_layout = "E-{{.*}}", llvm.target_triple = "{{.*}}", "onnx-mlir.accels" = ["NNPA-0x10101"], "onnx-mlir.symbol-postfix" = "{{.*}}"} +module { +} diff --git a/test/mlir/accelerators/nnpa/module_op_le/module_op.mlir b/test/mlir/accelerators/nnpa/module_op_le/module_op.mlir index 8d2432488c..17d85fa387 100644 --- a/test/mlir/accelerators/nnpa/module_op_le/module_op.mlir +++ b/test/mlir/accelerators/nnpa/module_op_le/module_op.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir --mcpu=z16 --accel=NNPA --printIR %s | FileCheck %s +// RUN: onnx-mlir --march=z16 --accel=NNPA --printIR %s | FileCheck %s // CHECK: module attributes {llvm.data_layout = "e-{{.*}}", "onnx-mlir.symbol-postfix" = "{{.*}}"} module { diff --git a/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir b/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir index 20438651da..0df2ee00ad 100644 --- a/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir +++ b/test/mlir/accelerators/nnpa/transform/fold-std-alloc.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --fold-std-alloc %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --fold-std-alloc %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir index c3a638feb6..ee1e31f280 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zhigh-clip-to-dlfloat -split-input-file %s || FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-clip-to-dlfloat -split-input-file %s || FileCheck %s func.func @should_clip_stick(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir index 99b9ec6c60..175f228aba 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --canonicalize %s -split-input-file | FileCheck %s func.func @remove_stick_and_unstick_same_layout(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { %0 = "zhigh.Stick"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf16, #zhigh.layout<{ dataLayout = "2D"}>> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir index 609cab1aec..78cae2e6f9 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/constprop.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --constprop-zhigh %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --constprop-zhigh %s -split-input-file | FileCheck %s // ----- diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/quantizedconstprop.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/quantizedconstprop.mlir new file mode 100644 index 0000000000..3125c749c7 --- /dev/null +++ b/test/mlir/accelerators/nnpa/transform/zhigh-constant-propagation-be/quantizedconstprop.mlir @@ -0,0 +1,32 @@ +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --constprop-zhigh %s -split-input-file | FileCheck %s + +// ----- + +// Note: from zdnn, the padding value might be value other than 0 + +func.func @quantized_weight_int8() -> tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>> { + %0 = onnx.Constant dense<0.000000e+00> : tensor + %3 = onnx.Constant dense<0.00656270096> : tensor + %inp = "onnx.Constant"() {value = dense<[ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65], + [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31, -32, -33, -34, -35, -36, -37, -38, -39, -40, -41, -42, -43, -44, -45, -46, -47, -48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63, -64, -65], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65], + [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31, -32, -33, -34, -35, -36, -37, -38, -39, -40, -41, -42, -43, -44, -45, -46, -47, -48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63, -64, -65], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65], + [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31, -32, -33, -34, -35, -36, -37, -38, -39, -40, -41, -42, -43, -44, -45, -46, -47, -48, -49, -50, -51, -52, -53, -54, -55, -56, -57, -58, -59, -60, -61, -62, -63, -64, -65], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65] +]> : tensor<7x65xi8>} : () -> tensor<7x65xi8> + %res:3 = "zhigh.QuantizedStick"(%inp, %3, %0) {layout = "2D", quantized_type = "WEIGHTS"} : (tensor<7x65xi8>, tensor, tensor) -> (tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) + return %res#0 : tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>> +} +// CHECK-LABEL: func.func @quantized_weight_int8 +// CHECK-SAME: () -> tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>> { +// CHECK: [[VAR_0_:%.+]] = "zhigh.StickifiedConstant"() {alignment = 4096 : i64, value = dense_resource : tensor<8192xi8>} : () -> tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>> +// CHECK: return [[VAR_0_]] : tensor<7x65xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>> +// CHECK: } +// CHECK: dialect_resources: { +// CHECK: builtin: { +// CHECK: zhigh: "0x0100000001FF02FE03FD04FC05FB06FA07F908F809F70AF60BF50CF40DF30EF20FF110F011EF12EE13ED14EC15EB16EA17E918E819E71AE61BE51CE41DE31EE21FE120E021DF22DE23DD24DC25DB26DA27D928D829D72AD62BD52CD42DD32ED22FD130D031CF32CE33CD34CC35CB36CA37C938C839C73AC63BC53CC43DC33EC23FC140C001FF02FE03FD04FC05FB06FA07F908F809F70AF60BF50CF40DF30EF20FF110F011EF12EE13ED14EC15EB16EA17E918E819E71AE61BE51CE41DE31EE21FE120E021DF22DE23DD24DC25DB26DA27D928D829D72AD62BD52CD42DD32ED22FD130D031CF32CE33CD34CC35CB36CA37C938C839C73AC63BC53CC43DC33EC23FC140C001FF02FE03FD04FC05FB06FA07F908F809F70AF60BF50CF40DF30EF20FF110F011EF12EE13ED14EC15EB16EA17E918E819E71AE61BE51CE41DE31EE21FE120E021DF22DE23DD24DC25DB26DA27D928D829D72AD62BD52CD42DD32ED22FD130D031CF32CE33CD34CC35CB36CA37C938C839C73AC63BC53CC43DC33EC23FC140C00101020203030404050506060707080809090A0A0B0B0C0C0D0D0E0E0F0F10101111121213131414151516161717181819191A1A1B1B1C1C1D1D1E1E1F1F20202121222223232424252526262727282829292A2A2B2B2C2C2D2D2E2E2F2F30303131323233333434353536363737383839393A3A3B3B3C3C3D3D3E3E3F3F4040000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041BF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041BF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041BF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000041410000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" +// CHECK: } +// CHECK: } + diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-decompose-stick-unstick.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-decompose-stick-unstick.mlir index c93417a689..52efd31c26 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-decompose-stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-decompose-stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt -mcpu=z16 -maccel=NNPA --zhigh-decompose-stick-unstick --split-input-file %s | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-decompose-stick-unstick --split-input-file %s | FileCheck %s func.func @test_relu(%arg0: tensor<1x3x5x?xf32>) -> tensor<1x3x5x?xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "4D"} : (tensor<1x3x5x?xf32>) -> tensor<1x3x5x?xf16, #zhigh.layout<{dataLayout = "4D"}>> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-layout-propagation.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-layout-propagation.mlir index 2a644f8d47..8ec3b55c2b 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-layout-propagation.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-layout-propagation.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zhigh-layout-prop --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-layout-prop --shape-inference %s -split-input-file | FileCheck %s func.func @add_layout_propagate_nhwc_1(%arg0: tensor<1x56x56x256xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, %arg1: tensor<1x256x56x56xf32>) -> tensor<1x256x56x56xf16, #zhigh.layout<{dataLayout = "4D"}>> { %0 = "zhigh.Unstick"(%arg0) : (tensor<1x56x56x256xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<1x256x56x56xf32> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-recompose-to-stick-unstick.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-recompose-to-stick-unstick.mlir index 4ae3b52a8e..3dac318abe 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-recompose-to-stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-recompose-to-stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt -mcpu=z16 -maccel=NNPA --zhigh-recompose-to-stick-unstick --split-input-file %s | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-recompose-to-stick-unstick --split-input-file %s | FileCheck %s func.func @test_relu(%arg0: tensor<1x3x5x?xf32>) -> tensor<1x3x5x?xf32> { %0 = "zhigh.F32ToDLF16"(%arg0) : (tensor<1x3x5x?xf32>) -> tensor<1x3x5x?xf16> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/conv.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/conv.mlir index 66108f4b2c..0537e8e2b0 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/conv.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/conv.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @conv_valid_padding(%arg0: tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, %arg1: tensor<2x2x3x1xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, %arg2: tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> { %0 = "zhigh.Conv2D"(%arg0, %arg1, %arg2) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1], act_func = "ACT_NONE"} : (tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x1xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<1xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/elementwise.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/elementwise.mlir index 69c66a655b..c99379b45d 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/elementwise.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/elementwise.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// /// Test the default behavior of unary lement-wise ops users give the shape of diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/gru.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/gru.mlir index bd7c05048b..01c60f3d42 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/gru.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/gru.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @gru_return_single_step(%input : tensor<3x5x7xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %h0 : tensor<1x5x9xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %input_weights : tensor<1x7x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %input_bias : tensor<1x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %hidden_weights : tensor<1x9x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>, %hidden_bias : tensor<1x27xf16, #zhigh.layout<{dataLayout = "ZRH"}>>) -> tensor<*xf16> { diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/lstm.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/lstm.mlir index d4efc2716f..378478155c 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/lstm.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/lstm.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @test_lstm_all_timesteps(%X: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, %W: tensor<1x8x64xf16, #zhigh.layout<{dataLayout = "FICO"}>>, %R: tensor<1x16x64xf16, #zhigh.layout<{dataLayout = "FICO"}>>) -> (tensor<*xf16>) { %cst = "onnx.NoValue"() {value} : () -> none diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/matmul.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/matmul.mlir index 9328578423..6eab625b55 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/matmul.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/matmul.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @test_matmul_2d(%arg0 : tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, %arg1 : tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<*xf16> { %cst = "onnx.NoValue"() {value} : () -> none @@ -8,7 +8,7 @@ func.func @test_matmul_2d(%arg0 : tensor<4x8xf16, #zhigh.layout<{dataLayout = "2 // CHECK-LABEL: func @test_matmul_2d // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, [[PARAM_1_:%.+]]: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK: return [[VAR_0_]] : tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>> // CHECK: } } @@ -23,7 +23,7 @@ func.func @test_matmul_3d_broadcast(%arg0 : tensor<2x4x8xf16, #zhigh.layout<{dat // CHECK-LABEL: func @test_matmul_3d_broadcast // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, [[PARAM_1_:%.+]]: tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: return [[VAR_0_]] : tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: } } @@ -38,7 +38,7 @@ func.func @test_matmul_3d_stack(%arg0 : tensor<2x4x8xf16, #zhigh.layout<{dataLay // CHECK-LABEL: func @test_matmul_3d_stack // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, [[PARAM_1_:%.+]]: tensor<2x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<2x4x8xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x8x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: return [[VAR_0_]] : tensor<2x4x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: } } @@ -53,7 +53,7 @@ func.func @test_matmul_2d_unknown_dims(%arg0 : tensor>, [[PARAM_1_:%.+]]: tensor>) -> tensor> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor>, tensor>, none) -> tensor> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor>, tensor>, none) -> tensor> // CHECK: return [[VAR_0_]] : tensor> // CHECK: } } @@ -68,7 +68,7 @@ func.func @test_matmul_3d_broadcast_unknown_dims(%arg0 : tensor<2x?x?xf16, #zhig // CHECK-LABEL: func @test_matmul_3d_broadcast_unknown_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, [[PARAM_1_:%.+]]: tensor>) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor>, none) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor>, none) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: return [[VAR_0_]] : tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: } } @@ -83,7 +83,7 @@ func.func @test_matmul_3d_stack_unknown_dims(%arg0 : tensor<2x?x?xf16, #zhigh.la // CHECK-LABEL: func @test_matmul_3d_stack_unknown_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, [[PARAM_1_:%.+]]: tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> { // CHECK: [[VAR_cst_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) : (tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_0_:%.+]] = "zhigh.MatMul"([[PARAM_0_]], [[PARAM_1_]], [[VAR_cst_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<2x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: return [[VAR_0_]] : tensor<2x?x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> // CHECK: } } diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/meanreduce.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/meanreduce.mlir index eee8622cf9..6a1dc1d425 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/meanreduce.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/meanreduce.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @should_lower_to_zlow(%arg0: tensor<1x5x7x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> { %0 = "zhigh.MeanReduce2d"(%arg0) : (tensor<1x5x7x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/pool.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/pool.mlir index 2eb25df189..5d92b987e1 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/pool.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/pool.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @maxpool_valid_padding(%arg0: tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> { %0 = "zhigh.MaxPool2D"(%arg0) {kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1], act_func = "ACT_NONE"} : (tensor<1x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_matmul.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_matmul.mlir new file mode 100644 index 0000000000..e1c8cb6cc4 --- /dev/null +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_matmul.mlir @@ -0,0 +1,41 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s + +func.func @test_zhigh_quantized_matmul(%arg0: tensor<1x3x5xf32>, %arg1: tensor<5x7xf32>, %arg2: tensor<7xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %x:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16"} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + %y:3 = "zhigh.QuantizedStick"(%arg1, %none, %none) {layout = "2D", quantized_type = "weights"} : (tensor<5x7xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + %b:3 = "zhigh.QuantizedStick"(%arg2, %none, %none) {layout = "1D", quantized_type = "int8"} : (tensor<7xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + %m:3 = "zhigh.QuantizedMatMul"(%x#0, %x#1, %x#2, %y#0, %y#1, %y#2, %b#0, %b#1, %b#2, %none, %none) {DequantizeOutput = 0 : si64} : (tensor<*xf16>, tensor, tensor, tensor<*xi8>, tensor, tensor, tensor<*xi8>, tensor, tensor, none, none) -> (tensor<*xf16>, tensor, tensor) + onnx.Return %m#0: tensor<*xf16> + +// CHECK-LABEL: func.func @test_zhigh_quantized_matmul +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5x7xf32>, [[PARAM_2_:%.+]]: tensor<7xf32>) -> tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_0_]], [[VAR_0_]]) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[Out_0_:%.+]], [[RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[PARAM_1_]], [[VAR_0_]], [[VAR_0_]]) {layout = "2D", quantized_type = "weights", sym_mode = 0 : i64} : (tensor<5x7xf32>, none, none) -> (tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[Out_3_:%.+]], [[RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[PARAM_2_]], [[VAR_0_]], [[VAR_0_]]) {layout = "1D", quantized_type = "int8", sym_mode = 0 : i64} : (tensor<7xf32>, none, none) -> (tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, tensor, tensor) +// CHECK: [[Out_6_:%.+]], [[OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[Out_]], [[RecScale_]], [[VAR_Offset_]], [[Out_]]_0, [[RecScale_]]_1, [[VAR_Offset_]]_2, [[Out_]]_3, [[RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_0_]], [[VAR_0_]]) {DequantizeOutput = 0 : si64, DisableClipping = 0 : si64, PreComputedBias = 0 : si64} : (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, tensor, tensor, none, none) -> (tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: onnx.Return [[Out_6_]] : tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> +// CHECK: } +} + +// ----- + +func.func @test_zhigh_quantized_matmul_dequantized(%arg0: tensor<1x3x5xf32>, %arg1: tensor<5x7xf32>, %arg2: tensor<7xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %x:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16"} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + %y:3 = "zhigh.QuantizedStick"(%arg1, %none, %none) {layout = "2D", quantized_type = "weights"} : (tensor<5x7xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + %b:3 = "zhigh.QuantizedStick"(%arg2, %none, %none) {layout = "1D", quantized_type = "int8"} : (tensor<7xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + %m:3 = "zhigh.QuantizedMatMul"(%x#0, %x#1, %x#2, %y#0, %y#1, %y#2, %b#0, %b#1, %b#2, %none, %none) {DequantizeOutput = -1 : si64} : (tensor<*xf16>, tensor, tensor, tensor<*xi8>, tensor, tensor, tensor<*xi8>, tensor, tensor, none, none) -> (tensor<*xf16>, tensor, tensor) + onnx.Return %m#0: tensor<*xf16> + +// CHECK-LABEL: func.func @test_zhigh_quantized_matmul_dequantized +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5x7xf32>, [[PARAM_2_:%.+]]: tensor<7xf32>) -> tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS"}>> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[VAR_0_]], [[VAR_0_]]) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: [[Out_0_:%.+]], [[RecScale_1_:%.+]], [[VAR_Offset_2_:%.+]] = "zhigh.QuantizedStick"([[PARAM_1_]], [[VAR_0_]], [[VAR_0_]]) {layout = "2D", quantized_type = "weights", sym_mode = 0 : i64} : (tensor<5x7xf32>, none, none) -> (tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: [[Out_3_:%.+]], [[RecScale_4_:%.+]], [[VAR_Offset_5_:%.+]] = "zhigh.QuantizedStick"([[PARAM_2_]], [[VAR_0_]], [[VAR_0_]]) {layout = "1D", quantized_type = "int8", sym_mode = 0 : i64} : (tensor<7xf32>, none, none) -> (tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, tensor, tensor) +// CHECK: [[Out_6_:%.+]], [[OutRecScale_:%.+]], [[VAR_OutOffset_:%.+]] = "zhigh.QuantizedMatMul"([[Out_]], [[RecScale_]], [[VAR_Offset_]], [[Out_]]_0, [[RecScale_]]_1, [[VAR_Offset_]]_2, [[Out_]]_3, [[RecScale_]]_4, [[VAR_Offset_]]_5, [[VAR_0_]], [[VAR_0_]]) {DequantizeOutput = -1 : si64, DisableClipping = 0 : si64, PreComputedBias = 0 : si64} : (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor, tensor<5x7xi8, #zhigh.layout<{dataLayout = "2D", quantizedType = "WEIGHTS"}>>, tensor, tensor, tensor<7xi8, #zhigh.layout<{dataLayout = "1D", quantizedType = "INT8"}>>, tensor, tensor, none, none) -> (tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor, tensor) +// CHECK: onnx.Return [[Out_6_]] : tensor<1x3x7xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_stick.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_stick.mlir new file mode 100644 index 0000000000..61b66f9543 --- /dev/null +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/quantized_stick.mlir @@ -0,0 +1,44 @@ +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s + +func.func @test_zhigh_quantized_stick_dlfloat16(%arg0: tensor<1x3x5xf32>) -> tensor<*xf16> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xf16>, tensor, tensor) + onnx.Return %0#0: tensor<*xf16> + +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_dlfloat16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5xf32>) -> tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> { +// CHECK: [[NONE:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[NONE]], [[NONE]]) {layout = "3DS", quantized_type = "dlfloat16", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>>, tensor, tensor) +// CHECK: onnx.Return [[Out_]] : tensor<1x3x5xf16, #zhigh.layout<{dataLayout = "3DS", quantizedType = "DLFLOAT16"}>> +// CHECK: } +} + +// ----- + +func.func @test_zhigh_quantized_stick_int8(%arg0: tensor<1x3x5xf32>) -> tensor<*xi8> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "int8", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + onnx.Return %0#0: tensor<*xi8> + +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_int8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5xf32>) -> tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "INT8"}>> { +// CHECK: [[NONE:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[NONE]], [[NONE]]) {layout = "3DS", quantized_type = "int8", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "INT8"}>>, tensor, tensor) +// CHECK: onnx.Return [[Out_]] : tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "INT8"}>> +// CHECK: } +} + +// ----- + +func.func @test_zhigh_quantized_stick_weights(%arg0: tensor<1x3x5xf32>) -> tensor<*xi8> { + %none = "onnx.NoValue"() {value} : () -> none + %0:3 = "zhigh.QuantizedStick"(%arg0, %none, %none) {layout = "3DS", quantized_type = "weights", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<*xi8>, tensor, tensor) + onnx.Return %0#0: tensor<*xi8> + +// CHECK-LABEL: func.func @test_zhigh_quantized_stick_weights +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x5xf32>) -> tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "WEIGHTS"}>> { +// CHECK: [[NONE:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Out_:%.+]], [[RecScale_:%.+]], [[VAR_Offset_:%.+]] = "zhigh.QuantizedStick"([[PARAM_0_]], [[NONE]], [[NONE]]) {layout = "3DS", quantized_type = "weights", sym_mode = 0 : i64} : (tensor<1x3x5xf32>, none, none) -> (tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "WEIGHTS"}>>, tensor, tensor) +// CHECK: onnx.Return [[Out_]] : tensor<1x3x5xi8, #zhigh.layout<{dataLayout = "3DS", quantizedType = "WEIGHTS"}>> +// CHECK: } +} diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/stick-unstick.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/stick-unstick.mlir index 60d73d9343..c616da172a 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/stick-unstick.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-shape-inference/stick-unstick.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference %s -split-input-file | FileCheck %s func.func @stick_unstick_static_dims(%arg0: tensor<1x3x5x7xf32>) -> tensor<*xf32> { %0 = "zhigh.Stick"(%arg0) {layout = "NHWC"} : (tensor<1x3x5x7xf32>) -> tensor<*xf16> diff --git a/test/mlir/accelerators/nnpa/transform/zlow-normalize-by-using-dummyop.mlir b/test/mlir/accelerators/nnpa/transform/zlow-normalize-by-using-dummyop.mlir index c03db866d7..ca55ab5ee2 100644 --- a/test/mlir/accelerators/nnpa/transform/zlow-normalize-by-using-dummyop.mlir +++ b/test/mlir/accelerators/nnpa/transform/zlow-normalize-by-using-dummyop.mlir @@ -1,11 +1,11 @@ -// RUN: (onnx-mlir-opt --mcpu=z16 --maccel=NNPA --normalize-memrefs %s 2>&1 || true) | FileCheck --check-prefix=FAILED %s +// RUN: (onnx-mlir-opt --march=z16 --maccel=NNPA --normalize-memrefs %s 2>&1 || true) | FileCheck --check-prefix=FAILED %s // COM: Current MLIR normalize-memres does not support multiple dereferencing uses // in a single op, check expected failure emitted by MLIR. // FAILED: "multiple dereferencing uses in a single op not supported" -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zlow-dummyop-for-multideref --normalize-memrefs --canonicalize %s | FileCheck --check-prefix=PASSED %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zlow-dummyop-for-multideref --normalize-memrefs --canonicalize %s | FileCheck --check-prefix=PASSED %s // COM: Check normalize memrefs when there are multiple dereferencing uses in a single op. // COM: Check that --zlow-dummyop-for-multideref can help to bypass the issue. diff --git a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir index fd38ea6bf0..07cffa92b5 100644 --- a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir +++ b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s #map = affine_map<(d0, d1) -> (0, d1 floordiv 64, 0, d0 floordiv 32, d0 mod 32, d1 mod 64)> func.func @remove_dangling_stick(%arg0: memref) -> memref { diff --git a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir index e3761e8bd6..65676f4805 100644 --- a/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir +++ b/test/mlir/accelerators/nnpa/transform/zlow-stick-unstick-expansion.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --zlow-stick-expansion %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zlow-stick-expansion %s -split-input-file | FileCheck %s // ----- @@ -11,12 +11,12 @@ func.func @test_stick_expansion_with_sat(%arg0: memref<16x8x128xf32>) -> memref< // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 64)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1 floordiv 64)> -// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<()[s0, s1] -> (s1 + 8)> -// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<()[s0, s1] -> (s1 + 16)> -// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<()[s0, s1] -> (s1 + 24)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d1 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0)[s0] -> (d0 + 8)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0] -> (d0 + 16)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + 24)> // CHECK-LABEL: func.func @test_stick_expansion_with_sat // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { // CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index @@ -28,19 +28,18 @@ func.func @test_stick_expansion_with_sat(%arg0: memref<16x8x128xf32>) -> memref< // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<8.57315738E+9> : vector<4xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<-8.57315738E+9> : vector<4xf32> -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map> // CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ // CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_1_]]#2] +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2) // CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[RES_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> -// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_1_]]#2, [[VAR_3_]]{{.}} +// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#2, [[VAR_3_]]) // CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf32> // CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf16, #map> // CHECK: affine.for [[I_3_:%.+]] = 0 to 64 step 32 { -// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> // CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[CST_4_]] : index // CHECK-NOT: separator of consecutive DAGs @@ -86,11 +85,11 @@ func.func @test_stick_expansion_with_sat(%arg0: memref<16x8x128xf32>) -> memref< // CHECK-DAG: [[VAR_39_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_33_]], [[VAR_34_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> // CHECK: [[VAR_40_:%.+]] = "zlow.vec_f32_to_dlf16"([[VAR_35_]], [[VAR_36_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> // CHECK: vector.store [[VAR_37_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_41_:%.+]] = affine.apply [[MAP_4_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_41_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_38_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]1] : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_42_:%.+]] = affine.apply [[MAP_5_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_42_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_39_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]2] : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_43_:%.+]] = affine.apply [[MAP_6_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_43_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_40_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_4_]]3] : memref<2x64xf16>, vector<8xf16> // CHECK: } // CHECK: } @@ -109,12 +108,12 @@ func.func @test_stick_expansion_without_sat(%arg0: memref<16x8x128xf32>) -> memr // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 64)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1 floordiv 64)> -// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<()[s0, s1] -> (s1 + 8)> -// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<()[s0, s1] -> (s1 + 16)> -// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<()[s0, s1] -> (s1 + 24)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d1 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0)[s0] -> (d0 + 8)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0] -> (d0 + 16)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + 24)> // CHECK-LABEL: func.func @test_stick_expansion_without_sat // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> { // CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index @@ -124,19 +123,18 @@ func.func @test_stick_expansion_without_sat(%arg0: memref<16x8x128xf32>) -> memr // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index // CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map> // CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ // CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_1_]]#2] +// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2) // CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[RES_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> -// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_1_]]#2, [[VAR_3_]]{{.}} +// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#2, [[VAR_3_]]) // CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf32> // CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf16, #map> // CHECK: affine.for [[I_3_:%.+]] = 0 to 64 step 32 { -// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_5_:%.+]] = affine.apply [[MAP_3_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> // CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[CST_4_]] : index // CHECK-NOT: separator of consecutive DAGs @@ -164,11 +162,11 @@ func.func @test_stick_expansion_without_sat(%arg0: memref<16x8x128xf32>) -> memr // CHECK-DAG: [[VAR_23_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_4_]], [[LOAD_PARAM_0_MEM_5_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> // CHECK: [[VAR_24_:%.+]] = "zlow.vec_f32_to_dlf16"([[LOAD_PARAM_0_MEM_6_]], [[LOAD_PARAM_0_MEM_7_]]) : (vector<4xf32>, vector<4xf32>) -> vector<8xf16> // CHECK: vector.store [[VAR_21_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_4_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_22_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_25_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_5_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_23_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_26_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_6_]](){{.}}[[VAR_2_]], [[I_3_]]{{.}} +// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} // CHECK: vector.store [[VAR_24_]], [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_27_]]{{.}} : memref<2x64xf16>, vector<8xf16> // CHECK: } // CHECK: } @@ -189,91 +187,71 @@ func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<1 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> // CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)> // CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0)[s0] -> (s0 floordiv 64)> -// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (d0 + 8)> -// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 16)> -// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0) -> (d0 + 24)> -// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: [[MAP_7_:#.+]] = affine_map<()[s0] -> (-s0 + 121)> -// CHECK-DAG: [[MAP_8_:#.+]] = affine_map<()[s0] -> ((-s0) mod 8)> -// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<()[s0] -> (-s0 - (-s0) mod 8 + 128)> -// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0)[s0] -> (d0 + 8)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0] -> (d0 + 16)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + 24)> // CHECK-LABEL: func.func @test_unstick_expansion // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf16, #map>) -> memref<16x8x128xf32> { -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index // CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index // CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index // CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index // CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index // CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index // CHECK-DAG: [[VAR_true_:%.+]] = arith.constant true // CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf32> -// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 -// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){ -// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) -// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2) -// CHECK: [[VAR_3_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at {{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}} : memref<16x8x128xf16, #map> -// CHECK: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#2){{.}}[[VAR_3_]]{{.}} -// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, read, locality<1>, data : memref<16x8x128xf16, #map> -// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_2_]]{{.}}, write, locality<1>, data : memref<16x8x128xf32> -// CHECK: scf.if [[VAR_true_]] { -// CHECK: scf.for [[I_3_:%.+]] = [[CST_0_]] to [[CST_64_]] step [[CST_32_]] { -// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK-DAG: [[VAR_6_:%.+]] = affine.apply [[MAP_3_]]([[I_3_]]) -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_6_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]) -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_8_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]) -// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[VAR_10_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[output1_2_:%.+]], [[VAR_output2_3_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[VAR_12_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_2_]]{{.}} -// CHECK: vector.store [[output1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]2] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_12_]], [[CST_4_]] : index -// CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]3] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_12_]], [[CST_8_]] : index -// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]4] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_12_]], [[CST_12_]] : index -// CHECK: vector.store [[VAR_output2_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]5] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_16_:%.+]] = arith.addi [[VAR_12_]], [[CST_16_]] : index -// CHECK: vector.store [[output1_2_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_12_]], [[CST_20_]] : index -// CHECK: vector.store [[VAR_output2_3_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_12_]], [[CST_24_]] : index -// CHECK: vector.store [[output1_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]8] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_12_]], [[CST_28_]] : index -// CHECK: vector.store [[VAR_output2_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: } -// CHECK: } else { -// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = affine.apply [[MAP_7_]](){{.}}[[VAR_2_]]{{.}} -// CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[LOAD_VAR_reinterpret_cast_MEM_4_]] step [[CST_8_]] { -// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[I_4_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[output1_0_]], [[VAR_output2_1_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_5_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[VAR_10_1_:%.+]] = affine.apply [[MAP_6_]]([[I_4_]]){{.}}[[VAR_2_]]{{.}} -// CHECK: vector.store [[output1_0_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = arith.addi [[VAR_10_1_]], [[CST_4_]] : index -// CHECK: vector.store [[VAR_output2_1_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]1] : memref<16x8x128xf32>, vector<4xf32> -// CHECK: } -// CHECK-DAG: [[VAR_6_1_:%.+]] = affine.apply [[MAP_8_]](){{.}}[[VAR_2_]]{{.}} -// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = affine.apply [[MAP_9_]](){{.}}[[VAR_2_]]{{.}} -// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_6_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_4_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}} : memref<2x64xf16>, vector<8xf16> -// CHECK: [[output1_]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_6_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) -// CHECK: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<8xf32> -// CHECK: vector.store [[output1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<4xf32> -// CHECK: vector.store [[VAR_output2_1_]], [[RES_1_]]{{.}}[[CST_4_]]{{.}} : memref<8xf32>, vector<4xf32> -// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[VAR_6_1_]] step [[CST_1_]] { -// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = krnl.load [[RES_1_]]{{.}}[[I_5_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_10_]]([[I_5_]]){{.}}[[VAR_2_]], [[LOAD_VAR_reinterpret_cast_MEM_1_]]{{.}} -// CHECK: krnl.store [[LOAD_VAR_reinterpret_cast_MEM_5_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]0] : memref<16x8x128xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 16){ +// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ +// CHECK-DAG: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){ +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_6_:%.+]] = affine.apply [[MAP_1_]]([[VAR_5_]]) +// CHECK: [[VAR_7_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at {{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}} : memref<16x8x128xf16, #map> +// CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_2_]]([[VAR_5_]]){{.}}[[VAR_7_]]{{.}} +// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}}, read, locality<1>, data : memref<16x8x128xf16, #map> +// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}}, write, locality<1>, data : memref<16x8x128xf32> +// CHECK: scf.if [[VAR_true_]] { +// CHECK: scf.for [[I_3_:%.+]] = [[CST_0_]] to [[CST_64_]] step [[CST_32_]] { +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_3_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_11_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_11_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_0_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_13_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_13_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_2_:%.+]], [[VAR_output2_3_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_15_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_15_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_4_:%.+]], [[VAR_output2_5_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: vector.store [[VAR_output1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_9_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_9_]], [[CST_4_]] : index +// CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]7] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_9_]], [[CST_8_]] : index +// CHECK: vector.store [[VAR_output1_0_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]8] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_9_]], [[CST_12_]] : index +// CHECK: vector.store [[VAR_output2_1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]9] : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_20_:%.+]] = arith.addi [[VAR_9_]], [[CST_16_]] : index +// CHECK: vector.store [[VAR_output1_2_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_20_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_21_:%.+]] = arith.addi [[VAR_9_]], [[CST_20_]] : index +// CHECK: vector.store [[VAR_output2_3_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_21_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addi [[VAR_9_]], [[CST_24_]] : index +// CHECK: vector.store [[VAR_output1_4_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_22_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK: [[VAR_23_:%.+]] = arith.addi [[VAR_9_]], [[CST_28_]] : index +// CHECK: vector.store [[VAR_output2_5_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_23_]]{{.}} : memref<16x8x128xf32>, vector<4xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } // CHECK: } // CHECK: } // CHECK: } @@ -281,3 +259,116 @@ func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<1 // CHECK: } } +// ----- + + +#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +func.func @test_unstick_expansion_127(%arg0: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> { + %alloc = memref.alloc() {alignment = 4096 : i64} : memref<16x8x127xf32> + "zlow.unstick"(%arg0, %alloc) {layout = "3DS"} : (memref<16x8x127xf16, #map>, memref<16x8x127xf32>) -> () + return %alloc : memref<16x8x127xf32> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 64)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0)[s0] -> (s0 floordiv 64)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (d0 * -64 + 63)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0] -> (d0 + 8)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0)[s0] -> (d0 + 16)> +// CHECK-DAG: [[MAP_7_:#.+]] = affine_map<(d0)[s0] -> (d0 + 24)> +// CHECK-DAG: [[MAP_8_:#.+]] = affine_map<()[s0] -> (-s0 + 120)> +// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<()[s0] -> ((-s0 + 127) mod 8)> +// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<()[s0] -> (-s0 - (-s0 + 127) mod 8 + 127)> +// CHECK-DAG: [[MAP_11_:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)> +// CHECK-LABEL: func.func @test_unstick_expansion_127 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_28_:%.+]] = arith.constant 28 : index +// CHECK-DAG: [[CST_24_:%.+]] = arith.constant 24 : index +// CHECK-DAG: [[CST_20_:%.+]] = arith.constant 20 : index +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x127xf32> +// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x127xf16, #map> to memref<2x64xf16> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 16){ +// CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ +// CHECK-DAG: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){ +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_6_:%.+]] = affine.apply [[MAP_1_]]([[VAR_5_]]) +// CHECK: [[VAR_7_:%.+]] = krnl.get_linear_offset_index [[PARAM_0_]] at {{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}} : memref<16x8x127xf16, #map> +// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_2_]]([[VAR_5_]]){{.}}[[VAR_7_]]{{.}} +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK: krnl.prefetch [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}}, read, locality<1>, data : memref<16x8x127xf16, #map> +// CHECK: krnl.prefetch [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_6_]]{{.}}, write, locality<1>, data : memref<16x8x127xf32> +// CHECK: [[VAR_9_:%.+]] = affine.apply [[MAP_3_]]([[VAR_5_]]){{.}}[[VAR_7_]]{{.}} +// CHECK: [[VAR_10_:%.+]] = arith.cmpi sge, [[VAR_9_]], [[CST_0_]] : index +// CHECK: scf.if [[VAR_10_]] { +// CHECK: scf.for [[I_3_:%.+]] = [[CST_0_]] to [[CST_64_]] step [[CST_32_]] { +// CHECK-DAG: [[VAR_11_:%.+]] = affine.apply [[MAP_4_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[I_3_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_:%.+]], [[VAR_output2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_13_:%.+]] = affine.apply [[MAP_5_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_1_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_13_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_1_:%.+]], [[VAR_output2_2_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_1_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_15_:%.+]] = affine.apply [[MAP_6_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_2_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_15_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_3_:%.+]], [[VAR_output2_4_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_2_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: [[VAR_17_:%.+]] = affine.apply [[MAP_7_]]([[I_3_]]){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_3_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_17_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_5_:%.+]], [[VAR_output2_6_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_3_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: vector.store [[VAR_output1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]1] : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.addi [[VAR_11_]], [[CST_4_]] : index +// CHECK: vector.store [[VAR_output2_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]9] : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_20_:%.+]] = arith.addi [[VAR_11_]], [[CST_8_]] : index +// CHECK: vector.store [[VAR_output1_1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_20_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_21_:%.+]] = arith.addi [[VAR_11_]], [[CST_12_]] : index +// CHECK: vector.store [[VAR_output2_2_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_21_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addi [[VAR_11_]], [[CST_16_]] : index +// CHECK: vector.store [[VAR_output1_3_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_22_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_23_:%.+]] = arith.addi [[VAR_11_]], [[CST_20_]] : index +// CHECK: vector.store [[VAR_output2_4_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_23_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_24_:%.+]] = arith.addi [[VAR_11_]], [[CST_24_]] : index +// CHECK: vector.store [[VAR_output1_5_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_24_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_25_:%.+]] = arith.addi [[VAR_11_]], [[CST_28_]] : index +// CHECK: vector.store [[VAR_output2_6_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_25_]]{{.}} : memref<16x8x127xf32>, vector<4xf32> +// CHECK: } +// CHECK: } else { +// CHECK: [[VAR_11_1_:%.+]] = affine.apply [[MAP_8_]](){{.}}[[VAR_6_]]{{.}} +// CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[VAR_11_1_]] step [[CST_8_]] { +// CHECK-DAG: [[VAR_15_1_:%.+]] = affine.apply [[MAP_4_]]([[I_4_]]){{.}}[[VAR_6_]]{{.}} +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[I_4_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_1_1_:%.+]], [[VAR_output2_2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_4_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: vector.store [[VAR_output1_1_1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]5] : memref<16x8x127xf32>, vector<4xf32> +// CHECK: [[VAR_17_1_:%.+]] = arith.addi [[VAR_15_1_]], [[CST_4_]] : index +// CHECK: vector.store [[VAR_output2_2_1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]7] : memref<16x8x127xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_5_:%.+]] = affine.apply [[MAP_9_]](){{.}}[[VAR_6_]]{{.}} +// CHECK-DAG: [[VAR_13_1_:%.+]] = affine.apply [[MAP_10_]](){{.}}[[VAR_6_]]{{.}} +// CHECK: [[LOAD_VAR_reinterpret_cast_MEM_6_:%.+]] = vector.load [[VAR_reinterpret_cast_]]{{.}}[[VAR_8_]], [[VAR_13_1_]]{{.}} : memref<2x64xf16>, vector<8xf16> +// CHECK: [[VAR_output1_1_:%.+]], [[VAR_output2_1_:%.+]] = "zlow.vec_dlf16_to_f32"([[LOAD_VAR_reinterpret_cast_MEM_6_]]) : (vector<8xf16>) -> (vector<4xf32>, vector<4xf32>) +// CHECK: vector.store [[VAR_output1_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_output2_1_]], [[RES_1_]]{{.}}[[CST_4_]]{{.}} : memref<8xf32>, vector<4xf32> +// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[LOAD_VAR_reinterpret_cast_MEM_5_]] step [[CST_1_]] { +// CHECK-DAG: [[VAR_15_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[I_5_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[LOAD_VAR_reinterpret_cast_MEM_4_:%.+]] = affine.apply [[MAP_11_]]([[I_5_]]){{.}}[[VAR_6_]], [[VAR_13_1_]]{{.}} +// CHECK: krnl.store [[VAR_15_1_]], [[RES_]]{{.}}[[VAR_1_]], [[VAR_3_]], [[VAR_1_]]6] : memref<16x8x127xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return [[RES_]] : memref<16x8x127xf32> +// CHECK: } +} + diff --git a/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir new file mode 100644 index 0000000000..1124514b79 --- /dev/null +++ b/test/mlir/conversion/krnl_to_affine/krnl_to_affine_parallel_clause.mlir @@ -0,0 +1,111 @@ +// RUN: onnx-mlir-opt -O3 --convert-krnl-to-affine --canonicalize %s -split-input-file | FileCheck %s + +// ----- + +func.func @parallel_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block), num_threads(%c8_i32) {proc_bind = "spread"} : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32> + %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_threads_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) {proc_bind = "spread"} : index +// CHECK: } +// CHECK: } +} + +// ----- + +func.func @parallel_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block), num_threads(%c8_i32) : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32> + %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_threads +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]), num_threads([[CST_8_]]) : index +// CHECK: } +// CHECK: } +} + +// ----- + +func.func @parallel_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_0[0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_1[0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + affine.store %c16384, %alloc_3[0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + %0 = krnl.define_loops 1 + %loop_block, %loop_local = krnl.block %0 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) + krnl.parallel(%loop_block) {proc_bind = "spread"} : !krnl.loop + krnl.iterate(%loop_block) with (%0 -> %arg1 = 0 to 16384){ + %1 = krnl.get_induction_var_value(%loop_block) : (!krnl.loop) -> index + %2 = vector.load %reshape[%1] : memref<16384xf32>, vector<32xf32> + %3 = vector.load %reshape_2[%1] : memref<16384xf32>, vector<32xf32> + %4 = arith.addf %2, %3 : vector<32xf32> + vector.store %4, %reshape_4[%1] : memref<16384xf32>, vector<32xf32> + } + return %alloc : memref<16x8x128xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @parallel_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) attributes {llvm.emit_c_interface} { +// CHECK: affine.parallel ([[arg1_:%.+]]) = (0) to (16384) step (32) { +// CHECK: krnl.parallel_clause([[arg1_]]) {proc_bind = "spread"} : index +// CHECK: } +// CHECK: } +} diff --git a/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir b/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir new file mode 100644 index 0000000000..4f04da7a4a --- /dev/null +++ b/test/mlir/conversion/krnl_to_llvm/call_with_return.mlir @@ -0,0 +1,10 @@ +// RUN: onnx-mlir-opt --convert-krnl-to-llvm %s -split-input-file | FileCheck %s + +func.func private @test_krnl_call_with_return(%arg0: memref<2x3xi32>) -> i32 { + %1 = "krnl.call"() {funcName = "get_omp_num_thread", numOfOutput = 0 : si64} : () -> (i32) + func.return %1: i32 +// CHECK: llvm.func @get_omp_num_thread() -> i32 +// CHECK: llvm.func @test_krnl_call_with_return +// CHECK: [[VAR_0_:%.+]] = llvm.call @get_omp_num_thread() : () -> i32 +// CHECK: llvm.return [[VAR_0_]] : i32 +} diff --git a/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/constants.mlir b/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/constants.mlir index 82bf2fb522..48196c7c58 100644 --- a/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/constants.mlir +++ b/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/constants.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s && rm model.constants.bin +// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s && rm -f model.constants.bin // Thresholds for this files: // -constants-to-file-single-threshold=0.03: 30 bytes for a single constants @@ -39,7 +39,7 @@ func.func @test_constants_to_file() -> memref<10xi64> { // CHECK-LABEL: module // CHECK: llvm.func @omGetExternalConstantAddr(!llvm.ptr, !llvm.ptr, i64) -// CHECK: llvm.func @omMMapBinaryFile(!llvm.ptr, !llvm.ptr, i64, i64) +// CHECK: llvm.func @omMMapBinaryFile(!llvm.ptr, !llvm.ptr, i64, i64) -> i1 // CHECK: llvm.mlir.global internal constant @constant_2(dense<[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]> : tensor<10xi64>) {addr_space = 0 : i32, alignment = 4096 : i64} : !llvm.array<10 x i64> // CHECK: llvm.mlir.global internal @om_external_constant_data_constant_1() {addr_space = 0 : i32, alignment = 4096 : i64} : !llvm.ptr { // CHECK: [[VAR_0_4_:%.+]] = llvm.mlir.zero : !llvm.ptr @@ -59,22 +59,38 @@ func.func @test_constants_to_file() -> memref<10xi64> { // CHECK: llvm.return [[VAR_0_6_]] : !llvm.ptr // CHECK: } -// CHECK: llvm.func @omLoadConstantsFromFile() { +// CHECK: llvm.func @omLoadConstantsFromFile() -> i1 { // CHECK-DAG: [[VAR_0_9_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 // CHECK-DAG: [[VAR_1_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0 : !llvm.ptr -// CHECK-DAG: [[VAR_2_3_:%.+]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-DAG: [[VAR_3_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1 : !llvm.ptr +// CHECK-DAG: [[VAR_2_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1 : !llvm.ptr +// CHECK-DAG: [[VAR_3_3_:%.+]] = llvm.mlir.constant(true) : i1 // CHECK-DAG: [[VAR_4_3_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 +// CHECK-DAG: [[VAR_5_3_:%.+]] = llvm.mlir.constant(0 : i64) : i64 // CHECK-DAG: [[VAR_6_3_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst : !llvm.ptr // CHECK-DAG: [[VAR_7_3_:%.+]] = llvm.mlir.addressof @om_external_constant_filename : !llvm.ptr -// CHECK: llvm.call @omMMapBinaryFile([[VAR_6_3_]], [[VAR_7_3_]], [[VAR_4_3_]], [[VAR_2_3_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> () -// CHECK: llvm.call @omGetExternalConstantAddr([[VAR_3_3_]], [[VAR_6_3_]], [[VAR_2_3_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK: [[VAR_8_3_:%.+]] = llvm.call @omMMapBinaryFile([[VAR_6_3_]], [[VAR_7_3_]], [[VAR_4_3_]], [[VAR_5_3_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> i1 +// CHECK: [[VAR_9_3_:%.+]] = llvm.icmp "ne" [[VAR_3_3_]], [[VAR_8_3_]] : i1 +// CHECK: llvm.cond_br [[VAR_9_3_]], ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: llvm.return [[VAR_8_3_]] : i1 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: llvm.call @omGetExternalConstantAddr([[VAR_2_3_]], [[VAR_6_3_]], [[VAR_5_3_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () // CHECK: llvm.call @omGetExternalConstantAddr([[VAR_1_3_]], [[VAR_6_3_]], [[VAR_0_9_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK: llvm.return +// CHECK: llvm.return [[VAR_3_3_]] : i1 // CHECK: } // CHECK: llvm.func @run_main_graph({{.*}}: !llvm.ptr) -> !llvm.ptr { -// CHECK: llvm.call @omLoadConstantsFromFile() : () -> () +// CHECK-DAG: [[VAR_3_4_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-DAG: [[VAR_4_4_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-DAG: [[VAR_5_4_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-DAG: [[VAR_6_4_:%.+]] = llvm.call @omLoadConstantsFromFile() : () -> i1 +// CHECK: [[VAR_7_4_:%.+]] = llvm.icmp "ne" [[VAR_5_4_]], [[VAR_6_4_]] : i1 +// CHECK: llvm.cond_br [[VAR_7_4_]], ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: [[VAR_8_4_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK: llvm.store [[VAR_4_4_]], [[VAR_8_4_]] : i32, !llvm.ptr +// CHECK: llvm.return [[VAR_3_4_]] : !llvm.ptr +// CHECK: ^bb2: // pred: ^bb0 // CHECK: } } diff --git a/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/symbol-postfix.mlir b/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/symbol-postfix.mlir index 0b0c3c7197..884756676d 100644 --- a/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/symbol-postfix.mlir +++ b/test/mlir/conversion/krnl_to_llvm/constants_to_file/big_endian/symbol-postfix.mlir @@ -1,5 +1,5 @@ // RUN: onnx-mlir-opt --convert-krnl-to-llvm --canonicalize %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s -check-prefix=CHECK-CONST-TO-FILE && rm model.constants.bin +// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s -check-prefix=CHECK-CONST-TO-FILE && rm -f model.constants.bin // ----- @@ -152,30 +152,57 @@ module attributes {"onnx-mlir.symbol-postfix" = "tag_constants_to_file"} { // CHECK-CONST-TO-FILE: llvm.return [[VAR_0_15_]] : !llvm.ptr // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile_tag_constants_to_file() { -// CHECK-CONST-TO-FILE-DAG: [[VAR_0_18_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_1_9_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_2_9_:%.+]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_3_9_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_4_7_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_6_6_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_7_4_:%.+]] = llvm.mlir.addressof @om_external_constant_filename_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE: llvm.call @omMMapBinaryFile([[VAR_6_6_]], [[VAR_7_4_]], [[VAR_4_7_]], [[VAR_2_9_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> () -// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_3_9_]], [[VAR_6_6_]], [[VAR_2_9_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_1_9_]], [[VAR_6_6_]], [[VAR_0_18_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK-CONST-TO-FILE: llvm.return +// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile_tag_constants_to_file() -> i1 { +// CHECK-CONST-TO-FILE-DAG: [[VAR_0_19_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_1_10_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_2_10_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_10_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_4_9_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_5_9_:%.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_6_8_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_7_5_:%.+]] = llvm.mlir.addressof @om_external_constant_filename_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE: [[VAR_8_4_:%.+]] = llvm.call @omMMapBinaryFile([[VAR_6_8_]], [[VAR_7_5_]], [[VAR_4_9_]], [[VAR_5_9_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> i1 +// CHECK-CONST-TO-FILE: [[VAR_9_4_:%.+]] = llvm.icmp "ne" [[VAR_3_10_]], [[VAR_8_4_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_9_4_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: llvm.return [[VAR_8_4_]] : i1 +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 +// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_2_10_]], [[VAR_6_8_]], [[VAR_5_9_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_1_10_]], [[VAR_6_8_]], [[VAR_0_19_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK-CONST-TO-FILE: llvm.return [[VAR_3_10_]] : i1 // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile() { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () -// CHECK-CONST-TO-FILE: llvm.return +// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile() -> i1 { +// CHECK-CONST-TO-FILE: [[VAR_0_20_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_20_]] : i1 // CHECK-CONST-TO-FILE: } // CHECK-CONST-TO-FILE: llvm.func @run_main_graph_tag_constants_to_file([[arg0_:%.+]]: !llvm.ptr) -> !llvm.ptr { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_11_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_4_10_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-CONST-TO-FILE-DAG: [[VAR_5_10_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_6_9_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: [[VAR_7_6_:%.+]] = llvm.icmp "ne" [[VAR_5_10_]], [[VAR_6_9_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_7_6_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_8_5_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.store [[VAR_4_10_]], [[VAR_8_5_]] : i32, !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_3_11_]] : !llvm.ptr +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @run_main_graph([[arg0_:%.+]]: !llvm.ptr) -> !llvm.ptr { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () -// CHECK-CONST-TO-FILE: [[VAR_0_20_:%.+]] = llvm.call @run_main_graph_tag_constants_to_file([[arg0_]]) : (!llvm.ptr) -> !llvm.ptr -// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_20_]] : !llvm.ptr + +// CHECK-CONST-TO-FILE: llvm.func @run_main_graph([[arg0_]]: !llvm.ptr) -> !llvm.ptr { +// CHECK-CONST-TO-FILE-DAG: [[VAR_0_22_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_1_12_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-CONST-TO-FILE-DAG: [[VAR_2_12_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_12_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: [[VAR_4_11_:%.+]] = llvm.icmp "ne" [[VAR_2_12_]], [[VAR_3_12_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_4_11_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_5_11_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.store [[VAR_1_12_]], [[VAR_5_11_]] : i32, !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_22_]] : !llvm.ptr +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_6_10_:%.+]] = llvm.call @run_main_graph_tag_constants_to_file([[arg0_]]) : (!llvm.ptr) -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_6_10_]] : !llvm.ptr // CHECK-CONST-TO-FILE: } } diff --git a/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/constants.mlir b/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/constants.mlir index 47dcf60bf0..00a3183d99 100644 --- a/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/constants.mlir +++ b/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/constants.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s && rm model.constants.bin +// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s && rm -f model.constants.bin // Thresholds for this files: // -constants-to-file-single-threshold=0.03: 30 bytes for a single constants @@ -38,7 +38,7 @@ func.func @test_constants_to_file() -> memref<10xi64> { return %2 : memref<10xi64> // CHECK: llvm.func @omGetExternalConstantAddr(!llvm.ptr, !llvm.ptr, i64) -// CHECK: llvm.func @omMMapBinaryFile(!llvm.ptr, !llvm.ptr, i64, i64) +// CHECK: llvm.func @omMMapBinaryFile(!llvm.ptr, !llvm.ptr, i64, i64) -> i1 // CHECK: llvm.mlir.global internal constant @constant_2(dense<[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]> : tensor<10xi64>) {addr_space = 0 : i32, alignment = 4096 : i64} : !llvm.array<10 x i64> // CHECK: llvm.mlir.global internal @om_external_constant_data_constant_1() {addr_space = 0 : i32, alignment = 4096 : i64} : !llvm.ptr { // CHECK: [[VAR_0_4_:%.+]] = llvm.mlir.zero : !llvm.ptr @@ -58,23 +58,39 @@ func.func @test_constants_to_file() -> memref<10xi64> { // CHECK: llvm.return [[VAR_0_6_]] : !llvm.ptr // CHECK: } -// CHECK: llvm.func @omLoadConstantsFromFile() { +// CHECK: llvm.func @omLoadConstantsFromFile() -> i1 { // CHECK-DAG: [[VAR_0_9_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 // CHECK-DAG: [[VAR_1_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0 : !llvm.ptr -// CHECK-DAG: [[VAR_2_3_:%.+]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-DAG: [[VAR_3_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1 : !llvm.ptr +// CHECK-DAG: [[VAR_1_4_:%.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-DAG: [[VAR_2_3_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1 : !llvm.ptr +// CHECK-DAG: [[VAR_3_3_:%.+]] = llvm.mlir.constant(true) : i1 // CHECK-DAG: [[VAR_4_3_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 // CHECK-DAG: [[VAR_5_3_:%.+]] = llvm.mlir.constant(1 : i64) : i64 // CHECK-DAG: [[VAR_6_3_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst : !llvm.ptr // CHECK-DAG: [[VAR_7_3_:%.+]] = llvm.mlir.addressof @om_external_constant_filename : !llvm.ptr -// CHECK: llvm.call @omMMapBinaryFile([[VAR_6_3_]], [[VAR_7_3_]], [[VAR_4_3_]], [[VAR_5_3_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> () -// CHECK: llvm.call @omGetExternalConstantAddr([[VAR_3_3_]], [[VAR_6_3_]], [[VAR_2_3_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK: [[VAR_8_3_:%.+]] = llvm.call @omMMapBinaryFile([[VAR_6_3_]], [[VAR_7_3_]], [[VAR_4_3_]], [[VAR_5_3_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> i1 +// CHECK: [[VAR_9_3_:%.+]] = llvm.icmp "ne" [[VAR_3_3_]], [[VAR_8_3_]] : i1 +// CHECK: llvm.cond_br [[VAR_9_3_]], ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: llvm.return [[VAR_8_3_]] : i1 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: llvm.call @omGetExternalConstantAddr([[VAR_2_3_]], [[VAR_6_3_]], [[VAR_1_4_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () // CHECK: llvm.call @omGetExternalConstantAddr([[VAR_1_3_]], [[VAR_6_3_]], [[VAR_0_9_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK: llvm.return +// CHECK: llvm.return [[VAR_3_3_]] : i1 // CHECK: } // CHECK: llvm.func @run_main_graph({{.*}}: !llvm.ptr) -> !llvm.ptr { -// CHECK: llvm.call @omLoadConstantsFromFile() : () -> () +// CHECK-DAG: [[VAR_3_4_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-DAG: [[VAR_4_4_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-DAG: [[VAR_5_4_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-DAG: [[VAR_6_4_:%.+]] = llvm.call @omLoadConstantsFromFile() : () -> i1 +// CHECK: [[VAR_7_4_:%.+]] = llvm.icmp "ne" [[VAR_5_4_]], [[VAR_6_4_]] : i1 +// CHECK: llvm.cond_br [[VAR_7_4_]], ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: [[VAR_8_4_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK: llvm.store [[VAR_4_4_]], [[VAR_8_4_]] : i32, !llvm.ptr +// CHECK: llvm.return [[VAR_3_4_]] : !llvm.ptr +// CHECK: ^bb2: // pred: ^bb0 // CHECK: } } diff --git a/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/symbol-postfix.mlir b/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/symbol-postfix.mlir index ee4ba9c75e..f763b6b3ba 100644 --- a/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/symbol-postfix.mlir +++ b/test/mlir/conversion/krnl_to_llvm/constants_to_file/litte_endian/symbol-postfix.mlir @@ -1,5 +1,5 @@ // RUN: onnx-mlir-opt --convert-krnl-to-llvm --canonicalize %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s -check-prefix=CHECK-CONST-TO-FILE && rm model.constants.bin +// RUN: onnx-mlir-opt --convert-krnl-to-llvm="store-constants-to-file constants-to-file-single-threshold=0.03 constants-to-file-total-threshold=0.00000006" --canonicalize %s -split-input-file | FileCheck %s -check-prefix=CHECK-CONST-TO-FILE && rm -f model.constants.bin // ----- @@ -152,31 +152,58 @@ module attributes {"onnx-mlir.symbol-postfix" = "tag_constants_to_file"} { // CHECK-CONST-TO-FILE: llvm.return [[VAR_0_15_]] : !llvm.ptr // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile_tag_constants_to_file() { -// CHECK-CONST-TO-FILE-DAG: [[VAR_0_18_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_1_9_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_2_9_:%.+]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_3_9_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_4_7_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_5_8_:%.+]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK-CONST-TO-FILE-DAG: [[VAR_6_6_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE-DAG: [[VAR_7_4_:%.+]] = llvm.mlir.addressof @om_external_constant_filename_tag_constants_to_file : !llvm.ptr -// CHECK-CONST-TO-FILE: llvm.call @omMMapBinaryFile([[VAR_6_6_]], [[VAR_7_4_]], [[VAR_4_7_]], [[VAR_5_8_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> () -// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_3_9_]], [[VAR_6_6_]], [[VAR_2_9_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_1_9_]], [[VAR_6_6_]], [[VAR_0_18_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () -// CHECK-CONST-TO-FILE: llvm.return +// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile_tag_constants_to_file() -> i1 { +// CHECK-CONST-TO-FILE-DAG: [[VAR_0_19_:%.+]] = llvm.mlir.constant(4096 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_1_10_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_0_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_1_11_:%.+]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_2_10_:%.+]] = llvm.mlir.addressof @om_external_constant_data_constant_1_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_10_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_4_9_:%.+]] = llvm.mlir.constant(4176 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_5_9_:%.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-CONST-TO-FILE-DAG: [[VAR_6_8_:%.+]] = llvm.mlir.addressof @om_external_constant_packedConst_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_7_5_:%.+]] = llvm.mlir.addressof @om_external_constant_filename_tag_constants_to_file : !llvm.ptr +// CHECK-CONST-TO-FILE: [[VAR_8_4_:%.+]] = llvm.call @omMMapBinaryFile([[VAR_6_8_]], [[VAR_7_5_]], [[VAR_4_9_]], [[VAR_5_9_]]) : (!llvm.ptr, !llvm.ptr, i64, i64) -> i1 +// CHECK-CONST-TO-FILE: [[VAR_9_4_:%.+]] = llvm.icmp "ne" [[VAR_3_10_]], [[VAR_8_4_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_9_4_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: llvm.return [[VAR_8_4_]] : i1 +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 +// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_2_10_]], [[VAR_6_8_]], [[VAR_1_11_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK-CONST-TO-FILE: llvm.call @omGetExternalConstantAddr([[VAR_1_10_]], [[VAR_6_8_]], [[VAR_0_19_]]) : (!llvm.ptr, !llvm.ptr, i64) -> () +// CHECK-CONST-TO-FILE: llvm.return [[VAR_3_10_]] : i1 // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile() { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () -// CHECK-CONST-TO-FILE: llvm.return +// CHECK-CONST-TO-FILE: llvm.func @omLoadConstantsFromFile() -> i1 { +// CHECK-CONST-TO-FILE: [[VAR_0_20_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_20_]] : i1 // CHECK-CONST-TO-FILE: } // CHECK-CONST-TO-FILE: llvm.func @run_main_graph_tag_constants_to_file([[arg0_:%.+]]: !llvm.ptr) -> !llvm.ptr { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_11_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_4_10_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-CONST-TO-FILE-DAG: [[VAR_5_10_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_6_9_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: [[VAR_7_6_:%.+]] = llvm.icmp "ne" [[VAR_5_10_]], [[VAR_6_9_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_7_6_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_8_5_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.store [[VAR_4_10_]], [[VAR_8_5_]] : i32, !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_3_11_]] : !llvm.ptr +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 // CHECK-CONST-TO-FILE: } -// CHECK-CONST-TO-FILE: llvm.func @run_main_graph([[arg0_:%.+]]: !llvm.ptr) -> !llvm.ptr { -// CHECK-CONST-TO-FILE: llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> () -// CHECK-CONST-TO-FILE: [[VAR_0_20_:%.+]] = llvm.call @run_main_graph_tag_constants_to_file([[arg0_]]) : (!llvm.ptr) -> !llvm.ptr -// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_20_]] : !llvm.ptr + +// CHECK-CONST-TO-FILE: llvm.func @run_main_graph([[arg0_]]: !llvm.ptr) -> !llvm.ptr { +// CHECK-CONST-TO-FILE-DAG: [[VAR_0_22_:%.+]] = llvm.mlir.zero : !llvm.ptr +// CHECK-CONST-TO-FILE-DAG: [[VAR_1_12_:%.+]] = llvm.mlir.constant(22 : i32) : i32 +// CHECK-CONST-TO-FILE-DAG: [[VAR_2_12_:%.+]] = llvm.mlir.constant(true) : i1 +// CHECK-CONST-TO-FILE-DAG: [[VAR_3_12_:%.+]] = llvm.call @omLoadConstantsFromFile_tag_constants_to_file() : () -> i1 +// CHECK-CONST-TO-FILE: [[VAR_4_11_:%.+]] = llvm.icmp "ne" [[VAR_2_12_]], [[VAR_3_12_]] : i1 +// CHECK-CONST-TO-FILE: llvm.cond_br [[VAR_4_11_]], ^bb1, ^bb2 +// CHECK-CONST-TO-FILE: ^bb1: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_5_11_:%.+]] = llvm.call @__errno_location() : () -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.store [[VAR_1_12_]], [[VAR_5_11_]] : i32, !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_0_22_]] : !llvm.ptr +// CHECK-CONST-TO-FILE: ^bb2: // pred: ^bb0 +// CHECK-CONST-TO-FILE: [[VAR_6_10_:%.+]] = llvm.call @run_main_graph_tag_constants_to_file([[arg0_]]) : (!llvm.ptr) -> !llvm.ptr +// CHECK-CONST-TO-FILE: llvm.return [[VAR_6_10_]] : !llvm.ptr // CHECK-CONST-TO-FILE: } } diff --git a/test/mlir/conversion/onnx_to_krnl/ControlFlow/If.mlir b/test/mlir/conversion/onnx_to_krnl/ControlFlow/If.mlir index 2e83a47a36..0891f3c5b0 100644 --- a/test/mlir/conversion/onnx_to_krnl/ControlFlow/If.mlir +++ b/test/mlir/conversion/onnx_to_krnl/ControlFlow/If.mlir @@ -8,15 +8,20 @@ func.func @test_if_simple(%arg0: tensor, %arg1: tensor, %arg2: tensor }) : (tensor) -> tensor return %0 : tensor -// CHECK-LABEL: @test_if_simple +// CHECK-LABEL: func.func @test_if_simple // CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref { -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][] : memref -// CHECK: [[VAR_1_:%.+]] = scf.if [[LOAD_PARAM_0_MEM_]] -> (memref) { -// CHECK: scf.yield [[PARAM_1_]] : memref +// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_2_]] : memref to tensor +// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref to tensor +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = scf.if [[LOAD_PARAM_0_MEM_]] -> (memref) { +// CHECK-DAG: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[VAR_1_]] : tensor to memref +// CHECK: scf.yield [[VAR_4_]] : memref // CHECK: } else { -// CHECK: scf.yield [[PARAM_2_]] : memref +// CHECK: [[VAR_4_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_0_]] : tensor to memref +// CHECK: scf.yield [[VAR_4_1_]] : memref // CHECK: } -// CHECK: return [[VAR_1_]] : memref +// CHECK: return [[VAR_3_]] : memref // CHECK: } } diff --git a/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir b/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir index 49ed94b610..838e3a4bca 100644 --- a/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir +++ b/test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir @@ -41,19 +41,19 @@ func.func private @test_loop_simple_main_graph(%arg0: tensor, %arg1: tensor // CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64> // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_2_]]{{.}} : memref<1xi64> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64 // CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index -// CHECK: krnl.store [[VAR_10_]], [[RES_3_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64> -// CHECK-DAG: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64> +// CHECK: krnl.store [[VAR_11_]], [[RES_3_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64> // CHECK-DAG: [[VAR_12_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref to memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[VAR_11_]] : tensor<1xi64> to memref<1xi64> -// CHECK-DAG: [[LOAD_VAR_12_MEM_:%.+]] = krnl.load [[VAR_12_]][] : memref +// CHECK-DAG: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[VAR_8_]] : tensor<1xi64> to memref<1xi64> +// CHECK: [[LOAD_VAR_12_MEM_:%.+]] = krnl.load [[VAR_12_]][] : memref // CHECK: krnl.store [[LOAD_VAR_12_MEM_]], [[RES_1_]][] : memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index @@ -111,8 +111,10 @@ func.func @test_loop(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK-DAG: [[VAR_dim_7_:%.+]] = memref.dim [[PARAM_2_]], [[CST_0_1_]] : memref // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index // CHECK: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_2_]], [[CST_0_2_]] : memref -// CHECK: [[VAR_11_:%.+]] = affine.max [[MAP_0_]]([[VAR_dim_7_]], [[VAR_dim_9_]]) -// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc([[VAR_11_]]) {{.*}}: memref +// CHECK-DAG: [[VAR_11_:%.+]] = affine.max [[MAP_0_]]([[VAR_dim_7_]], [[VAR_dim_9_]]) +// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index +// CHECK: [[RES_3_:%.+]] = memref.alloc([[VAR_11_]]) {{.*}}: memref +// CHECK-DAG: [[VAR_12_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref to tensor // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index @@ -123,11 +125,9 @@ func.func @test_loop(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK: [[VAR_20_:%.+]] = arith.addf [[LOAD_PARAM_2_MEM_]], [[LOAD_PARAM_2_MEM_1_]] : f32 // CHECK: krnl.store [[VAR_20_]], [[RES_3_]]{{.}}[[VAR_17_]]{{.}} : memref // CHECK: } -// CHECK-DAG: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref to tensor // CHECK-DAG: [[VAR_14_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref to memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = builtin.unrealized_conversion_cast [[VAR_13_]] : tensor to memref -// CHECK-DAG: [[LOAD_VAR_14_MEM_:%.+]] = krnl.load [[VAR_14_]][] : memref +// CHECK-DAG: [[VAR_15_:%.+]] = builtin.unrealized_conversion_cast [[VAR_12_]] : tensor to memref +// CHECK: [[LOAD_VAR_14_MEM_:%.+]] = krnl.load [[VAR_14_]][] : memref // CHECK: krnl.store [[LOAD_VAR_14_MEM_]], [[RES_1_]][] : memref // CHECK: "krnl.seqstore"([[VAR_15_]], [[RES_]], [[VAR_8_]]) : (memref, memref>, index) -> () // CHECK: }) : () -> () @@ -150,11 +150,10 @@ func.func @test_loop(%arg0: tensor, %arg1: tensor, %arg2: tensor // CHECK: [[VAR_dim_7_1_:%.+]] = memref.dim [[LOAD_RES_1_MEM_1_]], [[CST_0_9_]] : memref // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 0 to [[MAP_2_]]([[VAR_dim_7_1_]])){ // CHECK: [[VAR_11_1_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOOP_1_:%.+]] = krnl.load [[LOAD_RES_1_MEM_1_]]{{.}}[[VAR_11_1_]]{{.}} : memref -// CHECK: krnl.store [[LOOP_1_]], [[RES_4_]]{{.}}[[VAR_8_1_]], [[VAR_11_1_]]{{.}} : memref +// CHECK: [[VAR_12_1_:%.+]] = krnl.load [[LOAD_RES_1_MEM_1_]]{{.}}[[VAR_11_1_]]{{.}} : memref +// CHECK: krnl.store [[VAR_12_1_]], [[RES_4_]]{{.}}[[VAR_8_1_]], [[VAR_11_1_]]{{.}} : memref // CHECK: } // CHECK: }) : () -> () // CHECK: } // CHECK: return [[RES_4_]] : memref -// CHECK: } -} +} \ No newline at end of file diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir index d560acc6da..735312d524 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir @@ -446,12 +446,13 @@ func.func @where(%arg0: tensor<2x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x // ----- + func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { %0 = "onnx.Round"(%arg0) : (tensor<15xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py -// CHECK-LABEL: func @round +// CHECK-LABEL: func.func @round // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<15xf32>) -> memref<15xf32> { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 // CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 @@ -459,8 +460,8 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<15xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ -// CHECK: [[IV:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][[[IV]]] : memref<15xf32> +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32> // CHECK: [[VAR_3_:%.+]] = math.floor [[LOAD_PARAM_0_MEM_]] : f32 // CHECK: [[VAR_4_:%.+]] = arith.subf [[LOAD_PARAM_0_MEM_]], [[VAR_3_]] : f32 // CHECK-DAG: [[VAR_5_:%.+]] = arith.cmpf ogt, [[VAR_4_]], [[CST_5_dot_000000_]] : f32 @@ -477,7 +478,7 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_12_]], [[VAR_13_]], [[VAR_3_]] : f32 // CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_4_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_16_:%.+]] = arith.select [[VAR_15_]], [[VAR_14_]], [[VAR_7_]] : f32 -// CHECK: krnl.store [[VAR_16_]], [[RES_]][[[IV]]] : memref<15xf32> +// CHECK: krnl.store [[VAR_16_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32> // CHECK: } // CHECK: return [[RES_]] : memref<15xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir index 5e149d2d96..075f1626a0 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir @@ -1,7 +1,7 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. @@ -62,6 +62,7 @@ func.func @test_mean(%arg0: tensor<30xf32>, %arg1: tensor<30xf32>, %arg2: tensor // ----- + func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { %0 = "onnx.Round"(%arg0) : (tensor<15xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -69,35 +70,31 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-LABEL: func.func @round // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<15xf32>) -> memref<15xf32> { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<64xi8> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}[] : memref<64xi8> to memref<15xf32> -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ -// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> -// CHECK: [[VAR_3_:%.+]] = math.floor [[LOAD_PARAM_0_MEM_]] : vector<16xf32> -// CHECK: [[VAR_4_:%.+]] = arith.subf [[LOAD_PARAM_0_MEM_]], [[VAR_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.cmpf ogt, [[VAR_4_]], [[VAR_cst_]] : vector<16xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_3_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_5_]], [[VAR_6_]], [[VAR_3_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = arith.mulf [[VAR_3_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : vector<16xf32> -// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_3_]], [[VAR_10_]] : vector<16xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.cmpf oeq, [[VAR_11_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = arith.addf [[VAR_3_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_12_]], [[VAR_13_]], [[VAR_3_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_4_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_16_:%.+]] = arith.select [[VAR_15_]], [[VAR_14_]], [[VAR_7_]] : vector<16xi1>, vector<16xf32> -// CHECK: vector.store [[VAR_16_]], [[VAR_view_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}[] : memref<64xi8> to memref<15xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: [[VAR_3_:%.+]] = vector.shape_cast [[LOAD_PARAM_0_MEM_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_4_:%.+]] = vector.extract [[VAR_3_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_5_:%.+]] = "krnl.round_even"([[VAR_4_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = vector.insert [[VAR_5_]], [[VAR_3_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = vector.extract [[VAR_3_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_8_:%.+]] = "krnl.round_even"([[VAR_7_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.insert [[VAR_8_]], [[VAR_6_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.extract [[VAR_3_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_11_:%.+]] = "krnl.round_even"([[VAR_10_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.insert [[VAR_11_]], [[VAR_9_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.extract [[VAR_3_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_14_:%.+]] = "krnl.round_even"([[VAR_13_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_15_:%.+]] = vector.insert [[VAR_14_]], [[VAR_12_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK: [[VAR_16_:%.+]] = vector.shape_cast [[VAR_15_]] : vector<4x4xf32> to vector<16xf32> +// CHECK: vector.store [[VAR_16_]], [[VAR_view_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref<15xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Gemm_with_parallel_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Gemm_with_parallel_canonicalize_O3.mlir index fb8283d4e6..4965501596 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Gemm_with_parallel_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Gemm_with_parallel_canonicalize_O3.mlir @@ -30,15 +30,15 @@ func.func @test_gemm_parallel(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32> // CHECK: krnl.permute([[BLOCK_TILE__2_]], [[BLOCK_TILE__3_]], [[BLOCK_IN__3_]], [[BLOCK_TILE__4_]], [[BLOCK_IN__4_]], [[BLOCK_TILE__0_]], [[BLOCK_TILE__0_]]_1, [[BLOCK_IN__1_]]) [0, 3, 5, 1, 6, 2, 4, 7] : !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__4_]]) with ([[LOOP_0_]]#1 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#2 -> [[I_1_:%.+]] = 0 to 5, [[LOOP_0_]]#0 -> [[I_2_:%.+]] = 0 to 10){ // CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__4_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<32x256xf32> -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<256x64xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<32x256xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<256x64xf32> // CHECK: krnl.copy_to_tile_buffer [[RES_2_]], [[PARAM_1_]]{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[CST_0_dot_000000_]] {padToNext = [], tileSize = []} : memref<256x64xf32>, memref<5x10xf32> // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with (){ // CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index // CHECK: krnl.copy_to_tile_buffer [[RES_1_]], [[PARAM_0_]]{{.}}[[VAR_2_]]#1, [[VAR_3_]]{{.}}, [[CST_0_dot_000000_]] {padToNext = [], tileSize = [], transpose = true} : memref<32x256xf32>, memref<5x10xf32> // CHECK: krnl.iterate([[BLOCK_TILE__3_]], [[BLOCK_TILE__1_]]) with (){ // CHECK: [[VAR_4_:%.+]]:2 = krnl.get_induction_var_value([[BLOCK_TILE__3_]], [[BLOCK_TILE__1_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: krnl.matmul [[RES_1_]]{{.}}[[VAR_3_]], [[VAR_2_]]#1], [[RES_1_]]_9{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[RES_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}}, ([[BLOCK_IN__1_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_4_]]#1, [[VAR_4_]]#0, [[VAR_2_]]#1), ([[CST_10_]], [[CST_10_]], [[CST_5_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 16, 256], simdize = false} : memref<32x256xf32>, memref<256x64xf32>, memref<10x10xf32>, (!krnl.loop, !krnl.loop, !krnl.loop) +// CHECK: krnl.matmul [[RES_1_]]{{.}}[[VAR_3_]], [[VAR_2_]]#1], [[RES_2_]]{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[RES_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}}, ([[BLOCK_IN__1_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_4_]]#1, [[VAR_4_]]#0, [[VAR_2_]]#1), ([[CST_10_]], [[CST_10_]], [[CST_5_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 16, 256], simdize = false} : memref<32x256xf32>, memref<256x64xf32>, memref<10x10xf32>, (!krnl.loop, !krnl.loop, !krnl.loop) // CHECK: } // CHECK: } // CHECK: } @@ -82,15 +82,15 @@ func.func @test_gemm_parallel_success(%arg0 : tensor<1024x1024xf32>, %arg1 : ten // CHECK: krnl.parallel([[BLOCK_TILE__2_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__2_]], [[BLOCK_TILE__4_]]) with ([[LOOP_0_]]#1 -> [[I_0_:%.+]] = 0 to 1024, [[LOOP_0_]]#2 -> [[I_1_:%.+]] = 0 to 1024, [[LOOP_0_]]#0 -> [[I_2_:%.+]] = 0 to 1024){ // CHECK-DAG: [[VAR_2_:%.+]]:2 = krnl.get_induction_var_value([[BLOCK_TILE__2_]], [[BLOCK_TILE__4_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<32x256xf32> -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<256x64xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<32x256xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<256x64xf32> // CHECK: krnl.copy_to_tile_buffer [[RES_2_]], [[PARAM_1_]]{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[CST_0_dot_000000_]] {padToNext = [], tileSize = []} : memref<256x64xf32>, memref<1024x1024xf32> // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with (){ // CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index // CHECK: krnl.copy_to_tile_buffer [[RES_1_]], [[PARAM_0_]]{{.}}[[VAR_2_]]#1, [[VAR_3_]]{{.}}, [[CST_0_dot_000000_]] {padToNext = [], tileSize = [], transpose = true} : memref<32x256xf32>, memref<1024x1024xf32> // CHECK: krnl.iterate([[BLOCK_TILE__3_]], [[BLOCK_TILE__1_]]) with (){ // CHECK: [[VAR_4_:%.+]]:2 = krnl.get_induction_var_value([[BLOCK_TILE__3_]], [[BLOCK_TILE__1_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: krnl.matmul [[RES_1_]]{{.}}[[VAR_3_]], [[VAR_2_]]#1], [[RES_1_]]_9{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[RES_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}}, ([[BLOCK_IN__1_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_4_]]#1, [[VAR_4_]]#0, [[VAR_2_]]#1), ([[CST_1024_]], [[CST_1024_]], [[CST_1024_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 16, 256]} : memref<32x256xf32>, memref<256x64xf32>, memref<1024x1024xf32>, (!krnl.loop, !krnl.loop, !krnl.loop) +// CHECK: krnl.matmul [[RES_1_]]{{.}}[[VAR_3_]], [[VAR_2_]]#1], [[RES_2_]]{{.}}[[VAR_2_]]#1, [[VAR_2_]]#0], [[RES_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}}, ([[BLOCK_IN__1_]], [[BLOCK_IN__3_]], [[BLOCK_IN__4_]]), ([[VAR_4_]]#1, [[VAR_4_]]#0, [[VAR_2_]]#1), ([[CST_1024_]], [[CST_1024_]], [[CST_1024_]]) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 16, 256]} : memref<32x256xf32>, memref<256x64xf32>, memref<1024x1024xf32>, (!krnl.loop, !krnl.loop, !krnl.loop) // CHECK: } // CHECK: } // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir index 29c76f24fc..576e245cb9 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/MatMulInteger_with_canonicalize_O3.mlir @@ -1,9 +1,9 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s // ----- -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. diff --git a/test/mlir/conversion/onnx_to_krnl/Math/MatMul_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/MatMul_with_canonicalize_O3.mlir index e221f249a9..e1be9cebbc 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/MatMul_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/MatMul_with_canonicalize_O3.mlir @@ -1,7 +1,7 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir index 5ef0892d8b..f35603fc9e 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir @@ -1,13 +1,85 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. // ----- +func.func @test_reduce_scalar_axes(%arg0: tensor) -> tensor { + %axes= onnx.Constant dense<-2> : tensor + %0 = "onnx.ReduceSum"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor, tensor) -> tensor + return %0: tensor + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0, s1] -> (s1)> +// CHECK-LABEL: func.func @test_reduce_scalar_axes +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref +// CHECK: [[RES_:%.+]] = memref.alloc([[VAR_dim_]], [[VAR_dim_]]_0) {{.*}}: memref +// CHECK: krnl.memset [[RES_]], [[CST_0_dot_000000_]] : memref +// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3 +// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_2_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_1_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_1_]], [[VAR_dim_2_]]{{.}}){ +// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2] : memref +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#2] : memref +// CHECK: [[VAR_4_:%.+]] = arith.addf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_4_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#2] : memref +// CHECK: } +// CHECK: return [[RES_]] : memref +// CHECK: } +} + +// ----- + +// COM: Full reduction over all dimensions to a scalar value. +func.func @test_reduce_all_to_scalar(%arg0: tensor) -> tensor<*xf32> { + %axes = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor, none) -> tensor<*xf32> + return %0: tensor<*xf32> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-LABEL: func.func @test_reduce_all_to_scalar +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref +// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_0_]] : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[VAR_1_]]){ +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_5_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_8_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_1_MEM_1_]] : vector<32xf32> into f32 +// CHECK: krnl.store [[VAR_4_]], [[RES_2_]][] : memref +// CHECK: return [[RES_2_]] : memref +// CHECK: } +} + +// ----- func.func private @test_reducemax_v13(%arg0 : tensor<3x2x2xf32>) -> tensor<*xf32> { %0 ="onnx.ReduceMaxV13"(%arg0) {axes=[1], keepdims = 0 : si64} : (tensor<3x2x2xf32>)-> tensor<*xf32> @@ -291,7 +363,7 @@ func.func private @gpt2_original(%arg0 : tensor) -> tensor (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ // CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index // CHECK: scf.if [[VAR_9_]] { @@ -299,7 +371,7 @@ func.func private @gpt2_original(%arg0 : tensor) -> tensor, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_768_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 768){ // CHECK: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_14_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -312,53 +384,51 @@ func.func private @gpt2_original(%arg0 : tensor) -> tensor // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_2_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_2_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_2_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_2_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_2_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_2_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_8_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_9_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_2_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_5_]], [[LOAD_RES_2_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -404,7 +474,7 @@ func.func private @gpt2_no_keepdims(%arg0 : tensor) -> tensor<*xf32 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ // CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index // CHECK: scf.if [[VAR_9_]] { @@ -412,7 +482,7 @@ func.func private @gpt2_no_keepdims(%arg0 : tensor) -> tensor<*xf32 // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_768_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 768){ // CHECK: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_14_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -425,53 +495,51 @@ func.func private @gpt2_no_keepdims(%arg0 : tensor) -> tensor<*xf32 // CHECK: krnl.store [[VAR_13_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_1_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_8_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_9_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_1_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_5_]], [[LOAD_RES_1_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -495,9 +563,9 @@ func.func private @gpt2_reduce2(%arg0 : tensor) -> tensor<*xf32> { // CHECK-LABEL: func.func private @gpt2_reduce2 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs @@ -528,7 +596,7 @@ func.func private @gpt2_reduce2(%arg0 : tensor) -> tensor<*xf32> { // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ // CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_3_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index // CHECK: scf.if [[VAR_9_]] { @@ -536,7 +604,7 @@ func.func private @gpt2_reduce2(%arg0 : tensor) -> tensor<*xf32> { // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_768_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 768){ // CHECK: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_14_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -549,53 +617,51 @@ func.func private @gpt2_reduce2(%arg0 : tensor) -> tensor<*xf32> { // CHECK: krnl.store [[VAR_13_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_5_]], [[LOAD_RES_3_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -619,9 +685,11 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK-LABEL: func.func private @gpt2_one_not_multiple // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> -// CHECK-DAG: [[CST_776_:%.+]] = arith.constant 776 : index +// CHECK-DAG: [[CST_773_:%.+]] = arith.constant 773 : index // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_776_:%.+]] = arith.constant 776 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs @@ -652,7 +720,7 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ // CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_3_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> // CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} // CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index // CHECK: scf.if [[VAR_9_]] { @@ -660,7 +728,7 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_776_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 773){ // CHECK: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_14_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -673,53 +741,51 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK: krnl.store [[VAR_13_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_776_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[CST_773_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_5_]], [[LOAD_RES_3_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -736,12 +802,19 @@ func.func private @gpt2_no_simd_as_not_mult_of_VL(%arg0 : tensor) // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0, s1] -> (s1)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2, s3] -> (s3)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 4)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func private @gpt2_no_simd_as_not_mult_of_VL // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_872_:%.+]] = arith.constant 872 : index +// CHECK-DAG: [[CST_870_:%.+]] = arith.constant 870 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_873_:%.+]] = arith.constant 873 : index -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs @@ -756,24 +829,114 @@ func.func private @gpt2_no_simd_as_not_mult_of_VL(%arg0 : tensor) // CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_dim_]], [[VAR_dim_]]_0 : index // CHECK: [[VAR_3_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_2_]] : index // CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_3_]] : index to i64 -// CHECK: [[VAR_5_:%.+]] = arith.sitofp [[VAR_4_]] : i64 to f32 -// CHECK: krnl.memset [[RES_]], [[CST_0_dot_000000_]] : memref -// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[VAR_4_]] : i64 to f32 // CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref // CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_3_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]]{{.}}, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 97, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 9){ -// CHECK: [[VAR_8_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[VAR_8_]]#2, [[VAR_8_]]#3] : memref -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_0_]], [[CST_0_]]{{.}} : memref -// CHECK: [[VAR_11_:%.+]] = arith.addf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_11_]], [[RES_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_0_]], [[CST_0_]]{{.}} : memref -// CHECK: } -// CHECK: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]], [[VAR_dim_]]{{.}}, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]], [[VAR_dim_]], [[VAR_dim_]]_0], [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 1, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 1){ -// CHECK: [[VAR_8_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[VAR_8_1_]]#2, [[VAR_8_1_]]#3] : memref -// CHECK: [[LOAD_RES_MEM_2_:%.+]] = arith.divf [[LOAD_RES_MEM_1_]], [[VAR_5_]] : f32 -// CHECK: krnl.store [[LOAD_RES_MEM_2_]], [[RES_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[VAR_8_1_]]#2, [[VAR_8_1_]]#3] : memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[VAR_dim_3_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[VAR_dim_4_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_873_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref, memref<3xindex>) -> memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[VAR_dim_]], [[RES_2_]][0] : memref<2xindex> +// CHECK: affine.store [[VAR_dim_0_]], [[RES_2_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref, memref<2xindex>) -> memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ +// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} +// CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index +// CHECK: scf.if [[VAR_9_]] { +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_7_]]#1 to [[VAR_dim_0_]] step [[CST_1_]] { +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 870){ +// CHECK: [[VAR_15_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_15_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.addf [[LOAD_RES_3_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<4xf32> +// CHECK: vector.store [[VAR_18_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 872 to 873){ +// CHECK: [[VAR_15_1_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_15_1_]]{{.}} : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: [[VAR_18_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_18_1_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: } +// CHECK: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_13_:%.+]] = vector.reduction , [[LOAD_RES_3_MEM_2_]] : vector<4xf32> into f32 +// CHECK: [[VAR_14_:%.+]] = arith.divf [[VAR_13_]], [[VAR_5_]] : f32 +// CHECK: krnl.store [[VAR_14_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref +// CHECK: } +// CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOOP_2_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[CST_870_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_2_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_2_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_5_]] : vector<4xf32> +// CHECK: vector.store [[VAR_48_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_50_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_8_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_2_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_9_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_2_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[LOAD_RES_3_MEM_8_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[LOAD_RES_3_MEM_9_]], [[LOAD_VAR_reshape_MEM_8_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addf [[LOAD_RES_3_MEM_10_]], [[LOAD_VAR_reshape_MEM_9_]] : f32 +// CHECK: memref.store [[VAR_21_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_24_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_12_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_13_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_14_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_29_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_11_]], [[LOAD_RES_3_MEM_12_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_11_]], [[LOAD_RES_3_MEM_12_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_31_:%.+]] = arith.addf [[VAR_30_]], [[VAR_29_]] : vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_13_]], [[LOAD_RES_3_MEM_14_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_13_]], [[LOAD_RES_3_MEM_14_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_34_:%.+]] = arith.addf [[VAR_33_]], [[VAR_32_]] : vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.shuffle [[VAR_31_]], [[VAR_34_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = vector.shuffle [[VAR_31_]], [[VAR_34_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_36_]], [[VAR_35_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_39_:%.+]] = arith.divf [[VAR_37_]], [[VAR_38_]] : vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -795,59 +958,57 @@ func.func private @test_reducemax_v13_bis(%arg0 : tensor<1028x256xf32>) -> tenso // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<4xf32> // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1028xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 1028){ // CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = [[CST_0_]] to [[CST_256_]]){ -// CHECK: [[VAR_16_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_1_]]6] : memref<1028x256xf32>, vector<4xf32> +// CHECK: affine.for [[I_1_:%.+]] = 0 to 256 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_2_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_3_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_4_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_19_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_20_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_23_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_24_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_24_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_27_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_27_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_28_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_28_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_31_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_31_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.maxnumf [[VAR_7_]], [[VAR_8_]] : vector<4xf32> -// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_12_:%.+]] = arith.maxnumf [[VAR_10_]], [[VAR_11_]] : vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_15_:%.+]] = arith.maxnumf [[VAR_13_]], [[VAR_14_]] : vector<4xf32> -// CHECK: vector.store [[VAR_15_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<1028xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxnumf [[VAR_10_]], [[VAR_9_]] : vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[VAR_12_]] : vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.maxnumf [[VAR_16_]], [[VAR_15_]] : vector<4xf32> +// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<1028xf32>, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<1028xf32> // CHECK: } @@ -873,13 +1034,12 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<* // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_7_:%.+]] = arith.constant 7 : index -// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<7xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 7){ // CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> // CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) // CHECK: [[VAR_3_:%.+]] = arith.cmpi slt, [[VAR_2_]], [[CST_0_]] : index // CHECK: scf.if [[VAR_3_]] { @@ -887,7 +1047,7 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<* // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_8_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 8){ // CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[I_1_]], [[VAR_7_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -899,50 +1059,48 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<* // CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[I_1_]]{{.}} : memref<7xf32> // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_6_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_8_]]){ -// CHECK: [[VAR_18_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_1_]]8] : memref<7x8xf32>, vector<4xf32> +// CHECK: affine.for [[I_3_:%.+]] = 0 to 8 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[LOOP_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[LOAD_RES_1_MEM_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_21_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_21_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_22_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_22_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_26_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_30_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_30_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_31_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_8_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_9_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_1_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_10_]], [[VAR_10_1_]] : vector<4xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_12_]], [[VAR_13_]] : vector<4xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_17_:%.+]] = arith.maxnumf [[VAR_15_]], [[VAR_16_]] : vector<4xf32> -// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<7xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[VAR_11_]] : vector<4xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_16_:%.+]] = arith.maxnumf [[VAR_15_]], [[VAR_14_]] : vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = vector.shuffle [[VAR_13_]], [[VAR_16_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[VAR_13_]], [[VAR_16_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[VAR_17_]] : vector<4xf32> +// CHECK: vector.store [[VAR_19_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<7xf32>, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref<7xf32> @@ -961,16 +1119,15 @@ func.func private @test_reducemax_int_v13(%arg0 : tensor<128x256x768xi32>) -> te // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<128x256x768xi32>) -> memref<128x256xi32> { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<-2147483648> : vector<32xi32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<128x256xi32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 128, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<1x32xi32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1x32xi32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 768){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_5_]]{{.}} : memref<128x256x768xi32>, vector<32xi32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<1x32xi32>, vector<32xi32> @@ -1005,7 +1162,6 @@ func.func private @bertsquad10_same_pattern(%arg0 : tensor) -> te // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index @@ -1026,54 +1182,53 @@ func.func private @bertsquad10_same_pattern(%arg0 : tensor) -> te // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = affine.apply [[MAP_2_]]([[VAR_6_]]#1) +// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_3_]]([[VAR_6_]]#1) +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_4_]]([[VAR_6_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_23_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[VAR_23_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_2_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_7_]], [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_8_]], [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_9_]], [[I_2_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_26_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_26_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_2_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_27_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_1_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_30_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_31_:%.+]] = affine.apply [[MAP_3_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_31_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_34_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_35_:%.+]] = affine.apply [[MAP_4_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_35_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_33_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_34_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_35_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_14_:%.+]] = arith.addf [[VAR_12_]], [[VAR_13_]] : vector<4xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_17_:%.+]] = arith.addf [[VAR_15_]], [[VAR_16_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[VAR_14_]], [[VAR_17_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[VAR_14_]], [[VAR_17_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addf [[VAR_15_]], [[VAR_14_]] : vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[VAR_16_]], [[VAR_19_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_16_]], [[VAR_19_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.splat [[VAR_4_]] : vector<4xf32> -// CHECK: [[VAR_22_:%.+]] = arith.divf [[VAR_20_]], [[VAR_21_]] : vector<4xf32> -// CHECK: vector.store [[VAR_22_]], [[VAR_reshape_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.splat [[VAR_4_]] : vector<4xf32> +// CHECK: [[VAR_24_:%.+]] = arith.divf [[VAR_22_]], [[VAR_23_]] : vector<4xf32> +// CHECK: vector.store [[VAR_24_]], [[VAR_reshape_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -1101,7 +1256,6 @@ func.func private @bertsquad10_const_pattern(%arg0 : tensor<1x256x768xf32>) -> t // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x256x1xf32> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> // CHECK: affine.store [[CST_1_]], [[RES_1_]][0] : memref<2xindex> @@ -1111,52 +1265,51 @@ func.func private @bertsquad10_const_pattern(%arg0 : tensor<1x256x768xf32>) -> t // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1) +// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1) +// CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1) // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_17_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> +// CHECK: affine.for [[I_2_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_2_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_3_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_4_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_20_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_21_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_21_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_1_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_24_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_24_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_25_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_28_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_28_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_29_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_32_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_32_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_27_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_30_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_7_]], [[VAR_8_]] : vector<4xf32> -// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_12_:%.+]] = arith.addf [[VAR_10_]], [[VAR_11_]] : vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_15_:%.+]] = arith.addf [[VAR_13_]], [[VAR_14_]] : vector<4xf32> -// CHECK: [[VAR_16_:%.+]] = arith.divf [[VAR_15_]], [[VAR_cst_]] : vector<4xf32> -// CHECK: vector.store [[VAR_16_]], [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<1x256xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_10_]], [[VAR_9_]] : vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.addf [[VAR_13_]], [[VAR_12_]] : vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.addf [[VAR_16_]], [[VAR_15_]] : vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.divf [[VAR_17_]], [[VAR_cst_]] : vector<4xf32> +// CHECK: vector.store [[VAR_18_]], [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<1x256xf32>, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<1x256x1xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir index 7290d34032..8e76541332 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_parallel_canonicalize_O3.mlir @@ -2,6 +2,83 @@ // ----- +// COM: Full reduction over all dimensions to a scalar value. +func.func @test_reduce_all_to_scalar(%arg0: tensor) -> tensor<*xf32> { + %axes = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.ReduceMax"(%arg0, %axes) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor, none) -> tensor<*xf32> + return %0: tensor<*xf32> + +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0] -> (s0 - 31)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0, s1] -> (s1 + ((s0 - s1) floordiv 32) * 32)> +// CHECK-LABEL: func.func @test_reduce_all_to_scalar +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_dim_1_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref +// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[VAR_dim_1_]] : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_1_]], [[RES_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<256xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<8xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = arith.ceildivsi [[VAR_1_]], [[CST_8_]] : index +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ +// CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_8_:%.+]] = arith.muli [[VAR_7_]], [[VAR_2_]] : index +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_2_]] : index +// CHECK: [[VAR_10_:%.+]] = arith.cmpi slt, [[VAR_1_]], [[VAR_9_]] : index +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[VAR_1_]], [[VAR_9_]] : index +// CHECK-DAG: [[VAR_12_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]) +// CHECK: vector.store [[VAR_cst_0_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_13_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_11_]]{{.}} +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_8_]] to [[VAR_13_]] step [[CST_32_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK: vector.store [[VAR_19_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[VAR_14_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_11_]], [[VAR_8_]]{{.}} +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_14_]] to [[VAR_11_]] step [[CST_1_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = memref.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32> +// CHECK: [[VAR_19_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : f32 +// CHECK: memref.store [[VAR_19_1_]], [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[VAR_12_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_16_:%.+]] = vector.reduction , [[LOAD_RES_1_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_16_]], [[RES_2_]]{{.}}[[VAR_7_]]{{.}} : memref<8xf32> +// CHECK: } +// CHECK: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ +// CHECK: [[VAR_7_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_8_1_:%.+]] = krnl.load [[RES_2_]]{{.}}[[VAR_7_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = krnl.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: [[VAR_10_1_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[VAR_8_1_]] : f32 +// CHECK: krnl.store [[VAR_10_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: [[VAR_6_:%.+]] = vector.extract [[LOAD_RES_1_MEM_4_]][0] : f32 from vector<1xf32> +// CHECK: krnl.store [[VAR_6_]], [[RES_3_]][] : memref +// CHECK: return [[RES_3_]] : memref +// CHECK: } +} + +// ----- + // With enable-parallel, a krnl.parallel should be created, which takes a loop (to be parallelized) // as input. The krnl.parallel should be the last operator before krnl.iterate, since the lowering // needs to interpret krnl.block, krnl.permute, krnl.unroll first. diff --git a/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir index 149c8e7c2a..ca4ce47c7e 100644 --- a/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir @@ -1,7 +1,7 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // ----- @@ -23,7 +23,7 @@ func.func @layernorm_4D_with_scale_bias(%arg0: tensor<2x64x32x8xf32>, %arg1: ten // ----- -// collapsed range is not a multiple of 4, cannot do simd +// collapsed range is not a multiple of 4, cannot do simd: Update, it is now supported. func.func @layernorm_4D_with_scale_bias_no_SIMD(%arg0: tensor<2x64x31x3xf32>, %arg1: tensor<31x3xf32>, %arg2: tensor<31x3xf32>) -> tensor<*xf32> { %0 = "onnx.NoValue"() {value} : () -> none @@ -31,12 +31,445 @@ func.func @layernorm_4D_with_scale_bias_no_SIMD(%arg0: tensor<2x64x31x3xf32>, %a onnx.Return %Y : tensor<*xf32> // mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func @layernorm_4D_with_scale_bias_no_SIMD // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x64x31x3xf32>, [[PARAM_1_:%.+]]: memref<31x3xf32>, [[PARAM_2_:%.+]]: memref<31x3xf32>) -> memref<2x64x31x3xf32> { -// CHECK: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 31, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 3){ -// CHECK: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 2, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to 64, [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 1, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 1){ +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<9.300000e+01> : vector<4xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_92_:%.+]] = arith.constant 92 : index +// CHECK-DAG: [[CST_90_:%.+]] = arith.constant 90 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_93_:%.+]] = arith.constant 93 : index +// CHECK-DAG: [[CST_11904_:%.+]] = arith.constant 11904 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [1], value = dense<9.99999974E-6> : tensor<1xf32>} : () -> memref<1xf32> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_2_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_2_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_11_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_2_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_3_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_:%.+]] = arith.addf [[VAR_29_]], [[VAR_28_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_32_]], [[VAR_31_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_:%.+]] = arith.addf [[VAR_35_]], [[VAR_34_]] : vector<4xf32> +// CHECK: [[VAR_37_:%.+]] = arith.divf [[VAR_36_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[VAR_reshape_4_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[RES_]]([[RES_]]_8) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_7_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_4_]]([[RES_7_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_7_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_]], [[VAR_reshape_11_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_9_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_10_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_16_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_11_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_18_:%.+]] = memref.reshape [[RES_8_]]([[RES_11_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 11904){ +// CHECK: [[VAR_9_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_14_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_16_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_]], [[VAR_reshape_18_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_13_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_13_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_13_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_13_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_8_]]([[RES_13_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_14_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_14_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_14_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[RES_12_]]([[RES_14_]]) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_3_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 2, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_15_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_3_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_11_2_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_1_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_7_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_8_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_12_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_12_]], [[LOAD_VAR_reshape_MEM_8_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_1_]] : f32 +// CHECK-DAG: [[VAR_21_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_1_:%.+]] = arith.addf [[VAR_29_1_]], [[VAR_28_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_1_:%.+]] = arith.addf [[VAR_32_1_]], [[VAR_31_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_1_:%.+]] = arith.addf [[VAR_35_1_]], [[VAR_34_1_]] : vector<4xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_1_]], [[VAR_reshape_23_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_16_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_17_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_17_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_28_:%.+]] = memref.reshape [[RES_12_]]([[RES_17_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_18_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_18_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_30_:%.+]] = memref.reshape [[RES_4_]]([[RES_18_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_19_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_19_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_32_:%.+]] = memref.reshape [[RES_16_]]([[RES_19_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_4_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__4_]]) with ([[LOOP_4_]] -> [[I_8_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__4_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_2_:%.+]] = vector.load [[VAR_reshape_28_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_2_:%.+]] = vector.load [[VAR_reshape_30_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = arith.subf [[VAR_10_2_]], [[VAR_11_2_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_1_]], [[VAR_reshape_32_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_20_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_21_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_21_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_21_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_21_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_35_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_21_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_22_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_22_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_22_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_22_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_37_:%.+]] = memref.reshape [[RES_]]([[RES_]]_36) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_23_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_23_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_23_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_23_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_39_:%.+]] = memref.reshape [[RES_20_]]([[RES_23_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_5_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_9_:%.+]] = 0 to 2, [[LOOP_5_]]#1 -> [[I_10_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_6_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__5_:%.+]], [[BLOCK_IN__5_:%.+]] = krnl.block [[LOOP_6_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__5_]]) with ([[LOOP_6_]] -> [[I_11_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__5_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = vector.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_7_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_7_]]) with ([[LOOP_7_]] -> [[I_12_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_4_:%.+]] = krnl.get_induction_var_value([[LOOP_7_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = krnl.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_24_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_25_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_25_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_42_:%.+]] = memref.reshape [[RES_16_]]([[RES_25_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_26_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_26_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_44_:%.+]] = memref.reshape [[RES_24_]]([[RES_26_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_8_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__6_:%.+]], [[BLOCK_IN__6_:%.+]] = krnl.block [[LOOP_8_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__6_]]) with ([[LOOP_8_]] -> [[I_13_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__6_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_7_:%.+]] = vector.load [[VAR_reshape_42_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_4_:%.+]] = krnl.load [[VAR_0_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = vector.splat [[VAR_11_4_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = arith.addf [[LOOP_7_]], [[LOAD_VAR_reshape_MEM_4_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_5_1_1_]], [[VAR_reshape_44_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_27_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_28_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_28_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_47_:%.+]] = memref.reshape [[RES_24_]]([[RES_28_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_29_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_29_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_49_:%.+]] = memref.reshape [[RES_27_]]([[RES_29_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_9_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__7_:%.+]], [[BLOCK_IN__7_:%.+]] = krnl.block [[LOOP_9_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__7_]]) with ([[LOOP_9_]] -> [[I_14_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_6_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__7_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_47_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_5_:%.+]] = math.sqrt [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_5_]], [[VAR_reshape_49_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_30_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_31_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_31_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_52_:%.+]] = memref.reshape [[RES_27_]]([[RES_31_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_32_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_32_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_54_:%.+]] = memref.reshape [[RES_30_]]([[RES_32_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_10_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__8_:%.+]], [[BLOCK_IN__8_:%.+]] = krnl.block [[LOOP_10_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__8_]]) with ([[LOOP_10_]] -> [[I_15_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__8_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_52_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_6_:%.+]] = arith.divf [[VAR_cst_]], [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_6_]], [[VAR_reshape_54_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_33_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_34_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_34_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_34_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_34_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_57_:%.+]] = memref.reshape [[RES_20_]]([[RES_34_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_35_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_35_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_35_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_35_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_59_:%.+]] = memref.reshape [[RES_30_]]([[RES_35_]]) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_36_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_36_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_36_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_36_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_61_:%.+]] = memref.reshape [[RES_33_]]([[RES_36_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_11_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_11_]]#0, [[LOOP_11_]]#1) with ([[LOOP_11_]]#0 -> [[I_16_:%.+]] = 0 to 2, [[LOOP_11_]]#1 -> [[I_17_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_11_]]#0, [[LOOP_11_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_12_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__9_:%.+]], [[BLOCK_IN__9_:%.+]] = krnl.block [[LOOP_12_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__9_]]) with ([[LOOP_12_]] -> [[I_18_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__9_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_13_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_13_]]) with ([[LOOP_13_]] -> [[I_19_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_8_:%.+]] = krnl.get_induction_var_value([[LOOP_13_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_37_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_38_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_38_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_38_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_38_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_64_:%.+]] = memref.reshape [[RES_33_]]([[RES_38_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_39_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_39_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_66_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_39_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_40_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_40_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_40_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_40_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_68_:%.+]] = memref.reshape [[RES_37_]]([[RES_40_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_14_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_14_]]#0, [[LOOP_14_]]#1) with ([[LOOP_14_]]#0 -> [[I_20_:%.+]] = 0 to 2, [[LOOP_14_]]#1 -> [[I_21_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_4_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_14_]]#0, [[LOOP_14_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_15_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__10_:%.+]], [[BLOCK_IN__10_:%.+]] = krnl.block [[LOOP_15_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__10_]]) with ([[LOOP_15_]] -> [[I_22_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_9_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__10_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = vector.load [[VAR_reshape_66_]]{{.}}[[VAR_11_9_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_16_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_16_]]) with ([[LOOP_16_]] -> [[I_23_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_10_:%.+]] = krnl.get_induction_var_value([[LOOP_16_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_66_]]{{.}}[[VAR_11_10_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK: [[RES_41_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[RES_41_]] : memref<2x64x31x3xf32> to tensor<2x64x31x3xf32> +// CHECK-DAG: [[RES_42_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_42_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_42_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_42_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_71_:%.+]] = memref.reshape [[RES_37_]]([[RES_42_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_43_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_43_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_73_:%.+]] = memref.reshape [[PARAM_2_]]([[RES_43_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_44_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_44_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_44_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_44_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_75_:%.+]] = memref.reshape [[RES_41_]]([[RES_44_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_17_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_17_]]#0, [[LOOP_17_]]#1) with ([[LOOP_17_]]#0 -> [[I_24_:%.+]] = 0 to 2, [[LOOP_17_]]#1 -> [[I_25_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_17_]]#0, [[LOOP_17_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_18_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__11_:%.+]], [[BLOCK_IN__11_:%.+]] = krnl.block [[LOOP_18_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__11_]]) with ([[LOOP_18_]] -> [[I_26_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_11_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__11_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_73_]]{{.}}[[VAR_11_11_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_19_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_19_]]) with ([[LOOP_19_]] -> [[I_27_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_12_:%.+]] = krnl.get_induction_var_value([[LOOP_19_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_73_]]{{.}}[[VAR_11_12_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK: onnx.Return [[VAR_6_]] : tensor<2x64x31x3xf32> +// CHECK: } } // ----- diff --git a/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_parallel_canonicalize.mlir new file mode 100644 index 0000000000..e177dbf262 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_parallel_canonicalize.mlir @@ -0,0 +1,498 @@ +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl=enable-parallel --canonicalize %s -split-input-file | FileCheck %s + +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. + +// ----- + +// It should make the substitution with the fast algo +func.func @layernorm_4D_with_scale_bias(%arg0: tensor<2x64x32x8xf32>, %arg1: tensor<32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<*xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {axis = -2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x64x32x8xf32>, tensor<32x8xf32>, tensor<32x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + onnx.Return %Y : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_4D_with_scale_bias +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x64x32x8xf32>, [[PARAM_1_:%.+]]: memref<32x8xf32>, [[PARAM_2_:%.+]]: memref<32x8xf32>) -> memref<2x64x32x8xf32> { +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_:%.+]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 128){ +// CHECK: affine.for [[I_1_:%.+]] = 0 to 256 step 16 { +// CHECK: affine.for [[I_2_:%.+]] = 0 to 256 step 16 { +// CHECK: onnx.Return [[VAR_1_:%.+]] : tensor<2x64x32x8xf32> +} + +// ----- + +// collapsed range is not a multiple of 4, cannot do simd: Update, it is now supported. + +func.func @layernorm_4D_with_scale_bias_no_SIMD(%arg0: tensor<2x64x31x3xf32>, %arg1: tensor<31x3xf32>, %arg2: tensor<31x3xf32>) -> tensor<*xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {axis = -2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x64x31x3xf32>, tensor<31x3xf32>, tensor<31x3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + onnx.Return %Y : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 3)> +// CHECK-LABEL: func.func @layernorm_4D_with_scale_bias_no_SIMD +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x64x31x3xf32>, [[PARAM_1_:%.+]]: memref<31x3xf32>, [[PARAM_2_:%.+]]: memref<31x3xf32>) -> memref<2x64x31x3xf32> { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<9.300000e+01> : vector<4xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_92_:%.+]] = arith.constant 92 : index +// CHECK-DAG: [[CST_90_:%.+]] = arith.constant 90 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_93_:%.+]] = arith.constant 93 : index +// CHECK-DAG: [[CST_11904_:%.+]] = arith.constant 11904 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [1], value = dense<9.99999974E-6> : tensor<1xf32>} : () -> memref<1xf32> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_2_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_2_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_11_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_2_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_3_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_:%.+]] = arith.addf [[VAR_29_]], [[VAR_28_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_32_]], [[VAR_31_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_:%.+]] = arith.addf [[VAR_35_]], [[VAR_34_]] : vector<4xf32> +// CHECK: [[VAR_37_:%.+]] = arith.divf [[VAR_36_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[VAR_reshape_4_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[RES_]]([[RES_]]_8) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_7_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_4_]]([[RES_7_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_7_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_]], [[VAR_reshape_11_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_9_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_10_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_16_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_11_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_18_:%.+]] = memref.reshape [[RES_8_]]([[RES_11_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__2_]]) : !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 11904){ +// CHECK: [[VAR_9_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_14_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_16_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_]], [[VAR_reshape_18_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_13_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_13_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_13_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_13_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_8_]]([[RES_13_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_14_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_14_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_14_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[RES_12_]]([[RES_14_]]) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_3_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__3_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 2, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_15_:%.+]] = memref.alloc() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_3_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_11_2_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_1_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_7_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_8_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_12_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_12_]], [[LOAD_VAR_reshape_MEM_8_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_1_]] : f32 +// CHECK-DAG: [[VAR_21_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_1_:%.+]] = arith.addf [[VAR_29_1_]], [[VAR_28_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_1_:%.+]] = arith.addf [[VAR_32_1_]], [[VAR_31_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_1_:%.+]] = arith.addf [[VAR_35_1_]], [[VAR_34_1_]] : vector<4xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_1_]], [[VAR_reshape_23_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_16_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_17_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_17_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_28_:%.+]] = memref.reshape [[RES_12_]]([[RES_17_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_18_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_18_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_30_:%.+]] = memref.reshape [[RES_4_]]([[RES_18_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_19_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_19_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_32_:%.+]] = memref.reshape [[RES_16_]]([[RES_19_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_4_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__4_]]) with ([[LOOP_4_]] -> [[I_8_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__4_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_2_:%.+]] = vector.load [[VAR_reshape_28_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_2_:%.+]] = vector.load [[VAR_reshape_30_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = arith.subf [[VAR_10_2_]], [[VAR_11_2_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_1_]], [[VAR_reshape_32_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_20_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_21_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_21_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_21_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_21_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_35_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_21_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_22_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_22_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_22_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_22_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_37_:%.+]] = memref.reshape [[RES_]]([[RES_]]_36) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_23_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_23_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_23_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_23_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_39_:%.+]] = memref.reshape [[RES_20_]]([[RES_23_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_5_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.parallel([[LOOP_5_]]#1) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_9_:%.+]] = 0 to 2, [[LOOP_5_]]#1 -> [[I_10_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_6_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__5_:%.+]], [[BLOCK_IN__5_:%.+]] = krnl.block [[LOOP_6_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__5_]]) with ([[LOOP_6_]] -> [[I_11_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__5_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = vector.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_7_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_7_]]) with ([[LOOP_7_]] -> [[I_12_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_4_:%.+]] = krnl.get_induction_var_value([[LOOP_7_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = krnl.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_24_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_25_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_25_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_42_:%.+]] = memref.reshape [[RES_16_]]([[RES_25_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_26_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_26_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_44_:%.+]] = memref.reshape [[RES_24_]]([[RES_26_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_8_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__6_:%.+]], [[BLOCK_IN__6_:%.+]] = krnl.block [[LOOP_8_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__6_]]) with ([[LOOP_8_]] -> [[I_13_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__6_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_7_:%.+]] = vector.load [[VAR_reshape_42_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_4_:%.+]] = krnl.load [[VAR_0_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = vector.splat [[VAR_11_4_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = arith.addf [[LOOP_7_]], [[LOAD_VAR_reshape_MEM_4_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_5_1_1_]], [[VAR_reshape_44_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_27_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_28_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_28_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_47_:%.+]] = memref.reshape [[RES_24_]]([[RES_28_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_29_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_29_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_49_:%.+]] = memref.reshape [[RES_27_]]([[RES_29_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_9_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__7_:%.+]], [[BLOCK_IN__7_:%.+]] = krnl.block [[LOOP_9_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__7_]]) with ([[LOOP_9_]] -> [[I_14_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_6_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__7_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_47_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_5_:%.+]] = math.sqrt [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_5_]], [[VAR_reshape_49_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_30_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_31_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_31_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_52_:%.+]] = memref.reshape [[RES_27_]]([[RES_31_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_32_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_32_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_54_:%.+]] = memref.reshape [[RES_30_]]([[RES_32_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_10_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__8_:%.+]], [[BLOCK_IN__8_:%.+]] = krnl.block [[LOOP_10_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__8_]]) with ([[LOOP_10_]] -> [[I_15_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__8_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_52_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_6_:%.+]] = arith.divf [[VAR_cst_]], [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_6_]], [[VAR_reshape_54_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_33_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_34_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_34_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_34_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_34_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_57_:%.+]] = memref.reshape [[RES_20_]]([[RES_34_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_35_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_35_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_35_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_35_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_59_:%.+]] = memref.reshape [[RES_30_]]([[RES_35_]]) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_36_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_36_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_36_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_36_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_61_:%.+]] = memref.reshape [[RES_33_]]([[RES_36_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_11_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.parallel([[LOOP_11_]]#1) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_11_]]#0, [[LOOP_11_]]#1) with ([[LOOP_11_]]#0 -> [[I_16_:%.+]] = 0 to 2, [[LOOP_11_]]#1 -> [[I_17_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_11_]]#0, [[LOOP_11_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_12_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__9_:%.+]], [[BLOCK_IN__9_:%.+]] = krnl.block [[LOOP_12_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__9_]]) with ([[LOOP_12_]] -> [[I_18_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__9_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_13_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_13_]]) with ([[LOOP_13_]] -> [[I_19_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_8_:%.+]] = krnl.get_induction_var_value([[LOOP_13_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_37_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_38_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_38_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_38_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_38_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_64_:%.+]] = memref.reshape [[RES_33_]]([[RES_38_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_39_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_39_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_66_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_39_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_40_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_40_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_40_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_40_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_68_:%.+]] = memref.reshape [[RES_37_]]([[RES_40_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_14_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.parallel([[LOOP_14_]]#1) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_14_]]#0, [[LOOP_14_]]#1) with ([[LOOP_14_]]#0 -> [[I_20_:%.+]] = 0 to 2, [[LOOP_14_]]#1 -> [[I_21_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_4_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_14_]]#0, [[LOOP_14_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_15_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__10_:%.+]], [[BLOCK_IN__10_:%.+]] = krnl.block [[LOOP_15_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__10_]]) with ([[LOOP_15_]] -> [[I_22_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_9_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__10_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = vector.load [[VAR_reshape_66_]]{{.}}[[VAR_11_9_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_16_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_16_]]) with ([[LOOP_16_]] -> [[I_23_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_10_:%.+]] = krnl.get_induction_var_value([[LOOP_16_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_66_]]{{.}}[[VAR_11_10_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK: [[RES_41_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[RES_41_]] : memref<2x64x31x3xf32> to tensor<2x64x31x3xf32> +// CHECK-DAG: [[RES_42_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_42_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_42_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_42_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_71_:%.+]] = memref.reshape [[RES_37_]]([[RES_42_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_43_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_43_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_73_:%.+]] = memref.reshape [[PARAM_2_]]([[RES_43_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_44_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_44_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_44_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_44_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_75_:%.+]] = memref.reshape [[RES_41_]]([[RES_44_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_17_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.parallel([[LOOP_17_]]#1) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_17_]]#0, [[LOOP_17_]]#1) with ([[LOOP_17_]]#0 -> [[I_24_:%.+]] = 0 to 2, [[LOOP_17_]]#1 -> [[I_25_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_17_]]#0, [[LOOP_17_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_18_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__11_:%.+]], [[BLOCK_IN__11_:%.+]] = krnl.block [[LOOP_18_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__11_]]) with ([[LOOP_18_]] -> [[I_26_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_11_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__11_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_73_]]{{.}}[[VAR_11_11_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_19_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_19_]]) with ([[LOOP_19_]] -> [[I_27_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_12_:%.+]] = krnl.get_induction_var_value([[LOOP_19_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_73_]]{{.}}[[VAR_11_12_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK: onnx.Return [[VAR_6_]] : tensor<2x64x31x3xf32> +// CHECK: } +} + +// ----- + +// arg1 is defined for every outer loop, arg2 is defined for 64 of the 128 outer loops. +func.func @layernorm_4D_with_scale_bias_with_high_dims(%arg0: tensor<2x64x32x8xf32>, %arg1: tensor<2x64x32x8xf32>, %arg2: tensor<64x32x8xf32>) -> tensor<*xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %Y, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {axis = -2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x64x32x8xf32>, tensor<2x64x32x8xf32>, tensor<64x32x8xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + onnx.Return %Y : tensor<*xf32> + +// CHECK-LABEL: func.func @layernorm_4D_with_scale_bias_with_high_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x64x32x8xf32>, [[PARAM_1_:%.+]]: memref<2x64x32x8xf32>, [[PARAM_2_:%.+]]: memref<64x32x8xf32>) -> memref<2x64x32x8xf32> { +// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_1_:%.+]]) : (memref<2x64x32x8xf32>, memref<2xindex>) -> memref<128x256xf32> +// CHECK-DAG: [[VAR_reshape_6_:%.+]] = memref.reshape [[PARAM_2_]]([[RES_2_:%.+]]) : (memref<64x32x8xf32>, memref<2xindex>) -> memref<64x256xf32> +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_:%.+]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_:%.+]] -> [[I_0_:%.+]] = 0 to 128){ +// CHECK: affine.for [[I_2_:%.+]] = 0 to 256 step 16 { +// CHECK: affine.for [[I_2_:%.+]] = 0 to 256 step 16 { +} diff --git a/test/mlir/conversion/onnx_to_krnl/NN/Pooling.mlir b/test/mlir/conversion/onnx_to_krnl/NN/Pooling.mlir index 47d1ab873c..87a6ce9acd 100644 --- a/test/mlir/conversion/onnx_to_krnl/NN/Pooling.mlir +++ b/test/mlir/conversion/onnx_to_krnl/NN/Pooling.mlir @@ -110,6 +110,14 @@ func.func private @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) // CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0, d1) -> (0, d1)> // CHECK-DAG: [[MAP_7_:#.+]] = affine_map<(d0, d1) -> (32, d1 + 2)> // CHECK-DAG: [[MAP_8_:#.+]] = affine_map<(d0)[s0, s1, s2, s3, s4] -> (s0 - ((s2 ceildiv s4) * s4 - s2), -(d0 * s3 - s2) + s0, d0 * s3 + (s1 - 1) * s4 - s2 - ((s2 ceildiv s4) * s4 - s2) + 1, d0 * s3 + (s1 - 1) * s4 - s2 - (d0 * s3 - s2) + 1)> +// CHECK-DAG: [[MAP_9_:#.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: [[MAP_10_:#.+]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: [[MAP_11_:#.+]] = affine_map<(d0, d1, d2) -> (d2 + d0)> +// CHECK-DAG: [[MAP_12_:#.+]] = affine_map<(d0, d1, d2) -> (d2, d2 + d0)> +// CHECK-DAG: [[MAP_13_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-DAG: [[MAP_14_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d1)> +// CHECK-DAG: [[MAP_15_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d3 + d1)> +// CHECK-DAG: [[MAP_16_:#.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d3 + d1)> // CHECK-LABEL: func.func private @test_maxpool_pooling_operation // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x32x32xf32>) -> memref<1x3x31x31xf32> { // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index @@ -181,13 +189,21 @@ func.func private @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) // CHECK-DAG: [[VAR_13_:%.+]] = arith.subi [[VAR_10_]], [[VAR_9_]] : index // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to min [[MAP_8_]]([[VAR_1_]]#2){{.}}[[CST_32_2_]], [[CST_2_4_]], [[CST_0_7_]], [[CST_1_9_]], [[CST_1_10_]]{{.}}, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to min [[MAP_8_]]([[VAR_1_]]#3){{.}}[[CST_32_3_]], [[CST_2_5_]], [[CST_0_8_]], [[CST_1_11_]], [[CST_1_12_]]{{.}}){ -// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[I_4_]], [[VAR_4_]] : index -// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[I_5_]], [[VAR_9_]] : index +// CHECK-DAG: [[CST_0_11_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_16_:%.+]] = affine.apply [[MAP_9_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]]) +// CHECK-DAG: [[VAR_17_:%.+]] = affine.apply [[MAP_10_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]]) +// CHECK-DAG: [[VAR_18_:%.+]] = affine.apply [[MAP_11_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]]) +// CHECK-DAG: [[VAR_19_:%.+]] = affine.max [[MAP_12_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]]) +// CHECK-DAG: [[CST_0_12_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_20_:%.+]] = affine.apply [[MAP_13_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]], [[I_5_]]) +// CHECK-DAG: [[VAR_21_:%.+]] = affine.apply [[MAP_14_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]], [[I_5_]]) +// CHECK-DAG: [[VAR_22_:%.+]] = affine.apply [[MAP_15_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]], [[I_5_]]) +// CHECK-DAG: [[VAR_23_:%.+]] = affine.max [[MAP_16_]]([[VAR_1_]]#2, [[VAR_1_]]#3, [[I_4_]], [[I_5_]]) // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]6, [[VAR_1_]]7] : memref<1x3x32x32xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]9, [[VAR_23_]]{{.}} : memref<1x3x32x32xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref -// CHECK: [[VAR_20_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_20_]], [[RES_1_]][] : memref +// CHECK: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_26_]], [[RES_1_]][] : memref // CHECK: } // CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = krnl.load [[RES_1_]][] : memref // CHECK: krnl.store [[LOAD_RES_1_MEM_1_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]#2, [[VAR_1_]]#3] : memref<1x3x31x31xf32> diff --git a/test/mlir/conversion/onnx_to_krnl/NN/Pooling_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/NN/Pooling_with_canonicalize.mlir index 549ca2a6f0..0667a6702a 100644 --- a/test/mlir/conversion/onnx_to_krnl/NN/Pooling_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/NN/Pooling_with_canonicalize.mlir @@ -1,5 +1,7 @@ // RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// ----- + // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. @@ -7,55 +9,58 @@ func.func private @test_pool_unknown_dimensions(%arg0 : tensor<1x3x?x32xf32>) -> %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x?x32xf32>) -> tensor<*xf32> "func.return"(%0) : (tensor<*xf32>) -> () -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 - 1)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (0, d0)> -// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (s0, d0 + 2)> -// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (32, d0 + 2)> -// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0, s1, s2, s3, s4] -> (s0 - ((s2 ceildiv s4) * s4 - s2), -(d0 * s3 - s2) + s0, d0 * s3 + (s1 - 1) * s4 - s2 - ((s2 ceildiv s4) * s4 - s2) + 1, d0 * s3 + (s1 - 1) * s4 - s2 - (d0 * s3 - s2) + 1)> -// CHECK-LABEL: func private @test_pool_unknown_dimensions +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 - 1)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (0, d0)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0)[s0] -> (s0, d0 + 2)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (32, d0 + 2)> +// CHECK-DAG: [[MAP_5_:#.+]] = affine_map<(d0)[s0, s1, s2, s3, s4] -> (s0 - ((s2 ceildiv s4) * s4 - s2), -(d0 * s3 - s2) + s0, d0 * s3 + (s1 - 1) * s4 - s2 - ((s2 ceildiv s4) * s4 - s2) + 1, d0 * s3 + (s1 - 1) * s4 - s2 - (d0 * s3 - s2) + 1)> +// CHECK-DAG: [[MAP_6_:#.+]] = affine_map<(d0, d1) -> (d1, d0 + d1)> +// CHECK-LABEL: func.func private @test_pool_unknown_dimensions // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x3x?x32xf32>) -> memref<1x3x?x31xf32> { -// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x3x?x32xf32> -// CHECK: [[VAR_1_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_0_]]{{.}} -// CHECK-DAG: [[VAR_2_:%.+]] = memref.alloc([[VAR_1_]]) {{.*}}: memref<1x3x?x31xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = memref.alloca() : memref +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x3x?x32xf32> +// CHECK: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<1x3x?x31xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() : memref // CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_1_]]([[VAR_1_]]), [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 31){ -// CHECK: [[IV:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK: krnl.store [[CST_0_dot_000000_]], [[VAR_4_]][] : memref -// CHECK-DAG: [[VAR_5_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x3x?x32xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = affine.max [[MAP_2_]]([[IV]]#2) +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to [[MAP_1_]]([[VAR_0_]]), [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 31){ +// CHECK: [[VAR_2_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) +// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_1_]][] : memref +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_2_]] : memref<1x3x?x32xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = affine.max [[MAP_2_]]([[VAR_2_]]#2) // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = affine.min [[MAP_3_]]([[IV]]#2){{.}}[[VAR_5_]]{{.}} -// CHECK-DAG: [[VAR_8_:%.+]] = affine.max [[MAP_2_]]([[IV]]#3) -// CHECK-DAG: [[VAR_9_:%.+]] = affine.min [[MAP_4_]]([[IV]]#3) +// CHECK-DAG: [[VAR_4_:%.+]] = affine.min [[MAP_3_]]([[VAR_2_]]#2){{.}}[[VAR_dim_0_]]{{.}} +// CHECK-DAG: [[VAR_5_:%.+]] = affine.max [[MAP_2_]]([[VAR_2_]]#3) +// CHECK-DAG: [[VAR_6_:%.+]] = affine.min [[MAP_4_]]([[VAR_2_]]#3) // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.subi [[VAR_7_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[VAR_9_]], [[VAR_8_]] : index +// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_4_]], [[VAR_3_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.subi [[VAR_6_]], [[VAR_5_]] : index // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to min [[MAP_5_]]([[IV]]#2){{.}}[[VAR_5_]], [[CST_2_]], [[CST_0_]], [[CST_1_]], [[CST_1_]]{{.}}, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to min [[MAP_5_]]([[IV]]#3){{.}}[[CST_32_]], [[CST_2_]], [[CST_0_]], [[CST_1_]], [[CST_1_]]{{.}}){ -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[I_4_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_20_:%.+]] = arith.addi [[I_5_]], [[VAR_8_]] : index +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to min [[MAP_5_]]([[VAR_2_]]#2){{.}}[[VAR_dim_0_]], [[CST_2_]], [[CST_0_]], [[CST_1_]], [[CST_1_]]{{.}}, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to min [[MAP_5_]]([[VAR_2_]]#3){{.}}[[CST_32_]], [[CST_2_]], [[CST_0_]], [[CST_1_]], [[CST_1_]]{{.}}){ +// CHECK-DAG: [[VAR_16_:%.+]] = affine.max [[MAP_6_]]([[VAR_2_]]#2, [[I_4_]]) +// CHECK-DAG: [[VAR_17_:%.+]] = affine.max [[MAP_6_]]([[VAR_2_]]#3, [[I_5_]]) // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[IV]]#0, [[IV]]#1, [[VAR_19_]], [[VAR_20_]]{{.}} : memref<1x3x?x32xf32> -// CHECK-DAG: [[LOAD_VAR_4_MEM_:%.+]] = krnl.load [[VAR_4_]][] : memref -// CHECK: [[VAR_23_:%.+]] = arith.addf [[LOAD_VAR_4_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_23_]], [[VAR_4_]][] : memref +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_16_]], [[VAR_17_]]{{.}} : memref<1x3x?x32xf32> +// CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_20_]], [[RES_1_]][] : memref // CHECK: } -// CHECK: [[LOAD_VAR_4_MEM_1_:%.+]] = krnl.load [[VAR_4_]][] : memref -// CHECK: krnl.store [[LOAD_VAR_4_MEM_1_]], [[VAR_2_]]{{.}}[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3{{.}} : memref<1x3x?x31xf32> -// CHECK-DAG: [[LOAD_VAR_2_MEM_:%.+]] = krnl.load [[VAR_2_]]{{.}}[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3{{.}} : memref<1x3x?x31xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[VAR_10_]], [[VAR_11_]] : index -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64 -// CHECK: [[VAR_17_:%.+]] = arith.sitofp [[VAR_16_]] : i64 to f32 -// CHECK: [[VAR_18_:%.+]] = arith.divf [[LOAD_VAR_2_MEM_]], [[VAR_17_]] : f32 -// CHECK: krnl.store [[VAR_18_]], [[VAR_2_]]{{.}}[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3{{.}} : memref<1x3x?x31xf32> +// CHECK: [[LOAD_RES_1_MEM_1_:%.+]] = krnl.load [[RES_1_]][] : memref +// CHECK: krnl.store [[LOAD_RES_1_MEM_1_]], [[RES_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x3x?x31xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x3x?x31xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.muli [[VAR_7_]], [[VAR_8_]] : index +// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64 +// CHECK: [[VAR_14_:%.+]] = arith.sitofp [[VAR_13_]] : i64 to f32 +// CHECK: [[VAR_15_:%.+]] = arith.divf [[LOAD_RES_MEM_]], [[VAR_14_]] : f32 +// CHECK: krnl.store [[VAR_15_]], [[RES_]]{{.}}[[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_2_]]#3] : memref<1x3x?x31xf32> // CHECK: } -// CHECK: return [[VAR_2_]] : memref<1x3x?x31xf32> +// CHECK: return [[RES_]] : memref<1x3x?x31xf32> // CHECK: } } + diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir index f6b022444a..4873f5a1a4 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir @@ -3,10 +3,14 @@ // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. +// ----- + + func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_dequantizelinear_i8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> @@ -16,12 +20,13 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %ar // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xi8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[LOAD_PARAM_0_MEM_]] : i8 to f32 -// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_6_]], [[VAR_5_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.mulf [[VAR_7_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: [[VAR_5_:%.+]] = arith.extsi [[LOAD_PARAM_0_MEM_]] : i8 to i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[VAR_5_]] : i32 to f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32 +// CHECK: [[VAR_8_:%.+]] = arith.sitofp [[VAR_7_]] : i32 to f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<4xf32> // CHECK: } @@ -29,10 +34,12 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %ar // ----- + func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_dequantizelinear_ui8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> @@ -42,13 +49,15 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, % // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32 -// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_8_]], [[VAR_6_]] : f32 -// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.uitofp [[VAR_6_]] : i32 to f32 +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_9_:%.+]] = arith.extui [[VAR_8_]] : i8 to i32 +// CHECK: [[VAR_10_:%.+]] = arith.uitofp [[VAR_9_]] : i32 to f32 +// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_7_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<4xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir index 55dbdb1942..d8dd788672 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir @@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref +// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref +// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -75,45 +75,47 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor -// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref -// CHECK-DAG: [[VAR_28_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK: krnl.store [[VAR_28_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[VAR_28_]], [[RES_5_]][0] : memref<1xindex> +// CHECK: affine.store [[VAR_29_]], [[RES_5_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_30_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex> +// CHECK: affine.store [[VAR_30_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32 -// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32 -// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32 +// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32 +// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32 -// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32 -// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i8 -// CHECK: [[VAR_52_:%.+]] = builtin.unrealized_conversion_cast [[VAR_51_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_52_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32 +// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32 +// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : f32 to i32 +// CHECK: [[VAR_53_:%.+]] = arith.trunci [[VAR_52_]] : i32 to i8 +// CHECK: [[VAR_54_:%.+]] = builtin.unrealized_conversion_cast [[VAR_53_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_54_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_2_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index d7180c3e2c..2682e155c4 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt -O3 -mcpu=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --march=z16 --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. @@ -13,16 +13,10 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4096_:%.+]] = arith.constant 4096 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 @@ -33,100 +27,87 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_3_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> -// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref -// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> -// CHECK: vector.store [[VAR_cst_5_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() : memref -// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = [[CST_0_]] to [[CST_4096_]]){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_35_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: [[LOAD_RES_7_MEM_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: [[VAR_37_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4096){ +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } -// CHECK: [[LOAD_RES_5_MEM_1_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = vector.reduction , [[LOAD_RES_5_MEM_1_]] : vector<32xf32> into f32 -// CHECK-DAG: [[LOAD_RES_7_MEM_1_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_7_MEM_1_]] : vector<32xf32> into f32 -// CHECK: krnl.store [[VAR_2_]], [[RES_4_]][] : memref -// CHECK: krnl.store [[VAR_4_]], [[RES_6_]][] : memref -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = krnl.load [[RES_6_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_1_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_1_]] : vector<32xf32> into f32 +// CHECK: krnl.store [[VAR_3_]], [[RES_5_]][] : memref +// CHECK: krnl.store [[VAR_4_]], [[RES_7_]][] : memref +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK-DAG: [[LOAD_RES_7_MEM_:%.+]] = krnl.load [[RES_7_]][] : memref // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 // CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[CST_2_dot_550000_]] : f32 // CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_8_]], [[VAR_10_]] : f32 // CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 // CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 -// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_5_MEM_2_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> -// CHECK: [[VAR_35_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_1_]], [[LOAD_RES_5_MEM_2_]] : vector<8xf32> -// CHECK: [[LOAD_RES_7_MEM_2_:%.+]] = math.floor [[VAR_35_1_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[VAR_35_1_]], [[LOAD_RES_7_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_]], [[VAR_39_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_7_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<8xui8> +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert [[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_32_:%.+]] = "krnl.round_even"([[VAR_31_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.insert [[VAR_32_]], [[VAR_30_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_35_:%.+]] = "krnl.round_even"([[VAR_34_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_36_:%.+]] = vector.insert [[VAR_35_]], [[VAR_33_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = vector.shape_cast [[VAR_36_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = vector.splat [[VAR_15_]] : vector<16xf32> +// CHECK: [[VAR_39_:%.+]] = arith.addf [[VAR_37_]], [[VAR_38_]] : vector<16xf32> +// CHECK: [[VAR_40_:%.+]] = arith.maxnumf [[VAR_39_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.minnumf [[VAR_40_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.fptoui [[VAR_41_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_43_:%.+]] = arith.trunci [[VAR_42_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_44_:%.+]] = builtin.unrealized_conversion_cast [[VAR_43_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_44_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<256x16xui8>, memref, memref // CHECK: } } @@ -140,16 +121,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4335_:%.+]] = arith.constant 4335 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 @@ -158,124 +134,115 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_3_]][0] : memref<1xindex> -// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref -// CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref -// CHECK: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 255, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 17){ -// CHECK: [[VAR_30_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_30_]]#0, [[VAR_30_]]#1] : memref<255x17xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_33_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_33_]], [[RES_4_]][] : memref -// CHECK: } -// CHECK: [[RES_5_:%.+]] = memref.alloc() : memref -// CHECK: krnl.memset [[RES_5_]], [[CST_0_]] : memref -// CHECK: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to 255, [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 17){ -// CHECK: [[VAR_30_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_30_1_]]#0, [[VAR_30_1_]]#1] : memref<255x17xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_5_]][] : memref -// CHECK: [[VAR_33_1_:%.+]] = arith.maxnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_33_1_]], [[RES_5_]][] : memref +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4304){ +// CHECK: [[VAR_22_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_22_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_22_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_27_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 4320 to 4335){ +// CHECK: [[VAR_22_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_22_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_22_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_2_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 -// CHECK: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.divf [[VAR_5_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_8_]] : f32 -// CHECK: [[VAR_10_:%.+]] = arith.maxnumf [[VAR_9_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_11_:%.+]] = arith.minnumf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_12_:%.+]] = math.floor [[VAR_11_]] : f32 -// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_11_]], [[VAR_12_]] : f32 -// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf ogt, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_28_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: krnl.store [[VAR_27_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: krnl.store [[VAR_28_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_12_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.mulf [[VAR_12_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_18_:%.+]] = math.floor [[VAR_17_]] : f32 -// CHECK: [[VAR_19_:%.+]] = arith.mulf [[VAR_18_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.subf [[VAR_12_]], [[VAR_19_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_20_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[VAR_12_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_5_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: krnl.store [[VAR_4_]], [[RES_5_]][] : memref +// CHECK: krnl.store [[VAR_5_]], [[RES_7_]][] : memref +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK-DAG: [[LOAD_RES_7_MEM_:%.+]] = krnl.load [[RES_7_]][] : memref // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_12_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.select [[VAR_24_]], [[VAR_23_]], [[VAR_16_]] : f32 -// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i8 -// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref -// CHECK: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[CST_4335_]], [[RES_6_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_6_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> -// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[CST_4335_]], [[RES_7_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_18_:%.+]] = memref.reshape [[RES_]]([[RES_]]_17) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.divf [[VAR_9_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref +// CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_8_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4328){ -// CHECK: [[VAR_30_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_30_2_]]{{.}} : memref<4335xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.splat [[VAR_7_]] : vector<8xf32> -// CHECK: [[VAR_33_2_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[LOAD_RES_4_MEM_1_]] : vector<8xf32> -// CHECK: [[VAR_34_:%.+]] = math.floor [[VAR_33_2_]] : vector<8xf32> -// CHECK: [[VAR_35_:%.+]] = arith.subf [[VAR_33_2_]], [[VAR_34_]] : vector<8xf32> -// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : vector<8xf32> -// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_]], [[VAR_41_]] : vector<8xf32> -// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = vector.splat [[VAR_25_]] : vector<8xf32> -// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_47_]], [[VAR_48_]] : vector<8xf32> -// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_53_:%.+]] = builtin.unrealized_conversion_cast [[VAR_52_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_53_]], [[VAR_reshape_18_]]{{.}}[[VAR_30_2_]]{{.}} : memref<4335xui8>, vector<8xui8> +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ +// CHECK: [[VAR_22_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_27_2_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_28_2_:%.+]] = "krnl.round_even"([[VAR_27_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = vector.insert [[VAR_28_2_]], [[LOAD_RES_6_MEM_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_31_:%.+]] = "krnl.round_even"([[VAR_30_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.insert [[VAR_31_]], [[VAR_29_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_34_:%.+]] = "krnl.round_even"([[VAR_33_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.insert [[VAR_34_]], [[VAR_32_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_37_:%.+]] = "krnl.round_even"([[VAR_36_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_38_:%.+]] = vector.insert [[VAR_37_]], [[VAR_35_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = vector.shape_cast [[VAR_38_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.addf [[VAR_39_]], [[VAR_40_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.maxnumf [[VAR_41_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.minnumf [[VAR_42_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.fptoui [[VAR_43_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_45_:%.+]] = arith.trunci [[VAR_44_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_46_:%.+]] = builtin.unrealized_conversion_cast [[VAR_45_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_46_]], [[VAR_reshape_15_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4328 to 4335){ -// CHECK: [[VAR_30_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_30_3_]]{{.}} : memref<4335xf32> -// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_33_3_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK: [[VAR_34_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_33_3_]] : f32 -// CHECK-DAG: [[VAR_35_1_:%.+]] = arith.cmpf ogt, [[VAR_34_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_36_1_:%.+]] = arith.addf [[VAR_33_3_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_1_:%.+]] = arith.select [[VAR_35_1_]], [[VAR_36_1_]], [[VAR_33_3_]] : f32 -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.mulf [[VAR_33_3_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : f32 -// CHECK: [[VAR_40_1_:%.+]] = arith.mulf [[VAR_39_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_41_1_:%.+]] = arith.subf [[VAR_33_3_]], [[VAR_40_1_]] : f32 -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.cmpf oeq, [[VAR_41_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.addf [[VAR_33_3_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_44_1_:%.+]] = arith.select [[VAR_42_1_]], [[VAR_43_1_]], [[VAR_33_3_]] : f32 -// CHECK-DAG: [[VAR_45_1_:%.+]] = arith.cmpf oeq, [[VAR_34_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_46_1_:%.+]] = arith.select [[VAR_45_1_]], [[VAR_44_1_]], [[VAR_37_1_]] : f32 -// CHECK: [[VAR_47_1_:%.+]] = arith.addf [[VAR_46_1_]], [[VAR_25_]] : f32 -// CHECK: [[VAR_48_1_:%.+]] = arith.maxnumf [[VAR_47_1_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_49_1_:%.+]] = arith.minnumf [[VAR_48_1_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_50_1_:%.+]] = arith.fptoui [[VAR_49_1_]] : f32 to i8 -// CHECK: [[VAR_51_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_50_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_51_1_]], [[VAR_reshape_18_]]{{.}}[[VAR_30_3_]]{{.}} : memref<4335xui8> +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ +// CHECK: [[VAR_22_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_13_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 +// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_1_]]) : (f32) -> f32 +// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_27_3_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_28_3_:%.+]] = arith.minnumf [[VAR_27_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_29_1_:%.+]] = arith.fptoui [[VAR_28_3_]] : f32 to i32 +// CHECK: [[VAR_30_1_:%.+]] = arith.trunci [[VAR_29_1_]] : i32 to i8 +// CHECK: [[VAR_31_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_31_1_]], [[VAR_reshape_15_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<255x17xui8>, memref, memref // CHECK: } } @@ -291,14 +258,8 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x8xf32>) -> (memref<1x8xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 @@ -309,100 +270,81 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_3_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> -// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() : memref -// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> -// CHECK: vector.store [[VAR_cst_5_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() : memref -// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = [[CST_0_]] to [[CST_8_]]){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK: vector.store [[VAR_35_]], [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: [[LOAD_RES_7_MEM_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: [[VAR_37_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } -// CHECK: [[LOAD_RES_5_MEM_1_:%.+]] = vector.load [[RES_5_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = vector.reduction , [[LOAD_RES_5_MEM_1_]] : vector<8xf32> into f32 -// CHECK-DAG: [[LOAD_RES_7_MEM_1_:%.+]] = vector.load [[RES_7_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_7_MEM_1_]] : vector<8xf32> into f32 -// CHECK: krnl.store [[VAR_2_]], [[RES_4_]][] : memref -// CHECK: krnl.store [[VAR_4_]], [[RES_6_]][] : memref -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = krnl.load [[RES_6_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[CST_0_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_3_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_1_]] : vector<8xf32> into f32 +// CHECK-DAG: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_1_]] : vector<8xf32> into f32 +// CHECK: krnl.store [[VAR_3_]], [[RES_5_]][] : memref +// CHECK: krnl.store [[VAR_4_]], [[RES_7_]][] : memref +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK-DAG: [[LOAD_RES_7_MEM_:%.+]] = krnl.load [[RES_7_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 // CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[CST_2_dot_550000_]] : f32 // CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_8_]], [[VAR_10_]] : f32 // CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 // CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 -// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_RES_5_MEM_2_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> -// CHECK: [[VAR_35_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_1_]], [[LOAD_RES_5_MEM_2_]] : vector<8xf32> -// CHECK: [[LOAD_RES_7_MEM_2_:%.+]] = math.floor [[VAR_35_1_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[VAR_35_1_]], [[LOAD_RES_7_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_]], [[VAR_39_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_7_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_7_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_7_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi8> -// CHECK: [[VAR_55_:%.+]] = builtin.unrealized_conversion_cast [[VAR_54_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_55_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<8xf32> to vector<2x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert [[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shape_cast [[VAR_30_]] : vector<2x4xf32> to vector<8xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.splat [[VAR_15_]] : vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_31_]], [[VAR_32_]] : vector<8xf32> +// CHECK: [[VAR_34_:%.+]] = arith.maxnumf [[VAR_33_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[VAR_34_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_36_:%.+]] = arith.fptoui [[VAR_35_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_37_:%.+]] = arith.trunci [[VAR_36_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_38_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_38_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<1x8xui8>, memref, memref // CHECK: } } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir new file mode 100644 index 0000000000..269c85f9ce --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -0,0 +1,443 @@ +// RUN: onnx-mlir-opt -O3 --march=arch14 --shape-inference --convert-onnx-to-krnl=enable-parallel --canonicalize %s -split-input-file | FileCheck %s +// above: used --march=arch14 instead of --march=z16 on purpose to make sure either option works + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// ----- + + +func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> (tensor<256x16xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<256x16xf32>) -> (tensor<256x16xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<256x16xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 * 512)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (4096, d0 * 512 + 512)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0] -> (s0 - 31)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<()[s0] -> ((s0 floordiv 32) * 32)> +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_only +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_4096_:%.+]] = arith.constant 4096 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<256x16xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_3_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<256xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref<8xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<256xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref<8xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ +// CHECK: [[VAR_21_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_:%.+]] = affine.apply [[MAP_0_]]([[VAR_21_]]) +// CHECK-DAG: [[VAR_23_:%.+]] = affine.min [[MAP_1_]]([[VAR_21_]]) +// CHECK-DAG: [[VAR_24_:%.+]] = affine.apply [[MAP_2_]]([[VAR_21_]]) +// CHECK: vector.store [[VAR_cst_4_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_3_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_23_]]{{.}} +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_22_]] to [[VAR_25_]] step [[CST_32_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_35_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_4_]](){{.}}[[VAR_23_]]{{.}} +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_26_]] to [[VAR_23_]] step [[CST_1_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_35_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_36_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_35_1_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_36_1_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_29_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_30_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_29_]], [[RES_5_]]{{.}}[[VAR_21_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_30_]], [[RES_7_]]{{.}}[[VAR_21_]]{{.}} : memref<8xf32> +// CHECK: } +// CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ +// CHECK: [[VAR_21_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_21_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_23_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_21_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_22_1_]] : f32 +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_23_1_]] : f32 +// CHECK: krnl.store [[VAR_26_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = vector.extract [[LOAD_RES_4_MEM_4_]][0] : f32 from vector<1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = vector.extract [[LOAD_RES_6_MEM_4_]][0] : f32 from vector<1xf32> +// CHECK: krnl.store [[VAR_4_]], [[RES_8_]][] : memref +// CHECK: krnl.store [[VAR_5_]], [[RES_9_]][] : memref +// CHECK-DAG: [[LOAD_RES_8_MEM_:%.+]] = krnl.load [[RES_8_]][] : memref +// CHECK-DAG: [[LOAD_RES_9_MEM_:%.+]] = krnl.load [[RES_9_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_9_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.minnumf [[LOAD_RES_8_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.divf [[VAR_9_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref +// CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_10_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_17_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4096_]], [[RES_11_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[RES_]]([[RES_]]_18) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ +// CHECK: [[VAR_21_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_1_:%.+]] = vector.load [[VAR_reshape_17_]]{{.}}[[VAR_21_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_23_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_24_1_:%.+]] = arith.divf [[VAR_22_1_]], [[VAR_23_2_]] : vector<16xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.shape_cast [[VAR_24_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_26_2_:%.+]] = vector.extract [[VAR_25_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = "krnl.round_even"([[VAR_26_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.insert [[LOAD_RES_4_MEM_2_1_]], [[VAR_25_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.extract [[VAR_25_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_30_1_:%.+]] = "krnl.round_even"([[VAR_29_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.insert [[VAR_30_1_]], [[LOAD_RES_6_MEM_2_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.extract [[VAR_25_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.insert [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.extract [[VAR_25_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_36_2_:%.+]] = "krnl.round_even"([[VAR_35_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_37_:%.+]] = vector.insert [[VAR_36_2_]], [[LOAD_RES_6_MEM_1_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = vector.shape_cast [[VAR_37_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_40_:%.+]] = arith.addf [[VAR_38_]], [[VAR_39_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.maxnumf [[VAR_40_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.minnumf [[VAR_41_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.fptoui [[VAR_42_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_44_:%.+]] = arith.trunci [[VAR_43_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_45_:%.+]] = builtin.unrealized_conversion_cast [[VAR_44_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_45_]], [[VAR_reshape_19_]]{{.}}[[VAR_21_2_]]{{.}} : memref<4096xui8>, vector<16xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_7, [[RES_]]_8 : memref<256x16xui8>, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32>) -> (tensor<255x17xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<255x17xf32>) -> (tensor<255x17xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<255x17xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 * 542)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (4335, d0 * 542 + 542)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 * 32)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<()[s0] -> (s0 - 31)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0)[s0] -> (d0 * 542 + ((d0 * -542 + s0) floordiv 32) * 32)> +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_simd_and_scalar +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_4335_:%.+]] = arith.constant 4335 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<255x17xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_3_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<256xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref<8xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<256xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref<8xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ +// CHECK: [[VAR_22_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_:%.+]] = affine.apply [[MAP_0_]]([[VAR_22_]]) +// CHECK-DAG: [[VAR_24_:%.+]] = affine.min [[MAP_1_]]([[VAR_22_]]) +// CHECK-DAG: [[VAR_25_:%.+]] = affine.apply [[MAP_2_]]([[VAR_22_]]) +// CHECK: vector.store [[VAR_cst_4_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_3_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_3_]](){{.}}[[VAR_24_]]{{.}} +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_23_]] to [[VAR_26_]] step [[CST_32_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_4_]]([[VAR_22_]]){{.}}[[VAR_24_]]{{.}} +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_27_]] to [[VAR_24_]] step [[CST_1_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_37_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_36_1_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_37_1_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_31_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_30_]], [[RES_5_]]{{.}}[[VAR_22_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_31_]], [[RES_7_]]{{.}}[[VAR_22_]]{{.}} : memref<8xf32> +// CHECK: } +// CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ +// CHECK: [[VAR_22_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_22_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_24_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_22_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_23_1_]] : f32 +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_24_1_]] : f32 +// CHECK: krnl.store [[VAR_27_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = vector.extract [[LOAD_RES_4_MEM_4_]][0] : f32 from vector<1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = vector.extract [[LOAD_RES_6_MEM_4_]][0] : f32 from vector<1xf32> +// CHECK: krnl.store [[VAR_4_]], [[RES_8_]][] : memref +// CHECK: krnl.store [[VAR_5_]], [[RES_9_]][] : memref +// CHECK-DAG: [[LOAD_RES_8_MEM_:%.+]] = krnl.load [[RES_8_]][] : memref +// CHECK-DAG: [[LOAD_RES_9_MEM_:%.+]] = krnl.load [[RES_9_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = arith.maxnumf [[LOAD_RES_9_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.minnumf [[LOAD_RES_8_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_10_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.divf [[VAR_9_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref +// CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_10_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_17_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_4335_]], [[RES_11_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[RES_]]([[RES_]]_18) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ +// CHECK: [[VAR_22_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_1_:%.+]] = vector.load [[VAR_reshape_17_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_24_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_25_1_:%.+]] = arith.divf [[VAR_23_1_]], [[VAR_24_2_]] : vector<16xf32> +// CHECK: [[VAR_26_1_:%.+]] = vector.shape_cast [[VAR_25_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_27_2_:%.+]] = vector.extract [[VAR_26_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = "krnl.round_even"([[VAR_27_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.insert [[LOAD_RES_4_MEM_2_1_]], [[VAR_26_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_30_1_:%.+]] = vector.extract [[VAR_26_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_31_1_:%.+]] = "krnl.round_even"([[VAR_30_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.insert [[VAR_31_1_]], [[LOAD_RES_6_MEM_2_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.extract [[VAR_26_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.insert [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.extract [[VAR_26_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_37_2_:%.+]] = "krnl.round_even"([[VAR_36_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_38_:%.+]] = vector.insert [[VAR_37_2_]], [[LOAD_RES_6_MEM_1_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = vector.shape_cast [[VAR_38_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.addf [[VAR_39_]], [[VAR_40_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.maxnumf [[VAR_41_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.minnumf [[VAR_42_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.fptoui [[VAR_43_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_45_:%.+]] = arith.trunci [[VAR_44_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_46_:%.+]] = builtin.unrealized_conversion_cast [[VAR_45_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_46_]], [[VAR_reshape_19_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xui8>, vector<16xui8> +// CHECK: } +// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ +// CHECK: [[VAR_22_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_23_1_1_:%.+]] = krnl.load [[VAR_reshape_17_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_24_3_:%.+]] = arith.divf [[VAR_23_1_1_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_25_2_:%.+]] = "krnl.round_even"([[VAR_24_3_]]) : (f32) -> f32 +// CHECK: [[VAR_26_2_:%.+]] = arith.addf [[VAR_25_2_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_27_3_:%.+]] = arith.maxnumf [[VAR_26_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = arith.minnumf [[VAR_27_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[LOAD_RES_6_MEM_2_1_:%.+]] = arith.fptoui [[LOAD_RES_4_MEM_2_1_]] : f32 to i32 +// CHECK: [[VAR_30_2_:%.+]] = arith.trunci [[LOAD_RES_6_MEM_2_1_]] : i32 to i8 +// CHECK: [[VAR_31_2_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_2_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_31_2_]], [[VAR_reshape_19_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_7, [[RES_]]_8 : memref<255x17xui8>, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32>) -> (tensor<1x8xui8>, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<1x8xf32>) -> (tensor<1x8xui8>, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor<1x8xui8>, tensor, tensor + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dynamic_quantize_linear_reduced_simd_only +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x8xf32>) -> (memref<1x8xui8>, memref, memref) { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x8xui8> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_3_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_3_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_1_]] : vector<8xf32> into f32 +// CHECK-DAG: [[VAR_4_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_1_]] : vector<8xf32> into f32 +// CHECK: krnl.store [[VAR_3_]], [[RES_5_]][] : memref +// CHECK: krnl.store [[VAR_4_]], [[RES_7_]][] : memref +// CHECK-DAG: [[LOAD_RES_5_MEM_:%.+]] = krnl.load [[RES_5_]][] : memref +// CHECK-DAG: [[LOAD_RES_7_MEM_:%.+]] = krnl.load [[RES_7_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = arith.maxnumf [[LOAD_RES_7_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.minnumf [[LOAD_RES_5_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_9_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_8_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref +// CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.parallel([[BLOCK_TILE__1_]]) : !krnl.loop +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<8xf32> to vector<2x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert [[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shape_cast [[VAR_30_]] : vector<2x4xf32> to vector<8xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.splat [[VAR_15_]] : vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_31_]], [[VAR_32_]] : vector<8xf32> +// CHECK: [[VAR_34_:%.+]] = arith.maxnumf [[VAR_33_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[VAR_34_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_36_:%.+]] = arith.fptoui [[VAR_35_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_37_:%.+]] = arith.trunci [[VAR_36_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_38_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_38_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<1x8xui8>, memref, memref +// CHECK: } +} + diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir new file mode 100644 index 0000000000..15fedeab1b --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -0,0 +1,179 @@ +// RUN: onnx-mlir-opt --disable-quantization-zero-point --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// Test quantization with disabled zero point + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// ----- + + +func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { + %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dequantizelinear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_5_:%.+]] = arith.extui [[VAR_4_]] : i8 to i32 +// CHECK: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i32 to f32 +// CHECK: [[VAR_7_:%.+]] = arith.mulf [[VAR_6_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: } +// CHECK: return [[RES_]] : memref<4xf32> +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor) -> (tensor, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor, tensor, tensor + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-LABEL: func.func @test_dynamic_quantize_linear +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> (memref, memref, memref) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i8 +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_3_]], [[CST_0_2_]] : memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_]]#0, [[VAR_12_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_15_]], [[RES_3_]][] : memref +// CHECK: } +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref +// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_1_]]#0, [[VAR_12_1_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK: [[VAR_15_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_15_1_]], [[RES_4_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[CST_0_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_8_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_9_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_10_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_12_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_15_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_2_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_28_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.minnumf [[VAR_29_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_31_:%.+]] = arith.fptoui [[VAR_30_]] : f32 to i32 +// CHECK: [[VAR_32_:%.+]] = arith.trunci [[VAR_31_]] : i32 to i8 +// CHECK: [[VAR_33_:%.+]] = builtin.unrealized_conversion_cast [[VAR_32_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_33_]], [[VAR_reshape_14_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> + return %0 : tensor<6xui8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_2_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_4_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_5_:%.+]] = math.floor [[VAR_4_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[VAR_8_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.mulf [[VAR_5_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = math.floor [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_5_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[VAR_16_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.minnumf [[VAR_19_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.fptoui [[VAR_20_]] : f32 to i32 +// CHECK: [[VAR_22_:%.+]] = arith.trunci [[VAR_21_]] : i32 to i8 +// CHECK: [[VAR_23_:%.+]] = builtin.unrealized_conversion_cast [[VAR_22_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_23_]], [[RES_]]{{.}}[[VAR_2_]]{{.}} : memref<6xui8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: } +} + diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir index 65c77c702d..ece948da3d 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir @@ -6,12 +6,12 @@ // ----- -func.func @test_quantize_linear(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { +func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> return %0 : tensor<6xui8> // mlir2FileCheck.py -// CHECK-LABEL: func.func @test_quantize_linear +// CHECK-LABEL: func.func @test_quantize_linear_ui8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 // CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 @@ -22,7 +22,61 @@ func.func @test_quantize_linear(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 -// CHECK-DAG: [[VAR_3_:%.+]] = arith.uitofp [[VAR_2_]] : i8 to f32 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.uitofp [[VAR_3_]] : i32 to f32 +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpf ogt, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_13_:%.+]] = arith.select [[VAR_11_]], [[VAR_12_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.mulf [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.mulf [[VAR_15_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_9_]], [[VAR_16_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[VAR_20_]], [[VAR_13_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.addf [[VAR_22_]], [[VAR_4_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.maxnumf [[VAR_23_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.minnumf [[VAR_24_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i32 +// CHECK: [[VAR_27_:%.+]] = arith.trunci [[VAR_26_]] : i32 to i8 +// CHECK: [[VAR_28_:%.+]] = builtin.unrealized_conversion_cast [[VAR_27_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_28_]], [[RES_]]{{.}}[[VAR_6_]]{{.}} : memref<6xui8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: } +} + +// ----- + + +func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xi8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xi8> + return %0 : tensor<6xi8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_i8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xi8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_minus_1_dot_280000_:%.+]] = arith.constant -1.280000e+02 : f32 +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xi8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref +// CHECK: [[VAR_2_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.sitofp [[VAR_2_]] : i32 to f32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index @@ -45,13 +99,60 @@ func.func @test_quantize_linear(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: // CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 // CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.fptoui [[VAR_24_]] : f32 to i8 -// CHECK: [[VAR_26_:%.+]] = builtin.unrealized_conversion_cast [[VAR_25_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_26_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> +// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_1_dot_270000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.fptosi [[VAR_24_]] : f32 to i32 +// CHECK: [[VAR_26_:%.+]] = arith.trunci [[VAR_25_]] : i32 to i8 +// CHECK: krnl.store [[VAR_26_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xi8> // CHECK: } -// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: return [[RES_]] : memref<6xi8> // CHECK: } } +// ----- + +func.func @test_quantize_linear_ui8_scalar(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL: func.func @test_quantize_linear_ui8_scalar +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref +// CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.uitofp [[VAR_3_]] : i32 to f32 +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][] : memref +// CHECK: [[VAR_6_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_7_:%.+]] = math.floor [[VAR_6_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpf ogt, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_9_]], [[VAR_10_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.mulf [[VAR_7_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = math.floor [[VAR_12_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.mulf [[VAR_13_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.subf [[VAR_7_]], [[VAR_14_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf oeq, [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_16_]], [[VAR_17_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.cmpf oeq, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[VAR_18_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.addf [[VAR_20_]], [[VAR_4_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[VAR_21_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.minnumf [[VAR_22_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.fptoui [[VAR_23_]] : f32 to i32 +// CHECK: [[VAR_25_:%.+]] = arith.trunci [[VAR_24_]] : i32 to i8 +// CHECK: [[VAR_26_:%.+]] = builtin.unrealized_conversion_cast [[VAR_25_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_26_]], [[RES_]][] : memref +// CHECK: return [[RES_]] : memref +// CHECK: } +} diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_fast_math_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_fast_math_canonicalize.mlir new file mode 100644 index 0000000000..b1e8c43587 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_fast_math_canonicalize.mlir @@ -0,0 +1,114 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-krnl=enable-fast-math --canonicalize %s -split-input-file | FileCheck %s + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// Test fast math where the divide by scale is replaced by mutiply by the reciprocal of the scale. +// ----- + + +func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> + return %0 : tensor<6xui8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref +// CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.uitofp [[VAR_3_]] : i32 to f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_7_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_7_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_9_:%.+]] = arith.mulf [[LOAD_PARAM_0_MEM_]], [[VAR_5_]] : f32 +// CHECK: [[VAR_10_:%.+]] = math.floor [[VAR_9_]] : f32 +// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_9_]], [[VAR_10_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.cmpf ogt, [[VAR_11_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.addf [[VAR_10_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_12_]], [[VAR_13_]], [[VAR_10_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.mulf [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = math.floor [[VAR_15_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.mulf [[VAR_16_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = arith.subf [[VAR_10_]], [[VAR_17_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.cmpf oeq, [[VAR_18_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[VAR_10_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]] = arith.select [[VAR_19_]], [[VAR_20_]], [[VAR_10_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.cmpf oeq, [[VAR_11_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.select [[VAR_22_]], [[VAR_21_]], [[VAR_14_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.addf [[VAR_23_]], [[VAR_4_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.maxnumf [[VAR_24_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.minnumf [[VAR_25_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_27_:%.+]] = arith.fptoui [[VAR_26_]] : f32 to i32 +// CHECK: [[VAR_28_:%.+]] = arith.trunci [[VAR_27_]] : i32 to i8 +// CHECK: [[VAR_29_:%.+]] = builtin.unrealized_conversion_cast [[VAR_28_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_29_]], [[RES_]]{{.}}[[VAR_7_]]{{.}} : memref<6xui8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: } +} + +// ----- + + +func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xi8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xi8> + return %0 : tensor<6xi8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_i8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xi8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_minus_1_dot_280000_:%.+]] = arith.constant -1.280000e+02 : f32 +// CHECK-DAG: [[CST_1_dot_270000_:%.+]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xi8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref +// CHECK: [[VAR_2_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.sitofp [[VAR_2_]] : i32 to f32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_8_:%.+]] = arith.mulf [[LOAD_PARAM_0_MEM_]], [[VAR_4_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpf ogt, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_13_:%.+]] = arith.select [[VAR_11_]], [[VAR_12_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.mulf [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.mulf [[VAR_15_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_9_]], [[VAR_16_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[VAR_20_]], [[VAR_13_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.addf [[VAR_22_]], [[VAR_3_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.maxnumf [[VAR_23_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.minnumf [[VAR_24_]], [[CST_1_dot_270000_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.fptosi [[VAR_25_]] : f32 to i32 +// CHECK: [[VAR_27_:%.+]] = arith.trunci [[VAR_26_]] : i32 to i8 +// CHECK: krnl.store [[VAR_27_]], [[RES_]]{{.}}[[VAR_6_]]{{.}} : memref<6xi8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xi8> +// CHECK: } +} + diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_depth_to_space_op.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_depth_to_space_op.mlir index 86d4e6c7a7..d3eaec08ff 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_depth_to_space_op.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_depth_to_space_op.mlir @@ -7,47 +7,46 @@ func.func private @test_depth_to_space_dynamic_dims(%arg0 : tensor<1x?x8x?xf32>) %0 = "onnx.DepthToSpace"(%arg0) {blocksize = 4 : si64} : (tensor<1x?x8x?xf32>) -> tensor<1x?x32x?xf32> "func.return"(%0) : (tensor<1x?x32x?xf32>) -> () -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 floordiv 16)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0] -> (s0 * 4)> -// CHECK-LABEL: func private @test_depth_to_space_dynamic_dims +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 floordiv 16)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-LABEL: func.func private @test_depth_to_space_dynamic_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x?x8x?xf32>) -> memref<1x?x32x?xf32> { -// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_c32_:%.+]] = arith.constant 32 : index -// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index -// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[VAR_c8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<1x?x8x?xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c3_]] : memref<1x?x8x?xf32> +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x8x?xf32> +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_3_]] : memref<1x?x8x?xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_0_]]{{.}} -// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_2_]](){{.}}[[VAR_1_]]{{.}} +// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_0_]]{{.}} // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xindex> -// CHECK: krnl.store [[VAR_c1_]], [[RES_]]{{.}}[[VAR_c0_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c4_]], [[RES_]]{{.}}[[VAR_c1_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c4_]], [[RES_]]{{.}}[[VAR_c2_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_2_]], [[RES_]]{{.}}[[VAR_c3_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c8_]], [[RES_]]{{.}}[[VAR_c4_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_1_]], [[RES_]]{{.}}[[VAR_c5_]]{{.}} : memref<6xindex> -// CHECK-DAG: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : memref<1x?x8x?xf32> to tensor<1x?x8x?xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[RES_]] : memref<6xindex> to tensor<6xi64> -// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[VAR_5_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<1x?x8x?xf32>, tensor<6xi64>) -> tensor -// CHECK: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[VAR_7_]] : tensor to memref -// CHECK: [[VAR_9_:%.+]] = memref.cast [[VAR_8_]] : memref to memref<1x4x4x?x8x?xf32> -// CHECK: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : memref<1x4x4x?x8x?xf32> to tensor<1x4x4x?x8x?xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Transpose"([[VAR_10_]]) {perm = [0, 3, 4, 1, 5, 2]} : (tensor<1x4x4x?x8x?xf32>) -> tensor<1x?x8x4x?x4xf32> +// CHECK: krnl.store [[CST_1_]], [[RES_]]{{.}}[[CST_0_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_4_]], [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_4_]], [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[VAR_0_]], [[RES_]]{{.}}[[CST_3_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_8_]], [[RES_]]{{.}}[[CST_4_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[VAR_dim_0_]], [[RES_]]{{.}}[[CST_5_]]{{.}} : memref<6xindex> +// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[RES_]] : memref<6xindex> to tensor<6xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : memref<1x?x8x?xf32> to tensor<1x?x8x?xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[VAR_3_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<1x?x8x?xf32>, tensor<6xi64>) -> tensor +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[VAR_4_]] : tensor to memref +// CHECK: [[VAR_cast_:%.+]] = memref.cast [[VAR_5_]] : memref to memref<1x4x4x?x8x?xf32> +// CHECK: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[VAR_cast_]] : memref<1x4x4x?x8x?xf32> to tensor<1x4x4x?x8x?xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [0, 3, 4, 1, 5, 2]} : (tensor<1x4x4x?x8x?xf32>) -> tensor<1x?x8x4x?x4xf32> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4xindex> -// CHECK: krnl.store [[VAR_c1_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_2_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_c32_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<4xindex> -// CHECK: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[RES_1_]] : memref<4xindex> to tensor<4xi64> -// CHECK: [[VAR_14_:%.+]] = "onnx.Reshape"([[VAR_11_]], [[VAR_13_]]) {allowzero = 0 : si64} : (tensor<1x?x8x4x?x4xf32>, tensor<4xi64>) -> tensor -// CHECK: [[VAR_15_:%.+]] = builtin.unrealized_conversion_cast [[VAR_14_]] : tensor to memref -// CHECK: [[VAR_16_:%.+]] = memref.cast [[VAR_15_]] : memref to memref<1x?x32x?xf32> -// CHECK: return [[VAR_16_]] : memref<1x?x32x?xf32> -// CHECK: } +// CHECK: krnl.store [[CST_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[VAR_0_]], [[RES_1_]]{{.}}[[CST_1_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[CST_32_]], [[RES_1_]]{{.}}[[CST_2_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[VAR_1_]], [[RES_1_]]{{.}}[[CST_3_]]{{.}} : memref<4xindex> +// CHECK: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[RES_1_]] : memref<4xindex> to tensor<4xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[VAR_7_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<1x?x8x4x?x4xf32>, tensor<4xi64>) -> tensor +// CHECK: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : tensor to memref +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[VAR_10_]] : memref to memref<1x?x32x?xf32> +// CHECK: return [[VAR_cast_2_]] : memref<1x?x32x?xf32> } diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_space_to_depth.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_space_to_depth.mlir index 117209694c..ec9dbfecc8 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_space_to_depth.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/onnx_lowering_space_to_depth.mlir @@ -7,46 +7,44 @@ func.func private @test_space_to_depth_dynamic_dims(%arg0 : tensor<1x?x8x?xf32>) %0 = "onnx.SpaceToDepth"(%arg0) {blocksize = 4 : si64} : (tensor<1x?x8x?xf32>) -> tensor<1x?x2x?xf32> "func.return"(%0) : (tensor<1x?x2x?xf32>) -> () -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 16)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)> -// CHECK-LABEL: func private @test_space_to_depth_dynamic_dims +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 16)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)> +// CHECK-LABEL: func.func private @test_space_to_depth_dynamic_dims // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x?x8x?xf32>) -> memref<1x?x2x?xf32> { -// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index -// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_0_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c1_]] : memref<1x?x8x?xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c3_]] : memref<1x?x8x?xf32> +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref<1x?x8x?xf32> +// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[PARAM_0_]], [[CST_3_]] : memref<1x?x8x?xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_0_]]{{.}} -// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_1_]]{{.}} +// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_0_]]{{.}} // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xindex> -// CHECK: krnl.store [[VAR_c1_]], [[RES_]]{{.}}[[VAR_c0_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_0_]], [[RES_]]{{.}}[[VAR_c1_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c2_]], [[RES_]]{{.}}[[VAR_c2_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c4_]], [[RES_]]{{.}}[[VAR_c3_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_3_]], [[RES_]]{{.}}[[VAR_c4_]]{{.}} : memref<6xindex> -// CHECK: krnl.store [[VAR_c4_]], [[RES_]]{{.}}[[VAR_c5_]]{{.}} : memref<6xindex> -// CHECK-DAG: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : memref<1x?x8x?xf32> to tensor<1x?x8x?xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[RES_]] : memref<6xindex> to tensor<6xi64> -// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[VAR_5_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<1x?x8x?xf32>, tensor<6xi64>) -> tensor -// CHECK: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[VAR_7_]] : tensor to memref -// CHECK: [[VAR_9_:%.+]] = memref.cast [[VAR_8_]] : memref to memref<1x?x2x4x?x4xf32> -// CHECK: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : memref<1x?x2x4x?x4xf32> to tensor<1x?x2x4x?x4xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Transpose"([[VAR_10_]]) {perm = [0, 1, 3, 5, 2, 4]} : (tensor<1x?x2x4x?x4xf32>) -> tensor<1x?x4x4x2x?xf32> +// CHECK: krnl.store [[CST_1_]], [[RES_]]{{.}}[[CST_0_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[VAR_dim_]], [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_2_]], [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_4_]], [[RES_]]{{.}}[[CST_3_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[VAR_1_]], [[RES_]]{{.}}[[CST_4_]]{{.}} : memref<6xindex> +// CHECK: krnl.store [[CST_4_]], [[RES_]]{{.}}[[CST_5_]]{{.}} : memref<6xindex> +// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[RES_]] : memref<6xindex> to tensor<6xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : memref<1x?x8x?xf32> to tensor<1x?x8x?xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[VAR_3_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<1x?x8x?xf32>, tensor<6xi64>) -> tensor +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[VAR_4_]] : tensor to memref +// CHECK: [[VAR_cast_:%.+]] = memref.cast [[VAR_5_]] : memref to memref<1x?x2x4x?x4xf32> +// CHECK: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[VAR_cast_]] : memref<1x?x2x4x?x4xf32> to tensor<1x?x2x4x?x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [0, 1, 3, 5, 2, 4]} : (tensor<1x?x2x4x?x4xf32>) -> tensor<1x?x4x4x2x?xf32> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<4xindex> -// CHECK: krnl.store [[VAR_c1_]], [[RES_1_]]{{.}}[[VAR_c0_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_2_]], [[RES_1_]]{{.}}[[VAR_c1_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_c2_]], [[RES_1_]]{{.}}[[VAR_c2_]]{{.}} : memref<4xindex> -// CHECK: krnl.store [[VAR_3_]], [[RES_1_]]{{.}}[[VAR_c3_]]{{.}} : memref<4xindex> -// CHECK: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[RES_1_]] : memref<4xindex> to tensor<4xi64> -// CHECK: [[VAR_14_:%.+]] = "onnx.Reshape"([[VAR_11_]], [[VAR_13_]]) {allowzero = 0 : si64} : (tensor<1x?x4x4x2x?xf32>, tensor<4xi64>) -> tensor -// CHECK: [[VAR_15_:%.+]] = builtin.unrealized_conversion_cast [[VAR_14_]] : tensor to memref -// CHECK: [[VAR_16_:%.+]] = memref.cast [[VAR_15_]] : memref to memref<1x?x2x?xf32> -// CHECK: return [[VAR_16_]] : memref<1x?x2x?xf32> -// CHECK: } - +// CHECK: krnl.store [[CST_1_]], [[RES_1_]]{{.}}[[CST_0_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[VAR_0_]], [[RES_1_]]{{.}}[[CST_1_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[CST_2_]], [[RES_1_]]{{.}}[[CST_2_]]{{.}} : memref<4xindex> +// CHECK: krnl.store [[VAR_1_]], [[RES_1_]]{{.}}[[CST_3_]]{{.}} : memref<4xindex> +// CHECK: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[RES_1_]] : memref<4xindex> to tensor<4xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[VAR_7_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<1x?x4x4x2x?xf32>, tensor<4xi64>) -> tensor +// CHECK: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : tensor to memref +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[VAR_10_]] : memref to memref<1x?x2x?xf32> +// CHECK: return [[VAR_cast_2_]] : memref<1x?x2x?xf32> } diff --git a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir index 5309b274fd..f319459c68 100644 --- a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir +++ b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir @@ -124,8 +124,9 @@ func.func @test_fuse_element8(%arg0: tensor, %arg1: tensor<1xf32>) -> ten // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: [[VAR_4_:%.+]] = math.powf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i8 -// CHECK: krnl.store [[VAR_5_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref +// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i32 +// CHECK: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i32 to i8 +// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -322,13 +323,14 @@ func.func @fuse_element_20(%533: tensor, %537 : tensor, // ----- + func.func @test_fuse_element21(%arg0: tensor, %arg1: tensor<1xf32>, %arg2 : tensor<1xi8>) -> tensor { - %0 = "onnx.Pow"(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor %1 = "onnx.Cast"(%0) {to = i8} : (tensor) -> tensor %2 = "onnx.Add"(%1, %arg2) : (tensor, tensor<1xi8>) -> tensor return %2 : tensor -} +// mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @test_fuse_element21 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref<1xf32>, [[PARAM_2_:%.+]]: memref<1xi8>) -> memref { @@ -341,12 +343,13 @@ func.func @test_fuse_element21(%arg0: tensor, %arg1: tensor<1xf32>, %arg2 // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: [[VAR_4_:%.+]] = math.powf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i8 +// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i32 to i8 // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[CST_0_]]{{.}} : memref<1xi8> -// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[LOAD_PARAM_2_MEM_]] : i8 -// CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_6_]], [[LOAD_PARAM_2_MEM_]] : i8 +// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } - +} diff --git a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir new file mode 100644 index 0000000000..2279a7c901 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// ----- +func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32> + %1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32> + %2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32> + return %2 : tensor<1024xf32> +} +// CHECK-LABEL: func.func @test_reuse +// CHECK-NOT: memref.alloc diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir index d74f7288f6..0da75f096a 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir @@ -32,46 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor) -> () } -//TODO: Renable dynamic shape test -// func.func @test_softmax_dynamic -// ([[PARAM_0_:%.+]]: tensor) -> tensor { -// [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// [[CST_2_:%.+]] = arith.constant 2 : index -// [[CST_1_:%.+]] = arith.constant 1 : index -// [[CST_0_:%.+]] = arith.constant 0 : index -// [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor -// separator of consecutive DAGs -// [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor -// [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// separator of consecutive DAGs -// [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index -// [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index -// [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index -// [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> -// [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> -// [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor -// [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor -// [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor -// [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// separator of consecutive DAGs -// [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index -// [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index -// [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index -// [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> -// [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> -// [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor -// return [[VAR_28_]] : tensor -// } +// CHECK-LABEL: func.func @test_softmax_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index +// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index +// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index +// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> +// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor +// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor +// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index +// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index +// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index +// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> +// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor +// CHECK: return [[VAR_28_]] : tensor +// CHECK: } + // ----- diff --git a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir index a5893a6833..15de997cad 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Tensor/GatherElements.mlir @@ -3,17 +3,21 @@ func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>) -> tensor<2x2xf32> { %0 = "onnx.GatherElements"(%arg0, %arg1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32> return %0 : tensor<2x2xf32> -// CHECK: func.func @main_gather_elements([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> { -// CHECK-DAG: [[CST_:%.+]] = arith.constant dense<[2, 2, 1]> : tensor<3xindex> -// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<3> : tensor -// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> -// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.broadcast_in_dim [[VAR_0_]], dims = [] : (tensor) -> tensor<2x2xi64> -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_1_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> -// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_2_]] : tensor<2x2xi64> -// CHECK-NEXT: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64> -// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_reshape [[VAR_5_]], [[CST_]] : (tensor<2x2xi64>, tensor<3xindex>) -> tensor<2x2x1xi64> -// CHECK-DAG: [[VAR_7_:%.+]] = stablehlo.dynamic_iota [[CST_]], dim = 1 : (tensor<3xindex>) -> tensor<2x2x1xi64> -// CHECK-NEXT: [[VAR_8_:%.+]] = stablehlo.concatenate [[VAR_6_]], [[VAR_7_]], dim = 2 : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64> -// CHECK-NEXT: [[VAR_9_:%.+]] = "stablehlo.gather"([[PARAM_0_]], [[VAR_8_]]) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32> -// CHECK-NEXT: return [[VAR_9_]] : tensor<2x2xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @main_gather_elements +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<[2, 2, 1]> : tensor<3xindex> +// CHECK-DAG: [[VAR_c_:%.+]] = stablehlo.constant dense<3> : tensor +// CHECK-DAG: [[VAR_c_0_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.broadcast_in_dim [[VAR_c_]], dims = [] : (tensor) -> tensor<2x2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.compare LT, [[PARAM_1_]], [[VAR_c_0_]] : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1> +// CHECK: [[VAR_2_:%.+]] = stablehlo.add [[PARAM_1_]], [[VAR_0_]] : tensor<2x2xi64> +// CHECK: [[VAR_3_:%.+]] = stablehlo.select [[VAR_1_]], [[VAR_2_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.dynamic_reshape [[VAR_3_]], [[VAR_cst_]] : (tensor<2x2xi64>, tensor<3xindex>) -> tensor<2x2x1xi64> +// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_iota [[VAR_cst_]], dim = 1 : (tensor<3xindex>) -> tensor<2x2x1xi64> +// CHECK: [[VAR_6_:%.+]] = stablehlo.concatenate [[VAR_4_]], [[VAR_5_]], dim = 2 : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64> +// CHECK: [[VAR_7_:%.+]] = "stablehlo.gather"([[PARAM_0_]], [[VAR_6_]]) <{dimension_numbers = #stablehlo.gather, slice_sizes = array}> : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32> +// CHECK: return [[VAR_7_]] : tensor<2x2xf32> +// CHECK: } } \ No newline at end of file diff --git a/test/mlir/conversion/onnx_to_tosa/Flow/entrypoint.mlir b/test/mlir/conversion/onnx_to_tosa/Flow/entrypoint.mlir new file mode 100644 index 0000000000..5b07174576 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Flow/entrypoint.mlir @@ -0,0 +1,14 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +module { + func.func @main_graph(%arg0: tensor<32x512x1x1xf32>) -> tensor<32x512xf32> { + %0 = "onnx.Constant"() {value = dense<[32,512]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<32x512x1x1xf32>, tensor<2xi64>) -> tensor<32x512xf32> + return %1 : tensor<32x512xf32> + } + "onnx.EntryPoint"() {func = @main_graph} : () -> () +// CHECK-LABEL: func @forward +// CHECK-NOT: main_graph +// CHECK-NOT: "onnx.EntryPoint" +} + diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir index a772d6b4a0..ffc57013c3 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa=grouped-conv-threshold=4 -cse %s -split-input-file | FileCheck %s func.func @test_onnx_conv2d_stride_13(%arg0: tensor<5x3x256x256xf32>, %arg1 : tensor<2x3x64x64xf32>, %arg2: tensor<2xf32>) -> tensor<5x2x15x15xf32> { @@ -9,12 +9,13 @@ func.func @test_onnx_conv2d_stride_13(%arg0: tensor<5x3x256x256xf32>, %arg1 : te // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x64x64xf32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<2xf32>) -> tensor<5x2x15x15xf32> { // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> -// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x15x15x2xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x2xf32>, tensor<4xi32>) -> tensor<5x2x15x15xf32> -// CHECK: return %[[VAL_8]] : tensor<5x2x15x15xf32> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x3xf32>) -> tensor<5x245x245x3xf32> +// CHECK-DAG: %[[VAL_7:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_5]], %[[VAL_2]] {dilation = array, pad = array, stride = array} : (tensor<5x245x245x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x15x15x2xf32> +// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_7]], %[[VAL_8]] : (tensor<5x15x15x2xf32>, tensor<4xi32>) -> tensor<5x2x15x15xf32> +// CHECK: return %[[VAL_9]] : tensor<5x2x15x15xf32> } // ----- @@ -44,13 +45,14 @@ func.func @test_onnx_conv2d_no_dilation_pad(%arg0: tensor<5x3x256x256xf32>, %arg // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x3x256x256xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<7x3x64x64xf32>) -> tensor<5x7x15x15xf32> { // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> -// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<7x3x64x64xf32>, tensor<4xi32>) -> tensor<7x64x64x3xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<7xf32>}> : () -> tensor<7xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>) -> tensor<5x15x15x7xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x7xf32>, tensor<4xi32>) -> tensor<5x7x15x15xf32> -// CHECK: return %[[VAL_8]] : tensor<5x7x15x15xf32> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<7x3x64x64xf32>, tensor<4xi32>) -> tensor<7x64x64x3xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<7xf32>}> : () -> tensor<7xf32> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_3]] {size = array, start = array} : (tensor<5x256x256x3xf32>) -> tensor<5x246x246x3xf32> +// CHECK: %[[VAL_7:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] {dilation = array, pad = array, stride = array} : (tensor<5x246x246x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>) -> tensor<5x15x15x7xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_7]], %[[VAL_8]] : (tensor<5x15x15x7xf32>, tensor<4xi32>) -> tensor<5x7x15x15xf32> +// CHECK: return %[[VAL_9]] : tensor<5x7x15x15xf32> } // ----- @@ -80,28 +82,32 @@ func.func @test_onnx_conv2d_group(%arg0: tensor<5x64x256x256xf32>, %arg1 : tenso // CHECK-SAME: %[[VAL_1:.*]]: tensor<12x16x45x45xf32>, // CHECK-SAME: %[[VAL_2:.*]]: tensor<12xf32>) -> tensor<5x12x17x17xf32> { // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x64x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x64xf32> -// CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<12x16x45x45xf32>, tensor<4xi32>) -> tensor<12x45x45x16xf32> -// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_9:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.conv2d %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_21:.*]] = tosa.conv2d %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] {dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_22:.*]] = tosa.concat %[[VAL_9]], %[[VAL_13]], %[[VAL_17]], %[[VAL_21]] {axis = 3 : i32} : (tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>) -> tensor<5x17x17x12xf32> -// CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<5x17x17x12xf32>, tensor<4xi32>) -> tensor<5x12x17x17xf32> -// CHECK: return %[[VAL_24]] : tensor<5x12x17x17xf32> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x64x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x64xf32> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<12x16x45x45xf32>, tensor<4xi32>) -> tensor<12x45x45x16xf32> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x252x252x64xf32> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]] {size = array, start = array} : (tensor<5x252x252x64xf32>) -> tensor<5x252x252x16xf32> +// CHECK-DAG: %[[VAL_8:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> +// CHECK-DAG: %[[VAL_9:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: %[[VAL_10:.*]] = tosa.conv2d %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {dilation = array, pad = array, stride = array} : (tensor<5x252x252x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> +// CHECK-DAG: %[[VAL_11:.*]] = tosa.slice %[[VAL_6]] {size = array, start = array} : (tensor<5x252x252x64xf32>) -> tensor<5x252x252x16xf32> +// CHECK-DAG: %[[VAL_12:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> +// CHECK-DAG: %[[VAL_13:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: %[[VAL_14:.*]] = tosa.conv2d %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] {dilation = array, pad = array, stride = array} : (tensor<5x252x252x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> +// CHECK-DAG: %[[VAL_15:.*]] = tosa.slice %[[VAL_6]] {size = array, start = array} : (tensor<5x252x252x64xf32>) -> tensor<5x252x252x16xf32> +// CHECK-DAG: %[[VAL_16:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> +// CHECK-DAG: %[[VAL_17:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: %[[VAL_18:.*]] = tosa.conv2d %[[VAL_15]], %[[VAL_16]], %[[VAL_17]] {dilation = array, pad = array, stride = array} : (tensor<5x252x252x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> +// CHECK-DAG: %[[VAL_19:.*]] = tosa.slice %[[VAL_6]] {size = array, start = array} : (tensor<5x252x252x64xf32>) -> tensor<5x252x252x16xf32> +// CHECK-DAG: %[[VAL_20:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> +// CHECK-DAG: %[[VAL_21:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> +// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_20]], %[[VAL_21]] {dilation = array, pad = array, stride = array} : (tensor<5x252x252x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> +// CHECK-DAG: %[[VAL_23:.*]] = tosa.concat %[[VAL_10]], %[[VAL_14]], %[[VAL_18]], %[[VAL_22]] {axis = 3 : i32} : (tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>) -> tensor<5x17x17x12xf32> +// CHECK-DAG: %[[VAL_24:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_25:.*]] = tosa.transpose %[[VAL_23]], %[[VAL_24]] : (tensor<5x17x17x12xf32>, tensor<4xi32>) -> tensor<5x12x17x17xf32> +// CHECK: return %[[VAL_25]] : tensor<5x12x17x17xf32> } // ----- @@ -119,4 +125,76 @@ func.func @test_onnx_conv2d_autopad(%arg0: tensor<5x3x125x256xf32>, %arg1 : tens // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x125x256x2xf32>, tensor<4xi32>) -> tensor<5x2x125x256xf32> // CHECK: return %[[VAL_8]] : tensor<5x2x125x256xf32> +} + +// ----- +func.func @test_onnx_conv2d_group_higher_4(%arg0: tensor<5x128x256x256xf32>, %arg1 : tensor<16x16x45x45xf32>, %arg2: tensor<16xf32>) -> tensor<5x16x17x17xf32> { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 8 : si64, pads = [1, 1, 1, 1], strides = [13, 13]} : (tensor<5x128x256x256xf32>, tensor<16x16x45x45xf32>, tensor<16xf32>) -> tensor<5x16x17x17xf32> + return %0 : tensor<5x16x17x17xf32> +// CHECK-LABEL: func.func @test_onnx_conv2d_group_higher_4 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x128x256x256xf32>, [[PARAM_1_:%.+]]: tensor<16x16x45x45xf32>, [[PARAM_2_:%.+]]: tensor<16xf32>) -> tensor<5x16x17x17xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {auto_pad = "NOTSET", group = 8 : si64, pads = [1, 1, 1, 1], strides = [13, 13]} : (tensor<5x128x256x256xf32>, tensor<16x16x45x45xf32>, tensor<16xf32>) -> tensor<5x16x17x17xf32> +// CHECK: return [[VAR_0_]] : tensor<5x16x17x17xf32> +} + +// ----- +func.func @test_onnx_conv2d_group_to_depthwise(%arg0: tensor<32x48x112x112xf32>, %arg1 : tensor<48x1x3x3xf32>, %arg2: tensor<48xf32>) -> tensor<32x48x112x112xf32> { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1, 1], group = 48 : si64, kernel_shape = [3, 3], onnx_node_name = "Conv_1395", pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<32x48x112x112xf32>, tensor<48x1x3x3xf32>, tensor<48xf32>) -> tensor<32x48x112x112xf32> + return %0 : tensor<32x48x112x112xf32> +// CHECK-LABEL: func.func @test_onnx_conv2d_group_to_depthwise +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<32x48x112x112xf32>, [[PARAM_1_:%.+]]: tensor<48x1x3x3xf32>, [[PARAM_2_:%.+]]: tensor<48xf32>) -> tensor<32x48x112x112xf32> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[VAR_0_]] : (tensor<32x48x112x112xf32>, tensor<4xi32>) -> tensor<32x112x112x48xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: [[VAR_3_:%.+]] = tosa.transpose [[PARAM_1_]], [[VAR_2_]] : (tensor<48x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x48x1xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3x3x48x1xf32>) -> tensor<3x3x48x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.depthwise_conv2d [[VAR_1_]], [[VAR_4_]], [[PARAM_2_]] {dilation = array, pad = array, stride = array} : (tensor<32x112x112x48xf32>, tensor<3x3x48x1xf32>, tensor<48xf32>) -> tensor<32x112x112x48xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.transpose [[VAR_5_]], [[VAR_6_]] : (tensor<32x112x112x48xf32>, tensor<4xi32>) -> tensor<32x48x112x112xf32> +// CHECK: return [[VAR_7_]] : tensor<32x48x112x112xf32> +} + +// ----- + +func.func @test_onnx_conv2d_group_to_depthwise_integer_multiple(%arg0: tensor<32x24x112x112xf32>, %arg1 : tensor<48x1x3x3xf32>, %arg2: tensor<48xf32>) -> tensor<32x48x112x112xf32> { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1, 1], group = 24 : si64, kernel_shape = [3, 3], onnx_node_name = "Conv_1395", pads = [1, 1, 1, 1], strides = [1, 1]} : (tensor<32x24x112x112xf32>, tensor<48x1x3x3xf32>, tensor<48xf32>) -> tensor<32x48x112x112xf32> + return %0 : tensor<32x48x112x112xf32> +// CHECK-LABEL: func.func @test_onnx_conv2d_group_to_depthwise_integer_multiple +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<32x24x112x112xf32>, [[PARAM_1_:%.+]]: tensor<48x1x3x3xf32>, [[PARAM_2_:%.+]]: tensor<48xf32>) -> tensor<32x48x112x112xf32> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[VAR_0_]] : (tensor<32x24x112x112xf32>, tensor<4xi32>) -> tensor<32x112x112x24xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: [[VAR_3_:%.+]] = tosa.transpose [[PARAM_1_]], [[VAR_2_]] : (tensor<48x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x48x1xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3x3x48x1xf32>) -> tensor<3x3x24x2xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.depthwise_conv2d [[VAR_1_]], [[VAR_4_]], [[PARAM_2_]] {dilation = array, pad = array, stride = array} : (tensor<32x112x112x24xf32>, tensor<3x3x24x2xf32>, tensor<48xf32>) -> tensor<32x112x112x48xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: [[VAR_7_:%.+]] = tosa.transpose [[VAR_5_]], [[VAR_6_]] : (tensor<32x112x112x48xf32>, tensor<4xi32>) -> tensor<32x48x112x112xf32> +// CHECK: return [[VAR_7_]] : tensor<32x48x112x112xf32> +} + +// ----- + +func.func @test_onnx_conv2d_dyn_shapes(%arg0: tensor, %arg1 : tensor<2x3x64x64xf32>, %arg2: tensor<2xf32>) -> tensor { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {dilations = [1, 1], pads = [1, 1, 1, 1], strides = [13, 13]} : (tensor, tensor<2x3x64x64xf32>, tensor<2xf32>) -> tensor + return %0 : tensor +// CHECK-LABEL: func.func @test_onnx_conv2d_dyn_shapes +// CHECK: onnx.Conv +} + +// ----- + +func.func @test_onnx_conv2d_dyn_shapes_no_rank(%arg0: tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {dilations = [1, 1], pads = [1, 1, 1, 1], strides = [13, 13]} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_onnx_conv2d_dyn_shapes_no_rank +// CHECK: onnx.Conv +} + +// ----- + +func.func @test_onnx_conv2d_dyn_shapes_with_shape_inference(%arg0: tensor<5x3x256x256xf32>, %arg1 : tensor<2x3x64x64xf32>, %arg2: tensor<2xf32>) -> tensor { + %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {dilations = [1, 1], pads = [1, 1, 1, 1], strides = [13, 13]} : (tensor<5x3x256x256xf32>, tensor<2x3x64x64xf32>, tensor<2xf32>) -> tensor + return %0 : tensor +// CHECK-LABEL: func.func @test_onnx_conv2d_dyn_shapes_with_shape_inference +// CHECK: tosa.conv } \ No newline at end of file diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 623ef3fe5f..2385761bdb 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -1,5 +1,76 @@ // RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s +// ----- + +func.func @test_cast_f32_i8(%arg0: tensor<13x21x1xf32>) -> tensor<13x21x1xi8> { + %0 = "onnx.Cast"(%arg0) {to = i8} : (tensor<13x21x1xf32>) -> tensor<13x21x1xi8> + "func.return"(%0) : (tensor<13x21x1xi8>) -> () +// CHECK-LABEL: func.func @test_cast_f32_i8( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x1xf32>) -> tensor<13x21x1xi8> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: %[[VAL_2:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xi1> +// CHECK: %[[VAL_3:.*]] = tosa.floor %[[VAL_0]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.ceil %[[VAL_0]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.select %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (tensor<13x21x1xi1>, tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xi8> +// CHECK: return %[[VAL_6]] : tensor<13x21x1xi8> +// CHECK: } +} + +// ----- + +func.func @test_cast_int4_and_uint4_to_from_int8_uint8(%arg0: tensor<1xi4>, %arg1: tensor<1xui4>) -> (tensor<1xi4>, tensor<1xui4>) { + %0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8> + %1 = "onnx.Cast"(%0) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4> + %2 = "onnx.Cast"(%arg1) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8> + %3 = "onnx.Cast"(%2) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4> + onnx.Return %1, %3 : tensor<1xi4>, tensor<1xui4> + // CHECK-LABEL: func.func @test_cast_int4_and_uint4_to_from_int8_uint8( + // TOSA does not support int4 casting + // CHECK-NOT: tosa.cast +} + +// ----- + +func.func @test_cast_f16_i8(%arg0: tensor<13x21x1xf16>) -> tensor<13x21x1xi8> { + %0 = "onnx.Cast"(%arg0) {to = i8} : (tensor<13x21x1xf16>) -> tensor<13x21x1xi8> + "func.return"(%0) : (tensor<13x21x1xi8>) -> () +// CHECK-LABEL: func.func @test_cast_f16_i8( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x1xf16>) -> tensor<13x21x1xi8> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16> +// CHECK: %[[VAL_2:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x1xf16>, tensor<1x1x1xf16>) -> tensor<13x21x1xi1> +// CHECK: %[[VAL_3:.*]] = tosa.floor %[[VAL_0]] : (tensor<13x21x1xf16>) -> tensor<13x21x1xf16> +// CHECK: %[[VAL_4:.*]] = tosa.ceil %[[VAL_0]] : (tensor<13x21x1xf16>) -> tensor<13x21x1xf16> +// CHECK: %[[VAL_5:.*]] = tosa.select %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (tensor<13x21x1xi1>, tensor<13x21x1xf16>, tensor<13x21x1xf16>) -> tensor<13x21x1xf16> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<13x21x1xf16>) -> tensor<13x21x1xi8> +// CHECK: return %[[VAL_6]] : tensor<13x21x1xi8> +// CHECK: } +} + +// ----- + +func.func @test_cast_i8_i1(%arg0: tensor<1x21x1x1xi8>) -> tensor<1x21x1x1xi1> { + %0 = "onnx.Cast"(%arg0) {to = i1} : (tensor<1x21x1x1xi8>) -> tensor<1x21x1x1xi1> + "func.return"(%0) : (tensor<1x21x1x1xi1>) -> () +// CHECK-LABEL: func @test_cast_i8_i1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x21x1x1xi8>) -> tensor<1x21x1x1xi1> { +// CHECK-NEXT: tosa.cast [[PARAM_0_]] : (tensor<1x21x1x1xi8>) -> tensor<1x21x1x1xi1> +} + +// ----- + +func.func @test_cast_f32_i1(%arg0: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Cast"(%arg0) {to = i1} : (tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () + + // CHECK-LABEL: func @test_cast_f32_i1 + // CHECK-SAME: (%[[VAL_0:.*]]: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + // CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + // CHECK: return %[[VAL_1]] : tensor<13x21x1xi1> +} + +// ----- + func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { %0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> "func.return"(%0) : (tensor<10x10xf32>) -> () @@ -42,7 +113,6 @@ func.func @test_floor(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.floor [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> } - // ----- func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { @@ -65,6 +135,52 @@ func.func @test_add_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32> } +// ----- + +func.func @test_add_ui32(%arg0: tensor<13x21x1xui32>, %arg1: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32> + "func.return"(%0) : (tensor<13x21x1xui32>) -> () +// CHECK-LABEL: func.func @test_add_ui32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xui32>, [[PARAM_1_:%.+]]: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> { +// CHECK: [[VAR_0_:%.+]] = tosa.add [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xui32> +} + +// ----- + +func.func @test_add_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64> + "func.return"(%0) : (tensor<13x21x1xf64>) -> () +// CHECK-LABEL: func.func @test_add_f64 +// CHECK-NOT: onnx.Add +// CHECK: return {{.*}}: tensor<13x21x1xf64> +} + +// ----- + +func.func @test_add_dyn_shape_and_const(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<8400> : tensor<1xi64> + %1 = "onnx.Add"(%arg0, %0) : (tensor, tensor<1xi64>) -> tensor + "func.return"(%1) : (tensor) -> () +// CHECK-LABEL: test_add_dyn_shape_and_const +// CHECK: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<8400> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<1xi64>) -> tensor<1x1xi64> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_1_]] : (tensor, tensor<1x1xi64>) -> tensor +// CHECK: return [[VAR_2_]] : tensor +} + +// ----- + +func.func @test_add_dyn_shape_no_rank(%arg0: tensor<*xi64>) -> tensor<*xi64> { + %0 = "onnx.Add"(%arg0, %arg0) : (tensor<*xi64>, tensor<*xi64>) -> tensor<*xi64> + "func.return"(%0) : (tensor<*xi64>) -> () +// CHECK-LABEL: test_add_dyn_shape_no_rank +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xi64>) -> tensor<*xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.add [[PARAM_0_]], [[PARAM_0_]] : (tensor<*xi64>, tensor<*xi64>) -> tensor<*xi64> +// CHECK: return [[VAR_0_]] : tensor<*xi64> +// CHECK: } +} // ----- @@ -88,6 +204,47 @@ func.func @test_sub_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32> } +// ----- + +func.func @test_mul(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_mul +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_mul_dynamic(%arg0: tensor, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> + "func.return"(%0) : (tensor<13x?x?xf32>) -> () +// CHECK-LABEL: func @test_mul_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.mul [[PARAM_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> +} + +// ----- + +func.func @test_mul_rank_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_mul_rank_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<21x1xf32>) -> tensor<1x21x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_mul_rank_broadcast2(%arg0: tensor<21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_mul_rank_broadcast2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<21x1xf32>) -> tensor<1x21x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[VAR_0_]], [[PARAM_1_]] {shift = 0 : i8} : (tensor<1x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} // ----- @@ -101,6 +258,39 @@ func.func @test_div(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x1xi32>) -> t // ----- +func.func @test_div_dynamic(%arg0: tensor, %arg1: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> { + %0 = "onnx.Div"(%arg0, %arg1) : (tensor, tensor<13x?x?xi32>) -> tensor<13x?x?xi32> + "func.return"(%0) : (tensor<13x?x?xi32>) -> () +// CHECK-LABEL: func @test_div_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<13x?x?xi32>) -> tensor<13x?x?xi32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.int_div [[PARAM_0_]], [[PARAM_1_]] : (tensor, tensor<13x?x?xi32>) -> tensor<13x?x?xi32> +} + +// ----- + +func.func @test_div_dynamic_float(%arg0: tensor, %arg1: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { + %0 = "onnx.Div"(%arg0, %arg1) : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> + "func.return"(%0) : (tensor<13x?x?xf32>) -> () +// CHECK-LABEL: func.func @test_div_dynamic_float +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<13x?x?xf32>) -> tensor<13x?x?xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x?x?xf32>) -> tensor<13x?x?xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor, tensor<13x?x?xf32>) -> tensor<13x?x?xf32> +// CHECK: return [[VAR_1_]] : tensor<13x?x?xf32> +// CHECK: } +} + +// ----- + +func.func @test_div_unsigned(%arg0: tensor<13x21x1xui8>, %arg1: tensor<13x21x1xui8>) -> tensor<13x21x1xui8> { + %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xui8>, tensor<13x21x1xui8>) -> tensor<13x21x1xui8> + "func.return"(%0) : (tensor<13x21x1xui8>) -> () +// CHECK-LABEL: func @test_div_unsigned +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xui8>, [[PARAM_1_:%.+]]: tensor<13x21x1xui8>) -> tensor<13x21x1xui8> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.int_div [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xui8>, tensor<13x21x1xui8>) -> tensor<13x21x1xui8> +} + +// ----- + func.func @test_div_broadcast(%arg0: tensor<13x21x1xi32>, %arg1: tensor<1xi32>) -> tensor<13x21x1xi32> { %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<1xi32>) -> tensor<13x21x1xi32> "func.return"(%0) : (tensor<13x21x1xi32>) -> () @@ -123,6 +313,351 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1 // ----- +func.func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.LeakyRelu"(%arg0) {alpha = 0.707330704 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_leaky_relu +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.707330704> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR1]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR3:.*]] = tosa.greater_equal %arg0, %[[VAR0]] +// CHECK: %[[VAR6:.*]] = tosa.select %[[VAR3]], %arg0, %[[VAR2]] +} + +func.func @test_leaky_relu_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { + %0 = "onnx.LeakyRelu"(%arg0) {alpha = 0.707330704 : f32} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> + func.return %0 : tensor<13x21x3xbf16> +// CHECK-LABEL: test_leaky_relu_bf16 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xbf16>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<7.070310e-01> : tensor<1x1x1xbf16>}> +// CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR1]] {shift = 0 : i8} +// CHECK-DAG: %[[VAR3:.*]] = tosa.greater_equal %arg0, %[[VAR0]] +// CHECK: %[[VAR6:.*]] = tosa.select %[[VAR3]], %arg0, %[[VAR2]] +} + +// ----- + +func.func @test_prelu(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.PRelu"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_prelu +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.mul %arg0, %arg1 {shift = 0 : i8} +// CHECK: [[VAR_2_:%.+]] = tosa.greater_equal %arg0, [[VAR_0_]] +// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], %arg0, [[VAR_1_]] +} + +func.func @test_prelu_bf16(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { + %0 = "onnx.PRelu"(%arg0, %arg1) : (tensor<13x21x3xbf16>, tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> + func.return %0 : tensor<13x21x3xbf16> +// CHECK-LABEL: test_prelu_bf16 +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xbf16>}> : () -> tensor<1x1x1xbf16> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.mul %arg0, %arg1 {shift = 0 : i8} +// CHECK: [[VAR_2_:%.+]] = tosa.greater_equal %arg0, [[VAR_0_]] +// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], %arg0, [[VAR_1_]] +} + +// ----- + +func.func @test_selu_default_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.Selu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_selu_default_value +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.673260e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.050700e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp %arg0 +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] {shift = 0 : i8} +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater %arg0, [[VAR_2_]] +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], %arg0, [[VAR_5_]] +// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] {shift = 0 : i8} +// CHECK: return [[VAR_8_]] : tensor<13x21x3xf32> +} + +func.func @test_selu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_selu +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.500000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] {shift = 0 : i8} +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater %arg0, [[VAR_2_]] +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], %arg0, [[VAR_5_]] +// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] {shift = 0 : i8} +// CHECK: return [[VAR_8_]] : tensor<13x21x3xf32> +} + +// ----- + +func.func @test_selu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_selu_unranked +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Selu"([[PARAM_0_]]) {alpha = 1.500000e+00 : f32, gamma = 2.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_selu_dynamic(%arg0: tensor) -> tensor { + %0 = "onnx.Selu"(%arg0) {alpha = 1.5 : f32, gamma = 2.0 : f32} : (tensor) -> tensor + func.return %0 : tensor +// CHECK-LABEL: func.func @test_selu_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.500000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor) -> tensor +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_0_]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.sub [[VAR_4_]], [[VAR_0_]] : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_2_]] : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_1_]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: return [[VAR_8_]] : tensor +} + +// ----- + +func.func @test_softplus(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.Softplus"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_softplus +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.exp %arg0 +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_0_]], [[VAR_1_]] +// CHECK: [[VAR_3_:%.+]] = tosa.log [[VAR_2_]] +// CHECK: return [[VAR_3_]] : tensor<13x21x3xf32> +} + +// ----- + +func.func @test_softplus_dynamic(%arg0: tensor) -> tensor { + %0 = "onnx.Softplus"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +// CHECK-LABEL: func.func @test_softplus_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor) -> tensor +// CHECK: [[VAR_2_:%.+]] = tosa.add [[VAR_1_]], [[VAR_0_]] : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: [[VAR_3_:%.+]] = tosa.log [[VAR_2_]] : (tensor) -> tensor +// CHECK: return [[VAR_3_]] : tensor +} + +// ----- + +func.func @test_softplus_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.Softplus"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_softplus_unranked +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Softplus"([[PARAM_0_]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +} + + +// ----- + +func.func @test_thresholdedrelu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.ThresholdedRelu"(%arg0) {alpha = 0.5 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_thresholdedrelu +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<5.000000e-01> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> +// CHECK: [[VAR_2_:%.+]] = tosa.greater %arg0, [[VAR_0_]] +// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], %arg0, [[VAR_1_]] +// CHECK: return [[VAR_3_]] : tensor<13x21x3xf32> +} + + + +func.func @test_thresholdedrelu_default_value(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +// CHECK-LABEL: test_thresholdedrelu_default_value +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.greater %arg0, [[VAR_0_]] +// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], %arg0, [[VAR_1_]] +// CHECK: return [[VAR_3_]] : tensor<13x21x3xf32> +} + +// ----- + +func.func @test_thresholded_relu_dynamic(%arg0: tensor) -> tensor { + %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +// CHECK-LABEL: func.func @test_thresholded_relu_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.greater [[PARAM_0_]], [[VAR_0_]] : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: [[VAR_3_:%.+]] = tosa.select [[VAR_2_]], [[PARAM_0_]], [[VAR_1_]] : (tensor, tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: return [[VAR_3_]] : tensor +} + +// ----- + +func.func @test_thresholded_relu_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.ThresholdedRelu"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_thresholded_relu_unranked +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.ThresholdedRelu"([[PARAM_0_]]) {alpha = 1.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +} + +// ----- + +func.func @test_sigmoid(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Sigmoid"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_sigmoid +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.sigmoid [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_ceil(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Ceil"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_ceil +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.ceil [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_exp(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Exp"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_exp +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_log(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Log"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_log +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.log [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_reciprocal(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Reciprocal"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_reciprocal +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_tanh(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Tanh"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_tanh +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.tanh [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_clip(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor<3xi32> { + %0 = "onnx.Clip"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor, tensor) -> tensor<3xi32> + return %0 : tensor<3xi32> +// CHECK-LABEL: func @test_clip +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.maximum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xi32>, tensor) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.minimum [[VAR_0_]], [[PARAM_2_]] : (tensor<3xi32>, tensor) -> tensor<3xi32> +// CHECK-NEXT: return [[VAR_1_]] : tensor<3xi32> +// CHECK-NEXT: } +} + +// ----- + +// Test when min is none +func.func @test_clip_default_min_f32(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<3xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Clip"(%arg0, %cst, %arg1) : (tensor<3xf32>, none, tensor) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func @test_clip_default_min_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>, [[PARAM_1_:%.+]]: tensor) -> tensor<3xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.minimum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xf32>, tensor) -> tensor<3xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xf32> +// CHECK-NEXT: } +} + +// ----- + +// Test when max is none +func.func @test_clip_default_max_bf16(%arg0: tensor<3xbf16>, %arg1: tensor) -> tensor<3xbf16> { + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Clip"(%arg0, %arg1, %cst) : (tensor<3xbf16>, tensor, none) -> tensor<3xbf16> + return %0 : tensor<3xbf16> +// CHECK-LABEL: func @test_clip_default_max_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>, [[PARAM_1_:%.+]]: tensor) -> tensor<3xbf16> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.maximum [[PARAM_0_]], [[PARAM_1_]] : (tensor<3xbf16>, tensor) -> tensor<3xbf16> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16> +// CHECK-NEXT: } +} + +// ----- + +// Test when min and max are splat constants +func.func @test_clip_constant_minimum_maximum(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { + %cst1 = "onnx.Constant"() {value = dense<-2.0> : tensor} : () -> tensor + %cst2 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xbf16>} : () -> tensor<1xbf16> + %0 = "onnx.Clip"(%arg0, %cst1, %cst2) : (tensor<3xbf16>, tensor, tensor<1xbf16>) -> tensor<3xbf16> + return %0 : tensor<3xbf16> +// CHECK-LABEL: func @test_clip_constant_minimum_maximum +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 2.000000e+00 : f32, max_int = 2 : i64, min_fp = -2.000000e+00 : f32, min_int = -2 : i64} : (tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16> +// CHECK-NEXT: } +} + +// ----- + +// Test when min and max are constants and min is non-splat. +func.func @test_clip_constant_minimum_maximum_non_splat(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %cst1 = "onnx.Constant"() {value = dense<[-1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst2 = "onnx.Constant"() {value = dense<[2]> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "onnx.Clip"(%arg0, %cst1, %cst2) : (tensor<3xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +// CHECK-LABEL: func @test_clip_constant_minimum_maximum_non_splat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[-1, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_3_:%.+]] = tosa.minimum [[VAR_2_]], [[VAR_1_]] : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: return [[VAR_3_]] : tensor<3xi32> +// CHECK-NEXT: } +} + func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { %0 = "onnx.Div"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> "func.return"(%0) : (tensor<13x21x1xf32>) -> () @@ -132,3 +667,581 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> // CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> } + +// ----- + +func.func @test_pow(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_pow +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.pow [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +func.func @test_pow_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_pow_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.pow [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +} + +func.func @test_pow_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64> + "func.return"(%0) : (tensor<13x21x1xf64>) -> () +// CHECK-LABEL: func @test_pow +// CHECK-NOT: onnx.Pow +// CHECK: return {{.*}}: tensor<13x21x1xf64> +} + +// ----- + +func.func @test_pow_mixed_types(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) -> (tensor<3xf32>) { + // CHECK-LABEL: func @test_pow_mixed_types + // CHECK-SAME: ([[PARAM_0:%.*]]: tensor<3xf32>, [[PARAM_1:%.*]]: tensor<3xi32>) -> tensor<3xf32> + // CHECK: "onnx.Pow"([[PARAM_0]], [[PARAM_1]]) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + %0 = "onnx.Pow"(%arg0, %arg1) {onnx_node_name = "onnx.Pow_0"} : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +} + +// ----- + +func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.Sqrt"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func @test_sqrt +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.pow [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK-NEXT: return [[VAR_1_]] : tensor<3xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_abs_i32(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +// CHECK-LABEL: func @test_abs_i32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.abs [[PARAM_0_]] : (tensor<3xi32>) -> tensor<3xi32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xi32> +// CHECK-NEXT: } +} + +func.func @test_abs_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xbf16>) -> tensor<3xbf16> + return %0 : tensor<3xbf16> +// CHECK-LABEL: func @test_abs_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.abs [[PARAM_0_]] : (tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16> +// CHECK-NEXT: } +} + +func.func @test_abs_f64(%arg0: tensor<3xf64>) -> tensor<3xf64> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xf64>) -> tensor<3xf64> + return %0 : tensor<3xf64> +// CHECK-LABEL: func @test_abs_f64 +// CHECK-NOT: onnx.Abs +// CHECK: return {{.*}}: tensor<3xf64> +} + +// ----- + +func.func @test_erf_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.Erf"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func @test_erf_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.erf [[PARAM_0_]] : (tensor<3xf32>) -> tensor<3xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xf32> +// CHECK-NEXT: } +} + +func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { + %0 = "onnx.Erf"(%arg0) : (tensor<3xbf16>) -> tensor<3xbf16> + return %0 : tensor<3xbf16> +// CHECK-LABEL: func @test_erf_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.erf [[PARAM_0_]] : (tensor<3xbf16>) -> tensor<3xbf16> +// CHECK-NEXT: return [[VAR_0_]] : tensor<3xbf16> +// CHECK-NEXT: } +} + +func.func @test_erf_f64(%arg0: tensor<3xf64>) -> tensor<3xf64> { + %0 = "onnx.Erf"(%arg0) : (tensor<3xf64>) -> tensor<3xf64> + return %0 : tensor<3xf64> +// CHECK-LABEL: func @test_erf_f64 +// CHECK-NOT: onnx.Erf +// CHECK: return %0 : tensor<3xf64> +} + +// ----- + +func.func @test_bitwise_not(%arg0 : tensor<10x10xi32>) -> tensor<10x10xi32> { + %0 = "onnx.BitwiseNot"(%arg0) : (tensor<10x10xi32>) -> tensor<10x10xi32> + "func.return"(%0) : (tensor<10x10xi32>) -> () +// CHECK-LABEL: func @test_bitwise_not +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xi32>) -> tensor<10x10xi32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.bitwise_not [[PARAM_0_]] : (tensor<10x10xi32>) -> tensor<10x10xi32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xi32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_not(%arg0 : tensor<10x10xi1>) -> tensor<10x10xi1> { + %0 = "onnx.Not"(%arg0) : (tensor<10x10xi1>) -> tensor<10x10xi1> + "func.return"(%0) : (tensor<10x10xi1>) -> () +// CHECK-LABEL: func @test_not +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xi1>) -> tensor<10x10xi1> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.logical_not [[PARAM_0_]] : (tensor<10x10xi1>) -> tensor<10x10xi1> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xi1> +// CHECK-NEXT: } +} + +// ----- + +// Default values: alpha = 0.2, beta = 0.5 +func.func @test_hardsigmoid_default_values_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func.func @test_hardsigmoid_default_values_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<2.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: return [[VAR_4_]] : tensor<3xf32> +} + +func.func @test_hardsigmoid_default_values_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> { + %0 = "onnx.HardSigmoid"(%arg0) : (tensor<3xf16>) -> tensor<3xf16> + return %0 : tensor<3xf16> +// CHECK-LABEL: func @test_hardsigmoid_default_values_f16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<2.500000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.999510e-01> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf16>) -> tensor<3xf16> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK: return [[VAR_4_]] : tensor<3xf16> +} + +// alpha = 0.166666672, beta = 5.000000e-01 +func.func @test_hardsigmoid_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func @test_hardsigmoid_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: return [[VAR_4_]] : tensor<3xf32> +} + +func.func @test_hardsigmoid_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> { + %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf16>) -> tensor<3xf16> + return %0 : tensor<3xf16> +// CHECK-LABEL: func @test_hardsigmoid_f16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3xf16>) -> tensor<3xf16> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK: return [[VAR_4_]] : tensor<3xf16> +} + +// ----- + +func.func @test_hardsigmoid_dynamic(%arg0: tensor) -> tensor { + %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor) -> tensor + return %0 : tensor +// CHECK-LABEL: func.func @test_hardsigmoid_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1x1x1xf16>}> : () -> tensor<1x1x1xf16> +// CHECK: [[VAR_2_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor, tensor<1x1x1xf16>) -> tensor +// CHECK: [[VAR_3_:%.+]] = tosa.clamp [[VAR_2_]] {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf16>) -> tensor +// CHECK: return [[VAR_4_]] : tensor +} + +// ----- + +func.func @test_hardsigmoid_unranked(%arg0: tensor<*xf16>) -> tensor<*xf16> { + %0 = "onnx.HardSigmoid"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16> + return %0 : tensor<*xf16> +// CHECK-LABEL: func.func @test_hardsigmoid_unranked +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf16>) -> tensor<*xf16> { +// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<*xf16>) -> tensor<*xf16> +// CHECK: return [[VAR_0_]] : tensor<*xf16> +} + +// ----- + +func.func @test_elu_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<3xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func.func @test_elu_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<3xf32>) -> tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<3xi1>, tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: return [[VAR_7_]] +} + +func.func @test_elu_f16(%arg0: tensor<3xf16>) -> tensor<3xf16> { + %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor<3xf16>) -> tensor<3xf16> + return %0 : tensor<3xf16> +// CHECK-LABEL: func.func @test_elu_f16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf16>) -> tensor<3xf16> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.666260e-01> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<3xf16>) -> tensor<3xf16> +// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xf16> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<3xf16>, tensor<1xf16>) -> tensor<3xi1> +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<3xi1>, tensor<3xf16>, tensor<3xf16>) -> tensor<3xf16> +// CHECK: return [[VAR_7_]] +} + +// ----- + +func.func @test_elu_unranked(%arg0: tensor<*xf32>) -> tensor<3xf32> { + %0 = "onnx.Elu"(%arg0) {alpha = 0.166666672 : f32} : (tensor<*xf32>) -> tensor<3xf32> + return %0 : tensor<3xf32> +// CHECK-LABEL: func.func @test_elu_unranked +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.166666672> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.exp [[PARAM_0_]] : (tensor<*xf32>) -> tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.sub [[VAR_3_]], [[VAR_0_]] : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.greater_equal [[PARAM_0_]], [[VAR_2_]] : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xi1> +// CHECK: [[VAR_7_:%.+]] = tosa.select [[VAR_6_]], [[PARAM_0_]], [[VAR_5_]] : (tensor<*xi1>, tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: return [[VAR_7_]] : tensor<3xf32> +// CHECK: } +} + + +// ----- + +func.func @test_and(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_and +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.logical_and [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> +} + +// ----- + +func.func @test_and_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.And"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_and_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: [[VAR_1_:%.+]] = tosa.logical_and [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} +// ----- + +func.func @test_bitwise_and(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func @test_bitwise_and +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.bitwise_and [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> +} +// ----- + +func.func @test_bitwise_and_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseAnd"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func.func @test_bitwise_and_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi64>) -> tensor<1x1x1xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.bitwise_and [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64> +} +// ----- + +func.func @test_or(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_or +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.logical_or [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> +} +// ----- + +func.func @test_or_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.Or"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_or_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: [[VAR_1_:%.+]] = tosa.logical_or [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} +// ----- + +func.func @test_bitwise_or(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func @test_bitwise_or +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.bitwise_or [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> +} +// ----- + +func.func @test_bitwise_or_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseOr"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func.func @test_bitwise_or_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi64>) -> tensor<1x1x1xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.bitwise_or [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64> +} + +// ----- + +func.func @test_xor(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_xor +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xi1>) -> tensor<13x21x1xi1> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.logical_xor [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi1>, tensor<13x21x1xi1>) -> tensor<13x21x1xi1> +} + +// ----- + +func.func @test_xor_broadcast(%arg0: tensor<13x21x1xi1>, %arg1: tensor<1xi1>) -> tensor<13x21x1xi1> { + %0 = "onnx.Xor"(%arg0, %arg1) : (tensor<13x21x1xi1>, tensor<1xi1>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_xor_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<1xi1>) -> tensor<13x21x1xi1> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: [[VAR_1_:%.+]] = tosa.logical_xor [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi1>, tensor<1x1x1xi1>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} +// ----- + +func.func @test_bitwise_xor(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func @test_bitwise_xor +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<13x21x1xi64>) -> tensor<13x21x1xi64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.bitwise_xor [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xi64>, tensor<13x21x1xi64>) -> tensor<13x21x1xi64> +} +// ----- + +func.func @test_bitwise_xor_broadcast(%arg0: tensor<13x21x1xi64>, %arg1: tensor<1xi64>) -> tensor<13x21x1xi64> { + %0 = "onnx.BitwiseXor"(%arg0, %arg1) : (tensor<13x21x1xi64>, tensor<1xi64>) -> tensor<13x21x1xi64> + "func.return"(%0) : (tensor<13x21x1xi64>) -> () +// CHECK-LABEL: func.func @test_bitwise_xor_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi64>, [[PARAM_1_:%.+]]: tensor<1xi64>) -> tensor<13x21x1xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi64>) -> tensor<1x1x1xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.bitwise_xor [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi64>, tensor<1x1x1xi64>) -> tensor<13x21x1xi64> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi64> +} + +// ----- + +func.func @test_min(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_min +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.minimum [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_min_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_min_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.minimum [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +} +// ----- + +func.func @test_max(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_max +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.maximum [[PARAM_0_]], [[PARAM_1_]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_max_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_max_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.maximum [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_sin(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Sin"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_sin +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.sin [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_sin_dynamic(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Sin"(%arg0) : (tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_sin_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.sin [[PARAM_0_]] : (tensor) -> tensor +// CHECK-NEXT: return [[VAR_0_]] : tensor +// CHECK-NEXT: } +} + +// ----- + +func.func @test_cos(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = "onnx.Cos"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + "func.return"(%0) : (tensor<10x10xf32>) -> () +// CHECK-LABEL: func @test_cos +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.cos [[PARAM_0_]] : (tensor<10x10xf32>) -> tensor<10x10xf32> +// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32> +// CHECK-NEXT: } +} + +// ----- + +func.func @test_cos_dynamic(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Cos"(%arg0) : (tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func @test_cos_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.cos [[PARAM_0_]] : (tensor) -> tensor +// CHECK-NEXT: return [[VAR_0_]] : tensor +// CHECK-NEXT: } +} + +// ----- + +func.func @test_equal(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Equal"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_equal +// CHECK: [[VAR_0_:%.+]] = tosa.equal %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xi1> +} + +func.func @test_equal_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Equal"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_equal_broadcast +// CHECK: [[VAR_0_:%.+]] = tosa.reshape %arg1 {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.equal %arg0, [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} + +// ----- + +func.func @test_greaterequal(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.GreaterOrEqual"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_greaterequal +// CHECK: [[VAR_0_:%.+]] = tosa.greater_equal %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xi1> +} + +func.func @test_greaterequal_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.GreaterOrEqual"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_greaterequal_broadcast +// CHECK: [[VAR_0_:%.+]] = tosa.reshape %arg1 {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.greater_equal %arg0, [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} + +// ----- + +func.func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_greater +// CHECK: [[VAR_0_:%.+]] = tosa.greater %arg0, %arg1 : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xi1> +} + +func.func @test_greater_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_greater_broadcast +// CHECK: [[VAR_0_:%.+]] = tosa.reshape %arg1 {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.greater %arg0, [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} + +// ----- + +func.func @test_lessequal(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.LessOrEqual"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_lessequal +// CHECK: [[VAR_0_:%.+]] = tosa.greater_equal %arg1, %arg0 : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xi1> +} + +func.func @test_lessequal_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.LessOrEqual"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_lessequal_broadcast +// CHECK: [[VAR_0_:%.+]] = tosa.reshape %arg1 {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.greater_equal [[VAR_0_]], %arg0 : (tensor<1x1x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} + +// ----- + +func.func @test_less(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Less"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func @test_less +// CHECK: [[VAR_0_:%.+]] = tosa.greater %arg1, %arg0 : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xi1> +} + +func.func @test_less_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) -> tensor<13x21x1xi1> { + %0 = "onnx.Less"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xi1> + "func.return"(%0) : (tensor<13x21x1xi1>) -> () +// CHECK-LABEL: func.func @test_less_broadcast +// CHECK: [[VAR_0_:%.+]] = tosa.reshape %arg1 {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.greater [[VAR_0_]], %arg0 : (tensor<1x1x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xi1> +// CHECK: return [[VAR_1_]] : tensor<13x21x1xi1> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir index 5ccbd32a28..09b1685f44 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir @@ -41,4 +41,4 @@ func.func @gemm_to_fc_opt(%arg0: tensor<1x5xf32>, %arg1: tensor<4x5xf32>) -> ten // CHECK: %[[VAL_4:.*]] = tosa.fully_connected %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : (tensor<1x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[VAL_4]] : tensor<1x4xf32> // CHECK: } -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir index 3654d493ea..d1c99e06d0 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir @@ -161,4 +161,4 @@ func.func @test_mixed(%arg0: tensor<11x5xf32>, %arg1: tensor<3x11xf32>, %arg2: t // CHECK: [[VAR_13_:%.+]] = tosa.reshape [[VAR_12_]] {new_shape = array} : (tensor<1x5x3xf32>) -> tensor<5x3xf32> // CHECK: return [[VAR_13_]] : tensor<5x3xf32> // CHECK: } -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Pow-unsupported.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Pow-unsupported.mlir new file mode 100644 index 0000000000..6ba6f49000 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/Pow-unsupported.mlir @@ -0,0 +1,16 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +// onnx.Pow with integer exponent is not supported in TOSA. +// This test checks that the conversion does not fail but, instead, it keeps the +// original version of the op. + +func.func @test_less_broadcast(%arg0: tensor<5xi32>, %arg1: tensor<5xi32>) -> tensor<*xi32> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<5xi32>, tensor<5xi32>) -> tensor<*xi32> + onnx.Return %0 : tensor<*xi32> + +// CHECK: test_less_broadcast +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<5xi32>, %[[ARG_1:.*]]: tensor<5xi32>) -> tensor<5xi32> +// CHECK-NEXT: %[[VAL_0:.*]] = "onnx.Pow"(%[[ARG_0]], %[[ARG_1]]) : (tensor<5xi32>, tensor<5xi32>) -> tensor<5xi32> +// CHECK-NEXT: onnx.Return %[[VAL_0]] : tensor<5xi32> + +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMax.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMax.mlir new file mode 100644 index 0000000000..c8ea641c89 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMax.mlir @@ -0,0 +1,89 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @reduce_max(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMax"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_max( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_max_no_axes_attr(%arg0: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceMax"(%arg0, %none) : (tensor<2x5x9x11xf32>, none) -> tensor<1x1x1x1xf32> +return %0 : tensor<1x1x1x1xf32> +// CHECK-LABEL: func.func @reduce_max_no_axes_attr( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 0 : i32} : (tensor<2x5x9x11xf32>) -> tensor<1x5x9x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 1 : i32} : (tensor<1x5x9x11xf32>) -> tensor<1x1x9x11xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<1x1x1x1xf32> +} + +// ----- + +func.func @reduce_max_keepdims_false(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMax"(%arg0, %0) {keepdims = 0 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5xf32> +return %1 : tensor<2x5xf32> +// CHECK-LABEL: func.func @reduce_max_keepdims_false( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> +// CHECK: return %[[VAL_3]] : tensor<2x5xf32> +} + +// ----- + +func.func @reduce_max_noop_with_emtpy_axes_one(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMax"(%arg0, %0) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_max_noop_with_emtpy_axes_one( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_max_noop_with_emtpy_axes_one_none_input(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceMax"(%arg0, %none) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, none) -> tensor<2x5x9x11xf32> +return %0 : tensor<2x5x9x11xf32> +// CHECK-LABEL: func.func @reduce_max_noop_with_emtpy_axes_one_none_input( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> +// CHECK: return %[[VAL_1]] : tensor<2x5x9x11xf32> +} + +// ----- + +func.func @test_reducemaxV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> { + %0 = "onnx.ReduceMaxV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> + return %0 : tensor<1x32x1x1xf32> +// CHECK-LABEL: func.func @test_reducemaxV13 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_max %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_max [[VAR_0_]] {axis = 3 : i32} +// CHECK: return [[VAR_1_]] : tensor<1x32x1x1xf32> +} + +// ----- + +func.func @test_reducemaxV13_keep_dims_0(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32xf32> { + %0 = "onnx.ReduceMaxV13"(%arg0) {axes = [2, 3], keepdims = 0 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32xf32> + return %0 : tensor<1x32xf32> +// CHECK-LABEL: func.func @test_reducemaxV13_keep_dims_0 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_max %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_max [[VAR_0_]] {axis = 3 : i32} +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} +// CHECK: return [[VAR_2_]] : tensor<1x32xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir index 937638926e..71629af893 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir @@ -8,10 +8,9 @@ return %1 : tensor<2x5x1x1xf32> // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> -// CHECK: return %[[VAL_5]] : tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32> } // ----- @@ -26,10 +25,9 @@ return %0 : tensor<1x1x1x1xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<1x5x9x11xf32>) -> tensor<1x1x9x11xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.00101010106> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> -// CHECK: return %[[VAL_7]] : tensor<1x1x1x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.00101010106> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> +// CHECK: return %[[VAL_6]] : tensor<1x1x1x1xf32> } // ----- @@ -43,10 +41,9 @@ return %1 : tensor<2x5xf32> // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x5xf32>, tensor<1x1xf32>) -> tensor<2x5xf32> -// CHECK: return %[[VAL_6]] : tensor<2x5xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5xf32>, tensor<1x1xf32>) -> tensor<2x5xf32> +// CHECK: return %[[VAL_5]] : tensor<2x5xf32> } // ----- @@ -59,10 +56,9 @@ return %1 : tensor<2x5x1x1xf32> // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> -// CHECK: return %[[VAL_5]] : tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<2x5x1x1xf32> } // ----- @@ -73,9 +69,27 @@ func.func @reduce_mean_noop_with_emtpy_axes_one_none_input(%arg0: tensor<2x5x9x1 return %0 : tensor<2x5x9x11xf32> // CHECK-LABEL: func.func @reduce_mean_noop_with_emtpy_axes_one_none_input( // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x9x11xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x9x11xf32> -// CHECK: return %[[VAL_4]] : tensor<2x5x9x11xf32> +// CHECK: return %[[VAL_0]] : tensor<2x5x9x11xf32> } + +// ----- + +func.func @test_reducemeanV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> { + %0 = "onnx.ReduceMeanV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> + return %0 : tensor<1x32x1x1xf32> +// CHECK-LABEL: func.func @test_reducemeanV13 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_sum %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 3 : i32} +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<7.97193861E-5> : tensor<1x1x1x1xf32>}> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_1_]], [[VAR_2_]] {shift = 0 : i8} : (tensor<1x32x1x1xf32>, tensor<1x1x1x1xf32>) +// CHECK: return [[VAR_3_]] : tensor<1x32x1x1xf32> +} + +// ----- + +func.func @non_constant_axis(%arg0: tensor<3x2x2xf32>, %arg1: tensor<1xi64>) -> tensor<3x2xf32> { + %0 = "onnx.ReduceMean"(%arg0, %arg1) {keepdims = 0 : si64} : (tensor<3x2x2xf32>, tensor<1xi64>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// CHECK-LABEL: non_constant_axis +// CHECK: onnx.ReduceMean diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMin.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMin.mlir new file mode 100644 index 0000000000..c9326c5da3 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMin.mlir @@ -0,0 +1,89 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @reduce_min(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMin"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_min( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_min %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_min_no_axes_attr(%arg0: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceMin"(%arg0, %none) : (tensor<2x5x9x11xf32>, none) -> tensor<1x1x1x1xf32> +return %0 : tensor<1x1x1x1xf32> +// CHECK-LABEL: func.func @reduce_min_no_axes_attr( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_min %[[VAL_0]] {axis = 0 : i32} : (tensor<2x5x9x11xf32>) -> tensor<1x5x9x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 1 : i32} : (tensor<1x5x9x11xf32>) -> tensor<1x1x9x11xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<1x1x1x1xf32> +} + +// ----- + +func.func @reduce_min_keepdims_false(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMin"(%arg0, %0) {keepdims = 0 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5xf32> +return %1 : tensor<2x5xf32> +// CHECK-LABEL: func.func @reduce_min_keepdims_false( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_min %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> +// CHECK: return %[[VAL_3]] : tensor<2x5xf32> +} + +// ----- + +func.func @reduce_min_noop_with_emtpy_axes_one(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceMin"(%arg0, %0) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_min_noop_with_emtpy_axes_one( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_min %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_min_noop_with_emtpy_axes_one_none_input(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceMin"(%arg0, %none) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, none) -> tensor<2x5x9x11xf32> +return %0 : tensor<2x5x9x11xf32> +// CHECK-LABEL: func.func @reduce_min_noop_with_emtpy_axes_one_none_input( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> +// CHECK: return %[[VAL_1]] : tensor<2x5x9x11xf32> +} + +// ----- + +func.func @test_reduceminV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> { + %0 = "onnx.ReduceMinV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> + return %0 : tensor<1x32x1x1xf32> +// CHECK-LABEL: func.func @test_reduceminV13 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_min %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_min [[VAR_0_]] {axis = 3 : i32} +// CHECK: return [[VAR_1_]] : tensor<1x32x1x1xf32> +} + +// ----- + +func.func @test_reduceminV13_keep_dims_0(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32xf32> { + %0 = "onnx.ReduceMinV13"(%arg0) {axes = [2, 3], keepdims = 0 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32xf32> + return %0 : tensor<1x32xf32> +// CHECK-LABEL: func.func @test_reduceminV13_keep_dims_0 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_min %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_min [[VAR_0_]] {axis = 3 : i32} +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} +// CHECK: return [[VAR_2_]] : tensor<1x32xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceProd.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceProd.mlir new file mode 100644 index 0000000000..dbbaae004e --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceProd.mlir @@ -0,0 +1,90 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @reduce_prod(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceProd"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_prod( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_prod %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_prod_no_axes_attr(%arg0: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceProd"(%arg0, %none) : (tensor<2x5x9x11xf32>, none) -> tensor<1x1x1x1xf32> +return %0 : tensor<1x1x1x1xf32> +// CHECK-LABEL: func.func @reduce_prod_no_axes_attr( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_prod %[[VAL_0]] {axis = 0 : i32} : (tensor<2x5x9x11xf32>) -> tensor<1x5x9x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 1 : i32} : (tensor<1x5x9x11xf32>) -> tensor<1x1x9x11xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_prod %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_prod %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<1x1x1x1xf32> +} + +// ----- + +func.func @reduce_prod_keepdims_false(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceProd"(%arg0, %0) {keepdims = 0 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5xf32> +return %1 : tensor<2x5xf32> +// CHECK-LABEL: func.func @reduce_prod_keepdims_false( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_prod %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> +// CHECK: return %[[VAL_3]] : tensor<2x5xf32> +} + +// ----- + +func.func @reduce_prod_noop_with_emtpy_axes_one(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceProd"(%arg0, %0) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_prod_noop_with_emtpy_axes_one( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_prod %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_prod_noop_with_emtpy_axes_one_none_input(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceProd"(%arg0, %none) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, none) -> tensor<2x5x9x11xf32> +return %0 : tensor<2x5x9x11xf32> +// CHECK-LABEL: func.func @reduce_prod_noop_with_emtpy_axes_one_none_input( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> +// CHECK: return %[[VAL_1]] : tensor<2x5x9x11xf32> +} + +// ----- + +func.func @test_reduceprodV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> { + %0 = "onnx.ReduceProdV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> + return %0 : tensor<1x32x1x1xf32> +// CHECK-LABEL: func.func @test_reduceprodV13 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_prod %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_prod [[VAR_0_]] {axis = 3 : i32} +// CHECK: return [[VAR_1_]] : tensor<1x32x1x1xf32> +} + +// ----- + +func.func @test_reduceprodV13_keep_dims_false(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32xf32> { + %0 = "onnx.ReduceProdV13"(%arg0) {axes = [2, 3], keepdims = 0 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32xf32> + return %0 : tensor<1x32xf32> +// CHECK-LABEL: func.func @test_reduceprodV13_keep_dims_false +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_prod %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_prod [[VAR_0_]] {axis = 3 : i32} +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} +// CHECK: return [[VAR_2_]] : tensor<1x32xf32> +} + diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceSum.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceSum.mlir new file mode 100644 index 0000000000..4226b2e2d9 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceSum.mlir @@ -0,0 +1,89 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @reduce_sum(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceSum"(%arg0, %0) : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_sum( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_sum_no_axes_attr(%arg0: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceSum"(%arg0, %none) : (tensor<2x5x9x11xf32>, none) -> tensor<1x1x1x1xf32> +return %0 : tensor<1x1x1x1xf32> +// CHECK-LABEL: func.func @reduce_sum_no_axes_attr( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<1x1x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x5x9x11xf32>) -> tensor<1x5x9x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<1x5x9x11xf32>) -> tensor<1x1x9x11xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> +// CHECK: return %[[VAL_4]] : tensor<1x1x1x1xf32> +} + +// ----- + +func.func @reduce_sum_keepdims_false(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceSum"(%arg0, %0) {keepdims = 0 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5xf32> +return %1 : tensor<2x5xf32> +// CHECK-LABEL: func.func @reduce_sum_keepdims_false( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> +// CHECK: return %[[VAL_3]] : tensor<2x5xf32> +} + +// ----- + +func.func @reduce_sum_noop_with_emtpy_axes_one(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +%0 = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +%1 = "onnx.ReduceSum"(%arg0, %0) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, tensor<2xi64>) -> tensor<2x5x1x1xf32> +return %1 : tensor<2x5x1x1xf32> +// CHECK-LABEL: func.func @reduce_sum_noop_with_emtpy_axes_one( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x1x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> +// CHECK: return %[[VAL_2]] : tensor<2x5x1x1xf32> +} + +// ----- + +func.func @reduce_sum_noop_with_emtpy_axes_one_none_input(%arg0: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +%none = "onnx.NoValue"() {value} : () -> none +%0 = "onnx.ReduceSum"(%arg0, %none) {noop_with_empty_axes = 1 : si64} : (tensor<2x5x9x11xf32>, none) -> tensor<2x5x9x11xf32> +return %0 : tensor<2x5x9x11xf32> +// CHECK-LABEL: func.func @reduce_sum_noop_with_emtpy_axes_one_none_input( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> +// CHECK: return %[[VAL_1]] : tensor<2x5x9x11xf32> +} + +// ----- + +func.func @test_reducesumV13(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> { + %0 = "onnx.ReduceSumV11"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32x1x1xf32> + return %0 : tensor<1x32x1x1xf32> +// CHECK-LABEL: func.func @test_reducesumV13 +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_sum %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 3 : i32} +// CHECK: return [[VAR_1_]] : tensor<1x32x1x1xf32> +} + +// ----- + +func.func @test_reducesumV11_keep_dims_false(%arg0: tensor<1x32x112x112xf32>) -> tensor<1x32xf32> { + %0 = "onnx.ReduceSumV11"(%arg0) {axes = [2, 3], keepdims = 0 : si64} : (tensor<1x32x112x112xf32>) -> tensor<1x32xf32> + return %0 : tensor<1x32xf32> +// CHECK-LABEL: func.func @test_reducesumV11_keep_dims_false +// CHECK: [[VAR_0_:%.+]] = tosa.reduce_sum %arg0 {axis = 2 : i32} +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reduce_sum [[VAR_0_]] {axis = 3 : i32} +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} +// CHECK: return [[VAR_2_]] : tensor<1x32xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir index 2ecb8f5795..d2013c44a7 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir @@ -4,7 +4,9 @@ func.func @test_softmax_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "onnx.Softmax"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> // CHECK: test_softmax_v13(%[[VAL_0:.*]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[MAX:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 2 : i32} +// CHECK: %[[SUB:.*]] = tosa.sub %[[VAL_0]], %[[MAX]] +// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 2 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> @@ -16,7 +18,9 @@ func.func @test_softmax_v13_axis_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21 %2 = "onnx.Softmax"(%arg0) {axis = 1 : si64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> // CHECK: test_softmax_v13_axis_one(%[[VAL_0:.*]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[MAX:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 1 : i32} +// CHECK: %[[SUB:.*]] = tosa.sub %[[VAL_0]], %[[MAX]] +// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x1x3xf32>) -> tensor<13x1x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> @@ -28,7 +32,10 @@ func.func @test_softmax_before_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3 %2 = "onnx.SoftmaxV11"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> // CHECK: test_softmax_before_v13(%[[VAL_0:.*]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[MAX:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 1 : i32} +// CHECK: %[[MAX2:.*]] = tosa.reduce_max %[[MAX]] {axis = 2 : i32} +// CHECK: %[[SUB:.*]] = tosa.sub %[[VAL_0]], %[[MAX2]] +// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<13x1x3xf32>) -> tensor<13x1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<13x1x1xf32>) -> tensor<13x1x1xf32> @@ -41,10 +48,14 @@ func.func @test_softmax_before_v13_axis_zero(%arg0: tensor<13x21x3xf32>) -> tens %2 = "onnx.SoftmaxV11"(%arg0) {axis = 0 : si64}: (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> // CHECK: test_softmax_before_v13_axis_zero(%[[VAL_0:.*]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[MAX:.*]] = tosa.reduce_max %[[VAL_0]] {axis = 0 : i32} +// CHECK: %[[MAX2:.*]] = tosa.reduce_max %[[MAX]] {axis = 1 : i32} +// CHECK: %[[MAX3:.*]] = tosa.reduce_max %[[MAX2]] {axis = 2 : i32} +// CHECK: %[[SUB:.*]] = tosa.sub %[[VAL_0]], %[[MAX3]] +// CHECK: %[[VAL_1:.*]] = tosa.exp %[[SUB]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 1 : i32} : (tensor<1x21x3xf32>) -> tensor<1x1x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1x1xf32>) -> tensor<1x1x1xf32> // CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir b/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir index 8ebeaa30c6..4afb6243f2 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir @@ -59,7 +59,7 @@ func.func @test_default_averagepool_strides(%arg0 : tensor<5x5x32x32xf32>) -> te // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32> { // CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[VAR_0_]] : (tensor<5x5x32x32xf32>, tensor<4xi32>) -> tensor<5x32x32x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[VAR_1_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x32x32x5xf32>) -> tensor<5x16x16x5xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[VAR_1_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x32x32x5xf32>) -> tensor<5x16x16x5xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x16x16x5xf32>, tensor<4xi32>) -> tensor<5x5x16x16xf32> // CHECK: return [[VAR_4_]] : tensor<5x5x16x16xf32> @@ -67,7 +67,7 @@ func.func @test_default_averagepool_strides(%arg0 : tensor<5x5x32x32xf32>) -> te // ----- -/// Test the behavior of AveragePool with strides and non uniform padding +/// Test the behavior of AveragePool with strides and non uniform padding func.func @test_default_averagepool_strides_nonunifpad(%arg0 : tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> { %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2,2], pads = [1, 0, 0, 0], strides = [2, 2] } : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> "func.return"(%0) : (tensor<5x5x15x16xf32>) -> () @@ -76,7 +76,8 @@ func.func @test_default_averagepool_strides_nonunifpad(%arg0 : tensor<5x5x30x32x // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> { // CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[VAR_0_]] : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[VAR_1_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x15x16x5xf32> +// CHECK: [[SLICE:%.+]] = tosa.slice [[VAR_1_]] {size = array, start = array} : (tensor<5x30x32x5xf32>) -> tensor<5x29x32x5xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[SLICE]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x29x32x5xf32>) -> tensor<5x15x16x5xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x15x16x5xf32>, tensor<4xi32>) -> tensor<5x5x15x16xf32> // CHECK: return [[VAR_4_]] : tensor<5x5x15x16xf32> @@ -93,7 +94,7 @@ func.func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x3 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32> { // CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[VAR_0_]] : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[VAR_1_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x16x16x5xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.avg_pool2d [[VAR_1_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x16x16x5xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x16x16x5xf32>, tensor<4xi32>) -> tensor<5x5x16x16xf32> // CHECK: return [[VAR_4_]] : tensor<5x5x16x16xf32> @@ -189,8 +190,66 @@ func.func @test_averagepool_strides_nonunifpad_ceil_with_count_include_pad(%arg0 // CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x30x32xf32>, tensor<4x2xi64>, tensor) -> tensor<5x5x31x34xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x5x31x34xf32>, tensor<4xi32>) -> tensor<5x31x34x5xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tosa.avg_pool2d [[VAR_4_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x31x34x5xf32>) -> tensor<5x16x17x5xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.avg_pool2d [[VAR_4_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x31x34x5xf32>) -> tensor<5x16x17x5xf32> // CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_7_:%.+]] = tosa.transpose [[VAR_5_]], [[VAR_6_]] : (tensor<5x16x17x5xf32>, tensor<4xi32>) -> tensor<5x5x16x17xf32> // CHECK: return [[VAR_7_]] : tensor<5x5x16x17xf32> // CHECK: } + +// ----- + +func.func @test_averagepool_dilations_one(%arg0 : tensor<1x1x4x4xf32>) -> tensor<1x1x3x3xf32> { + %0 = "onnx.AveragePool"(%arg0) { auto_pad = "NOTSET", + ceil_mode = 1 : si64, + count_include_pad = 0 : si64, + dilations = [1, 1], + kernel_shape = [2, 2], + strides = [1, 1]} : (tensor<1x1x4x4xf32>) -> tensor<1x1x3x3xf32> + "func.return"(%0) : (tensor<1x1x3x3xf32>) -> () +} +// CHECK-LABEL: test_averagepool_dilations_one +// CHECK: tosa.avg_pool2d + +// ----- + +func.func @test_averagepool_dilations(%arg0 : tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> { + %0 = "onnx.AveragePool"(%arg0) { auto_pad = "NOTSET", + ceil_mode = 1 : si64, + count_include_pad = 0 : si64, + dilations = [2, 2], + kernel_shape = [2, 2], + strides = [1, 1]} : (tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> + "func.return"(%0) : (tensor<1x1x2x2xf32>) -> () +} +// CHECK-LABEL: test_averagepool_dilations +// CHECK: onnx.AveragePool + +// ----- + +func.func @test_averagepool_5d(%arg0: tensor<1x1x32x32x32xf32>) -> tensor<1x1x8x8x8xf32> { + %0 = "onnx.AveragePool"(%arg0) { + auto_pad = "NOTSET", + ceil_mode = 0 : si64, + count_include_pad = 1 : si64, + dilations = [2, 2, 2], + kernel_shape = [5, 5, 5], + strides = [3, 3, 3]} : (tensor<1x1x32x32x32xf32>) -> tensor<1x1x8x8x8xf32> + return %0 : tensor<1x1x8x8x8xf32> +} +// CHECK-LABEL: test_averagepool_5d +// CHECK: onnx.AveragePool + +// ----- + +func.func @test_averagepool_dilations_one_dyn_shape(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.AveragePool"(%arg0) { auto_pad = "NOTSET", + ceil_mode = 1 : si64, + count_include_pad = 0 : si64, + dilations = [1, 1], + kernel_shape = [2, 2], + strides = [1, 1]} : (tensor<*xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +} +// CHECK-LABEL: test_averagepool_dilations_one +// CHECK: onnx.AveragePool + diff --git a/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir new file mode 100644 index 0000000000..0c47f22ad4 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir @@ -0,0 +1,114 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_batchnorm_f32(%arg0: tensor<100x3x10x10xf32>) -> tensor<100x3x10x10xf32> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<100x3x10x10xf32> + return %4 : tensor<100x3x10x10xf32> +// CHECK-LABEL: func @test_batchnorm_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x10x10xf32>) -> tensor<100x3x10x10xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[VAR_2_]] {new_shape = array} : (tensor<3xf32>) -> tensor<1x3x1x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<3xf32>) -> tensor<1x3x1x1xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} : (tensor<3xf32>) -> tensor<1x3x1x1xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3xf32>) -> tensor<1x3x1x1xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.00000007E-5> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_4_]] : (tensor<100x3x10x10xf32>, tensor<1x3x1x1xf32>) -> tensor<100x3x10x10xf32> +// CHECK: [[VAR_10_:%.+]] = tosa.add [[VAR_7_]], [[VAR_8_]] : (tensor<1x3x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x1x1xf32> +// CHECK: [[VAR_11_:%.+]] = tosa.rsqrt [[VAR_10_]] : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1xf32> +// CHECK: [[VAR_12_:%.+]] = tosa.mul [[VAR_9_]], [[VAR_11_]] {shift = 0 : i8} : (tensor<100x3x10x10xf32>, tensor<1x3x1x1xf32>) -> tensor<100x3x10x10xf32> +// CHECK: [[VAR_13_:%.+]] = tosa.mul [[VAR_12_]], [[VAR_5_]] {shift = 0 : i8} : (tensor<100x3x10x10xf32>, tensor<1x3x1x1xf32>) -> tensor<100x3x10x10xf32> +// CHECK: [[VAR_14_:%.+]] = tosa.add [[VAR_13_]], [[VAR_6_]] : (tensor<100x3x10x10xf32>, tensor<1x3x1x1xf32>) -> tensor<100x3x10x10xf32> +// CHECK: return [[VAR_14_]] : tensor<100x3x10x10xf32> +} + +// ----- +func.func @test_batchnorm_f16_dynamic(%arg0: tensor<100x3x?x?xf16>) -> tensor<*xf16> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> tensor<*xf16> + return %4 : tensor<*xf16> +// CHECK-LABEL: func @test_batchnorm_f16_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x?x?xf16>) -> tensor<100x3x?x?xf16> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16>}> : () -> tensor<3xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf16>}> : () -> tensor<3xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf16>}> : () -> tensor<3xf16> +// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf16>}> : () -> tensor<3xf16> +// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[VAR_2_]] {new_shape = array} : (tensor<3xf16>) -> tensor<1x3x1x1xf16> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<3xf16>) -> tensor<1x3x1x1xf16> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} : (tensor<3xf16>) -> tensor<1x3x1x1xf16> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3xf16>) -> tensor<1x3x1x1xf16> +// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.001360e-05> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16> +// CHECK: [[VAR_9_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_4_]] : (tensor<100x3x?x?xf16>, tensor<1x3x1x1xf16>) -> tensor<100x3x?x?xf16> +// CHECK: [[VAR_10_:%.+]] = tosa.add [[VAR_7_]], [[VAR_8_]] : (tensor<1x3x1x1xf16>, tensor<1x1x1x1xf16>) -> tensor<1x3x1x1xf16> +// CHECK: [[VAR_11_:%.+]] = tosa.rsqrt [[VAR_10_]] : (tensor<1x3x1x1xf16>) -> tensor<1x3x1x1xf16> +// CHECK: [[VAR_12_:%.+]] = tosa.mul [[VAR_9_]], [[VAR_11_]] {shift = 0 : i8} : (tensor<100x3x?x?xf16>, tensor<1x3x1x1xf16>) -> tensor<100x3x?x?xf16> +// CHECK: [[VAR_13_:%.+]] = tosa.mul [[VAR_12_]], [[VAR_5_]] {shift = 0 : i8} : (tensor<100x3x?x?xf16>, tensor<1x3x1x1xf16>) -> tensor<100x3x?x?xf16> +// CHECK: [[VAR_14_:%.+]] = tosa.add [[VAR_13_]], [[VAR_6_]] : (tensor<100x3x?x?xf16>, tensor<1x3x1x1xf16>) -> tensor<100x3x?x?xf16> +// CHECK: return [[VAR_14_]] : tensor<100x3x?x?xf16> +} + +// ----- + +func.func @test_batchnorm_bf16_dynamic(%arg0: tensor<100x3x?x?xbf16>) -> tensor<*xbf16> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> tensor<*xbf16> + return %4 : tensor<*xbf16> +// CHECK-LABEL: func @test_batchnorm_bf16_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x?x?xbf16>) -> tensor<100x3x?x?xbf16> +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xbf16>}> : () -> tensor<3xbf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xbf16>}> : () -> tensor<3xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xbf16>}> : () -> tensor<3xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xbf16>}> : () -> tensor<3xbf16> +// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[VAR_2_]] {new_shape = array} : (tensor<3xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<3xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK-DAG: [[VAR_6_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} : (tensor<3xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.001360e-05> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> +// CHECK: [[VAR_9_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_4_]] : (tensor<100x3x?x?xbf16>, tensor<1x3x1x1xbf16>) -> tensor<100x3x?x?xbf16> +// CHECK: [[VAR_10_:%.+]] = tosa.add [[VAR_7_]], [[VAR_8_]] : (tensor<1x3x1x1xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK: [[VAR_11_:%.+]] = tosa.rsqrt [[VAR_10_]] : (tensor<1x3x1x1xbf16>) -> tensor<1x3x1x1xbf16> +// CHECK: [[VAR_12_:%.+]] = tosa.mul [[VAR_9_]], [[VAR_11_]] {shift = 0 : i8} : (tensor<100x3x?x?xbf16>, tensor<1x3x1x1xbf16>) -> tensor<100x3x?x?xbf16> +// CHECK: [[VAR_13_:%.+]] = tosa.mul [[VAR_12_]], [[VAR_5_]] {shift = 0 : i8} : (tensor<100x3x?x?xbf16>, tensor<1x3x1x1xbf16>) -> tensor<100x3x?x?xbf16> +// CHECK: [[VAR_14_:%.+]] = tosa.add [[VAR_13_]], [[VAR_6_]] : (tensor<100x3x?x?xbf16>, tensor<1x3x1x1xbf16>) -> tensor<100x3x?x?xbf16> +// CHECK: return [[VAR_14_]] : tensor<100x3x?x?xbf16> +} + +// ----- + +func.func @test_batchnorm_f64(%arg0: tensor<100x3x10x10xf64>) -> tensor<100x3x10x10xf64> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf64>} : () -> tensor<3xf64> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf64>} : () -> tensor<3xf64> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf64>} : () -> tensor<3xf64> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf64>} : () -> tensor<3xf64> + %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<100x3x10x10xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3xf64>) -> tensor<100x3x10x10xf64> + return %4 : tensor<100x3x10x10xf64> +// CHECK-LABEL: @test_batchnorm_f64 +// CHECK-SAME: ([[PARAM_0:%.*]]: tensor<100x3x10x10xf64>) -> tensor<100x3x10x10xf64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_4_:%.+]] = tosa.reshape [[VAR_2_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_5_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_6_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_7_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.0000000656873453E-5> : tensor<1x1x1x1xf64>}> : () -> tensor<1x1x1x1xf64> +// CHECK-NEXT: [[VAR_9_:%.+]] = tosa.sub %arg0, [[VAR_4_]] : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_10_:%.+]] = tosa.add %7, [[VAR_8_]] : (tensor<1x3x1x1xf64>, tensor<1x1x1x1xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_11_:%.+]] = tosa.rsqrt [[VAR_10_]] : (tensor<1x3x1x1xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_12_:%.+]] = tosa.mul [[VAR_9_]], %11 {shift = 0 : i8} : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_13_:%.+]] = tosa.mul [[VAR_12_]], %5 {shift = 0 : i8} : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_14_:%.+]] = tosa.add [[VAR_13_]], [[VAR_6_]] : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: return [[VAR_14_]] : tensor<100x3x10x10xf64> +} diff --git a/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir new file mode 100644 index 0000000000..4947576aa5 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir @@ -0,0 +1,96 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_dequantizeLinear(%arg0 : tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xi8>, tensor, tensor) -> tensor<32x3x224x224xf32> + "func.return"(%2) : (tensor<32x3x224x224xf32>) -> () +} +// CHECK-LABEL: @test_dequantizeLinear +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[CAST_0:.*]] = tosa.cast %[[ARG_0]] : (tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTZP:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[SUB:.*]] = tosa.sub %[[CAST_0]], %[[CASTZP]] : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[SCALE]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: return %[[MUL]] : tensor<32x3x224x224xf32> + +// ----- + +func.func @test_dequantizeLinear_f16(%arg0 : tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf16> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xi8>, tensor, tensor) -> tensor<32x3x224x224xf16> + "func.return"(%2) : (tensor<32x3x224x224xf16>) -> () +} + +// CHECK-LABEL: @test_dequantizeLinear_f16 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf16> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf16>}> : () -> tensor<1x1x1x1xf16> +// CHECK-DAG: %[[CAST_0:.*]] = tosa.cast %[[ARG_0]] : (tensor<32x3x224x224xi8>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTZP:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[SUB:.*]] = tosa.sub %[[CAST_0]], %[[CASTZP]] : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CASTSCALE:.*]] = tosa.cast %[[SCALE]] : (tensor<1x1x1x1xf16>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[CASTSCALE]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xf16> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xf16> + +// ----- + +func.func @per_axis(%arg0: tensor<8x2xi8>) -> tensor<8x2xf32> { + %0 = onnx.Constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %1 = onnx.Constant dense<[0, 1]> : tensor<2xi8> + %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) + {axis = 1 : si64, + saturate = 1 : si64} : (tensor<8x2xi8>, tensor<2xf32>, tensor<2xi8>) -> tensor<8x2xf32> + return %2 : tensor<8x2xf32> +} + +// ----- + +// The default `axis` is `1` when it's absent in ONNX, which conflicts +// with the allowed range of `axis` when the input has rank 1. +// See https://github.com/onnx/onnx/issues/6067 +func.func @default_axis(%arg0 : tensor<32xi8>) -> tensor<32xf32> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.DequantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32xi8>, tensor, tensor) -> tensor<32xf32> + return %2 : tensor<32xf32> +} + +// CHECK-LABEL: default_axis +// CHECK-NOT: onnx.DequantizeLinear + +// ----- + +func.func @no_zeropoint(%arg0: tensor<5xi8>, %arg1: tensor) -> tensor<5xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %0) {axis = 0 : si64} : (tensor<5xi8>, tensor, none) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @no_zeropoint( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xi8>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<5xf32> { +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xi8>) -> tensor<5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<5xf32>, tensor<1xf32>) -> tensor<5xf32> +// CHECK: return %[[VAL_4]] : tensor<5xf32> + +// ----- + +func.func @f8E4M3FN(%arg0: tensor<5xf8E4M3FN>, %arg1: tensor) -> tensor<5xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %0) {axis = 0 : si64} : (tensor<5xf8E4M3FN>, tensor, none) -> tensor<5xf32> + return %1 : tensor<5xf32> +} + +// CHECK-LABEL: @f8E4M3FN +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf8E4M3FN>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<5xf32> { +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xf8E4M3FN>) -> tensor<5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor<5xf32>, tensor<1xf32>) -> tensor<5xf32> +// CHECK: return %[[VAL_4]] : tensor<5xf32> diff --git a/test/mlir/conversion/onnx_to_tosa/NN/MatMul.mlir b/test/mlir/conversion/onnx_to_tosa/NN/MatMul.mlir new file mode 100644 index 0000000000..451c7f4277 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/NN/MatMul.mlir @@ -0,0 +1,166 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_onnx_to_matmul2d(%arg0 : tensor<4x8xf32>, %arg1 : tensor<8x16xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_to_matmul2d(%arg0: tensor<4x8xf32>, %arg1: tensor<8x16xf32>) -> tensor<4x16xf32> + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x8xf32>) -> tensor<1x4x8xf32> + // CHECK: %1 = tosa.reshape %arg1 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> + // CHECK: %2 = tosa.matmul %0, %1 : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32> + // CHECK: %3 = tosa.reshape %2 {new_shape = array} : (tensor<1x4x16xf32>) -> tensor<4x16xf32> + // CHECK: return %3 : tensor<4x16xf32> +} + +// ----- + +func.func @test_onnx_to_matmul3dbcast(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<8x16xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x4x8xf32>, tensor<8x16xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_to_matmul3dbcast(%arg0: tensor<100x4x8xf32>, %arg1: tensor<8x16xf32>) -> tensor<100x4x16xf32> { + // CHECK: %0 = tosa.reshape %arg1 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> + // CHECK: %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<100x4x8xf32>) -> tensor<1x400x8xf32> + // CHECK: %2 = "tosa.const"() <{value = dense<[1, 0, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> + // CHECK: %3 = tosa.transpose %0, %2 : (tensor<1x8x16xf32>, tensor<3xi32>) -> tensor<8x1x16xf32> + // CHECK: %4 = tosa.reshape %3 {new_shape = array} : (tensor<8x1x16xf32>) -> tensor<1x8x16xf32> + // CHECK: %5 = tosa.matmul %1, %4 : (tensor<1x400x8xf32>, tensor<1x8x16xf32>) -> tensor<1x400x16xf32> + // CHECK: %6 = tosa.reshape %5 {new_shape = array} : (tensor<1x400x16xf32>) -> tensor<100x4x16xf32> + // CHECK: return %6 : tensor<100x4x16xf32> +} + +// ----- + +func.func @test_onnx_1d(%arg0 : tensor<6xf32>, %arg1 : tensor<6xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<6xf32>, tensor<6xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_1d(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>) -> tensor { + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> + // CHECK: %1 = tosa.reshape %arg1 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> + // CHECK: %2 = tosa.reshape %0 {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> + // CHECK: %3 = tosa.reshape %1 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> + // CHECK: %4 = tosa.matmul %2, %3 : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> + // CHECK: %5 = tosa.reshape %4 {new_shape = array} : (tensor<1x1x1xf32>) -> tensor + // CHECK: return %5 : tensor +} + +// ----- + +func.func @test_onnx_12d(%arg0 : tensor<6xf32>, %arg1 : tensor<6x1xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<6xf32>, tensor<6x1xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_12d(%arg0: tensor<6xf32>, %arg1: tensor<6x1xf32>) -> tensor<1xf32> { + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> + // CHECK: %1 = tosa.reshape %0 {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> + // CHECK: %2 = tosa.reshape %arg1 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> + // CHECK: %3 = tosa.matmul %1, %2 : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> + // CHECK: %4 = tosa.reshape %3 {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> + // CHECK: return %4 : tensor<1xf32> +} + +// ----- + +func.func @test_onnx_21d(%arg0 : tensor<2x6xf32>, %arg1 : tensor<6xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<2x6xf32>, tensor<6xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_21d(%arg0: tensor<2x6xf32>, %arg1: tensor<6xf32>) -> tensor<2xf32> { + // CHECK: %0 = tosa.reshape %arg1 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> + // CHECK: %1 = tosa.reshape %arg0 {new_shape = array} : (tensor<2x6xf32>) -> tensor<1x2x6xf32> + // CHECK: %2 = tosa.reshape %0 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> + // CHECK: %3 = tosa.matmul %1, %2 : (tensor<1x2x6xf32>, tensor<1x6x1xf32>) -> tensor<1x2x1xf32> + // CHECK: %4 = tosa.reshape %3 {new_shape = array} : (tensor<1x2x1xf32>) -> tensor<2xf32> + // CHECK: return %4 : tensor<2xf32> +} + +// ----- + +func.func @test_onnx_4d(%arg0 : tensor<10x10x6x2xf32>, %arg1 : tensor<10x10x2x6xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<10x10x6x2xf32>, tensor<10x10x2x6xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_4d(%arg0: tensor<10x10x6x2xf32>, %arg1: tensor<10x10x2x6xf32>) -> tensor<10x10x6x6xf32> { + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<10x10x6x2xf32>) -> tensor<100x6x2xf32> + // CHECK: %1 = tosa.reshape %arg1 {new_shape = array} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32> + // CHECK: %2 = tosa.matmul %0, %1 : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32> + // CHECK: %3 = tosa.reshape %2 {new_shape = array} : (tensor<100x6x6xf32>) -> tensor<10x10x6x6xf32> + // CHECK: return %3 : tensor<10x10x6x6xf32> +} + +// ----- + +func.func @test_onnx_4d_mixed(%arg0 : tensor<10x6x2xf32>, %arg1 : tensor<10x10x2x6xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<10x6x2xf32>, tensor<10x10x2x6xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_4d_mixed(%arg0: tensor<10x6x2xf32>, %arg1: tensor<10x10x2x6xf32>) -> tensor<10x10x6x6xf32> { + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<10x6x2xf32>) -> tensor<1x10x6x2xf32> + // CHECK: %1 = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK: %2 = tosa.transpose %0, %1 : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> + // CHECK: %3 = tosa.reshape %2 {new_shape = array} : (tensor<10x1x6x2xf32>) -> tensor<10x6x2xf32> + // CHECK: %4 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK: %5 = tosa.transpose %arg1, %4 : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> + // CHECK: %6 = tosa.reshape %5 {new_shape = array} : (tensor<10x2x10x6xf32>) -> tensor<10x2x60xf32> + // CHECK: %7 = tosa.matmul %3, %6 : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> + // CHECK: %8 = tosa.reshape %7 {new_shape = array} : (tensor<10x6x60xf32>) -> tensor<10x6x10x6xf32> + // CHECK: %9 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK: %10 = tosa.transpose %8, %9 : (tensor<10x6x10x6xf32>, tensor<4xi32>) -> tensor<10x10x6x6xf32> + // CHECK: return %10 : tensor<10x10x6x6xf32> +} + +// ----- + +func.func @test_onnx_to_matmul4d_non_broadcastable(%arg0 : tensor<4x1x5x6xf32>, %arg1 : tensor<1x3x6x7xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<4x1x5x6xf32>, tensor<1x3x6x7xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: func.func @test_onnx_to_matmul4d_non_broadcastable(%arg0: tensor<4x1x5x6xf32>, %arg1: tensor<1x3x6x7xf32>) -> tensor<4x3x5x7xf32> { + // CHECK: %0 = tosa.reshape %arg0 {new_shape = array} : (tensor<4x1x5x6xf32>) -> tensor<1x20x6xf32> + // CHECK: %1 = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK: %2 = tosa.transpose %arg1, %1 : (tensor<1x3x6x7xf32>, tensor<4xi32>) -> tensor<6x1x3x7xf32> + // CHECK: %3 = tosa.reshape %2 {new_shape = array} : (tensor<6x1x3x7xf32>) -> tensor<1x6x21xf32> + // CHECK: %4 = tosa.matmul %0, %3 : (tensor<1x20x6xf32>, tensor<1x6x21xf32>) -> tensor<1x20x21xf32> + // CHECK: %5 = tosa.reshape %4 {new_shape = array} : (tensor<1x20x21xf32>) -> tensor<4x5x3x7xf32> + // CHECK: %6 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK: %7 = tosa.transpose %5, %6 : (tensor<4x5x3x7xf32>, tensor<4xi32>) -> tensor<4x3x5x7xf32> + // CHECK: return %7 : tensor<4x3x5x7xf32> +} + +// ----- + +func.func @test_onnx_to_matmul3d_fp16(%arg0 : tensor<100x4x8xf16>, %arg1 : tensor<100x8x16xf16>) -> tensor<*xf16> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<*xf16> + "func.return"(%0) : (tensor<*xf16>) -> () + // CHECK: %0 = tosa.matmul %arg0, %arg1 : (tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf32> + // CHECK: %1 = tosa.cast %0 : (tensor<100x4x16xf32>) -> tensor<100x4x16xf16> + // CHECK: return %1 : tensor<100x4x16xf16> +} + +// ----- + +func.func @test_onnx_to_matmul3d_bf16(%arg0 : tensor<100x4x8xbf16>, %arg1 : tensor<100x8x16xbf16>) -> tensor<*xbf16> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<*xbf16> + "func.return"(%0) : (tensor<*xbf16>) -> () + // CHECK: %0 = tosa.matmul %arg0, %arg1 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xf32> + // CHECK: %1 = tosa.cast %0 : (tensor<100x4x16xf32>) -> tensor<100x4x16xbf16> + // CHECK: return %1 : tensor<100x4x16xbf16> +} + +// ----- + +func.func @test_onnx_to_matmul3d_fp32(%arg0 : tensor<100x4x8xf32>, %arg1 : tensor<100x8x16xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK: %0 = tosa.matmul %arg0, %arg1 : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<100x4x16xf32> + // CHECK: return %0 : tensor<100x4x16xf32> +} + +// ----- + +func.func @test_onnx_to_matmul2d_dyn(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK-NOT: tosa.matmul +} + +// ----- + +func.func @test_onnx_to_matmul3d_dyn(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () + // CHECK-NOT: tosa.matmul +} diff --git a/test/mlir/conversion/onnx_to_tosa/NN/MaxPoolSingleOut.mlir b/test/mlir/conversion/onnx_to_tosa/NN/MaxPoolSingleOut.mlir index 6a5ee4c0da..dc483e4cec 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/MaxPoolSingleOut.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/MaxPoolSingleOut.mlir @@ -52,24 +52,25 @@ func.func @test_default_maxpoolsingleout_strides(%arg0 : tensor<5x5x32x32xf32>) // CHECK-LABEL: func.func @test_default_maxpoolsingleout_strides(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x5x16x16xf32> { // CHECK-DAG: "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %0 : (tensor<5x5x32x32xf32>, tensor<4xi32>) -> tensor<5x32x32x5xf32> -// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x32x32x5xf32>) -> tensor<5x16x16x5xf32> +// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x32x32x5xf32>) -> tensor<5x16x16x5xf32> // CHECK-DAG: "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %3 : (tensor<5x16x16x5xf32>, tensor<4xi32>) -> tensor<5x5x16x16xf32> // CHECK-DAG: return %[[TRANS_MPOOL_RES]] : tensor<5x5x16x16xf32> // ----- -/// Test the behavior of Max Pool with strides and non uniform padding +/// Test the behavior of Max Pool with strides and non uniform padding func.func @test_default_maxpoolsingleout_strides_nonunifpad(%arg0 : tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> { %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2,2], pads = [1, 0, 0, 0], strides = [2, 2] } : (tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> "func.return"(%0) : (tensor<5x5x15x16xf32>) -> () } // CHECK-LABEL: func.func @test_default_maxpoolsingleout_strides_nonunifpad(%arg0: tensor<5x5x30x32xf32>) -> tensor<5x5x15x16xf32> { -// CHECK-DAG: "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %0 : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> -// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x15x16x5xf32> -// CHECK-DAG: "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %3 : (tensor<5x15x16x5xf32>, tensor<4xi32>) -> tensor<5x5x15x16xf32> +// CHECK-DAG: %[[TRANS_CONST_1:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %[[TRANS_CONST_1]] : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> +// CHECK-DAG: %[[SLICE:.+]] = tosa.slice %[[TRANS_ARG]] {size = array, start = array} : (tensor<5x30x32x5xf32>) -> tensor<5x29x32x5xf32> +// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[SLICE]] {kernel = array, pad = array, stride = array} : (tensor<5x29x32x5xf32>) -> tensor<5x15x16x5xf32> +// CHECK-DAG: %[[TRANS_CONST_2:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %[[TRANS_CONST_2]] : (tensor<5x15x16x5xf32>, tensor<4xi32>) -> tensor<5x5x15x16xf32> // CHECK-DAG: return %[[TRANS_MPOOL_RES]] : tensor<5x5x15x16xf32> // ----- @@ -80,11 +81,11 @@ func.func @test_default_maxpoolsingleout_strides_nonunifpad_ceil(%arg0 : tensor< "func.return"(%0) : (tensor<5x5x16x16xf32>) -> () } // CHECK-LABEL: func.func @test_default_maxpoolsingleout_strides_nonunifpad_ceil(%arg0: tensor<5x5x30x32xf32>) -> tensor<5x5x16x16xf32> { -// CHECK-DAG: "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %0 : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> -// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x16x16x5xf32> -// CHECK-DAG: "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %3 : (tensor<5x16x16x5xf32>, tensor<4xi32>) -> tensor<5x5x16x16xf32> +// CHECK-DAG: %[[TRANS_CONST_1:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %[[TRANS_CONST_1]] : (tensor<5x5x30x32xf32>, tensor<4xi32>) -> tensor<5x30x32x5xf32> +// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x30x32x5xf32>) -> tensor<5x16x16x5xf32> +// CHECK-DAG: %[[TRANS_CONST_2:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %[[TRANS_CONST_2]] : (tensor<5x16x16x5xf32>, tensor<4xi32>) -> tensor<5x5x16x16xf32> // CHECK-DAG: return %[[TRANS_MPOOL_RES]] : tensor<5x5x16x16xf32> @@ -115,3 +116,35 @@ func.func @test_default_maxpoolsingleout_same_upper_ceil_mode(%arg0 : tensor<5x5 // CHECK-DAG: "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %3 : (tensor<5x4x4x5xf32>, tensor<4xi32>) -> tensor<5x5x4x4xf32> // CHECK-DAG: return %[[TRANS_MPOOL_RES]] : tensor<5x5x4x4xf32> + +// ----- + +func.func @test_maxpoolsingleout_dilation1(%arg0 : tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {kernel_shape = [3,3], dilations = [1,1]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> + return %0 : tensor<5x5x30x30xf32> +} +// CHECK-LABEL: func.func @test_maxpoolsingleout_dilation1(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> { +// CHECK-DAG: "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_ARG:.*]] = tosa.transpose %arg0, %0 : (tensor<5x5x32x32xf32>, tensor<4xi32>) -> tensor<5x32x32x5xf32> +// CHECK-DAG: %[[MPOOL_RES:.*]] = tosa.max_pool2d %[[TRANS_ARG]] {kernel = array, pad = array, stride = array} : (tensor<5x32x32x5xf32>) -> tensor<5x30x30x5xf32> +// CHECK-DAG: "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %[[TRANS_MPOOL_RES:.*]] = tosa.transpose %[[MPOOL_RES]], %3 : (tensor<5x30x30x5xf32>, tensor<4xi32>) -> tensor<5x5x30x30xf32> +// CHECK-DAG: return %[[TRANS_MPOOL_RES]] : tensor<5x5x30x30xf32> + +// ----- + +func.func @test_maxpoolsingleout_dilation2(%arg0 : tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {kernel_shape = [3,3], dilations = [2,2]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> + return %0 : tensor<5x5x30x30xf32> +} +// CHECK-LABEL: func.func @test_maxpoolsingleout_dilation2(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> { +// CHECK: onnx.MaxPoolSingleOut + +// ----- + +func.func @test_maxpoolsingleout_dilation1_dyn(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {kernel_shape = [3,3], dilations = [1,1]} : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func.func @test_maxpoolsingleout_dilation1_dyn +// CHECK: onnx.MaxPoolSingleOut diff --git a/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir b/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir new file mode 100644 index 0000000000..fd51d6e377 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/NN/QuantizeLinear.mlir @@ -0,0 +1,111 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_quantizeLinear(%arg0 : tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xf32>, tensor, tensor) -> tensor<32x3x224x224xi8> + "func.return"(%2) : (tensor<32x3x224x224xi8>) -> () +} +// CHECK-LABEL: @test_quantizeLinear +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi8>}> : () -> tensor<1x1x1x1xi8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[REC:.*]] = tosa.reciprocal %[[SCALE]] : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[ARG_0]], %[[REC]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[MUL_CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[ZPCAST:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xi8>) -> tensor<1x1x1x1xi32> +// CHECK-DAG: %[[ADD:.*]] = tosa.add %[[MUL_CAST]], %[[ZPCAST]] : (tensor<32x3x224x224xi32>, tensor<1x1x1x1xi32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[ADD]] {max_fp = 1.270000e+02 : f32, max_int = 127 : i64, min_fp = -1.280000e+02 : f32, min_int = -128 : i64} : (tensor<32x3x224x224xi32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[CLAMP]] : (tensor<32x3x224x224xi32>) -> tensor<32x3x224x224xi8> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xi8> + +// ----- + +func.func @test_quantizeLinear_none(%arg0 : tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi8> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = "onnx.NoValue"() {onnx_node_name = "onnx.NoValue_0", value} : () -> none + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xf32>, tensor, none) -> tensor<32x3x224x224xi8> + "func.return"(%2) : (tensor<32x3x224x224xi8>) -> () +} + +// CHECK-LABEL: @test_quantizeLinear_none +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xui8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[REC:.*]] = tosa.reciprocal %[[SCALE]] : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[ARG_0]], %[[REC]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[MUL_CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[MUL_CAST]] : (tensor<32x3x224x224xi32>) -> tensor<32x3x224x224xui8> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xui8> + +// ----- + +func.func @test_quantizeLinear_per_axis(%arg0: tensor<8x2xf32>) -> tensor<8x2xi8> { + %0 = onnx.Constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %1 = onnx.Constant dense<[0, 1]> : tensor<2xi8> + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) + {axis = 1 : si64, + saturate = 1 : si64} : (tensor<8x2xf32>, tensor<2xf32>, tensor<2xi8>) -> tensor<8x2xi8> + return %2 : tensor<8x2xi8> +} +// CHECK-LABEL: func.func @test_quantizeLinear_per_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x2xf32>) -> tensor<8x2xi8> { +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>}> : () -> tensor<1x2xf32> +// CHECK: %[[REC:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x2xf32>) -> tensor<1x2xf32> +// CHECK: %[[MUL:.*]] = tosa.mul %[[VAL_0]], %[[REC]] {shift = 0 : i8} : (tensor<8x2xf32>, tensor<1x2xf32>) -> tensor<8x2xf32> +// CHECK: %[[MUL_CAST:.*]] = tosa.cast %[[MUL]] : (tensor<8x2xf32>) -> tensor<8x2xi32> +// CHECK: %[[ZP:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1]]> : tensor<1x2xi8>}> : () -> tensor<1x2xi8> +// CHECK: %[[ZPCAST:.*]] = tosa.cast %[[ZP]] : (tensor<1x2xi8>) -> tensor<1x2xi32> +// CHECK: %[[ADD:.*]] = tosa.add %[[MUL_CAST]], %[[ZPCAST]] : (tensor<8x2xi32>, tensor<1x2xi32>) -> tensor<8x2xi32> +// CHECK: %[[CLAMP:.*]] = tosa.clamp %[[ADD]] {max_fp = 1.270000e+02 : f32, max_int = 127 : i64, min_fp = -1.280000e+02 : f32, min_int = -128 : i64} : (tensor<8x2xi32>) -> tensor<8x2xi32> +// CHECK: %[[CAST:.*]] = tosa.cast %[[CLAMP]] : (tensor<8x2xi32>) -> tensor<8x2xi8> +// CHECK: return %[[CAST]] : tensor<8x2xi8> +// CHECK: } + +// ----- + +func.func @test_quantizeLinear_negative_axis(%arg0: tensor<8x2xf32>) -> tensor<8x2xi8> { + %0 = onnx.Constant dense<2.000000e+00> : tensor<8xf32> + %1 = onnx.Constant dense<1> : tensor<8xi8> + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) + {axis = -2 : si64, + saturate = 1 : si64} : (tensor<8x2xf32>, tensor<8xf32>, tensor<8xi8>) -> tensor<8x2xi8> + return %2 : tensor<8x2xi8> +} +// CHECK-LABEL: test_quantizeLinear_negative_axis +// CHECK: "tosa.const"() {{.*}} : tensor<8x1xi8> + +// ----- + +func.func @test_quantizeLinear_ui8(%arg0 : tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xui8> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32x3x224x224xf32>, tensor, tensor) -> tensor<32x3x224x224xui8> + "func.return"(%2) : (tensor<32x3x224x224xui8>) -> () +} +// CHECK-LABEL: @test_quantizeLinear_ui8 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xui8> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xui8>}> : () -> tensor<1x1x1x1xui8> +// CHECK-DAG: %[[SCALE:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[REC:.*]] = tosa.reciprocal %[[SCALE]] : (tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> +// CHECK-DAG: %[[MUL:.*]] = tosa.mul %[[ARG_0]], %[[REC]] {shift = 0 : i8} : (tensor<32x3x224x224xf32>, tensor<1x1x1x1xf32>) -> tensor<32x3x224x224xf32> +// CHECK-DAG: %[[MUL_CAST:.*]] = tosa.cast %[[MUL]] : (tensor<32x3x224x224xf32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[ZPCAST:.*]] = tosa.cast %[[ZP]] : (tensor<1x1x1x1xui8>) -> tensor<1x1x1x1xi32> +// CHECK-DAG: %[[ADD:.*]] = tosa.add %[[MUL_CAST]], %[[ZPCAST]] : (tensor<32x3x224x224xi32>, tensor<1x1x1x1xi32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[CLAMP:.*]] = tosa.clamp %[[ADD]] {max_fp = 2.550000e+02 : f32, max_int = 255 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<32x3x224x224xi32>) -> tensor<32x3x224x224xi32> +// CHECK-DAG: %[[CAST:.*]] = tosa.cast %[[CLAMP]] : (tensor<32x3x224x224xi32>) -> tensor<32x3x224x224xui8> +// CHECK-DAG: return %[[CAST]] : tensor<32x3x224x224xui8> + +// ----- + +// The default `axis` is `1` when it's absent in ONNX, which conflicts +// with the allowed range of `axis` when the input has rank 1. +// See https://github.com/onnx/onnx/issues/6067 +func.func @default_axis(%arg0 : tensor<32xf32>) -> tensor<32xi8> { + %0 = onnx.Constant dense<3.125000e-02> : tensor + %1 = onnx.Constant dense<0> : tensor + %2 = "onnx.QuantizeLinear"(%arg0, %0, %1) {axis = 1 : si64} : (tensor<32xf32>, tensor, tensor) -> tensor<32xi8> + return %2 : tensor<32xi8> +} + +// CHECK-LABEL: default_axis +// CHECK-NOT: onnx.QuantizeLinear diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Concat.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Concat.mlir new file mode 100644 index 0000000000..2796dae7b3 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Concat.mlir @@ -0,0 +1,34 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + + +func.func @test_concat(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { + %0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> + "func.return"(%0) : (tensor<5x5x4x32xf32>) -> () +// CHECK-LABEL: func.func @test_concat( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x1x32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { +// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> +// CHECK: return %[[VAL_2]] : tensor<5x5x4x32xf32> +} + +// ----- +func.func @test_concat_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>, %arg1 : tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { + %0 = "onnx.Concat"(%arg0, %arg1) { axis = 2 : si64} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> + "func.return"(%0) : (tensor<5x5x?x32xf32>) -> () +// CHECK-LABEL: func.func @test_concat_dynamic_shape( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x?x32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> { +// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<5x5x?x32xf32>, tensor<5x5x?x32xf32>) -> tensor<5x5x?x32xf32> +// CHECK: return %[[VAL_2]] : tensor<5x5x?x32xf32> +} + +// ----- +func.func @test_concat_negative_axis(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { + %0 = "onnx.Concat"(%arg0, %arg1) { axis = -2 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> + "func.return"(%0) : (tensor<5x5x4x32xf32>) -> () +// CHECK-LABEL: func.func @test_concat_negative_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x1x32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> { +// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_0]], %[[VAL_1]] {axis = 2 : i32} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>) -> tensor<5x5x4x32xf32> +// CHECK: return %[[VAL_2]] : tensor<5x5x4x32xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Expand.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Expand.mlir new file mode 100644 index 0000000000..6c3eeb6401 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Expand.mlir @@ -0,0 +1,120 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_expand(%arg0: tensor<1x64x1x1xf32>) -> tensor<1x64x64x64xf32> { + %0 = "onnx.Constant"() {value = dense<[1, 64, 64, 64]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<1x64x1x1xf32>, tensor<4xi64>) -> tensor<1x64x64x64xf32> + return %1 : tensor<1x64x64x64xf32> +} + +// CHECK-LABEL: func.func @test_expand +// CHECK: %[[VAL:.*]] = tosa.tile %{{.*}} {multiples = array} : (tensor<1x64x1x1xf32>) -> tensor<1x64x64x64xf32> +// CHECK-NEXT: return %[[VAL]] : tensor<1x64x64x64xf32> + +// ----- + +func.func @test_expand_splat(%arg0: tensor<1x64x1x1xf32>) -> tensor<64x64x64x64xf32> { + %0 = "onnx.Constant"() {value = dense<64> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<1x64x1x1xf32>, tensor<4xi64>) -> tensor<64x64x64x64xf32> + return %1 : tensor<64x64x64x64xf32> +} + +// CHECK-LABEL: func.func @test_expand_splat +// CHECK: %[[VAL:.*]] = tosa.tile %{{.*}} {multiples = array} : (tensor<1x64x1x1xf32>) -> tensor<64x64x64x64xf32> +// CHECK-NEXT: return %[[VAL]] : tensor<64x64x64x64xf32> + +// ----- + +func.func @test_expand_new_dims_out(%arg0: tensor<1x64x1xf32>) -> tensor<64x64x64x64xf32> { + %0 = "onnx.Constant"() {value = dense<[64, 64, 64, 64]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<1x64x1xf32>, tensor<4xi64>) -> tensor<64x64x64x64xf32> + return %1 : tensor<64x64x64x64xf32> +} + +// CHECK-LABEL: func.func @test_expand_new_dims_out +// CHECK: %[[RES:.*]] = tosa.reshape %{{.*}} {new_shape = array} : (tensor<1x64x1xf32>) -> tensor<1x64x1x1xf32> +// CHECK-NEXT: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<1x64x1x1xf32>) -> tensor<64x64x64x64xf32> +// CHECK-NEXT: return %[[TILE]] : tensor<64x64x64x64xf32> + +// ----- + +func.func @test_expand_new_dims_start(%arg0: tensor<256x256x16xf32>) -> tensor<1x512x256x16xf32> { + %0 = "onnx.Constant"() {value = dense<[1, 512, 256, 16]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<256x256x16xf32>, tensor<4xi64>) -> tensor<1x512x256x16xf32> + return %1 : tensor<1x512x256x16xf32> +} + +// CHECK-LABEL: func.func @test_expand_new_dims_start +// CHECK: %[[RES:.*]] = tosa.reshape %{{.*}} {new_shape = array} : (tensor<256x256x16xf32>) -> tensor<1x256x256x16xf32> +// CHECK-NEXT: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<1x256x256x16xf32>) -> tensor<1x512x256x16xf32> +// CHECK-NEXT: return %[[TILE]] : tensor<1x512x256x16xf32> + +// ----- + +func.func @test_expand_new_dims_mix(%arg0: tensor<128x64xf32>) -> tensor<1x128x16x128x16xf32> { + %0 = "onnx.Constant"() {value = dense<[1, 128, 16, 128, 16]> : tensor<5xi64>} : () -> tensor<5xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<128x64xf32>, tensor<5xi64>) -> tensor<1x128x16x128x16xf32> + return %1 : tensor<1x128x16x128x16xf32> +} + +// CHECK-LABEL: func.func @test_expand_new_dims_mix +// CHECK: %[[RES:.*]] = tosa.reshape %{{.*}} {new_shape = array} : (tensor<128x64xf32>) -> tensor<1x128x1x64x1xf32> +// CHECK-NEXT: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<1x128x1x64x1xf32>) -> tensor<1x128x16x128x16xf32> +// CHECK-NEXT: return %[[TILE]] : tensor<1x128x16x128x16xf32> + +// ----- + +func.func @test_expand_no_tile(%arg0: tensor<128x16xf32>) -> tensor<1x1x128x16xf32> { + %0 = "onnx.Constant"() {value = dense<[1, 1, 128, 16]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<128x16xf32>, tensor<4xi64>) -> tensor<1x1x128x16xf32> + return %1 : tensor<1x1x128x16xf32> +} + +// CHECK-LABEL: func.func @test_expand_no_tile +// CHECK: %[[RES:.*]] = tosa.reshape %{{.*}} {new_shape = array} : (tensor<128x16xf32>) -> tensor<1x1x128x16xf32> +// CHECK-NEXT: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<1x1x128x16xf32>) -> tensor<1x1x128x16xf32> +// CHECK-NEXT: return %[[TILE]] : tensor<1x1x128x16xf32> + +// ----- +func.func @test_expand_tile_one_dim_big(%arg0: tensor<1x6x1x1xf32>) -> tensor<1x6x576x672xf32> { + %0 = onnx.Constant dense<[1, 1, 576, 672]> : tensor<4xi64> + %1 = "onnx.Expand"(%arg0, %0) {onnx_node_name = "Expand_1417"} : (tensor<1x6x1x1xf32>, tensor<4xi64>) -> tensor<1x6x576x672xf32> + return %1 : tensor<1x6x576x672xf32> +} +// CHECK-LABEL: func.func @test_expand_tile_one_dim_big +// CHECK: %[[TILE:.*]] = tosa.tile %{{.*}} {multiples = array} : (tensor<1x6x1x1xf32>) -> tensor<1x6x576x672xf32> +// CHECK: return %[[TILE]] : tensor<1x6x576x672xf32> + +// ----- + +func.func @test_expand_smaller_dims(%arg0: tensor<128x64x1x1xf32>) -> tensor<1x128x64x64x128xf32> { + %0 = "onnx.Constant"() {value = dense<[1, 64, 128]> : tensor<3xi64>} : () -> tensor<3xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<128x64x1x1xf32>, tensor<3xi64>) -> tensor<1x128x64x64x128xf32> + return %1 : tensor<1x128x64x64x128xf32> +} + +// CHECK-LABEL: func.func @test_expand_smaller_dims +// CHECK: %[[RES:.*]] = tosa.reshape {{.*}} {new_shape = array} : (tensor<128x64x1x1xf32>) -> tensor<1x128x64x1x1xf32> +// CHECK: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<1x128x64x1x1xf32>) -> tensor<1x128x64x64x128xf32> +// CHECK: return %[[TILE]] : tensor<1x128x64x64x128xf32> + +// ----- + +func.func @test_expand_mixed_smaller(%arg0 : tensor<2x1x6x1xbf16>) -> tensor<2x7x6x5xbf16> { + %0 = "onnx.Constant"() {value = dense<[7, 1, 5]> : tensor<3xi64> } : () -> tensor<3xi64> + %1 = "onnx.Expand"(%arg0, %0) : (tensor<2x1x6x1xbf16>, tensor<3xi64>) -> tensor<2x7x6x5xbf16> + func.return %1 : tensor<2x7x6x5xbf16> +} + +// CHECK-LABEL: func.func @test_expand_mixed_smaller +// CHECK: %[[RES:.*]] = tosa.reshape {{.*}} {new_shape = array} : (tensor<2x1x6x1xbf16>) -> tensor<2x1x6x1xbf16> +// CHECK: %[[TILE:.*]] = tosa.tile %[[RES]] {multiples = array} : (tensor<2x1x6x1xbf16>) -> tensor<2x7x6x5xbf16> +// CHECK: return %[[TILE]] : tensor<2x7x6x5xbf16> + +// ----- + +func.func @test_expand_no_legalization(%arg0: tensor<1x64x1x1xf32>, %arg1: tensor<4xi64>) -> tensor<1x64x64x64xf32> { + %0 = "onnx.Expand"(%arg0, %arg1) : (tensor<1x64x1x1xf32>, tensor<4xi64>) -> tensor<1x64x64x64xf32> + return %0 : tensor<1x64x64x64xf32> +} +// CHECK-LABEL: func @test_expand_no_legalization +// CHECK: onnx.Expand diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/EyeLike.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/EyeLike.mlir new file mode 100644 index 0000000000..0776dcbbbb --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/EyeLike.mlir @@ -0,0 +1,98 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_eyelike_dtype_f32(%arg0 : tensor<4x4xi32>) -> tensor<4x4xf32> { + %1 = "onnx.EyeLike"(%arg0) {dtype = 1 : si64} : (tensor<4x4xi32>) -> tensor<4x4xf32> + "onnx.Return"(%1) : (tensor<4x4xf32>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_dtype_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xi32>) -> tensor<4x4xf32> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<4x4xf32>}> : () -> tensor<4x4xf32> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xf32> + +// ----- + +func.func @test_eyelike_int8(%arg0 : tensor<4x4xi8>) -> tensor<4x4xi8> { + %1 = "onnx.EyeLike"(%arg0) : (tensor<4x4xi8>) -> tensor<4x4xi8> + "onnx.Return"(%1) : (tensor<4x4xi8>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_int8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xi8>) -> tensor<4x4xi8> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]> : tensor<4x4xi8>}> : () -> tensor<4x4xi8> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xi8> + +// ----- + +func.func @test_eyelike_bool(%arg0 : tensor<4x4xi1>) -> tensor<4x4xi1> { + %1 = "onnx.EyeLike"(%arg0) : (tensor<4x4xi1>) -> tensor<4x4xi1> + "onnx.Return"(%1) : (tensor<4x4xi1>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_bool +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xi1>) -> tensor<4x4xi1> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}true, false, false, false], [false, true, false, false], [false, false, true, false], [false, false, false, true]]> : tensor<4x4xi1>}> : () -> tensor<4x4xi1> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xi1> + +// ----- + +func.func @test_eyelike_k_pos(%arg0 : tensor<4x4xf64>) -> tensor<4x4xf64> { + %1 = "onnx.EyeLike"(%arg0) {k = 2 : si64} : (tensor<4x4xf64>) -> tensor<4x4xf64> + "onnx.Return"(%1) : (tensor<4x4xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_k_pos +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xf64>) -> tensor<4x4xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]]> : tensor<4x4xf64>}> : () -> tensor<4x4xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xf64> + +// ----- + +func.func @test_eyelike_k_neg(%arg0 : tensor<4x4xf64>) -> tensor<4x4xf64> { + %1 = "onnx.EyeLike"(%arg0) {k = -2 : si64} : (tensor<4x4xf64>) -> tensor<4x4xf64> + "onnx.Return"(%1) : (tensor<4x4xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_k_neg +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xf64>) -> tensor<4x4xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00]]> : tensor<4x4xf64>}> : () -> tensor<4x4xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xf64> + +// ----- + +func.func @test_eyelike_k_out_of_rang(%arg0 : tensor<4x4xf64>) -> tensor<4x4xf64> { + %1 = "onnx.EyeLike"(%arg0) {k = 42 : si64} : (tensor<4x4xf64>) -> tensor<4x4xf64> + "onnx.Return"(%1) : (tensor<4x4xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_k_out_of_rang +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x4xf64>) -> tensor<4x4xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4x4xf64>}> : () -> tensor<4x4xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x4xf64> + +// ----- + +func.func @test_eyelike_dif_dim(%arg0 : tensor<2x5xf64>) -> tensor<2x5xf64> { + %1 = "onnx.EyeLike"(%arg0) : (tensor<2x5xf64>) -> tensor<2x5xf64> + "onnx.Return"(%1) : (tensor<2x5xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_dif_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5xf64>) -> tensor<2x5xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]]> : tensor<2x5xf64>}> : () -> tensor<2x5xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<2x5xf64> + +// ----- + +func.func @test_eyelike_dif_dim_k(%arg0 : tensor<2x5xf64>) -> tensor<2x5xf64> { + %1 = "onnx.EyeLike"(%arg0) {k = 1 : si64} : (tensor<2x5xf64>) -> tensor<2x5xf64> + "onnx.Return"(%1) : (tensor<2x5xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_dif_dim_k +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5xf64>) -> tensor<2x5xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00]]> : tensor<2x5xf64>}> : () -> tensor<2x5xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<2x5xf64> + +// ----- + +func.func @test_eyelike_dif_dim_2(%arg0 : tensor<4x2xf64>) -> tensor<4x2xf64> { + %1 = "onnx.EyeLike"(%arg0) : (tensor<4x2xf64>) -> tensor<4x2xf64> + "onnx.Return"(%1) : (tensor<4x2xf64>) -> () +} +// CHECK-LABEL: func.func @test_eyelike_dif_dim_2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x2xf64>) -> tensor<4x2xf64> { +// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{\[\[}}1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00], [0.000000e+00, 0.000000e+00], [0.000000e+00, 0.000000e+00]{{.}}> : tensor<4x2xf64>}> : () -> tensor<4x2xf64> +// CHECK: onnx.Return [[VAR_0_]] : tensor<4x2xf64> diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Flatten.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Flatten.mlir new file mode 100644 index 0000000000..71231e5db4 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Flatten.mlir @@ -0,0 +1,42 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_static_flatten(%arg0: tensor<32x512x1x1xf32>) -> tensor<32x512xf32> { + %0 = "onnx.Flatten"(%arg0) {axis = 1 : si64} : (tensor<32x512x1x1xf32>) -> tensor<32x512xf32> + return %0 : tensor<32x512xf32> + //CHECK: {{%.+}} = tosa.reshape %arg0 {new_shape = array} : (tensor<32x512x1x1xf32>) -> tensor<32x512xf32> +} + +// ----- +func.func @test_static_flatten_mult(%arg0: tensor<32x51x5x3xf32>) -> tensor<1632x15xf32> { + %0 = "onnx.Flatten"(%arg0) {axis = 2 : si64} : (tensor<32x51x5x3xf32>) -> tensor<1632x15xf32> + return %0 : tensor<1632x15xf32> + //CHECK: {{%.+}} = tosa.reshape %arg0 {new_shape = array} : (tensor<32x51x5x3xf32>) -> tensor<1632x15xf32> +} + +// ----- +func.func @test_flatten_axes_0(%arg0: tensor<32x51x1x1xf32>) -> tensor<1x1632xf32> { + %0 = "onnx.Flatten"(%arg0) {axis = 0 : si64} : (tensor<32x51x1x1xf32>) -> tensor<1x1632xf32> + return %0 : tensor<1x1632xf32> + //CHECK: {{%.+}} = tosa.reshape %arg0 {new_shape = array} : (tensor<32x51x1x1xf32>) -> tensor<1x1632xf32> +} + +// ----- +func.func @test_flatten_axes_last(%arg0: tensor<32x51x1x3xf32>) -> tensor<1632x3xf32> { + %0 = "onnx.Flatten"(%arg0) {axis = 3 : si64} : (tensor<32x51x1x3xf32>) -> tensor<1632x3xf32> + return %0 : tensor<1632x3xf32> + //CHECK: {{%.+}} = tosa.reshape %arg0 {new_shape = array} : (tensor<32x51x1x3xf32>) -> tensor<1632x3xf32> +} + +// ----- +func.func @test_flatten_axes_equals_rank(%arg0: tensor<32x51x1x3xf32>) -> tensor<4896x1xf32> { + %0 = "onnx.Flatten"(%arg0) {axis = 4 : si64} : (tensor<32x51x1x3xf32>) -> tensor<4896x1xf32> + return %0 : tensor<4896x1xf32> + //CHECK: {{%.+}} = tosa.reshape %arg0 {new_shape = array} : (tensor<32x51x1x3xf32>) -> tensor<4896x1xf32> +} + +// ----- +func.func @test_flatten_dyn_shape(%arg0: tensor) -> tensor { + %0 = "onnx.Flatten"(%arg0) {axis = 4 : si64} : (tensor) -> tensor + return %0 : tensor + //CHECK: onnx.Flatten +} \ No newline at end of file diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir new file mode 100644 index 0000000000..a358e35cd3 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Gather.mlir @@ -0,0 +1,164 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { + %indices = "onnx.Constant"() {value = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32> + "func.return"(%0) : (tensor<2x2x2xf32>) -> () +// CHECK-LABEL: func.func @test_gather_axis0( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2xf32>) -> tensor<2x2x2xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 1], [1, 2]]> : tensor<2x2xi64>}> : () -> tensor<2x2xi64> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<2x2xi64>, tensor<1x1xi64>) -> tensor<2x2xi64> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor<2x2xi64>, tensor<1x1xi64>) -> tensor<2x2xi1> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_3]] : (tensor<2x2xi1>, tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<2x2xi64>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[0, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_8]] : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<3x2xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<3x2xf32>) -> tensor<1x3x2xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x2xi32>) -> tensor<1x4xi32> +// CHECK: %[[VAL_12:.*]] = tosa.gather %[[VAL_10]], %[[VAL_11]] : (tensor<1x3x2xf32>, tensor<1x4xi32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x4x2xf32>) -> tensor<2x2x2xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<2x2x2xf32>, tensor<3xi32>) -> tensor<2x2x2xf32> +// CHECK: return %[[VAL_15]] : tensor<2x2x2xf32> +} + +// ----- + +// Test negative indices. +func.func @test_gather_axis0_neg_idx(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> { + %indices = "onnx.Constant"() {value = dense<[[0, -1], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2x2xf32> + "func.return"(%0) : (tensor<2x2x2xf32>) -> () +// CHECK-LABEL: func.func @test_gather_axis0_neg_idx( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2xf32>) -> tensor<2x2x2xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, -1], [1, 2]]> : tensor<2x2xi64>}> : () -> tensor<2x2xi64> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<2x2xi64>, tensor<1x1xi64>) -> tensor<2x2xi64> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor<2x2xi64>, tensor<1x1xi64>) -> tensor<2x2xi1> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_3]] : (tensor<2x2xi1>, tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<2x2xi64>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[0, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_8]] : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<3x2xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<3x2xf32>) -> tensor<1x3x2xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x2xi32>) -> tensor<1x4xi32> +// CHECK: %[[VAL_12:.*]] = tosa.gather %[[VAL_10]], %[[VAL_11]] : (tensor<1x3x2xf32>, tensor<1x4xi32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x4x2xf32>) -> tensor<2x2x2xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<2x2x2xf32>, tensor<3xi32>) -> tensor<2x2x2xf32> +// CHECK: return %[[VAL_15]] : tensor<2x2x2xf32> +} + +// ----- + +// Test along axis 1. Transpose should be different. +func.func @test_gather_axis1(%arg0 : tensor<3x3xf32>) -> tensor<3x1x2xf32> { + %indices = "onnx.Constant"() {value = dense<[[0, 2]]> : tensor<1x2xi64>} : () -> tensor<1x2xi64> + %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> + "func.return"(%0) : (tensor<3x1x2xf32>) -> () +// CHECK-LABEL: func.func @test_gather_axis1( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>) -> tensor<3x1x2xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 2]]> : tensor<1x2xi64>}> : () -> tensor<1x2xi64> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x2xi64> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x2xi1> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_3]] : (tensor<1x2xi1>, tensor<1x2xi64>, tensor<1x2xi64>) -> tensor<1x2xi64> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<1x2xi64>) -> tensor<1x2xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_8]] : (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<3x3xf32>) -> tensor<1x3x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.gather %[[VAL_10]], %[[VAL_11]] : (tensor<1x3x3xf32>, tensor<1x2xi32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x1x2xf32> +// CHECK: return %[[VAL_15]] : tensor<3x1x2xf32> +// CHECK: } +} + +// ----- + +func.func @test_gather_dynamic_indices(%arg0 : tensor<3x3xf32>, %indices: tensor<1x2xi64>) -> tensor<3x1x2xf32> { + %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi64>) -> tensor<3x1x2xf32> + "func.return"(%0) : (tensor<3x1x2xf32>) -> () +// CHECK-LABEL: func.func @test_gather_dynamic_indices( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi64>) -> tensor<3x1x2xf32> { +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x2xi64> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor<1x2xi64>, tensor<1x1xi64>) -> tensor<1x2xi1> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_3]] : (tensor<1x2xi1>, tensor<1x2xi64>, tensor<1x2xi64>) -> tensor<1x2xi64> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<1x2xi64>) -> tensor<1x2xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_8]] : (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<3x3xf32>) -> tensor<1x3x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.gather %[[VAL_10]], %[[VAL_11]] : (tensor<1x3x3xf32>, tensor<1x2xi32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x1x2xf32> +// CHECK: return %[[VAL_15]] : tensor<3x1x2xf32> +} + +// ----- + +func.func @test_gather_dynamic_indices_i32(%arg0 : tensor<3x3xf32>, %indices: tensor<1x2xi32>) -> tensor<3x1x2xf32> { + %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor<1x2xi32>) -> tensor<3x1x2xf32> + "func.return"(%0) : (tensor<3x1x2xf32>) -> () +// CHECK-LABEL: func.func @test_gather_dynamic_indices_i32( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2xi32>) -> tensor<3x1x2xf32> { +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3> : tensor<1x1xi32>}> : () -> tensor<1x1xi32> +// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<1x2xi32>, tensor<1x1xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi32>}> : () -> tensor<1x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_5]] : (tensor<1x2xi32>, tensor<1x1xi32>) -> tensor<1x2xi1> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_1]], %[[VAL_4]] : (tensor<1x2xi1>, tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_9]] : (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<3x3xf32>) -> tensor<1x3x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<1x2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_13:.*]] = tosa.gather %[[VAL_11]], %[[VAL_12]] : (tensor<1x3x3xf32>, tensor<1x2xi32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_14]], %[[VAL_15]] : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x1x2xf32> +// CHECK: return %[[VAL_16]] : tensor<3x1x2xf32> +} + +// ----- + +func.func @test_gather_like_slice(%arg0 : tensor<3x3xf32>) -> tensor<3xf32> { + %indices = onnx.Constant dense<0> : tensor + %0 = "onnx.Gather"(%arg0, %indices) {axis = 1 : si64} : (tensor<3x3xf32>, tensor) -> tensor<3xf32> + "func.return"(%0) : (tensor<3xf32>) -> () +// CHECK-LABEL: test_gather_like_slice +// CHECK-SAME: (%[[ARG:.*]]: tensor<3x3xf32>) +// CHECK: %[[VAL_1:.*]] = tosa.slice %[[ARG]] {size = array, start = array} : (tensor<3x3xf32>) -> tensor<3x1xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reshape %[[VAL_1]] {{.*}} -> tensor<3xf32> +// CHECK: return %[[VAL_2]] +} + +// ----- + +func.func @test_gather_like_slice_non_zero(%arg0 : tensor<3x3xf32>) -> tensor<3xf32> { + %indices = onnx.Constant dense<2> : tensor + %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor<3x3xf32>, tensor) -> tensor<3xf32> + "func.return"(%0) : (tensor<3xf32>) -> () +// CHECK-LABEL: test_gather_like_slice +// CHECK-SAME: (%[[ARG:.*]]: tensor<3x3xf32>) +// CHECK: %[[VAL_1:.*]] = tosa.slice %[[ARG]] {size = array, start = array} : (tensor<3x3xf32>) -> tensor<1x3xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reshape %[[VAL_1]] {{.*}} -> tensor<3xf32> +// CHECK: return %[[VAL_2]] +} + +// ----- + +func.func @test_gather_dynamic_shape_indices_i32(%arg0 : tensor, %indices: tensor) -> tensor { + %0 = "onnx.Gather"(%arg0, %indices) {axis = 0 : si64} : (tensor, tensor) -> tensor + "func.return"(%0) : (tensor) -> () +// CHECK-LABEL: test_gather_dynamic_shape_indices_i32 +// CHECK: onnx.Gather +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir new file mode 100644 index 0000000000..ea0b6d662c --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Padding.mlir @@ -0,0 +1,137 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<24x22x52x42xf32> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<24x22x52x42xf32> + return %2 : tensor<24x22x52x42xf32> +// CHECK-LABEL: test_pad_f32 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4.500000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + +// ----- +func.func @test_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4.5000]> : tensor<1xf32>} : () -> tensor<1xf32> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32> + return %2 : tensor<20x16x44x32xf32> +// CHECK-LABEL: test_no_pad_f32 +// CHECK: return %arg0 +} + +// ----- +func.func @test_novalue_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x45x33xf32> { + %0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xf32> + return %2 : tensor<20x16x45x33xf32> +// CHECK-LABEL: test_novalue_pad_f32 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + +// ----- +func.func @test_novalue_no_pad_f32(%arg0: tensor<20x16x44x32xf32>) -> tensor<20x16x44x32xf32> { + %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xf32> + return %2 : tensor<20x16x44x32xf32> +// CHECK-LABEL: test_novalue_no_pad_f32 +// CHECK: return %arg0 +} + +// ----- +func.func @test_no_const_pad_f32(%arg0: tensor<20x16x44x32xf32>, %arg1: tensor<8xi64>, %arg2: tensor<1xf32>) -> tensor<20x16x44x32xf32> { + %noval = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xf32>, tensor<8xi64>, tensor<1xf32>, none) -> tensor<20x16x44x32xf32> + return %2 : tensor<20x16x44x32xf32> +// CHECK-LABEL: test_no_const_pad_f32 +// CHECK: "onnx.Pad" +} + +// ----- +func.func @test_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<24x22x52x42xi64> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<24x22x52x42xi64> + return %2 : tensor<24x22x52x42xi64> +// CHECK-LABEL: test_pad_i64 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor}> : () -> tensor +// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + +// ----- +func.func @test_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xi64>} : () -> tensor<1xi64> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64> + return %2 : tensor<20x16x44x32xi64> +// CHECK-LABEL: test_no_pad_i64 +// CHECK: return %arg0 +} + +// ----- +func.func @test_novalue_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x45x33xi64> { + %0 = "onnx.Constant"() {value = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x45x33xi64> + return %2 : tensor<20x16x45x33xi64> +// CHECK-LABEL: test_novalue_pad_i64 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 0], [0, 0], [1, 0], [1, 0]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + +// ----- +func.func @test_novalue_no_pad_i64(%arg0: tensor<20x16x44x32xi64>) -> tensor<20x16x44x32xi64> { + %0 = "onnx.Constant"() {value = dense<[0, 0, 0, 0, 0, 0, 0, 0]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %0, %1, %1) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, none, none) -> tensor<20x16x44x32xi64> + return %2 : tensor<20x16x44x32xi64> +// CHECK-LABEL: test_novalue_no_pad_i64 +// CHECK: return %arg0 +} + +// ----- +func.func @test_no_const_pad_i64(%arg0: tensor<20x16x44x32xi64>, %arg1: tensor<8xi64>, %arg2: tensor<1xi64>) -> tensor<20x16x44x32xi64> { + %noval = "onnx.NoValue"() {value} : () -> none + %2 = "onnx.Pad"(%arg0, %arg1, %arg2, %noval) {mode = "constant"} : (tensor<20x16x44x32xi64>, tensor<8xi64>, tensor<1xi64>, none) -> tensor<20x16x44x32xi64> + return %2 : tensor<20x16x44x32xi64> +// CHECK-LABEL: test_no_const_pad_i64 +// CHECK: "onnx.Pad" +} + +// ----- +func.func @test_pad_ui32(%arg0: tensor<20x16x44x32xui32>) -> tensor<24x22x52x42xui32> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4]> : tensor<1xui32>} : () -> tensor<1xui32> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xui32>, tensor<8xi64>, tensor<1xui32>, none) -> tensor<24x22x52x42xui32> + return %2 : tensor<24x22x52x42xui32> +// CHECK-LABEL: test_pad_ui32 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4> : tensor}> : () -> tensor +// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + +// ----- +func.func @test_pad_bf16(%arg0: tensor<20x16x44x32xbf16>) -> tensor<24x22x52x42xbf16> { + %noval = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Constant"() {value = dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>} : () -> tensor<8xi64> + %1 = "onnx.Constant"() {value = dense<[4.500000e+00]> : tensor<1xbf16>} : () -> tensor<1xbf16> + %2 = "onnx.Pad"(%arg0, %0, %1, %noval) {mode = "constant"} : (tensor<20x16x44x32xbf16>, tensor<8xi64>, tensor<1xbf16>, none) -> tensor<24x22x52x42xbf16> + return %2 : tensor<24x22x52x42xbf16> +// CHECK-LABEL: test_pad_bf16 +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[{{\[}}0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>}> : () -> tensor<4x2xi64> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<4.500000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[VAR1]] +} + diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir index 73a653be96..4501f1ed7e 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir @@ -19,3 +19,13 @@ func.func @test_reshape_allowzero(%arg0 : tensor<12x128x1024xf32>) -> tensor<12x // CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<12x128x1024xf32>) -> tensor<12x128x16x64xf32> // CHECK-NEXT: return [[VAR_1_]] : tensor<12x128x16x64xf32> } + +func.func @test_reshape_fp8(%arg0 : tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> { + %0 = "onnx.Constant"() {value = dense<[-1, 128, 16, 64]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<128x1024xf8E5M2FNUZ>, tensor<4xi64>) -> tensor<1x128x16x64xf8E5M2FNUZ> + "func.return"(%1) : (tensor<1x128x16x64xf8E5M2FNUZ>) -> () +// CHECK-LABEL: @test_reshape_fp8 +// CHECK-SAME: ([[PARAM_0_:%.+]] tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> { +// CHECK: [[VAR_1_:%.+]] = tosa.reshape %arg0 {new_shape = array} : (tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> +// CHECK-NEXT: return [[VAR_1_]] : tensor<1x128x16x64xf8E5M2FNUZ> + } diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir index c3e372533a..af185ced75 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir @@ -102,6 +102,18 @@ func.func @test_resize_half_pixel_nearest_floor_downsample(%arg0: tensor<1x1x1x1 // ----- +func.func @test_resize_f64(%arg0: tensor<1x1x1x4xf64>) -> tensor<1x1x1x12xf64> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Constant"() {value = dense<[1, 1, 1, 12]> : tensor<4xi64>} : () -> tensor<4xi64> + %2 = "onnx.Resize"(%arg0, %0, %0, %1) {coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, mode = "nearest", nearest_mode = "floor"} : (tensor<1x1x1x4xf64>, none, none, tensor<4xi64>) -> tensor<1x1x1x12xf64> + return %2 : tensor<1x1x1x12xf64> +// CHECK-LABEL: func.func @test_resize_f64 +// CHECK-NOT: onnx.Resize +// CHECK: return {{.*}}: tensor<1x1x1x12xf64> +} + +// ----- + func.func @test_resize_input_one(%arg0: tensor<1x1x1x1xf32>) -> tensor<1x1x4x4xf32> { %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.Constant"() {value = dense<[1, 1, 4, 4]> : tensor<4xi64>} : () -> tensor<4xi64> @@ -239,4 +251,4 @@ func.func @test_resize_cubic_disallowed(%arg0: tensor<1x1x2x4xf32>) -> tensor<1x return %2 : tensor<1x1x2x8xf32> // CHECK-LABEL: func.func @test_resize_cubic_disallowed // CHECK-LABEL: onnx.Resize -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Shrink.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Shrink.mlir new file mode 100644 index 0000000000..d7ec449791 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Shrink.mlir @@ -0,0 +1,52 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_shrink_float(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { + %2 = "onnx.Shrink"(%arg0) {lambd = -7.500000e-01 : f32, bias = 5.000000e-01 : f32} : (tensor<4x4xf32>) -> tensor<4x4xf32> + return %2 : tensor<4x4xf32> +// CHECK-LABEL: func.func @test_shrink_float( +// CHECK: %0 = "tosa.const"() <{value = dense<-7.500000e-01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %1 = "tosa.const"() <{value = dense<7.500000e-01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %2 = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %3 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %4 = tosa.greater %1, %arg0 : (tensor<1x1xf32>, tensor<4x4xf32>) -> tensor<4x4xi1> +// CHECK: %5 = tosa.add %arg0, %2 : (tensor<4x4xf32>, tensor<1x1xf32>) -> tensor<4x4xf32> +// CHECK: %6 = tosa.select %4, %5, %3 : (tensor<4x4xi1>, tensor<4x4xf32>, tensor<1x1xf32>) -> tensor<4x4xf32> +// CHECK: %7 = tosa.greater %arg0, %0 : (tensor<4x4xf32>, tensor<1x1xf32>) -> tensor<4x4xi1> +// CHECK: %8 = tosa.sub %arg0, %2 : (tensor<4x4xf32>, tensor<1x1xf32>) -> tensor<4x4xf32> +// CHECK: %9 = tosa.select %7, %8, %6 : (tensor<4x4xi1>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK: return %9 : tensor<4x4xf32> +} + +func.func @test_shrink_int(%arg0: tensor<4x4xi8>) -> tensor<4x4xi8> { + %2 = "onnx.Shrink"(%arg0) {lambd = -7.500000e-01 : f32, bias = 5.000000e-01 : f32} : (tensor<4x4xi8>) -> tensor<4x4xi8> + return %2 : tensor<4x4xi8> +// CHECK-LABEL: func.func @test_shrink_int( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xi8>) -> tensor<4x4xi8> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_2]], %[[VAL_0]] : (tensor<1x1xi8>, tensor<4x4xi8>) -> tensor<4x4xi1> +// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_0]], %[[VAL_3]] : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : (tensor<4x4xi1>, tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_0]], %[[VAL_1]] : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi1> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_0]], %[[VAL_3]] : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] : (tensor<4x4xi1>, tensor<4x4xi8>, tensor<4x4xi8>) -> tensor<4x4xi8> +// CHECK: return %[[VAL_9]] : tensor<4x4xi8> +// CHECK: } +} + +func.func @test_shrink_int_constants_are_one(%arg0: tensor<4x4xi8>) -> tensor<4x4xi8> { + %2 = "onnx.Shrink"(%arg0) {lambd = 1.000000e00 : f32, bias = 1.000000e00 : f32} : (tensor<4x4xi8>) -> tensor<4x4xi8> + return %2 : tensor<4x4xi8> +// CHECK-LABEL: func.func @test_shrink_int_constants_are_one( +// CHECK: %0 = "tosa.const"() <{value = dense<1> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %1 = "tosa.const"() <{value = dense<-1> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %2 = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> +// CHECK: %3 = tosa.greater %1, %arg0 : (tensor<1x1xi8>, tensor<4x4xi8>) -> tensor<4x4xi1> +// CHECK: %4 = tosa.add %arg0, %0 : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %5 = tosa.select %3, %4, %2 : (tensor<4x4xi1>, tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %6 = tosa.greater %arg0, %0 : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi1> +// CHECK: %7 = tosa.sub %arg0, %0 : (tensor<4x4xi8>, tensor<1x1xi8>) -> tensor<4x4xi8> +// CHECK: %8 = tosa.select %6, %7, %5 : (tensor<4x4xi1>, tensor<4x4xi8>, tensor<4x4xi8>) -> tensor<4x4xi8> +// CHECK: return %8 : tensor<4x4xi8> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Slice.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Slice.mlir new file mode 100644 index 0000000000..c81ffa5e76 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Slice.mlir @@ -0,0 +1,59 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + + +func.func @test_slice_constant_default_steps(%arg0 : tensor<2x4xf32>) -> tensor<1x3xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, none) -> tensor<1x3xf32> + "func.return"(%1) : (tensor<1x3xf32>) -> () +// CHECK-LABEL: func @test_slice_constant_default_steps +// CHECK: %0 = tosa.slice %arg0 {size = array, start = array} : (tensor<2x4xf32>) -> tensor<1x3xf32> +} + +func.func @test_slice_all_constant_negative(%arg0 : tensor<2x4xf32>) -> tensor<1x3xf32> { + %axes = "onnx.Constant"() {value = dense<[0, -1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[2, -1]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32> + "func.return"(%1) : (tensor<1x3xf32>) -> () +// CHECK-LABEL: func @test_slice_all_constant_negative +// CHECK: %0 = tosa.slice %arg0 {size = array, start = array} : (tensor<2x4xf32>) -> tensor<1x3xf32> +} + +func.func @test_slice_all_constant_end_outofbound(%arg0 : tensor<2x4xf32>) -> tensor<1x3xf32> { + %axes = "onnx.Constant"() {value = dense<[0, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %starts = "onnx.Constant"() {value = dense<[1, 0]> : tensor<2xi64> } : () -> tensor<2xi64> + %ends = "onnx.Constant"() {value = dense<[5, 3]> : tensor<2xi64> } : () -> tensor<2xi64> + %steps = "onnx.Constant"() {value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> + %1 = "onnx.Slice"(%arg0, %starts, %ends, %axes, %steps) : (tensor<2x4xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32> + "func.return"(%1) : (tensor<1x3xf32>) -> () +// CHECK-LABEL: func @test_slice_all_constant_end_outofbound +// CHECK: %0 = tosa.slice %arg0 {size = array, start = array} : (tensor<2x4xf32>) -> tensor<1x3xf32> +} + +// ----- + +func.func @slice_all_dynamic(%arg0: tensor<20x10x5xf32>, + %arg1: tensor<1xi64>, + %arg2: tensor<1xi64>, + %arg3: tensor<1xi64>, + %arg4: tensor<1xi64>) + -> tensor<20x9x5xf32> { + %0 = "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<20x10x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<20x9x5xf32> + return %0 : tensor<20x9x5xf32> +} + +// ----- + +func.func @slice_default_axes(%arg0: tensor<20x10x5xf32>, + %arg1: tensor<3xi64>, + %arg2: tensor<3xi64>) + -> tensor<20x10x1xf32> { + %0 = onnx.Constant dense<[0, 1, 2]> : tensor<3xi64> + %1 = onnx.Constant dense<1> : tensor<3xi64> + %2 = "onnx.Slice"(%arg0, %arg1, %arg2, %0, %1) {onnx_node_name = "onnx.Slice_0"} : (tensor<20x10x5xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<20x10x1xf32> + return %2 : tensor<20x10x1xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir new file mode 100644 index 0000000000..367d3ce988 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Split.mlir @@ -0,0 +1,119 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_split_equal(%arg0 : tensor<16x32x64xf32>) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = 0 : si64} : (tensor<16x32x64xf32>, none) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) + return %0, %1 : tensor<8x32x64xf32>, tensor<8x32x64xf32> +} + +// CHECK-LABEL: func.func @test_split_equal +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf32>) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<8x32x64xf32>, tensor<8x32x64xf32> + +// ----- + +func.func @test_split_variable(%arg0 : tensor<16x32x64xf16>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) { + %split = "onnx.Constant"() {value = dense<[2, 30]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xf16>, tensor<2xi64>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) + "func.return"(%0, %1) : (tensor<16x2x64xf16>, tensor<16x30x64xf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_variable +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf16>) -> (tensor<16x2x64xf16>, tensor<16x30x64xf16>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x2x64xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x30x64xf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<16x2x64xf16>, tensor<16x30x64xf16> + +// ----- + +func.func @test_split_multiple(%arg0 : tensor<16x32x64xf16>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) { + %split = "onnx.Constant"() {value = dense<[4, 8, 20]> : tensor<3xi64>} : () -> tensor<3xi64> + %0, %1, %2 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xf16>, tensor<3xi64>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) + "func.return"(%0, %1, %2) : (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_multiple +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf16>) -> (tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x4x64xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x8x64xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf16>) -> tensor<16x20x64xf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]], [[VAR_2_]] : tensor<16x4x64xf16>, tensor<16x8x64xf16>, tensor<16x20x64xf16> + + +// ----- + +func.func @test_no_split(%arg0 : tensor<16x32x64xi32>) -> tensor<16x16x64xi32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = 1 : si64} : (tensor<16x32x64xi32>, none) -> (tensor<16x16x64xi32>, tensor<16x16x64xi32>) + "func.return"(%0) : (tensor<16x16x64xi32>) -> () +} + +// CHECK-LABEL: func.func @test_no_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xi32>) -> tensor<16x16x64xi32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xi32>) -> tensor<16x16x64xi32> +// CHECK: return [[VAR_0_]] : tensor<16x16x64xi32> + + +// ----- + +func.func @test_split_negative_axis(%arg0 : tensor<16x32x64xbf16>) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) { axis = -2 : si64} : (tensor<16x32x64xbf16>, none) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) + "func.return"(%0, %1) : (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) -> () +} + +// CHECK-LABEL: func.func @test_split_negative_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xbf16>) -> (tensor<16x16x64xbf16>, tensor<16x16x64xbf16>) { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xbf16>) -> tensor<16x16x64xbf16> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xbf16>) -> tensor<16x16x64xbf16> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<16x16x64xbf16>, tensor<16x16x64xbf16> + +// ----- + +func.func @test_non_constant_split(%arg0 : tensor<16x32x64xi16>, %arg1 : tensor<2xi64>) -> tensor<16x?x64xi16> { + %0, %1 = "onnx.Split"(%arg0, %arg1) {axis = 1 : si64} : (tensor<16x32x64xi16>, tensor<2xi64>) -> (tensor<16x?x64xi16>, tensor<16x?x64xi16>) + "func.return"(%0) : (tensor<16x?x64xi16>) -> () +} + +// CHECK-LABEL: func.func @test_non_constant_split +// CHECK-NOT: tosa.slice + +// ----- + +func.func @test_zero_split(%arg0 : tensor<16x32x64xi16>) -> tensor<16x0x64xi16> { + %split = "onnx.Constant"() {value = dense<[32, 0]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x64xi16>, tensor<2xi64>) -> (tensor<16x32x64xi16>, tensor<16x0x64xi16>) + "func.return"(%1) : (tensor<16x0x64xi16>) -> () +} + +// CHECK-LABEL: func.func @test_zero_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xi16>) -> tensor<16x0x64xi16> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xi16>) -> tensor<16x0x64xi16> +// CHECK: return [[VAR_0_]] : tensor<16x0x64xi16> + +// ----- +// Legalization won't happen since tosa.slice doesn't +// allow dynamic entry in 'size' attribute +func.func @test_dynamic_shapes(%arg0 : tensor<16x32x?xf32>) -> tensor<16x2x?xf32> { + %split = "onnx.Constant"() {value = dense<[2, 30]> : tensor<2xi64>} : () -> tensor<2xi64> + %0, %1 = "onnx.Split"(%arg0, %split) {axis = 1 : si64} : (tensor<16x32x?xf32>, tensor<2xi64>) -> (tensor<16x2x?xf32>, tensor<16x30x?xf32>) + return %0 : tensor<16x2x?xf32> +} + +// CHECK-LABEL: func.func @test_dynamic_shapes +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x?xf32>) -> tensor<16x2x?xf32> { +// CHECK-NOT: tosa.slice + +// ----- +func.func @test_num_outputs(%arg0 : tensor<16x32x64xf32>) -> tensor<8x32x64xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0, %1 = "onnx.Split"(%arg0, %cst) {axis = 0 : si64, num_outputs = 2 : si64} : (tensor<16x32x64xf32>, none) -> (tensor<8x32x64xf32>, tensor<8x32x64xf32>) + return %0 : tensor<8x32x64xf32> +} + +// CHECK-LABEL: func.func @test_num_outputs +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x32x64xf32>) -> tensor<8x32x64xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<16x32x64xf32>) -> tensor<8x32x64xf32> +// CHECK: return [[VAR_0_]] : tensor<8x32x64xf32> diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Squeeze.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Squeeze.mlir new file mode 100644 index 0000000000..dd2465cdcf --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Squeeze.mlir @@ -0,0 +1,41 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa --canonicalize %s -split-input-file | FileCheck %s + +func.func @test_squeeze(%arg0 : tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32> { + %0 = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "onnx.Squeeze"(%arg0, %0) : (tensor<16x1x32x1x64xf32>, tensor<2xi64>) -> (tensor<16x32x64xf32>) + "func.return"(%1) : (tensor<16x32x64xf32>) -> () +// CHECK-LABEL: func.func @test_squeeze +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<16x1x32x1x64xf32>) -> tensor<16x32x64xf32> +// CHECK: return [[VAR_0_]] : tensor<16x32x64xf32> +// CHECK: } +} + +func.func @test_squeeze_unknown_dimensions(%arg0 : tensor<1x1x32x1x64xf32>) -> tensor<32x64xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Squeeze"(%arg0, %0) : (tensor<1x1x32x1x64xf32>, none) -> (tensor<32x64xf32>) + "func.return"(%1) : (tensor<32x64xf32>) -> () +// CHECK-LABEL: func.func @test_squeeze_unknown_dimensions +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x32x1x64xf32>) -> tensor<32x64xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<1x1x32x1x64xf32>) -> tensor<32x64xf32> +// CHECK: return [[VAR_0_]] : tensor<32x64xf32> +// CHECK: } +} + +// ----- + +func.func @squeeze_runtime(%arg0: tensor<1x3x4x5xf32> , %arg1: tensor<1xi64> ) -> tensor<3x4x5xf32> { + %0 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<1x3x4x5xf32>, tensor<1xi64>) -> tensor<3x4x5xf32> + return %0 : tensor<3x4x5xf32> +// CHECK-LABEL: squeeze_runtime +// CHECK: tosa.reshape {{.*}} {new_shape = array} : (tensor<1x3x4x5xf32>) -> tensor<3x4x5xf32> +} + +// ----- + +func.func @squeeze_dynamic(%arg0: tensor<1x3x4x5xf32> , %arg1: tensor<1xi64> ) -> tensor { + %0 = "onnx.Squeeze"(%arg0, %arg1) : (tensor<1x3x4x5xf32>, tensor<1xi64>) -> tensor + return %0 : tensor +// CHECK-LABEL: squeeze_dynamic +// CHECK: onnx.Squeeze +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir new file mode 100644 index 0000000000..fe8aa6f1e1 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Tile.mlir @@ -0,0 +1,58 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_tile(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> { + %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> + %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x10x30x32xf32> + "func.return"(%tile) : (tensor<5x10x30x32xf32>) -> () +// CHECK-LABEL: test_tile +// CHECK: tosa.tile{{.*}} {multiples = array} : (tensor<5x5x1x32xf32>) -> tensor<5x10x30x32xf32> +} + +// ----- + +func.func @test_tile_dynamic_shape(%arg0 : tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> { + %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> + %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x?x32xf32>, tensor<4xi64>) -> tensor<5x10x?x32xf32> + "func.return"(%tile) : (tensor<5x10x?x32xf32>) -> () +// CHECK-LABEL: test_tile_dynamic_shape +// CHECK: tosa.tile{{.*}} {multiples = array} : (tensor<5x5x?x32xf32>) -> tensor<5x10x?x32xf32> +} + +// ----- + +func.func @test_tile_input_not_ranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> + %tile = "onnx.Tile"(%arg0, %const) : (tensor<*xf32>, tensor<4xi64>) -> tensor<*xf32> + "func.return"(%tile) : (tensor<*xf32>) -> () +// CHECK-LABEL: test_tile_input_not_ranked +// CHECK-NOT: tosa.tile +} + +// ----- + +func.func @test_tile_non_constant_reps(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> { + %tile = "onnx.Tile"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32> + "func.return"(%tile) : (tensor<*xf32>) -> () +// CHECK-LABEL: test_tile_non_constant_reps +// CHECK-NOT: tosa.tile +} + +// ----- + +func.func @test_tile_no_tosa_type(%arg0 : tensor<5x5x1x32xcomplex>) -> tensor<5x10x30x32xcomplex> { + %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> + %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xcomplex>, tensor<4xi64>) -> tensor<5x10x30x32xcomplex> + "func.return"(%tile) : (tensor<5x10x30x32xcomplex>) -> () +// CHECK-LABEL: test_tile_no_tosa_type +// CHECK-NOT: tosa.tile +} + +// ----- + +func.func @test_tile_no_valid_tosa_tile_type(%arg0 : tensor<5x5x1x32xf64>) -> tensor<5x10x30x32xf64> { + %const = onnx.Constant dense<[1, 2, 30, 1]> : tensor<4xi64> + %tile = "onnx.Tile"(%arg0, %const) : (tensor<5x5x1x32xf64>, tensor<4xi64>) -> tensor<5x10x30x32xf64> + "func.return"(%tile) : (tensor<5x10x30x32xf64>) -> () +// CHECK-LABEL: test_tile_no_valid_tosa_tile_type +// CHECK-NOT: tosa.tile +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir new file mode 100644 index 0000000000..1c0e27ffc1 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir @@ -0,0 +1,42 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> { + %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> + "func.return"(%0) : (tensor<32x1x5x5xf32>) -> () +// CHECK-LABEL: func.func @test_default_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<32x1x5x5xf32> +// CHECK: return %[[VAL_2]] : tensor<32x1x5x5xf32> +} + +// ----- + +func.func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x1x32x5xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<5x1x32x5xf32> + "func.return"(%0) : (tensor<5x1x32x5xf32>) -> () +// CHECK-LABEL: func.func @test_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x1x32xf32>) -> tensor<5x1x32x5xf32> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x1x32x5xf32> +// CHECK: return %[[VAL_2]] : tensor<5x1x32x5xf32> +} + +// ----- + +func.func @test_transpose_f64(%arg0 : tensor<5x5x1x32xf64>) -> tensor<5x1x32x5xf64> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<5x5x1x32xf64>) -> tensor<5x1x32x5xf64> + return %0 : tensor<5x1x32x5xf64> +// CHECK-LABEL: func.func @test_transpose +// CHECK-NOT: onnx.Transpose +// CHECK: return {{.*}}: tensor<5x1x32x5xf64> +} + +// ----- + +func.func @test_transpose_dyn(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<*xf32>) -> tensor<*xf32> + "func.return"(%0) : (tensor<*xf32>) -> () +// CHECK-LABEL: func.func @test_transpose_dyn +// CHECK: onnx.Transpose +} \ No newline at end of file diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Unsqueeze.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Unsqueeze.mlir new file mode 100644 index 0000000000..6d33fa9b25 --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Unsqueeze.mlir @@ -0,0 +1,51 @@ +// RUN: onnx-mlir-opt --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + +func.func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<1x10x10x1xf32> { + %0 = "onnx.Constant"() {value = dense<[0, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "onnx.Unsqueeze"(%arg0, %0) : (tensor<10x10xf32>, tensor<2xi64>) -> tensor<1x10x10x1xf32> + func.return %1 : tensor<1x10x10x1xf32> +// CHECK-LABEL: func.func @test_unsqueeze( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x10xf32>) -> tensor<1x10x10x1xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<10x10xf32>) -> tensor<1x10x10x1xf32> +// CHECK: return %[[VAL_1]] : tensor<1x10x10x1xf32> +// CHECK: } +} + +func.func @test_unsqueeze_negative_axis(%arg0 : tensor<16x32x64xf32>) -> tensor<16x32x1x64xf32> { + %0 = "onnx.Constant"() {value = dense<[-2]> : tensor<1xi64>} : () -> tensor<1xi64> + %1 = "onnx.Unsqueeze"(%arg0, %0) : (tensor<16x32x64xf32>, tensor<1xi64>) -> tensor<16x32x1x64xf32> + func.return %1 : tensor<16x32x1x64xf32> +// CHECK-LABEL: func.func @test_unsqueeze_negative_axis( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x64xf32>) -> tensor<16x32x1x64xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<16x32x64xf32>) -> tensor<16x32x1x64xf32> +// CHECK: return %[[VAL_1]] : tensor<16x32x1x64xf32> +// CHECK: } +} + +func.func @test_unsqueeze_mix(%arg0 : tensor<16x32x64xf32>) -> tensor<16x1x32x1x64xf32> { + %0 = "onnx.Constant"() {value = dense<[1, -2]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "onnx.Unsqueeze"(%arg0, %0) : (tensor<16x32x64xf32>, tensor<2xi64>) -> tensor<16x1x32x1x64xf32> + func.return %1 : tensor<16x1x32x1x64xf32> +// CHECK-LABEL: func.func @test_unsqueeze_mix( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32x64xf32>) -> tensor<16x1x32x1x64xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array} : (tensor<16x32x64xf32>) -> tensor<16x1x32x1x64xf32> +// CHECK: return %[[VAL_1]] : tensor<16x1x32x1x64xf32> +// CHECK: } +} + +// ----- + +func.func @unsqueeze_runtime(%arg0: tensor<3x4x5xf32> , %arg1: tensor<1xi64> ) -> tensor<3x4x1x5xf32> { + %0 = "onnx.Unsqueeze"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<1xi64>) -> tensor<3x4x1x5xf32> + return %0 : tensor<3x4x1x5xf32> +// CHECK-LABEL: unsqueeze_runtime +// CHECK: tosa.reshape {{.*}} {new_shape = array} : (tensor<3x4x5xf32>) -> tensor<3x4x1x5xf32> +} +// ----- + +func.func @unsqueeze_dynamic(%arg0: tensor<1x3x4x5xf32> , %arg1: tensor<1xi64> ) -> tensor { + %0 = "onnx.Unsqueeze"(%arg0, %arg1) : (tensor<1x3x4x5xf32>, tensor<1xi64>) -> tensor + return %0 : tensor +// CHECK-LABEL: unsqueeze_dynamic +// CHECK: onnx.Unsqueeze +} diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Where.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Where.mlir new file mode 100644 index 0000000000..e5049cb64f --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Where.mlir @@ -0,0 +1,44 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa -cse %s -split-input-file | FileCheck %s + + +func.func @test_where(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xf32>, %arg2: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func @test_where +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>, [[PARAM_2_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { +// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.select [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] : (tensor<13x21x1xi1>, tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +} + +// ----- + +func.func @test_where_broadcast(%arg0: tensor<21x1xi1>, %arg1: tensor<13x21x1xf32>, %arg2: tensor<1xf32>) -> tensor<13x21x1xf32> { + %0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<21x1xi1>, tensor<13x21x1xf32>, tensor<1xf32>) -> tensor<13x21x1xf32> + "func.return"(%0) : (tensor<13x21x1xf32>) -> () +// CHECK-LABEL: func.func @test_where_broadcast +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<21x1xi1>) -> tensor<1x21x1xi1> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.select [[VAR_0_]], [[PARAM_1_]], [[VAR_1_]] : (tensor<1x21x1xi1>, tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +// CHECK: return [[VAR_2_]] : tensor<13x21x1xf32> +} + +// ----- + +func.func @test_where_ui32(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xui32>, %arg2: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> { + %0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32> + "func.return"(%0) : (tensor<13x21x1xui32>) -> () +// CHECK-LABEL: func.func @test_where_ui32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi1>, [[PARAM_1_:%.+]]: tensor<13x21x1xui32>, [[PARAM_2_:%.+]]: tensor<13x21x1xui32>) -> tensor<13x21x1xui32> { +// CHECK: [[VAR_0_:%.+]] = tosa.select [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] : (tensor<13x21x1xi1>, tensor<13x21x1xui32>, tensor<13x21x1xui32>) -> tensor<13x21x1xui32> +// CHECK: return [[VAR_0_]] : tensor<13x21x1xui32> +} + +// ----- + +func.func @test_where_f64(%arg0: tensor<13x21x1xi1>, %arg1: tensor<13x21x1xf64>, %arg2: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> { + %0 = "onnx.Where"(%arg0, %arg1, %arg2) : (tensor<13x21x1xi1>, tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64> + "func.return"(%0) : (tensor<13x21x1xf64>) -> () +// CHECK-LABEL: func.func @test_where_f64 +// CHECK-NOT: onnx.Where +// CHECK: return {{.*}}: tensor<13x21x1xf64> +} \ No newline at end of file diff --git a/test/mlir/driver/bytecode.mlir b/test/mlir/driver/bytecode.mlir new file mode 100644 index 0000000000..0acf722424 --- /dev/null +++ b/test/mlir/driver/bytecode.mlir @@ -0,0 +1,5 @@ +// RUN: onnx-mlir --tag="encoder_model" --printBytecode %s | onnx-mlir-opt | FileCheck %s + +// CHECK: module attributes {{{.*}}"onnx-mlir.symbol-postfix" = "encoder_model"} +module { +} diff --git a/test/mlir/driver/compile_phases.mlir b/test/mlir/driver/compile_phases.mlir index 3e94ccbfb0..5391164d4a 100644 --- a/test/mlir/driver/compile_phases.mlir +++ b/test/mlir/driver/compile_phases.mlir @@ -1,14 +1,38 @@ -// RUN: onnx-mlir %s -o %t| FileCheck %s && rm %t.so - -// CHECK: [1/6] {{.*}} Importing ONNX Model to MLIR Module from -// CHECK: [2/6] {{.*}} Compiling and Optimizing MLIR Module -// CHECK: [3/6] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode -// CHECK: [4/6] {{.*}} Generating Object from LLVM Bitcode -// CHECK: [5/6] {{.*}} Linking and Generating the Output Shared Library -// CHECK: [6/6] {{.*}} Compilation completed +// RUN: onnx-mlir %s -o %t 2>&1 | FileCheck --check-prefix=EMIT-LIB %s && rm %t.so +// RUN: onnx-mlir %s --EmitObj -o %t 2>&1 | FileCheck --check-prefix=EMIT-OBJ %s && rm %t.o +// Disabled as jni libs do not exist in test env (AMD): onnx-mlir %s --EmitJNI -o %t 2>&1 | FileCheck --check-prefix=EMIT-JNI %s && rm %t.jar +// RUN: onnx-mlir %s --EmitLLVMIR -o %t 2>&1 | FileCheck --check-prefix=EMIT-LLVMIR %s && rm %t.onnx.mlir + +// EMIT-LIB: [1/6] {{.*}} Importing ONNX Model to MLIR Module from +// EMIT-LIB: [2/6] {{.*}} Compiling and Optimizing MLIR Module +// EMIT-LIB: [3/6] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode +// EMIT-LIB: [4/6] {{.*}} Generating Object from LLVM Bitcode +// EMIT-LIB: [5/6] {{.*}} Linking and Generating the Output Shared Library +// EMIT-LIB: [6/6] {{.*}} Compilation completed + +// EMIT-OBJ: [1/5] {{.*}} Importing ONNX Model to MLIR Module from +// EMIT-OBJ: [2/5] {{.*}} Compiling and Optimizing MLIR Module +// EMIT-OBJ: [3/5] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode +// EMIT-OBJ: [4/5] {{.*}} Generating Object from LLVM Bitcode +// EMIT-OBJ: [5/5] {{.*}} Compilation completed + +// EMIT-JNI: [1/8] {{.*}} Importing ONNX Model to MLIR Module from +// EMIT-JNI: [2/8] {{.*}} Compiling and Optimizing MLIR Module +// EMIT-JNI: [3/8] {{.*}} Translating MLIR Module to LLVM and Generating LLVM Optimized Bitcode +// EMIT-JNI: [4/8] {{.*}} Generating Object from LLVM Bitcode +// EMIT-JNI: [5/8] {{.*}} Generating JNI Object +// EMIT-JNI: [6/8] {{.*}} Linking and Generating the Output Shared Library +// EMIT-JNI: [7/8] {{.*}} Creating JNI Jar +// EMIT-JNI: [8/8] {{.*}} Compilation completed + +// EMIT-LLVMIR: [1/3] {{.*}} Importing ONNX Model to MLIR Module from +// EMIT-LLVMIR: [2/3] {{.*}} Compiling and Optimizing MLIR Module +// EMIT-LLVMIR: [3/3] {{.*}} Compilation completed module { func.func @main_graph(%arg0: tensor) -> tensor { onnx.Return %arg0 : tensor } "onnx.EntryPoint"() {func = @main_graph} : () -> () } + + diff --git a/test/mlir/driver/do_not_emit_full_mlir.mlir b/test/mlir/driver/do_not_emit_full_mlir.mlir new file mode 100644 index 0000000000..794a7a0934 --- /dev/null +++ b/test/mlir/driver/do_not_emit_full_mlir.mlir @@ -0,0 +1,9 @@ +// RUN: onnx-mlir --EmitMLIR --do-not-emit-full-mlir-code %s -o %t && test ! -f %t.onnx.mlir && rm %t.tmp + +module { + func.func @main_graph(%arg0: tensor) -> tensor { + onnx.Return %arg0 : tensor + } + "onnx.EntryPoint"() {func = @main_graph} : () -> () +} + diff --git a/test/mlir/lit.cfg.py b/test/mlir/lit.cfg.py index 8f6f19bea4..bf1706a95f 100644 --- a/test/mlir/lit.cfg.py +++ b/test/mlir/lit.cfg.py @@ -26,6 +26,11 @@ llvm_config.use_default_substitutions() +config.excludes = ["onnx_to_mhlo"] + +# Xilinx fork: Don't care about krnl dialect. Simplifies LLVM bumps +config.excludes += ["onnx_to_krnl", "krnl_to_affine", "krnl_to_llvm"] + # Tweak the PATH to include the tools dir. llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index f91d261eaa..fcb743876c 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -182,6 +182,15 @@ func.func @test_constantofshape_verifier_4() -> tensor<2xi64> { // ----- +func.func @test_constantofshape_elided() -> tensor<2xi64> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor<2xi64> + %1 = "onnx.ConstantOfShape"(%0) : (tensor<2xi64>) -> tensor<2xi64> + "onnx.Return"(%1) : (tensor<2xi64>) -> () +} + +// ----- + func.func @test_flatten_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Flatten: 'axis' value is 5, accepted range is [-4, 4]}} %1 = "onnx.Flatten"(%arg0) {axis = 5 : si64} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> @@ -214,6 +223,15 @@ func.func @test_gatherElements_verifier_2(%data: tensor<2x2xf32>, %indices: tens // ----- +func.func @test_gatherElements_verifier_elided(%data: tensor<12x14x1024xf32>) -> tensor<12x14x14xf32> { + // Tests that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<12x14x14xi64> + %1 = "onnx.GatherElements"(%data, %indices) {axis = -1 : si64} : (tensor<12x14x1024xf32>, tensor<12x14x14xi64>) -> tensor<12x14x14xf32> + "onnx.Return"(%1) : (tensor<12x14x14xf32>) -> () +} + +// ----- + func.func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Hardmax: 'axis' value is 3, accepted range is [-2, 1]}} %1 = "onnx.Hardmax"(%arg0) {axis = 3: si64} : (tensor<2x2xf32>) -> tensor<*xf32> @@ -307,6 +325,16 @@ func.func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32 // expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}} %indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64> %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () +} + +// ----- + +func.func @test_gatherND_verifier_elided(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> { + // Test that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<3x3x2xi64> + %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () } // ----- @@ -580,6 +608,15 @@ func.func @test_splitToSequence_verifier_6(%arg0: tensor<2x2xf32>) -> !onnx.Seq< // ----- +func.func @test_splitToSequence_verifier_elided(%arg0: tensor<2x2xf32>) -> !onnx.Seq> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor + %1 = "onnx.SplitToSequence"(%arg0, %0) : (tensor<2x2xf32>, tensor) -> !onnx.Seq> + "onnx.Return"(%1) : (!onnx.Seq>) -> () +} + +// ----- + func.func @test_topK_verifier_1(%arg0: tensor<3x4xi64>, %arg1: tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>) { // expected-error @+1 {{onnx.TopK: 'axis' value is 2, accepted range is [-2, 1]}} %1, %2 = "onnx.TopK"(%arg0, %arg1) {axis = 2 : si64, largest = 1 : si64, sorted = 1 : si64} : (tensor<3x4xi64>, tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>) @@ -682,3 +719,147 @@ func.func @test_matmulinteger_wrong_B_broadcast(%arg0: tensor<16x32xui8>, %arg1: %0 = "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (tensor<16x32xui8>, tensor<5x32x64xui8>, tensor<16xui8>, tensor<5x1x2xui8>) -> tensor<5x16x64xi32> onnx.Return %0 : tensor<5x16x64xi32> } + +// ----- + +func.func @test_add_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_sub_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Sub"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_mul_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Mul"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_div_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Div"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_equal_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xi1> { + // expected-error @+1 {{op requires the same element type for all operands}} + %0 = "onnx.Equal"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xi1> + onnx.Return %0 : tensor<5x16xi1> +} + +// ----- + +func.func @test_greater_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xi1> { + // expected-error @+1 {{op requires the same element type for all operands}} + %0 = "onnx.Greater"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xi1> + onnx.Return %0 : tensor<5x16xi1> +} + +// ----- + +func.func @test_greater_or_equal_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xi1> { + // expected-error @+1 {{op requires the same element type for all operands}} + %0 = "onnx.GreaterOrEqual"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xi1> + onnx.Return %0 : tensor<5x16xi1> +} + +// ----- + +func.func @test_less_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xi1> { + // expected-error @+1 {{op requires the same element type for all operands}} + %0 = "onnx.Less"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xi1> + onnx.Return %0 : tensor<5x16xi1> +} + +// ----- + +func.func @test_less_or_equal_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xi1> { + // expected-error @+1 {{op requires the same element type for all operands}} + %0 = "onnx.LessOrEqual"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xi1> + onnx.Return %0 : tensor<5x16xi1> +} + +// ----- + +func.func @test_min_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Min"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_max_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Max"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_mod_diff_element_type(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xbf16>) -> tensor<5x16xf32> { + // expected-error @+1 {{op requires the same element type for all operands and results}} + %0 = "onnx.Mod"(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xbf16>) -> tensor<5x16xf32> + onnx.Return %0 : tensor<5x16xf32> +} + +// ----- + +func.func @test_grid_sample_diff_ranks(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Input(=4) and grid(=3) have different dim sizes.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_diff_batch(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Input and grid must have the same batch value.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_align_corners(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op align_corners needs to be 0 or 1}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 2 : si64, mode = "linear", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_mode(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op mode needs to be linear, nearest or cubic}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "sampling", padding_mode = "border"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_padding(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op padding_mode needs to be zeros, border or reflection}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "cubic", padding_mode = "bottom"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func.func @test_grid_sample_wrong_dim_grid(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x3xf32>) -> tensor<*xf32> { + // expected-error @+1 {{'onnx.GridSample' op Grid last dim must have been '2' instead of '3'.}} + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x3xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index 10f43761b8..809a7e3e79 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --shape-inference --canonicalize="test-convergence=true" --shape-inference --cse %s -split-input-file -verify-diagnostics | FileCheck %s // ----- @@ -59,45 +59,6 @@ func.func @test_dropout(%arg: tensor<10x10xf32>) -> (tensor<10x10xf32>, none) { // ----- -//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> { -func.func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> { - %cst = "onnx.NoValue"() {value} : () -> none - %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32> - %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32> - onnx.Return %1 : tensor<*xf32> - - // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> - // onnx.Return [[GEMM]] : tensor<*xf32> -} - -// ----- - -//CHECK-LABEL: @test_gemm_add_fusion_beta_zero(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> { -func.func @test_gemm_add_fusion_beta_zero(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> { - %cst = "onnx.NoValue"() {value} : () -> none - %0 = "onnx.Gemm"(%arg0, %arg1, %cst) {beta = 0.0 : f32}: (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32> - %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32> - onnx.Return %1 : tensor<*xf32> - - // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> - // onnx.Return [[GEMM]] : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { -func.func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> { - %cst = "onnx.NoValue"() {value} : () -> none - %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32> - %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<256xf32>) -> tensor<*xf32> - onnx.Return %1 : tensor<*xf32> - - // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> - // onnx.Return [[GEMM]] : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: @cast_elimination(%{{.*}}: tensor<2xf32>) -> tensor<2xf32> { func.func @cast_elimination(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "onnx.Cast"(%arg0) {to = f32} : (tensor<2xf32>) -> tensor<2xf32> @@ -131,8 +92,8 @@ func.func @cast_slice_swap(%arg0: tensor<3xi32>, %arg1: tensor<1xi64>, %arg2: te // CHECK-LABEL: func.func @cast_slice_swap // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>, [[PARAM_1_:%.+]]: tensor<1xi64>, [[PARAM_2_:%.+]]: tensor<1xi64>, [[PARAM_3_:%.+]]: tensor<1xi64>, [[PARAM_4_:%.+]]: tensor<1xi64>) -> tensor<1xi64> { -// CHECK: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<3xi32>) -> tensor<*xi64> -// CHECK: [[VAR_1_:%.+]] = "onnx.Slice"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) : (tensor<*xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i64} : (tensor<3xi32>) -> tensor<3xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Slice"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> // CHECK: onnx.Return [[VAR_1_]] : tensor<1xi64> // CHECK: } } @@ -150,14 +111,15 @@ func.func @test_conv_batchnormtestmode_fusion_nobias(%arg0: tensor<1x3x224x224xf // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.00000007E-5> : tensor<1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[PARAM_5_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> - // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<*xf32> - // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_2_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<3xi64>) -> tensor<*xf32> - // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_1_]], [[VAR_5_]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Neg"([[PARAM_4_]]) : (tensor<64xf32>) -> tensor<*xf32> - // CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_7_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_9_:%.+]] = "onnx.Add"([[PARAM_3_]], [[VAR_8_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_10_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_6_]], [[VAR_9_]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32> + // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_2_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<3xi64>) -> tensor<64x1x1x1xf32> + // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_1_]], [[VAR_5_]]) : (tensor<64x3x7x7xf32>, tensor<64x1x1x1xf32>) -> tensor<64x3x7x7xf32> + // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Neg"([[PARAM_4_]]) : (tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_7_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_9_:%.+]] = "onnx.Add"([[PARAM_3_]], [[VAR_8_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_10_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_6_]], [[VAR_9_]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32> + // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationInferenceMode"{{.*}} // CHECK: onnx.Return [[VAR_10_]] : tensor<1x64x112x112xf32> } @@ -174,14 +136,15 @@ func.func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, % // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.00000007E-5> : tensor<1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[PARAM_6_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> - // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<*xf32> - // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_3_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<3xi64>) -> tensor<*xf32> - // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_5_]]) : (tensor<64x3x7x7xf32>, tensor<*xf32>) -> tensor<*xf32> + // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_3_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<3xi64>) -> tensor<64x1x1x1xf32> + // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_5_]]) : (tensor<64x3x7x7xf32>, tensor<64x1x1x1xf32>) -> tensor<64x3x7x7xf32> // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Sub"([[PARAM_1_]], [[PARAM_5_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> - // CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_7_]]) : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32> - // CHECK: [[VAR_9_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_8_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_10_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_6_]], [[VAR_9_]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32> + // CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_7_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_9_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_8_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_10_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_6_]], [[VAR_9_]]) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]} : (tensor<1x3x224x224xf32>, tensor<64x3x7x7xf32>, tensor<64xf32>) -> tensor<1x64x112x112xf32> + // CHECK-NOT: {{.*}} = "onnx.BatchNormalizationInferenceMode"{{.*}} // CHECK: onnx.Return [[VAR_10_]] : tensor<1x64x112x112xf32> } @@ -413,6 +376,122 @@ func.func @test_reshape_fusion3(%arg0: tensor) -> tensor // ----- +// No fusion should happen if multiple unknown dimensions (-1) are given. +func.func @test_reshape_no_fusion(%arg0: tensor<1x3x1152x1x1344xf32>) -> (tensor<1x3x576x2x326x326x2xf32>) { + %0 = onnx.Constant dense<[0, 0, -1, 2, 0]> : tensor<5xi64> loc(unknown) + %1 = onnx.Constant dense<[0, 0, 0, 0, -1, 2]> : tensor<6xi64> loc(unknown) + %5 = onnx.Constant dense<[0, 0, -1, 0, -1, 326, 2]> : tensor<7xi64> loc(unknown) + %2 = "onnx.Reshape"(%arg0, %0) {allowzero = 0 : si64} : (tensor<1x3x1152x1x1344xf32>, tensor<5xi64>) -> tensor<1x3x576x2x1344xf32> + %3 = "onnx.Reshape"(%2, %1) {allowzero = 0 : si64} : (tensor<1x3x576x2x1344xf32>, tensor<6xi64>) -> tensor<1x3x576x2x672x2xf32> + %7 = "onnx.Reshape"(%3, %5) {allowzero = 0 : si64} : (tensor<1x3x576x2x672x2xf32>, tensor<7xi64>) -> tensor<1x3x576x2x326x326x2xf32> + onnx.Return %7 : tensor<1x3x576x2x326x326x2xf32> + +// CHECK-LABEL: func.func @test_reshape_no_fusion +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1x1344xf32>) -> tensor<1x3x576x2x326x326x2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 0, -1, 2, 0]> : tensor<5xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, -1, 2]> : tensor<6xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[0, 0, -1, 0, -1, 326, 2]> : tensor<7xi64> +// CHECK: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<1x3x1152x1x1344xf32>, tensor<5xi64>) -> tensor<1x3x576x2x1344xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Reshape"([[VAR_3_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x3x576x2x1344xf32>, tensor<6xi64>) -> tensor<1x3x576x2x672x2xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Reshape"([[VAR_4_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<1x3x576x2x672x2xf32>, tensor<7xi64>) -> tensor<1x3x576x2x326x326x2xf32> +// CHECK: onnx.Return [[VAR_5_]] : tensor<1x3x576x2x326x326x2xf32> +// CHECK: } +} + +// ----- + +func.func @reshape_allowzero_to_reshape(%arg0: tensor<1x2048x1x1xbf16>) -> (tensor<1x1x1x2048xbf16>) { + %0 = onnx.Constant dense<[1, 1, 1, 2048]> : tensor<4xi64> loc(unknown) + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<1x2048x1x1xbf16>, tensor<4xi64>) -> tensor<1x1x1x2048xbf16> + return %1: tensor<1x1x1x2048xbf16> + +// CHECK-LABEL: func.func @reshape_allowzero_to_reshape +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 0 +} + +// ----- + +func.func @reshape_allowzero_to_reshape_unranked(%arg0: tensor<*xbf16>, %arg1: tensor<4xi64>) -> (tensor<*xbf16>) { + %0 = onnx.Constant dense<[1, 1, 1, 2048]> : tensor<4xi64> loc(unknown) + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<*xbf16>, tensor<4xi64>) -> tensor<*xbf16> + return %1: tensor<*xbf16> + +// CHECK-LABEL: func.func @reshape_allowzero_to_reshape_unranked +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 0 +} + +// ----- + +func.func @reshape_allowzero_to_reshape_known_input_zero_value(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[80, 0, 2]> : tensor<3xi64> } : () -> tensor<3xi64> + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32> + return %1: tensor<*xf32> + +// CHECK-LABEL: func.func @reshape_allowzero_to_reshape_known_input_zero_value +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 1 +} + +// ----- + +func.func @reshape_allowzero_to_reshape_unranked_no_conversion(%arg0: tensor<*xbf16>, %arg1: tensor<4xi64>) -> (tensor<*xbf16>) { + %0 = onnx.Constant dense<[1, 1, 0, 2048]> : tensor<4xi64> loc(unknown) + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<*xbf16>, tensor<4xi64>) -> tensor<*xbf16> + return %1: tensor<*xbf16> + +// CHECK-LABEL: func.func @reshape_allowzero_to_reshape_unranked +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 1 +} + +// ----- + +func.func @reshape_allowzero_no_conversion_failure(%arg0: tensor<1x2048x1x0xbf16>) -> (tensor<1x1x1x0x0x2048xbf16>) { + %0 = onnx.Constant dense<[-1, 1, 1, 0, 0, 2048]> : tensor<6xi64> loc(unknown) + // expected-error@below {{'onnx.Reshape' op Allowzero is set and shape contains both -1 and 0. Dimension corresponding to -1 cannot be determined uniquely.}} + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<1x2048x1x0xbf16>, tensor<6xi64>) -> tensor<1x1x1x0x0x2048xbf16> + return %1: tensor<1x1x1x0x0x2048xbf16> +} + +// ----- + +func.func @reshape_allowzero_no_conversion(%arg0: tensor<1x2048x1x0xbf16>) -> (tensor<1x1x1x0x0x2048xbf16>) { + %0 = onnx.Constant dense<[1, 1, 1, 0, 0, 2048]> : tensor<6xi64> loc(unknown) + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor<1x2048x1x0xbf16>, tensor<6xi64>) -> tensor<1x1x1x0x0x2048xbf16> + return %1: tensor<1x1x1x0x0x2048xbf16> + +// CHECK-LABEL: func.func @reshape_allowzero_no_conversion +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 1 +} + +// ----- + +func.func @reshape_allowzero_no_conversion(%arg0: tensor) -> (tensor<*xbf16>) { + %0 = onnx.Constant dense<[1, 1, 1, 0, 0, 2048]> : tensor<6xi64> loc(unknown) + %1 = "onnx.Reshape"(%arg0, %0) { allowzero = 1 : si64 } : (tensor, tensor<6xi64>) -> tensor<*xbf16> + return %1: tensor<*xbf16> + +// CHECK-LABEL: func.func @reshape_allowzero_no_conversion +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 1 +} + +// ----- + +func.func @reshape_allowzero_to_reshape_dynamic_no_conversion(%arg0: tensor, %arg1: tensor<5xi64>) -> (tensor) { + %0 = "onnx.Reshape"(%arg0, %arg1) {allowzero = 1 : si64 } : (tensor, tensor<5xi64>) -> tensor + return %0: tensor + +// CHECK-LABEL: func.func @reshape_allowzero_to_reshape_dynamic_no_conversion +// CHECK: "onnx.Reshape" +// CHECK-SAME: allowzero = 1 +} + +// ----- + // Check the combining of transposes into an identity transpose, which in turns is removed. // CHECK-LABEL: func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { func.func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { @@ -453,8 +532,8 @@ func.func @test_shape2(%arg0 : tensor) -> tensor<*xi64> { onnx.Return %0 : tensor<*xi64> // CHECK-LABEL: @test_shape2 - // CHECK-NEXT: %0 = "onnx.Shape"(%arg0) {start = 0 : si64} : (tensor) -> tensor<*xi64> - // CHECK-NEXT: onnx.Return %0 : tensor<*xi64> + // CHECK-NEXT: %0 = "onnx.Shape"(%arg0) {start = 0 : si64} : (tensor) -> tensor<4xi64> + // CHECK-NEXT: onnx.Return %0 : tensor<4xi64> } @@ -476,8 +555,8 @@ func.func @test_size2(%arg0 : tensor<*xf32>) -> tensor<*xi64> { onnx.Return %0 : tensor<*xi64> // CHECK-LABEL: @test_size2 - // CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64> - // CHECK-NEXT: onnx.Return %0 : tensor<*xi64> + // CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor + // CHECK-NEXT: onnx.Return %0 : tensor } // ----- @@ -498,12 +577,22 @@ func.func @test_global_average_pool_dyn_dims(%arg0: tensor<1x?x?x5xf32>) -> tens %0 = "onnx.GlobalAveragePool"(%arg0) : (tensor<1x?x?x5xf32>) -> tensor<1x?x?x1xf32> onnx.Return %0 : tensor<1x?x?x1xf32> // CHECK-LABEL: test_global_average_pool_dyn_dims - // CHECK: [[RES:%.+]] = "onnx.ReduceMeanV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x?x?x5xf32>) -> tensor<1x?x?x1xf32> - // CHECK: onnx.Return [[RES]] : tensor<1x?x?x1xf32> + // CHECK: [[RES:%.+]] = "onnx.ReduceMeanV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x?x?x5xf32>) -> tensor<1x?x1x1xf32> + // CHECK: onnx.Return [[RES]] : tensor<1x?x1x1xf32> } // ----- +// COM: Test that GlobalAveragePool with dynamic rank does not crash +func.func @test_global_average_pool_dynamic_rank(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.GlobalAveragePool"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> + // CHECK-LABEL: test_global_average_pool_dynamic_rank + // CHECK: "onnx.GlobalAveragePool" + } + +// ----- + // COM: Test rewriting GlobalMaxPool into ReduceMaxV13 func.func @test_global_average_pool(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x1x1xf32> { %0 = "onnx.GlobalMaxPool"(%arg0) : (tensor<1x3x5x5xf32>) -> tensor<1x3x1x1xf32> @@ -520,8 +609,8 @@ func.func @test_global_average_pool_dyn_dims(%arg0: tensor<1x?x?x5xf32>) -> tens %0 = "onnx.GlobalMaxPool"(%arg0) : (tensor<1x?x?x5xf32>) -> tensor<1x?x?x1xf32> onnx.Return %0 : tensor<1x?x?x1xf32> // CHECK-LABEL: test_global_average_pool_dyn_dims - // CHECK: [[RES:%.+]] = "onnx.ReduceMaxV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x?x?x5xf32>) -> tensor<1x?x?x1xf32> - // CHECK: onnx.Return [[RES]] : tensor<1x?x?x1xf32> + // CHECK: [[RES:%.+]] = "onnx.ReduceMaxV13"(%arg0) {axes = [2, 3], keepdims = 1 : si64} : (tensor<1x?x?x5xf32>) -> tensor<1x?x1x1xf32> + // CHECK: onnx.Return [[RES]] : tensor<1x?x1x1xf32> } // ----- @@ -732,14 +821,14 @@ func.func @test_rewrite_batchnormtestmode_Nd(%arg0 : tensor<1x64x112x112xf32>, % // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.00000007E-5> : tensor<1xf32> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 2]> : tensor<2xi64> // CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> - // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<*xf32> - // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> - // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_5_]]) : (tensor<1x64x112x112xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_4_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_8_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_7_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> - // CHECK: [[VAR_9_:%.+]] = "onnx.Unsqueeze"([[VAR_8_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> - // CHECK: [[VAR_10_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_9_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x64x112x112xf32> + // CHECK: [[VAR_3_:%.+]] = "onnx.Sqrt"([[VAR_2_]]) : (tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_4_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32> + // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_5_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<1x64x112x112xf32> + // CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_4_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_8_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_7_]]) : (tensor<64xf32>, tensor<64xf32>) -> tensor<64xf32> + // CHECK: [[VAR_9_:%.+]] = "onnx.Unsqueeze"([[VAR_8_]], [[VAR_1_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32> + // CHECK: [[VAR_10_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_9_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<1x64x112x112xf32> // CHECK: onnx.Return [[VAR_10_]] : tensor<1x64x112x112xf32> } @@ -753,12 +842,12 @@ func.func @test_rewrite_batchnormtestmode_1d(%arg0 : tensor<64xf32>, %scale : te // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<64xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>, [[PARAM_3_:%.+]]: tensor<1xf32>, [[PARAM_4_:%.+]]: tensor<1xf32>) -> tensor<64xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1.00000007E-5> : tensor<1xf32> // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_0_]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<1xf32>) -> tensor<*xf32> -// CHECK: [[VAR_3_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_5_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<64xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_5_]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<64xf32>, tensor<1xf32>) -> tensor<64xf32> // CHECK: onnx.Return [[VAR_7_]] : tensor<64xf32> } @@ -771,13 +860,13 @@ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale // CHECK-LABEL: func.func @test_rewrite_batchnormtestmode_1d_f16 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<64xf16>, [[PARAM_1_:%.+]]: tensor<1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>, [[PARAM_3_:%.+]]: tensor<1xf32>, [[PARAM_4_:%.+]]: tensor<1xf32>) -> tensor<64xf16> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1.001360e-05> : tensor<1xf16> -// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_0_]]) : (tensor<1xf32>, tensor<1xf16>) -> tensor<*xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_3_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf16>, tensor<*xf32>) -> tensor<*xf16> -// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_5_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<*xf16>, tensor<*xf32>) -> tensor<64xf16> +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"({{.*}}, [[VAR_0_]]) : (tensor<1xf16>, tensor<1xf16>) -> tensor<1xf16> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<1xf16>) -> tensor<1xf16> +// CHECK: [[VAR_3_:%.+]] = "onnx.Div"({{.*}}, [[VAR_2_]]) : (tensor<1xf16>, tensor<1xf16>) -> tensor<1xf16> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf16>, tensor<1xf16>) -> tensor<64xf16> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Mul"({{.*}}, [[VAR_3_]]) : (tensor<1xf16>, tensor<1xf16>) -> tensor<1xf16> +// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"({{.*}}, [[VAR_5_]]) : (tensor<1xf16>, tensor<1xf16>) -> tensor<1xf16> +// CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<64xf16>, tensor<1xf16>) -> tensor<64xf16> // CHECK: onnx.Return [[VAR_7_]] : tensor<64xf16> // CHECK: } } @@ -820,8 +909,8 @@ func.func @test_fuse_add_conv_bias(%arg0 : tensor<1x1x28x28xf32>, %arg1 : tensor // CHECK-LABEL: func.func @test_fuse_add_conv_bias // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>, [[PARAM_1_:%.+]]: tensor<8x1x5x5xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<1x8x28x28xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[-0.161539719, -0.433835655, 0.091641359, -0.0168522168, -0.0650264397, -0.131737873, 0.0204175506, -0.121110231]> : tensor<8xf32> -// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_2_]], [[VAR_0_]]) : (tensor<8xf32>, tensor<8xf32>) -> tensor<*xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<*xf32>) -> tensor<1x8x28x28xf32> +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_2_]], [[VAR_0_]]) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {auto_pad = "SAME_UPPER", dilations = [1, 1], group = 1 : si64, kernel_shape = [5, 5], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x5x5xf32>, tensor<8xf32>) -> tensor<1x8x28x28xf32> // CHECK: onnx.Return [[VAR_2_]] : tensor<1x8x28x28xf32> } @@ -868,15 +957,15 @@ func.func @test_fuse_mul_conv(%arg0: tensor<1x1x28x28xf32>) -> tensor<*xf32> { onnx.Return %4 : tensor<*xf32> // CHECK-LABEL: func.func @test_fuse_mul_conv - // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>) -> tensor<*xf32> { + // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x28x28xf32>) -> tensor<1x8x27x27xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}{{.}}[-0.161539719]{{.}}, {{.}}[-0.433835655]{{.}}, {{.}}[0.091641359]{{.}}, {{.}}[-0.0168522168]{{.}}, {{.}}[-0.0650264397]{{.}}, {{.}}[-0.131737873]{{.}}, {{.}}[0.0204175506]{{.}}, {{.}}[-0.121110231]{{.}}{{.}}> : tensor<8x1x1xf32> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<{{.}}[{{.}}[0.0234164055, 0.0228030644], [2.442580e-02, 0.0237577036]{{.}}{{.}}, {{.}}{{.}}[-0.0410864502, 0.0488203131], [0.164448678, -0.0200194642]{{.}}{{.}}, {{.}}{{.}}[-4.34581793E-9, 0.025325032], [0.0373019315, 0.165243402]{{.}}{{.}}, {{.}}{{.}}[-0.0198689923, 0.131284416], [0.0572107285, 2.33985098E-8]{{.}}{{.}}, {{.}}{{.}}[0.0187684372, -0.148515195], [0.0154875498, 0.019133633]{{.}}{{.}}, {{.}}{{.}}[0.0176953916, -0.0154658081], [0.0233727545, -0.274110436]{{.}}{{.}}, {{.}}{{.}}[-0.021181887, 0.0936150252], [0.135688141, -0.0202601217]{{.}}{{.}}, {{.}}{{.}}[-0.0201558527, 0.0192655921], [0.227748245, -0.196346223]{{.}}{{.}}]> : tensor<8x1x2x2xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[VAR_4_:%.+]] = "onnx.Unsqueeze"([[VAR_1_]], [[VAR_0_]]) : (tensor<8x1x1xf32>, tensor<1xi64>) -> tensor<*xf32> - // CHECK: [[VAR_5_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_2_]]) : (tensor<*xf32>, tensor<8x1x2x2xf32>) -> tensor<*xf32> - // CHECK: [[VAR_6_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_5_]], [[VAR_3_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [2, 2], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<*xf32>, none) -> tensor<*xf32> - // CHECK: onnx.Return [[VAR_6_]] : tensor<*xf32> + // CHECK: [[VAR_4_:%.+]] = "onnx.Unsqueeze"([[VAR_1_]], [[VAR_0_]]) : (tensor<8x1x1xf32>, tensor<1xi64>) -> tensor<8x1x1x1xf32> + // CHECK: [[VAR_5_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_2_]]) : (tensor<8x1x1x1xf32>, tensor<8x1x2x2xf32>) -> tensor<8x1x2x2xf32> + // CHECK: [[VAR_6_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_5_]], [[VAR_3_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [2, 2], strides = [1, 1]} : (tensor<1x1x28x28xf32>, tensor<8x1x2x2xf32>, none) -> tensor<1x8x27x27xf32> + // CHECK: onnx.Return [[VAR_6_]] : tensor<1x8x27x27xf32> } // ----- @@ -907,83 +996,6 @@ func.func @test_less_should_not_remove_cast(%arg0 : tensor, %arg1 : tensor< // ----- -// Check deriving a new maximum trip count from the break condition of the loop. -// In this test, the new maximum trip count is a constant. -func.func @test_loop_derive_max_trip_count(%arg0: tensor) -> tensor { - %0 = onnx.Constant dense<9223372036854775807> : tensor - %1 = onnx.Constant dense : tensor - %2 = onnx.Constant dense<0> : tensor - %3 = onnx.Constant dense<30> : tensor - %4:4 = "onnx.Loop"(%0, %1, %2, %3, %arg0) ({ - ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): - %5 = onnx.Constant dense<4> : tensor - %6 = "onnx.Add"(%arg3, %5) : (tensor, tensor) -> tensor - %7 = "onnx.Relu"(%arg5) : (tensor) -> tensor - %8 = "onnx.Less"(%6, %arg4) : (tensor, tensor) -> tensor - onnx.Yield %8, %6, %arg4, %7 : tensor, tensor, tensor, tensor - }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) - onnx.Return %4#3 : tensor -// CHECK-LABEL: func.func @test_loop_derive_max_trip_count -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8> : tensor -// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<4> : tensor -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense : tensor -// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0> : tensor -// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<30> : tensor -// CHECK: [[VAR_5_:%.+]]:4 = "onnx.Loop"([[VAR_0_]], [[VAR_2_]], [[VAR_3_]], [[VAR_4_]], [[PARAM_0_]]) ({ -// CHECK: ^bb0([[arg1_:%.+]]: tensor, [[arg2_:%.+]]: tensor, [[arg3_:%.+]]: tensor, [[arg4_:%.+]]: tensor, [[arg5_:%.+]]: tensor): -// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Add"([[arg3_]], [[VAR_1_]]) : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[arg5_]]) : (tensor) -> tensor -// CHECK: onnx.Yield [[arg2_]], [[VAR_6_]], [[arg4_]], [[VAR_7_]] : tensor, tensor, tensor, tensor -// CHECK: }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) -// CHECK: onnx.Return [[VAR_5_]]#3 : tensor - -} - -// ----- - -// Check deriving a new maximum trip count from the break condition of the loop. -// In this test, the new maximum trip count is not a constant. -func.func @test_loop_derive_max_trip_count_non_constant_ub(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = onnx.Constant dense<9223372036854775807> : tensor - %1 = onnx.Constant dense : tensor - %2 = onnx.Constant dense<0> : tensor - %3:4 = "onnx.Loop"(%0, %1, %2, %arg1, %arg0) ({ - ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor): - %4 = onnx.Constant dense<1> : tensor - %5 = "onnx.Add"(%arg4, %4) : (tensor, tensor) -> tensor - %6 = "onnx.Relu"(%arg6) : (tensor) -> tensor - %7 = "onnx.Less"(%5, %arg5) : (tensor, tensor) -> tensor - onnx.Yield %7, %5, %arg5, %6 : tensor, tensor, tensor, tensor - }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) - onnx.Return %3#3 : tensor -// CHECK-LABEL: func @test_loop_derive_max_trip_count_non_constant_ub -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor -// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9223372036854775807> : tensor -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense : tensor -// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0> : tensor -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor -// CHECK: [[VAR_5_:%.+]] = "onnx.Cast"([[VAR_3_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor -// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[VAR_4_]], [[VAR_5_]]) : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Cast"([[VAR_6_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor -// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor -// CHECK: [[VAR_9_:%.+]] = "onnx.Div"([[VAR_7_]], [[VAR_8_]]) : (tensor, tensor) -> tensor -// CHECK: [[VAR_10_:%.+]] = "onnx.Ceil"([[VAR_9_]]) : (tensor) -> tensor -// CHECK: [[VAR_11_:%.+]] = "onnx.Cast"([[VAR_10_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor -// CHECK: [[VAR_12_:%.+]] = "onnx.Min"([[VAR_1_]], [[VAR_1_]]1) : (tensor, tensor) -> tensor -// CHECK: [[VAR_13_:%.+]]:4 = "onnx.Loop"([[VAR_12_]], [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]], [[PARAM_0_]]) ({ -// CHECK: ^bb0([[arg2_:%.+]]: tensor, [[arg3_:%.+]]: tensor, [[arg4_:%.+]]: tensor, [[arg5_:%.+]]: tensor, [[arg6_:%.+]]: tensor): -// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Add"([[arg4_]], [[VAR_0_]]) : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.Relu"([[arg6_]]) : (tensor) -> tensor -// CHECK: onnx.Yield [[arg3_]], [[VAR_14_]], [[arg5_]], [[VAR_15_]] : tensor, tensor, tensor, tensor -// CHECK: }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) -// CHECK: onnx.Return [[VAR_13_]]#3 : tensor - -} - -// ----- - func.func @test_rnn_layout1(%arg0: tensor<5x4x2xf32>, %arg1: tensor<1x3x2xf32>, %arg2: tensor<1x3x3xf32>, %arg3: tensor<5x1x3xf32>) -> tensor<5x1x3xf32> { %cst = "onnx.NoValue"() {value} : () -> none %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %arg3) {layout = 1 : si64} : (tensor<5x4x2xf32>, tensor<1x3x2xf32>, tensor<1x3x3xf32>, none, none, tensor<5x1x3xf32>) -> (tensor<5x4x1x3xf32>, tensor<5x1x3xf32>) @@ -1044,10 +1056,10 @@ func.func @test_lstm_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x16x3x return %Y_h : tensor<*xf32> // CHECK-LABEL: func.func @test_lstm_seq_lens_bs1_in_X -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x16x3xf32>, [[PARAM_2_:%.+]]: tensor<1x16x4xf32>, [[PARAM_3_:%.+]]: tensor<1x32xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<1x1x4xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<*xf32>, none) -// CHECK: return [[Y_h_]] : tensor<*xf32> +// CHECK: [[Y_:%.+]], [[Y_h_:%.+]], [[VAR_Y_c_:%.+]] = "onnx.LSTM"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]], [[VAR_0_]], [[VAR_0_]]) {direction = "forward", hidden_size = 4 : si64, input_forget = 0 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x16x3xf32>, tensor<1x16x4xf32>, tensor<1x32xf32>, none, tensor<1x?x4xf32>, none, none) -> (none, tensor<1x1x4xf32>, none) +// CHECK: return [[Y_h_]] : tensor<1x1x4xf32> // CHECK: } } @@ -1088,10 +1100,10 @@ func.func @test_gru_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x12x3xf return %Y_h : tensor<*xf32> // CHECK-LABEL: func.func @test_gru_seq_lens_bs1_in_X -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x12x3xf32>, [[PARAM_2_:%.+]]: tensor<1x12x4xf32>, [[PARAM_3_:%.+]]: tensor<1x24xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<1x1x4xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>) -// CHECK: return [[VAR_Y_h_]] : tensor<*xf32> +// CHECK: [[Y_]], [[VAR_Y_h_:%.+]] = "onnx.GRU"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {direction = "forward", hidden_size = 4 : si64, layout = 0 : si64, linear_before_reset = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x12x3xf32>, tensor<1x12x4xf32>, tensor<1x24xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<1x1x4xf32>) +// CHECK: return [[VAR_Y_h_]] : tensor<1x1x4xf32> // CHECK: } } @@ -1130,10 +1142,10 @@ func.func @test_rnn_seq_lens_bs1_in_X(%X: tensor<7x1x3xf32>, %W: tensor<1x4x3xf3 return %Y_h : tensor<*xf32> // CHECK-LABEL: func.func @test_rnn_seq_lens_bs1_in_X -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<7x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x4x3xf32>, [[PARAM_2_:%.+]]: tensor<1x4x4xf32>, [[PARAM_3_:%.+]]: tensor<1x8xf32>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor<1x?x4xf32>) -> tensor<1x1x4xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<*xf32>) -// CHECK: return [[VAR_Y_h_]] : tensor<*xf32> +// CHECK: [[Y_]], [[VAR_Y_h_:%.+]] = "onnx.RNN"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[VAR_0_]], [[PARAM_5_]]) {activations = ["Tanh", "Tanh"], direction = "forward", hidden_size = 4 : si64, layout = 0 : si64} : (tensor<7x1x3xf32>, tensor<1x4x3xf32>, tensor<1x4x4xf32>, tensor<1x8xf32>, none, tensor<1x?x4xf32>) -> (none, tensor<1x1x4xf32>) +// CHECK: return [[VAR_Y_h_]] : tensor<1x1x4xf32> // CHECK: } } @@ -1282,61 +1294,36 @@ func.func @shape_transform_identity_map(%arg0: tensor<128x128xf32>) -> tensor<12 // ----- -// COM: Expand Pow into multiple Mul if exponent is an integer and <= 64. +// COM: Expand Pow into multiple Mul if exponent is an integer and <= 2. func.func @expand_pow_into_mul_f32(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %cst = onnx.Constant dense<5.0> : tensor + %cst = onnx.Constant dense<2.0> : tensor %0 = "onnx.Pow"(%arg0, %cst) : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> onnx.Return %0 : tensor<3x4x5xf32> // CHECK-LABEL: func.func @expand_pow_into_mul_f32 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { // CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_2_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: onnx.Return [[VAR_3_]] : tensor<3x4x5xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<3x4x5xf32> // CHECK: } } // ----- -// COM: Expand Pow into multiple Mul if exponent is an integer and <= 64. -func.func @expand_pow_into_mul_f16(%arg0: tensor<3x4x5xf16>) -> tensor<3x4x5xf16> { - %cst = onnx.Constant dense<5.0> : tensor - %0 = "onnx.Pow"(%arg0, %cst) : (tensor<3x4x5xf16>, tensor) -> tensor<3x4x5xf16> - onnx.Return %0 : tensor<3x4x5xf16> +// COM: Expand a bfloat16 Pow into multiple Mul if exponent is an integer and <= 2. +func.func @expand_pow_bf16_into_mul(%arg0: tensor<3x4x5xbf16>) -> tensor<3x4x5xbf16> { + %cst = onnx.Constant dense<2.0> : tensor + %0 = "onnx.Pow"(%arg0, %cst) : (tensor<3x4x5xbf16>, tensor) -> tensor<3x4x5xbf16> + onnx.Return %0 : tensor<3x4x5xbf16> -// CHECK-LABEL: func.func @expand_pow_into_mul_f16 -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf16>) -> tensor<3x4x5xf16> { -// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<3x4x5xf16>, tensor<3x4x5xf16>) -> tensor<3x4x5xf16> -// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]]) : (tensor<3x4x5xf16>, tensor<3x4x5xf16>) -> tensor<3x4x5xf16> -// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_2_]]) : (tensor<3x4x5xf16>, tensor<3x4x5xf16>) -> tensor<3x4x5xf16> -// CHECK: onnx.Return [[VAR_3_]] : tensor<3x4x5xf16> +// CHECK-LABEL: func.func @expand_pow_bf16_into_mul +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xbf16>) -> tensor<3x4x5xbf16> { +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<3x4x5xbf16>, tensor<3x4x5xbf16>) -> tensor<3x4x5xbf16> +// CHECK: onnx.Return [[VAR_1_]] : tensor<3x4x5xbf16> // CHECK: } } // ----- -// COM: Expand Pow into multiple Mul if exponent is an integer and <= 64. - -func.func @expand_pow_into_mul13(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %cst = onnx.Constant dense<13.0> : tensor - %0 = "onnx.Pow"(%arg0, %cst) : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> - onnx.Return %0 : tensor<3x4x5xf32> - -// mlir2FileCheck.py -// CHECK-LABEL: func.func @expand_pow_into_mul13 -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_0_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_3_]]) : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: onnx.Return [[VAR_4_]] : tensor<3x4x5xf32> -// CHECK: } -} - -// ----- - func.func @expand_pow_into_constant(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { %cst = onnx.Constant dense<0.0> : tensor %0 = "onnx.Pow"(%arg0, %cst) : (tensor<3x4x5xf32>, tensor) -> tensor<3x4x5xf32> @@ -1357,11 +1344,11 @@ func.func @mul_broadcast_axis_unsqueeze(%279: tensor<1x64x112x112xf32>, %138: te onnx.Return %280 : tensor<*xf32> // CHECK-LABEL: func.func @mul_broadcast_axis_unsqueeze -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x64x112x112xf32>, [[PARAM_1_:%.+]]: tensor<64xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x64x112x112xf32>, [[PARAM_1_:%.+]]: tensor<64xf32>) -> tensor<1x64x112x112xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2]> : tensor<2xi64> // CHECK: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<64xf32>, tensor<2xi64>) -> tensor<64x1x1xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<*xf32> -// CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x64x112x112xf32>, tensor<64x1x1xf32>) -> tensor<1x64x112x112xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x64x112x112xf32> // CHECK: } } @@ -1826,17 +1813,3 @@ func.func @test_where_with_always_false_3(%arg0: tensor) -> tensor<2xi6 // CHECK: } } -// ----- - -func.func @test_dequantize_linear(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { - %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor - %1 = "onnx.DequantizeLinear"(%0, %arg1, %arg2) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor - return %1: tensor - -// CHECK-LABEL: func.func @test_dequantize_linear -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor { -// CHECK-NOT: "onnx.QuantizeLinear" -// CHECK-NOT: "onnx.DequantizeLinear" -// CHECK: return [[PARAM_0_]] : tensor -// CHECK: } -} diff --git a/test/mlir/onnx/onnx_canonicalization_locations.mlir b/test/mlir/onnx/onnx_canonicalization_locations.mlir new file mode 100644 index 0000000000..32d39272ef --- /dev/null +++ b/test/mlir/onnx/onnx_canonicalization_locations.mlir @@ -0,0 +1,17 @@ +// RUN: onnx-mlir-opt --shape-inference --canonicalize="test-convergence=true" --shape-inference --cse %s -split-input-file --mlir-print-debuginfo | FileCheck %s + +// CHECK-LABEL: func.func @layernorm_with_bias +func.func @layernorm_with_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %arg3: tensor<768xf32>) -> tensor<1x384x768xf32> { + %none = "onnx.NoValue"() {value} : () -> none + %y, %mean, %stddev = "onnx.LayerNormalization"(%arg0, %arg1, %none) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) loc("LN") + %ret = "onnx.Add"(%y, %arg3) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> loc("Bias") + return %ret : tensor<1x384x768xf32> + // CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) loc([[LOC_FUSED:#.+]]) + // CHECK: return [[Y_]] : tensor<1x384x768xf32> + // CHECK-DAG: [[LOC_LN:#.+]] = loc("LN") + // CHECK-DAG: [[LOC_BIAS:#.+]] = loc("Bias") + // CHECK: [[LOC_FUSED]] = loc(fused[[[LOC_LN]], [[LOC_BIAS]]]) +} + + + diff --git a/test/mlir/onnx/onnx_canonicalization_without_shape_inference.mlir b/test/mlir/onnx/onnx_canonicalization_without_shape_inference.mlir new file mode 100644 index 0000000000..15d463e0dd --- /dev/null +++ b/test/mlir/onnx/onnx_canonicalization_without_shape_inference.mlir @@ -0,0 +1,117 @@ +// RUN: onnx-mlir-opt --canonicalize="test-convergence=true" %s -split-input-file | FileCheck %s + +// FIXME: This tests have issues when running shape-inference previous to canonicalize. + +// CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { +func.func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<256xf32>) -> tensor<*xf32> + onnx.Return %1 : tensor<*xf32> + + // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> + // onnx.Return [[GEMM]] : tensor<*xf32> +} + +// ----- + +//CHECK-LABEL: @test_gemm_add_fusion(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> { +func.func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32> + onnx.Return %1 : tensor<*xf32> + + // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> + // onnx.Return [[GEMM]] : tensor<*xf32> +} + +// ----- + +//CHECK-LABEL: @test_gemm_add_fusion_beta_zero(%{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128x128xf32>, %{{.*}}: tensor<128xf32>) -> tensor<*xf32> { +func.func @test_gemm_add_fusion_beta_zero(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128xf32>) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "onnx.Gemm"(%arg0, %arg1, %cst) {beta = 0.0 : f32}: (tensor<128x128xf32>, tensor<128x128xf32>, none) -> tensor<*xf32> + %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<128xf32>) -> tensor<*xf32> + onnx.Return %1 : tensor<*xf32> + + // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> + // onnx.Return [[GEMM]] : tensor<*xf32> +} + +// ----- + +// Check deriving a new maximum trip count from the break condition of the loop. +// In this test, the new maximum trip count is a constant. +func.func @test_loop_derive_max_trip_count(%arg0: tensor) -> tensor { + %0 = onnx.Constant dense<9223372036854775807> : tensor + %1 = onnx.Constant dense : tensor + %2 = onnx.Constant dense<0> : tensor + %3 = onnx.Constant dense<30> : tensor + %4:4 = "onnx.Loop"(%0, %1, %2, %3, %arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): + %5 = onnx.Constant dense<4> : tensor + %6 = "onnx.Add"(%arg3, %5) : (tensor, tensor) -> tensor + %7 = "onnx.Relu"(%arg5) : (tensor) -> tensor + %8 = "onnx.Less"(%6, %arg4) : (tensor, tensor) -> tensor + onnx.Yield %8, %6, %arg4, %7 : tensor, tensor, tensor, tensor + }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + onnx.Return %4#3 : tensor +// CHECK-LABEL: func.func @test_loop_derive_max_trip_count +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<4> : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<30> : tensor +// CHECK: [[VAR_5_:%.+]]:4 = "onnx.Loop"([[VAR_0_]], [[VAR_2_]], [[VAR_3_]], [[VAR_4_]], [[PARAM_0_]]) ({ +// CHECK: ^bb0([[arg1_:%.+]]: tensor, [[arg2_:%.+]]: tensor, [[arg3_:%.+]]: tensor, [[arg4_:%.+]]: tensor, [[arg5_:%.+]]: tensor): +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Add"([[arg3_]], [[VAR_1_]]) : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Relu"([[arg5_]]) : (tensor) -> tensor +// CHECK: onnx.Yield [[arg2_]], [[VAR_6_]], [[arg4_]], [[VAR_7_]] : tensor, tensor, tensor, tensor +// CHECK: }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) +// CHECK: onnx.Return [[VAR_5_]]#3 : tensor + +} + +// ----- + +// Check deriving a new maximum trip count from the break condition of the loop. +// In this test, the new maximum trip count is not a constant. +func.func @test_loop_derive_max_trip_count_non_constant_ub(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = onnx.Constant dense<9223372036854775807> : tensor + %1 = onnx.Constant dense : tensor + %2 = onnx.Constant dense<0> : tensor + %3:4 = "onnx.Loop"(%0, %1, %2, %arg1, %arg0) ({ + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor): + %4 = onnx.Constant dense<1> : tensor + %5 = "onnx.Add"(%arg4, %4) : (tensor, tensor) -> tensor + %6 = "onnx.Relu"(%arg6) : (tensor) -> tensor + %7 = "onnx.Less"(%5, %arg5) : (tensor, tensor) -> tensor + onnx.Yield %7, %5, %arg5, %6 : tensor, tensor, tensor, tensor + }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + onnx.Return %3#3 : tensor +// CHECK-LABEL: func @test_loop_derive_max_trip_count_non_constant_ub +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9223372036854775807> : tensor +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0> : tensor +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor +// CHECK: [[VAR_5_:%.+]] = "onnx.Cast"([[VAR_3_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor +// CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[VAR_4_]], [[VAR_5_]]) : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Cast"([[VAR_6_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = f32} : (tensor) -> tensor +// CHECK: [[VAR_9_:%.+]] = "onnx.Div"([[VAR_7_]], [[VAR_8_]]) : (tensor, tensor) -> tensor +// CHECK: [[VAR_10_:%.+]] = "onnx.Ceil"([[VAR_9_]]) : (tensor) -> tensor +// CHECK: [[VAR_11_:%.+]] = "onnx.Cast"([[VAR_10_]]) {saturate = 1 : si64, to = i64} : (tensor) -> tensor +// CHECK: [[VAR_12_:%.+]] = "onnx.Min"([[VAR_1_]], [[VAR_1_]]1) : (tensor, tensor) -> tensor +// CHECK: [[VAR_13_:%.+]]:4 = "onnx.Loop"([[VAR_12_]], [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]], [[PARAM_0_]]) ({ +// CHECK: ^bb0([[arg2_:%.+]]: tensor, [[arg3_:%.+]]: tensor, [[arg4_:%.+]]: tensor, [[arg5_:%.+]]: tensor, [[arg6_:%.+]]: tensor): +// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Add"([[arg4_]], [[VAR_0_]]) : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.Relu"([[arg6_]]) : (tensor) -> tensor +// CHECK: onnx.Yield [[arg3_]], [[VAR_14_]], [[arg5_]], [[VAR_15_]] : tensor, tensor, tensor, tensor +// CHECK: }) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) +// CHECK: onnx.Return [[VAR_13_]]#3 : tensor + +} \ No newline at end of file diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir index c19b1974f7..3b9781321e 100644 --- a/test/mlir/onnx/onnx_constprop.mlir +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -151,6 +151,42 @@ func.func @test_add_const_associative2_2uses(%x: tensor, %y: tensor, % // ----- +// (x + c) + y will not be rewritten to (x + y) + c because x and c are scalar +func.func @test_add_const_associative_scalar_not_apply_1(%x: tensor<1xi32>, %y: tensor<5xi32>) -> tensor<5xi32> { + %c = onnx.Constant dense<1> : tensor + %1 = "onnx.Add"(%x, %c) : (tensor<1xi32> , tensor) -> tensor<1xi32> + %2 = "onnx.Add"(%1, %y) : (tensor<1xi32> , tensor<5xi32>) -> tensor<5xi32> + onnx.Return %2: tensor<5xi32> + +// CHECK-LABEL: func.func @test_add_const_associative_scalar_not_apply_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi32>, [[PARAM_1_:%.+]]: tensor<5xi32>) -> tensor<5xi32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_0_]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_1_]]) : (tensor<1xi32>, tensor<5xi32>) -> tensor<5xi32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<5xi32> +// CHECK: } +} + +// ----- + +// (x + (y+c)) will not be rewritten to (x + y) + c because y and c are scalar +func.func @test_add_const_associative_scalar_not_apply_2(%x: tensor<5xi32>, %y: tensor<1xi32>) -> tensor<5xi32> { + %c = onnx.Constant dense<1> : tensor + %1 = "onnx.Add"(%y, %c) : (tensor<1xi32> , tensor) -> tensor<1xi32> + %2 = "onnx.Add"(%x, %1) : (tensor<5xi32> , tensor<1xi32>) -> tensor<5xi32> + onnx.Return %2: tensor<5xi32> + +// CHECK-LABEL: func.func @test_add_const_associative_scalar_not_apply_2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<5xi32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_1_]], [[VAR_0_]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_1_]]) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<5xi32> +// CHECK: } +} + +// ----- + // CHECK-LABEL: @test_add_zeros(%arg0: tensor<3xi32>) -> tensor<3xi32> func.func @test_add_zeros(%arg0 : tensor<3xi32>) -> tensor<3xi32> { %0 = onnx.Constant dense<[0, 0, 0]> : tensor<3xi32> @@ -252,6 +288,42 @@ func.func @test_mul_ones(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: onnx.Return %arg0 : tensor<2x2xf32> } +// ----- + +// (x * c) * y will not be rewritten to (x * y) * c because x and c are scalar +func.func @test_mul_const_associative_scalar_not_apply_1(%x: tensor<1xi32>, %y: tensor<5xi32>) -> tensor<5xi32> { + %c = onnx.Constant dense<5> : tensor + %1 = "onnx.Mul"(%x, %c) : (tensor<1xi32> , tensor) -> tensor<1xi32> + %2 = "onnx.Mul"(%1, %y) : (tensor<1xi32> , tensor<5xi32>) -> tensor<5xi32> + onnx.Return %2: tensor<5xi32> + +// CHECK-LABEL: func.func @test_mul_const_associative_scalar_not_apply_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi32>, [[PARAM_1_:%.+]]: tensor<5xi32>) -> tensor<5xi32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<5> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[PARAM_1_]]) : (tensor<1xi32>, tensor<5xi32>) -> tensor<5xi32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<5xi32> +// CHECK: } +} + +// ----- + +// (x * (y*c)) will not be rewritten to (x * y) * c because y and c are scalar +func.func @test_mul_const_associative_scalar_not_apply_2(%x: tensor<5xi32>, %y: tensor<1xi32>) -> tensor<5xi32> { + %c = onnx.Constant dense<5> : tensor + %1 = "onnx.Mul"(%y, %c) : (tensor<1xi32> , tensor) -> tensor<1xi32> + %2 = "onnx.Mul"(%x, %1) : (tensor<5xi32> , tensor<1xi32>) -> tensor<5xi32> + onnx.Return %2: tensor<5xi32> + +// CHECK-LABEL: func.func @test_mul_const_associative_scalar_not_apply_2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<5xi32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<5> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_1_]], [[VAR_0_]]) : (tensor<1xi32>, tensor) -> tensor<1xi32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<5xi32> +// CHECK: } +} + //===----------------------------------------------------------------------===// /// SUB and NEG tests. @@ -320,6 +392,40 @@ func.func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { // ----- +// CHECK-LABEL: @test_abs() -> tensor<4x2xbf16> +func.func @test_abs() -> tensor<4x2xbf16> { + // Test Positive, Negative, Zero, -Zero, +Inf, -Inf, NaN, -NaN + %0 = onnx.Constant dense<[[12.5, -12.5], [0.0, -0.0], [0x7F80, 0xFF80], [0xFFC0, 0x7FC0]]> : tensor<4x2xbf16> + %1 = "onnx.Abs"(%0) : (tensor<4x2xbf16>) -> tensor<4x2xbf16> + "onnx.Return"(%1) : (tensor<4x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[1.250000e+01, 1.250000e+01], [0.000000e+00, 0.000000e+00], [0x7F80, 0x7F80], [0x7FC0, 0x7FC0]]> + // CHECK-NOT: "onnx.Abs" +} + +// ----- + +// CHECK-LABEL: @test_abs2() -> tensor<2x2xi32> +func.func @test_abs2() -> tensor<2x2xi32> { + %0 = onnx.Constant dense<[[12, -12], [0, -1000]]> : tensor<2x2xi32> + %1 = "onnx.Abs"(%0) : (tensor<2x2xi32>) -> tensor<2x2xi32> + "onnx.Return"(%1) : (tensor<2x2xi32>) -> () + // CHECK: onnx.Constant dense<{{.}}[12, 12], [0, 1000]]> + // CHECK-NOT: "onnx.Abs" +} + +// ----- + +// CHECK-LABEL: @test_abs3() -> tensor<1x2xui64> +func.func @test_abs3() -> tensor<1x2xui64> { + %0 = onnx.Constant dense<[[18446744073709551615, 18446744073709551614]]> : tensor<1x2xui64> + %1 = "onnx.Abs"(%0) : (tensor<1x2xui64>) -> tensor<1x2xui64> + "onnx.Return"(%1) : (tensor<1x2xui64>) -> () + // CHECK: onnx.Constant dense<{{.}}[18446744073709551615, 18446744073709551614]]> + // CHECK-NOT: "onnx.Abs" +} + +// ----- + // CHECK-LABEL: @test_ceil() -> tensor<3x2xbf16> func.func @test_ceil() -> tensor<3x2xbf16> { // Test Positive, Negative, Zero, NaN, +Inf, -Inf @@ -420,6 +526,17 @@ func.func @test_reciprocal() -> tensor<3x2xbf16> { // ----- +// CHECK-LABEL: @test_round() -> tensor<5x2xbf16> +func.func @test_round() -> tensor<5x2xbf16> { + %0 = onnx.Constant dense<[[0.9, 2.5], [2.3, 1.5], [-4.5, -3.5], [-2.6, 0x7FC0],[0x7F80, 0xFF80]]> : tensor<5x2xbf16> + %1 = "onnx.Round"(%0) : (tensor<5x2xbf16>) -> tensor<5x2xbf16> + "onnx.Return"(%1) : (tensor<5x2xbf16>) -> () + // CHECK: onnx.Constant dense<{{.}}[1.000000e+00, 2.000000e+00], [2.000000e+00, 2.000000e+00], [-4.000000e+00, -4.000000e+00], [-3.000000e+00, 0x7FC0], [0x7F80, 0xFF80]]> + // CHECK-NOT: "onnx.Round" +} + +// ----- + // CHECK-LABEL: @test_sin() -> tensor<3x2xf32> func.func @test_sin() -> tensor<3x2xf32> { // Test Positive, Negative, Zero, NaN, +Inf, -Inf @@ -2395,6 +2512,17 @@ func.func @test_pow() -> tensor<2x2xf32> { // CHECK: } } +// ----- + +func.func @test_pow_i32_f32_no_prop(%arg0: tensor<1x2xi32>, %arg1: tensor) -> tensor<1x2xi32> { + %0 = onnx.Constant dense<[[1, 2]]> : tensor<1x2xi32> + %1 = onnx.Constant dense<2.0> : tensor + %2 = "onnx.Pow"(%0, %1) : (tensor<1x2xi32>, tensor) -> tensor<1x2xi32> + "onnx.Return"(%2) : (tensor<1x2xi32>) -> () + // CHECK-LABEL: @test_pow_i32_f32_no_prop + // CHECK: "onnx.Pow" +} + //===----------------------------------------------------------------------===// /// Reciprocal test @@ -2408,3 +2536,153 @@ func.func @test_reciprocal() -> tensor<1x2xf32> { // CHECK: {{.*}} = onnx.Constant dense<{{\[}}[-2.500000e-01, 6.250000e-02]{{\]}}> : tensor<1x2xf32> // CHECK-NOT: {{.*}} = "onnx.Reciprocal"{{.*}} } +// ----- +//---------------------------------------------// +// reverseSequence tests + +// CHECK-LABEL: @test_reverse_seq_2d_batchaxis() -> tensor<4x4xf32> +func.func @test_reverse_seq_2d_batchaxis() -> tensor<4x4xf32> { + %0 = onnx.Constant dense<[4,3,2,1]> : tensor<4xi64> + %1 = onnx.Constant dense<[[0.0,4.0,8.0,12.0],[1.0,5.0,9.0,13.0],[2.0,6.0,10.0,14.0],[3.0,7.0,11.0,15.0]]> : tensor<4x4xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<4x4xf32>, tensor<4xi64>) -> tensor<4x4xf32> + onnx.Return %2 : tensor<4x4xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[3.000000e+00, 6.000000e+00, 9.000000e+00, 1.200000e+01], [2.000000e+00, 5.000000e+00, 8.000000e+00, 1.300000e+01], [1.000000e+00, 4.000000e+00, 1.000000e+01, 1.400000e+01], [0.000000e+00, 7.000000e+00, 1.100000e+01, 1.500000e+01]{{.}}> : tensor<4x4xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_2d_timeaxis() -> tensor<4x4xf32> +func.func @test_reverse_seq_2d_timeaxis() -> tensor<4x4xf32> { + %0 = onnx.Constant dense<[1,2,3,4]> : tensor<4xi64> + %1 = onnx.Constant dense<[[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0],[12.0,13.0,14.0,15.0]]> : tensor<4x4xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<4x4xf32>, tensor<4xi64>) -> tensor<4x4xf32> + onnx.Return %2 : tensor<4x4xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [5.000000e+00, 4.000000e+00, 6.000000e+00, 7.000000e+00], [1.000000e+01, 9.000000e+00, 8.000000e+00, 1.100000e+01], [1.500000e+01, 1.400000e+01, 1.300000e+01, 1.200000e+01]{{.}}> : tensor<4x4xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq() -> tensor<3x3x1x2xf32> +func.func @test_reverse_seq() -> tensor<3x3x1x2xf32> { + %0 = onnx.Constant dense<[3,2,1]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[1.1,2.1]],[[3.1,4.1]],[[5.1,6.1]]],[[[1.2,2.2]],[[3.2,4.2]],[[5.2,6.2]]],[[[1.3,2.3]],[[3.3,4.3]],[[5.3,6.3]]]]> : tensor<3x3x1x2xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> + onnx.Return %2 : tensor<3x3x1x2xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[5.100000e+00, 6.100000e+00]], {{.}}[3.100000e+00, 4.100000e+00]{{.}}, {{.}}[1.100000e+00, 2.100000e+00]{{.}}], [{{.}}[3.200000e+00, 4.200000e+00]{{.}}, {{.}}[1.200000e+00, 2.200000e+00]{{.}}, {{.}}[5.200000e+00, 6.1999998]{{.}}], [{{.}}[1.300000e+00, 2.300000e+00]{{.}}, {{.}}[3.300000e+00, 4.300000e+00]{{.}}, {{.}}[5.300000e+00, 6.300000e+00]{{.}}]{{.}}> : tensor<3x3x1x2xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_sameseq_values() -> tensor<3x3x1x2xf32> +func.func @test_reverse_seq_sameseq_values() -> tensor<3x3x1x2xf32> { + %0 = onnx.Constant dense<[3,3,3]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[1.1,2.1]],[[3.1,4.1]],[[5.1,6.1]]],[[[1.2,2.2]],[[3.2,4.2]],[[5.2,6.2]]],[[[1.3,2.3]],[[3.3,4.3]],[[5.3,6.3]]]]> : tensor<3x3x1x2xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> + onnx.Return %2 : tensor<3x3x1x2xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[5.100000e+00, 6.100000e+00]{{.}}, {{.}}[3.100000e+00, 4.100000e+00]{{.}}, {{.}}[1.100000e+00, 2.100000e+00]{{.}}], [{{.}}[5.200000e+00, 6.1999998]{{.}}, {{.}}[3.200000e+00, 4.200000e+00]{{.}}, {{.}}[1.200000e+00, 2.200000e+00]{{.}}], [{{.}}[5.300000e+00, 6.300000e+00]{{.}}, {{.}}[3.300000e+00, 4.300000e+00]{{.}}, {{.}}[1.300000e+00, 2.300000e+00]{{.}}]{{.}}> : tensor<3x3x1x2xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_int() -> tensor<3x3x1x2xi32> +func.func @test_reverse_seq_int() -> tensor<3x3x1x2xi32> { + %0 = onnx.Constant dense<[3,2,1]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[11,21]],[[31,41]],[[51,61]]],[[[12,22]],[[32,42]],[[52,62]]],[[[13,23]],[[33,43]],[[53,63]]]]> : tensor<3x3x1x2xi32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xi32>, tensor<3xi64>) -> tensor<3x3x1x2xi32> + onnx.Return %2 : tensor<3x3x1x2xi32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[51, 61]{{.}}, {{.}}[31, 41]{{.}}, {{.}}[11, 21]{{.}}], [{{.}}[32, 42]{{.}}, {{.}}[12, 22]{{.}}, {{.}}[52, 62]{{.}}], [{{.}}[13, 23]{{.}}, {{.}}[33, 43]{{.}}, {{.}}[53, 63]{{.}}]{{.}}> : tensor<3x3x1x2xi32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_sameseq_values_int() -> tensor<3x3x1x2xi32> +func.func @test_reverse_seq_sameseq_values_int() -> tensor<3x3x1x2xi32> { + %0 = onnx.Constant dense<[3,3,3]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[11,21]],[[31,41]],[[51,61]]],[[[12,22]],[[32,42]],[[52,62]]],[[[13,23]],[[33,43]],[[53,63]]]]> : tensor<3x3x1x2xi32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xi32>, tensor<3xi64>) -> tensor<3x3x1x2xi32> + onnx.Return %2 : tensor<3x3x1x2xi32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[51, 61]{{.}}, {{.}}[31, 41]{{.}}, {{.}}[11, 21]{{.}}], [{{.}}[52, 62]{{.}}, {{.}}[32, 42]{{.}}, {{.}}[12, 22]{{.}}], [{{.}}[53, 63]{{.}}, {{.}}[33, 43]{{.}}, {{.}}[13, 23]{{.}}]{{.}}> : tensor<3x3x1x2xi32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_batch_axis_1() -> tensor<3x3x1x2xf32> +func.func @test_reverse_seq_batch_axis_1() -> tensor<3x3x1x2xf32> { + %0 = onnx.Constant dense<[3,2,1]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[1.1,2.1]],[[3.1,4.1]],[[5.1,6.1]]],[[[1.2,2.2]],[[3.2,4.2]],[[5.2,6.2]]],[[[1.3,2.3]],[[3.3,4.3]],[[5.3,6.3]]]]> : tensor<3x3x1x2xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> + onnx.Return %2 : tensor<3x3x1x2xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[1.300000e+00, 2.300000e+00]{{.}}, {{.}}[3.200000e+00, 4.200000e+00]{{.}}, {{.}}[5.100000e+00, 6.100000e+00]{{.}}], [{{.}}[1.200000e+00, 2.200000e+00]{{.}}, {{.}}[3.100000e+00, 4.100000e+00]{{.}}, {{.}}[5.200000e+00, 6.1999998]{{.}}], [{{.}}[1.100000e+00, 2.100000e+00]{{.}}, {{.}}[3.300000e+00, 4.300000e+00]{{.}}, {{.}}[5.300000e+00, 6.300000e+00]{{.}}]{{.}}> : tensor<3x3x1x2xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_sameseq_values_batch_axis_1() -> tensor<3x3x1x2xf32> +func.func @test_reverse_seq_sameseq_values_batch_axis_1() -> tensor<3x3x1x2xf32> { + %0 = onnx.Constant dense<[3,3,3]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[1.1,2.1]],[[3.1,4.1]],[[5.1,6.1]]],[[[1.2,2.2]],[[3.2,4.2]],[[5.2,6.2]]],[[[1.3,2.3]],[[3.3,4.3]],[[5.3,6.3]]]]> : tensor<3x3x1x2xf32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> + onnx.Return %2 : tensor<3x3x1x2xf32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[1.300000e+00, 2.300000e+00]{{.}}, {{.}}[3.300000e+00, 4.300000e+00]{{.}}, {{.}}[5.300000e+00, 6.300000e+00]{{.}}], [{{.}}[1.200000e+00, 2.200000e+00]{{.}}, {{.}}[3.200000e+00, 4.200000e+00]{{.}}, {{.}}[5.200000e+00, 6.1999998]{{.}}], [{{.}}[1.100000e+00, 2.100000e+00]{{.}}, {{.}}[3.100000e+00, 4.100000e+00]{{.}}, {{.}}[5.100000e+00, 6.100000e+00]{{.}}]{{.}}> : tensor<3x3x1x2xf32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + + + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_int_batch_axis_1() -> tensor<3x3x1x2xi32> +func.func @test_reverse_seq_int_batch_axis_1() -> tensor<3x3x1x2xi32> { + %0 = onnx.Constant dense<[3,2,1]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[11,21]],[[31,41]],[[51,61]]],[[[12,22]],[[32,42]],[[52,62]]],[[[13,23]],[[33,43]],[[53,63]]]]> : tensor<3x3x1x2xi32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xi32>, tensor<3xi64>) -> tensor<3x3x1x2xi32> + onnx.Return %2 : tensor<3x3x1x2xi32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[13, 23]{{.}}, {{.}}[32, 42]{{.}}, {{.}}[51, 61]{{.}}], [{{.}}[12, 22]{{.}}, {{.}}[31, 41]{{.}}, {{.}}[52, 62]{{.}}], [{{.}}[11, 21]{{.}}, {{.}}[33, 43]{{.}}, {{.}}[53, 63]{{.}}]{{.}}> : tensor<3x3x1x2xi32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +// ----- +//---------------------------------------------// +// CHECK-LABEL: @test_reverse_seq_sameseq_values_int_ba_1() -> tensor<3x3x1x2xi32> +func.func @test_reverse_seq_sameseq_values_int_ba_1() -> tensor<3x3x1x2xi32> { + %0 = onnx.Constant dense<[3,3,3]> : tensor<3xi64> + %1 = onnx.Constant dense<[[[[11,21]],[[31,41]],[[51,61]]],[[[12,22]],[[32,42]],[[52,62]]],[[[13,23]],[[33,43]],[[53,63]]]]> : tensor<3x3x1x2xi32> + %2 = "onnx.ReverseSequence"(%1, %0) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xi32>, tensor<3xi64>) -> tensor<3x3x1x2xi32> + onnx.Return %2 : tensor<3x3x1x2xi32> + // CHECK: {{.*}} = onnx.Constant dense<{{.}}[{{.}}[13, 23]{{.}}, {{.}}[33, 43]{{.}}, {{.}}[53, 63]{{.}}], [{{.}}[12, 22]{{.}}, {{.}}[32, 42]{{.}}, {{.}}[52, 62]{{.}}], [{{.}}[11, 21]{{.}}, {{.}}[31, 41]{{.}}, {{.}}[51, 61]{{.}}]{{.}}> : tensor<3x3x1x2xi32> + // CHECK-NOT: {{.*}} = "onnx.ReverseSequence"{{.*}} +} + +//===----------------------------------------------------------------------===// +/// Abs test + +// ----- + +// CHECK-LABEL: @test_abs() -> tensor<2xf32> +func.func @test_abs() -> tensor<2xf32> { + %0 = onnx.Constant dense<[-4.0, 16.0]> : tensor<2xf32> + %1 = "onnx.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32> + "onnx.Return"(%1) : (tensor<2xf32>) -> () + // CHECK: {{.*}} = onnx.Constant dense<[4.000000e+00, 1.600000e+01]> : tensor<2xf32> + // CHECK-NOT: {{.*}} = "onnx.Abs"{{.*}} +} + +//===----------------------------------------------------------------------===// +/// Round test + +// ----- + +// CHECK-LABEL: @test_round() -> tensor<5xf32> +func.func @test_round() -> tensor<5xf32> { + %0 = onnx.Constant dense<[0.9, 2.5, 2.3, 1.5, -4.5]> : tensor<5xf32> + %1 = "onnx.Round"(%0) : (tensor<5xf32>) -> tensor<5xf32> + "onnx.Return"(%1) : (tensor<5xf32>) -> () + // CHECK: {{.*}} = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 2.000000e+00, 2.000000e+00, -4.000000e+00]> : tensor<5xf32> + // CHECK-NOT: {{.*}} = "onnx.Round"{{.*}} +} diff --git a/test/mlir/onnx/onnx_constprop_locations.mlir b/test/mlir/onnx/onnx_constprop_locations.mlir new file mode 100644 index 0000000000..c4124ca182 --- /dev/null +++ b/test/mlir/onnx/onnx_constprop_locations.mlir @@ -0,0 +1,30 @@ +// RUN: onnx-mlir-opt --shape-inference --constprop-onnx %s -split-input-file --mlir-print-debuginfo | FileCheck %s + + +//===----------------------------------------------------------------------===// +/// Commutative tests + +// CHECK-LABEL: @test_add_constant_1_loc +func.func @test_add_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant") + %1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Add") + "onnx.Return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]]) + // CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_ADD:#.+]]) + // CHECK-DAG: [[LOC_CONST]] = loc("Constant") + // CHECK-DAG: [[LOC_ADD]] = loc("Add") +} + +// ----- + +// CHECK-LABEL: @test_mul_constant_1_loc +func.func @test_mul_constant_1_loc(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = onnx.Constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32> loc("Constant") + %1 = "onnx.Mul"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> loc("Mul") + "onnx.Return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = onnx.Constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32> loc([[LOC_CONST:#.+]]) + // CHECK-NEXT: [[MUL:%.+]] = "onnx.Mul"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> loc([[LOC_MUL:#.+]]) + // CHECK-DAG: [[LOC_CONST]] = loc("Constant") + // CHECK-DAG: [[LOC_MUL]] = loc("Mul") +} + diff --git a/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir b/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir index 9ef7544648..2d94d98e36 100644 --- a/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir +++ b/test/mlir/onnx/onnx_constprop_no_shape_inference.mlir @@ -99,3 +99,287 @@ func.func @test_scatternd_i32() -> (tensor<4x4x4xi32>) { } // ----- + +//===----------------------------------------------------------------------===// +/// Checks to ensure that constprop does not crash on non static shapes. +/// This does only checks the absence of crashes, not that constants get folded + +// binary ops +// CHECK-LABEL: @test_add_dynamic_result +func.func @test_add_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Add"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_sub_dynamic_result +func.func @test_sub_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Sub"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_mul_dynamic_result +func.func @test_mul_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Mul"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_div_dynamic_result +func.func @test_div_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Div"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_bitwise_and_dynamic_result +func.func @test_bitwise_and_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.BitwiseAnd"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_bitwise_or_dynamic_result +func.func @test_bitwise_or_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.BitwiseOr"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_and_dynamic_result +func.func @test_and_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi1> } : tensor<4x4x4xi1> + %1 = "onnx.And"(%0, %0) : (tensor<4x4x4xi1>, tensor<4x4x4xi1>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_or_dynamic_result +func.func @test_or_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi1> } : tensor<4x4x4xi1> + %1 = "onnx.Or"(%0, %0) : (tensor<4x4x4xi1>, tensor<4x4x4xi1>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_xor_dynamic_result +func.func @test_xor_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi1> } : tensor<4x4x4xi1> + %1 = "onnx.And"(%0, %0) : (tensor<4x4x4xi1>, tensor<4x4x4xi1>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_eq_dynamic_result +func.func @test_eq_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi1> } : tensor<4x4x4xi1> + %1 = "onnx.Equal"(%0, %0) : (tensor<4x4x4xi1>, tensor<4x4x4xi1>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_less_dynamic_result +func.func @test_less_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Less"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_greater_dynamic_result +func.func @test_greater_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Greater"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_less_or_equal_dynamic_result +func.func @test_less_or_equal_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.LessOrEqual"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_greater_or_equal_dynamic_result +func.func @test_greater_or_equal_dynamic_result() -> (tensor<*xi1>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.GreaterOrEqual"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi1> + onnx.Return %1 : tensor<*xi1> +} + +// ----- + +// CHECK-LABEL: @test_mod_dynamic_result +func.func @test_mod_dynamic_result() -> (tensor<*xi32>) { + %0 = onnx.Constant { value = dense<1>: tensor<4x4x4xi32> } : tensor<4x4x4xi32> + %1 = "onnx.Mod"(%0, %0) : (tensor<4x4x4xi32>, tensor<4x4x4xi32>) -> tensor<*xi32> + onnx.Return %1 : tensor<*xi32> +} + +// ----- +// misc ops + +// CHECK-LABEL: @test_where_dynamic_result() +func.func @test_where_dynamic_result() -> tensor<*xf32> { + %0 = onnx.Constant dense<[true, false]> : tensor<2xi1> + %1 = onnx.Constant dense<[[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]]> : tensor<3x2xf32> + %2 = onnx.Constant dense<[[2.0]]> : tensor<1x1xf32> + %3 = "onnx.Where"(%0, %1, %2) : (tensor<2xi1>, tensor<3x2xf32>, tensor<1x1xf32>) -> tensor<*xf32> + "onnx.Return"(%3) : (tensor<*xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_matmul_2d_dynamic_result() +func.func @test_matmul_2d_dynamic_result() -> (tensor<*xf32>) { + %0 = "onnx.Constant"() {value = dense<1.> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %1 = "onnx.Constant"() {value = dense<1.> : tensor<3x1xf32>} : () -> tensor<3x1xf32> + %3 = "onnx.MatMul"(%0, %1) : (tensor<2x3xf32>, tensor<3x1xf32>) -> tensor<*xf32> + onnx.Return %3 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_gemm_dynamic_result() +func.func @test_gemm_dynamic_result() -> (tensor<*xi32>) { + %0 = "onnx.Constant"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %1 = "onnx.Constant"() {value = dense<[[10, 20], [30, 40]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %2 = "onnx.Constant"() {value = dense<[[1000, 2000], [3000, 4000]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %3 = "onnx.Gemm"(%0, %1, %2) : (tensor<2x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<*xi32> + onnx.Return %3 : tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @test_squeeze_dynamic_result() +func.func @test_squeeze_dynamic_result() -> tensor<*xf32> { + %0 = onnx.Constant dense<[[[4.0]], [[16.0]]]> : tensor<2x1x1xf32> + %1 = onnx.Constant dense<[1, 2]> : tensor<2xi64> + %2 = "onnx.Squeeze"(%0, %1) : (tensor<2x1x1xf32>, tensor<2xi64>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_unsqueeze_dynamic_result() +func.func @test_unsqueeze_dynamic_result() -> tensor<*xf32> { + %0 = onnx.Constant dense<[4.0, 16.0]> : tensor<2xf32> + %1 = onnx.Constant dense<[1, 2]> : tensor<2xi64> + %2 = "onnx.Unsqueeze"(%0, %1) : (tensor<2xf32>, tensor<2xi64>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_pad_dynamic_result() +func.func @test_pad_dynamic_result() -> tensor<*xf32> { + %data = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %pads = onnx.Constant dense<[0, 2, 0, 0]> : tensor<4xi64> + %non = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Pad"(%data, %pads, %non, %non) { mode = "constant" } : (tensor<3x2xf32>, tensor<4xi64>, none, none) -> tensor<*xf32> + onnx.Return %1 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_concat_negative_axis_dynamic_result() +func.func @test_concat_negative_axis_dynamic_result() -> tensor<*xf32>{ + %0 = onnx.Constant dense<[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]> : tensor<3x2xf32> + %1 = onnx.Constant dense<[[11.0, 12.0], [13.0, 14.0], [15.0, 16.0]]> : tensor<3x2xf32> + %2 = "onnx.Concat"(%0, %1) {axis = -1 : si64} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_gather_axis_0_dynamic_result() +func.func @test_gather_axis_0_dynamic_result() -> tensor<*xf32>{ + %0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %1 = onnx.Constant dense<[[0, 1], [1, 2]]> : tensor<2x2xi64> + %2 = "onnx.Gather"(%0, %1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<*xf32> + "onnx.Return"(%2) : (tensor<*xf32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_nonzero_dynamic_result() +func.func @test_nonzero_dynamic_result() -> tensor<*xi64> { + %0 = "onnx.Constant"() {value = dense<[[2, 1], [0, 2], [0, 1]]> : tensor<3x2xi8>} : () -> tensor<3x2xi8> + %1 = "onnx.NonZero"(%0) : (tensor<3x2xi8>) -> tensor<*xi64> + onnx.Return %1 : tensor<*xi64> +} + +// ----- + +// CHECK-LABEL: @test_scatternd_f32_dynamic_result() +func.func @test_scatternd_f32_dynamic_result() -> (tensor<*xf32>) { + %0 = onnx.Constant { name = "constant.0", value = dense<[1., 2., 3., 4., 5., 6., 7., 8.]>:tensor<8xf32> } : tensor<8xf32> + %1 = onnx.Constant { name = "constant.1", value = dense< [[4], [3], [1], [7]]>:tensor<4x1xi64> } : tensor<4x1xi64> + %2 = onnx.Constant { name = "constant.2", value = dense<[9., 10., 11., 12.]>:tensor<4xf32> } : tensor<4xf32> + %3 = "onnx.ScatterND"(%0, %1, %2) {node_name = "ScatterND_6467", node_type = "ScatterND"} : (tensor<8xf32>, tensor<4x1xi64>, tensor<4xf32>) -> tensor<*xf32> + onnx.Return %3 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: @test_pow_dynamic_result() +func.func @test_pow_dynamic_result() -> tensor<*xf32> { + %0 = onnx.Constant dense<64.0> : tensor<2x2xf32> + %1 = onnx.Constant dense<0.5> : tensor + %2 = "onnx.Pow"(%0, %1) : (tensor<2x2xf32> , tensor) -> tensor<*xf32> + onnx.Return %2 : tensor<*xf32> +} + + +// ----- + +/// variadic ops + +// CHECK-LABEL: @test_max_dynamic_result +func.func @test_max_dynamic_result() -> tensor<*xi32> { + %0 = "onnx.Constant"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %1 = "onnx.Max"(%0) : (tensor<2x2xi32>) -> tensor<*xi32> + "onnx.Return"(%1) : (tensor<*xi32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_min_dynamic_result +func.func @test_min_dynamic_result() -> tensor<*xi32> { + %0 = "onnx.Constant"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %1 = "onnx.Min"(%0) : (tensor<2x2xi32>) -> tensor<*xi32> + "onnx.Return"(%1) : (tensor<*xi32>) -> () +} + +// ----- + +// CHECK-LABEL: @test_sum_dynamic_result +func.func @test_sum_dynamic_result() -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<0.5> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %1 = "onnx.Sum"(%0) : (tensor<2x2xf32>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () +} + +// ----- diff --git a/test/mlir/onnx/onnx_decompose.mlir b/test/mlir/onnx/onnx_decompose.mlir index 6fcdd0bbd1..711f7bcc72 100644 --- a/test/mlir/onnx/onnx_decompose.mlir +++ b/test/mlir/onnx/onnx_decompose.mlir @@ -2,6 +2,42 @@ // ----- +func.func @test_grid_sample_v16_bicubic(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bicubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bicubic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "cubic", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_bilinear(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "bilinear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_bilinear +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "linear", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_grid_sample_v16_nearest(%arg0: tensor<2x1x4x4xf32>, %arg1: tensor<2x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSampleV16"(%arg0, %arg1) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_grid_sample_v16_nearest +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<2x6x6x2xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.GridSample"([[PARAM_0_]], [[PARAM_1_]]) {align_corners = 1 : si64, mode = "nearest", padding_mode = "zeros"} : (tensor<2x1x4x4xf32>, tensor<2x6x6x2xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + func.func @test_dft(%arg0 : tensor, %arg1 : tensor) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none %0 ="onnx.DFTV17"(%arg0, %arg1) : (tensor, tensor)-> tensor<*xf32> @@ -10,10 +46,9 @@ func.func @test_dft(%arg0 : tensor, %arg1 : tensor) -> tensor< // mlir2FileCheck.py // CHECK-LABEL: func.func @test_dft // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor<*xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor -// CHECK: [[VAR_2_:%.+]] = "onnx.DFT"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]]) {inverse = 0 : si64, onesided = 0 : si64} : (tensor, tensor, tensor) -> tensor<*xf32> -// CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32> +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.DFT"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {inverse = 0 : si64, onesided = 0 : si64} : (tensor, tensor, tensor) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> // CHECK: } } @@ -90,7 +125,7 @@ func.func @test_reducelogsumexp(%arg0 : tensor, %arg1 : tensor // CHECK-NEXT: [[REDUCE_MAX:%.+]] = "onnx.ReduceMax"(%arg0, %arg1) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: [[SUB:%.+]] = "onnx.Sub"(%arg0, [[REDUCE_MAX]]) : (tensor, tensor<*xf32>) -> tensor<*xf32> // CHECK-NEXT: [[EXP:%.+]] = "onnx.Exp"([[SUB]]) : (tensor<*xf32>) -> tensor<*xf32> - // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]], %arg1) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<*xf32>, tensor) -> tensor<*xf32> + // CHECK-NEXT: [[REDUCE_SUM:%.+]] = "onnx.ReduceSum"([[EXP]], %arg1) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK-NEXT: [[LOG:%.+]] = "onnx.Log"([[REDUCE_SUM]]) : (tensor<*xf32>) -> tensor<*xf32> // CHECK-NEXT: [[SQUEEZE:%.+]] = "onnx.Squeeze"([[REDUCE_MAX]], %arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK-NEXT: [[RES:%.+]] = "onnx.Add"([[LOG]], [[SQUEEZE]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> @@ -163,15 +198,15 @@ func.func @test_scaler_no_offset(%arg0: tensor<3xf32>) -> tensor<3xf32> { // ----- // scaler no offset, int input -// CHECK-LABEL: func @test_scaler_no_offset2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-LABEL: func.func @test_scaler_no_offset2 func.func @test_scaler_no_offset2(%arg0: tensor<3xi32>) -> tensor<3xf32> { %0 = "onnx.Scaler"(%arg0) {scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - - // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> - // CHECK-NEXT: %1 = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> - // CHECK-NEXT: %2 = "onnx.Mul"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return %2 : tensor<3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_0_]]) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<3xf32> } // ----- @@ -190,61 +225,65 @@ func.func @test_scaler_no_scale(%arg0: tensor<3xf32>) -> tensor<3xf32> { // ----- // scaler no scale, int input -// CHECK-LABEL: func @test_scaler_no_scale2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-LABEL: func @test_scaler_no_scale2 func.func @test_scaler_no_scale2(%arg0: tensor<3xi32>) -> tensor<3xf32> { %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32]} : (tensor<3xi32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> - // CHECK-NEXT: %1 = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> - // CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return %2 : tensor<3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[VAR_1_]], [[VAR_0_]]) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<3xf32> } // ----- // normal scaler -// CHECK-LABEL: func @test_scaler_normal(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-LABEL: func @test_scaler_normal func.func @test_scaler_normal(%arg0: tensor<3xf32>) -> tensor<3xf32> { %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xf32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - // CHECK-NEXT: %0 = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> - // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> - // CHECK-NEXT: %2 = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> - // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return %3 : tensor<3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_3_]] : tensor<3xf32> } // ----- // normal scaler, int input -// CHECK-LABEL: func @test_scaler_normal2(%{{.*}}: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-LABEL: func @test_scaler_normal2 func.func @test_scaler_normal2(%arg0: tensor<3xi32>) -> tensor<3xf32> { %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32, 0.99999988 : f32, 0.999999701 : f32], scale = [3.125000e-02 : f32, 0.0909090936 : f32, 0.0333333351 : f32]} : (tensor<3xi32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - // CHECK-NEXT: %0 = "onnx.Cast"(%arg0) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> - // CHECK-NEXT: %1 = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> - // CHECK-NEXT: %2 = "onnx.Sub"(%0, %1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - // CHECK-NEXT: %3 = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> - // CHECK-NEXT: %4 = "onnx.Mul"(%2, %3) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return %4 : tensor<3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xi32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[3.125000e-02, 0.0909090936, 0.0333333351]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1986.99939, 0.99999988, 0.999999701]> : tensor<3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = f32} : (tensor<3xi32>) -> tensor<*xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Sub"([[VAR_2_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_3_]], [[VAR_0_]]) : (tensor<*xf32>, tensor<3xf32>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_4_]] : tensor<3xf32> } // ----- // normal scaler with constant offset and scale -// CHECK-LABEL: func @test_scaler_constant(%{{.*}}: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-LABEL: func @test_scaler_constant func.func @test_scaler_constant(%arg0: tensor<3xf32>) -> tensor<3xf32> { %0 = "onnx.Scaler"(%arg0) {offset = [1986.99939 : f32], scale = [3.125000e-02 : f32]} : (tensor<3xf32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - // CHECK-NEXT: %0 = onnx.Constant dense<1986.99939> : tensor<1xf32> - // CHECK-NEXT: %1 = "onnx.Sub"(%arg0, %0) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> - // CHECK-NEXT: %2 = onnx.Constant dense<3.125000e-02> : tensor<1xf32> - // CHECK-NEXT: %3 = "onnx.Mul"(%1, %2) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return %3 : tensor<3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<3.125000e-02> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1986.99939> : tensor<1xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_3_]] : tensor<3xf32> } // ----- @@ -265,11 +304,11 @@ func.func @test_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { func.func @test_upsample(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<4xf32>) -> tensor<1x1x4x6xf32> { %0 = "onnx.Upsample"(%arg0, %arg1) {mode = "nearest"} : (tensor<1x1x2x2xf32>, tensor<4xf32>) -> tensor<1x1x4x6xf32> onnx.Return %0 : tensor<1x1x4x6xf32> - // CHECK-LABEL: test_upsample - // CHECK: [[NONE_0:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[NONE_1:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[RES:%.+]] = "onnx.Resize"(%arg0, [[NONE_0]], %arg1, [[NONE_1]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x1x2x2xf32>, none, tensor<4xf32>, none) -> tensor<1x1x4x6xf32> - // CHECK: onnx.Return [[RES]] : tensor<1x1x4x6xf32> +// CHECK-LABEL: func.func @test_upsample +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x2x2xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>) -> tensor<1x1x4x6xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_1_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]], [[VAR_0_]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x1x2x2xf32>, none, tensor<4xf32>, none) -> tensor<1x1x4x6xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<1x1x4x6xf32> } // ----- @@ -277,12 +316,12 @@ func.func @test_upsample(%arg0: tensor<1x1x2x2xf32>, %arg1: tensor<4xf32>) -> te func.func @test_upsamplev7(%arg0: tensor<1x1x2x2xf32>) -> tensor<1x1x4x6xf32> { %0 = "onnx.UpsampleV7"(%arg0) {mode = "nearest", scales = [0.1 : f32, 0.2 : f32, 0.3 : f32, 0.4 : f32]} : (tensor<1x1x2x2xf32>) -> tensor<1x1x4x6xf32> onnx.Return %0 : tensor<1x1x4x6xf32> - // CHECK-LABEL: test_upsamplev7 - // CHECK: [[NONE_0:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[SCALES:%.+]] = onnx.Constant dense<[1.000000e-01, 2.000000e-01, 3.000000e-01, 4.000000e-01]> : tensor<4xf32> - // CHECK: [[NONE_1:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[RES:%.+]] = "onnx.Resize"(%arg0, [[NONE_0]], [[SCALES]], [[NONE_1]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x1x2x2xf32>, none, tensor<4xf32>, none) -> tensor<1x1x4x6xf32> - // CHECK: onnx.Return [[RES]] : tensor<1x1x4x6xf32> +// CHECK-LABEL: func.func @test_upsamplev7 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x2x2xf32>) -> tensor<1x1x4x6xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1.000000e-01, 2.000000e-01, 3.000000e-01, 4.000000e-01]> : tensor<4xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_0_]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x1x2x2xf32>, none, tensor<4xf32>, none) -> tensor<1x1x4x6xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x1x4x6xf32> } // ----- @@ -303,12 +342,12 @@ func.func @test_padv2(%arg0: tensor<1x3x224x224xf32>) -> tensor<*xf32> { func.func @test_resizev10(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<4xf32>) -> tensor<*xf32> { %0 = "onnx.ResizeV10"(%arg0, %arg1) {mode = "nearest"} : (tensor<1x2x3x4xf32>, tensor<4xf32>) -> tensor<*xf32> onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_resizev10 - // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>) -> tensor<*xf32> { - // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue" - // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue" - // CHECK: [[VAR_2_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]], [[VAR_1_]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x2x3x4xf32>, none, tensor<4xf32>, none) -> tensor<*xf32> - // CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32> +// CHECK-LABEL: func.func @test_resizev10 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_1_:%.+]] = "onnx.Resize"([[PARAM_0_]], [[VAR_0_]], [[PARAM_1_]], [[VAR_0_]]) {antialias = 0 : si64, coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, keep_aspect_ratio_policy = "stretch", mode = "nearest", nearest_mode = "round_prefer_floor"} : (tensor<1x2x3x4xf32>, none, tensor<4xf32>, none) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> +// CHECK: } } // ----- @@ -365,102 +404,108 @@ func.func @test_seqence_construct_2(%arg0: tensor<*xi16>) -> !onnx.Seq) -> () { +func.func @test_clipv6(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.ClipV6"(%arg0) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_clipv6 - // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) { - // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor - // CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<6.000000e+00> : tensor - // CHECK: [[VAR_2_:%.+]] = "onnx.Clip"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_clipv6 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<6.000000e+00> : tensor +// CHECK: [[VAR_2_:%.+]] = "onnx.Clip"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32> } // ----- -func.func @test_splitV11(%arg0 : tensor<*xf32>) -> () { +func.func @test_splitV11(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.SplitV11"(%arg0) {axis = 1 : si64, split = [1]} : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_splitV11 - // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> - // CHECK: [[VAR_1_:%.+]] = "onnx.Split"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_splitV11 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_splitV11_no_split(%arg0 : tensor<*xf32>) -> () { +func.func @test_splitV11_no_split(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.SplitV11"(%arg0) {axis = 1 : si64} : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_splitV11_no_split - // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[VAR_1_:%.+]] = "onnx.Split"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, none) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_splitV11_no_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_1_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<*xf32>, none) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_splitV13(%arg0 : tensor<*xf32>) -> () { +func.func @test_splitV13(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = onnx.Constant dense<1> : tensor<1xi64> %1 = "onnx.SplitV13"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> - onnx.Return + onnx.Return %1 : tensor<*xf32> - // CHECK-LABEL: func @test_splitV13 - // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> - // CHECK: [[VAR_1_:%.+]] = "onnx.Split"(%arg0, %0) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_splitV13 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_squeezeV11(%arg0 : tensor<*xf32>) -> () { +func.func @test_squeezeV11(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.SqueezeV11"(%arg0) {axes = [1]} : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_squeezeV11 - // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> - // CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"(%arg0, %0) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_squeezeV11 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[VAR_0_]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_squeezeV11_no_axes(%arg0 : tensor<*xf32>) -> () { +func.func @test_squeezeV11_no_axes(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.SqueezeV11"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_squeezeV11_no_axes - // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"(%arg0, %0) : (tensor<*xf32>, none) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_squeezeV11_no_axes +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_1_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[VAR_0_]]) : (tensor<*xf32>, none) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_unsqueezeV11(%arg0 : tensor<*xf32>) -> () { +func.func @test_unsqueezeV11(%arg0 : tensor<*xf32>) -> tensor<*xf32> { %0 = "onnx.UnsqueezeV11"(%arg0) {axes = [1]} : (tensor<*xf32>) -> tensor<*xf32> - onnx.Return + onnx.Return %0 : tensor<*xf32> - // CHECK-LABEL: func @test_unsqueezeV11 - // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> - // CHECK: [[VAR_1_:%.+]] = "onnx.Unsqueeze"(%arg0, %0) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> - // CHECK: onnx.Return +// CHECK-LABEL: func.func @test_unsqueezeV11 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_0_]], [[VAR_0_]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> } // ----- -func.func @test_padV13(%arg0 : tensor<*xi64>, %arg1 : tensor<2xi64>) -> () { +func.func @test_padV13(%arg0 : tensor<*xi64>, %arg1 : tensor<2xi64>) -> tensor<*xi64> { %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.PadV13"(%arg0, %arg1, %0) : (tensor<*xi64>, tensor<2xi64>, none) -> tensor<*xi64> - onnx.Return - // CHECK-LABEL: func @test_padV13 - // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none - // CHECK: [[VAR_2_:%.+]] = "onnx.Pad"(%arg0, %arg1, [[VAR_0_]], [[VAR_1_]]) {mode = "constant"} : (tensor<*xi64>, tensor<2xi64>, none, none) -> tensor<*xi64> - // CHECK: onnx.Return + onnx.Return %1 : tensor<*xi64> +// CHECK-LABEL: func.func @test_padV13 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xi64>, [[PARAM_1_:%.+]]: tensor<2xi64>) -> tensor<*xi64> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_1_:%.+]] = "onnx.Pad"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]], [[VAR_0_]]) {mode = "constant"} : (tensor<*xi64>, tensor<2xi64>, none, none) -> tensor<*xi64> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xi64> } // ----- @@ -492,21 +537,20 @@ func.func @concat_fuse_0(%arg0: tensor, %arg1: tensor) -> (t // ----- -func.func @test_concatfuse_1(%arg0: tensor, %arg1: tensor) -> (tensor<2xi64>, tensor<50x?xf32>) +func.func @test_concatfuse_1(%arg0: tensor, %arg1: tensor) -> (tensor, tensor<2xi64>, tensor<50x?xf32>) { %1 = "onnx.Concat"(%arg0, %arg1) {axis = 1 : si64} : (tensor, tensor) -> tensor %2 = "onnx.Transpose"(%1) {perm = [1, 0]} : (tensor) -> tensor<50x?xf32> %3 = "onnx.Shape"(%1) : (tensor) -> tensor<2xi64> %4 = "onnx.Sin"(%1) : (tensor) -> tensor - onnx.Return %3, %2 : tensor<2xi64>, tensor<50x?xf32> + onnx.Return %4, %3, %2 : tensor, tensor<2xi64>, tensor<50x?xf32> // CHECK-LABEL: func.func @test_concatfuse_1 -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> (tensor<2xi64>, tensor<50x?xf32>) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> (tensor, tensor<2xi64>, tensor<50x?xf32>) { // CHECK: [[VAR_0_:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]]) {axis = 1 : si64} : (tensor, tensor) -> tensor // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[VAR_0_]]) {perm = [1, 0]} : (tensor) -> tensor<50x?xf32> // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Shape"([[VAR_0_]]) {start = 0 : si64} : (tensor) -> tensor<2xi64> // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Sin"([[VAR_0_]]) : (tensor) -> tensor -// CHECK: onnx.Return [[VAR_2_]], [[VAR_1_]] : tensor<2xi64>, tensor<50x?xf32> -// CHECK: } +// CHECK: onnx.Return [[VAR_3_]], [[VAR_2_]], [[VAR_1_]] : tensor, tensor<2xi64>, tensor<50x?xf32> } // ----- @@ -542,47 +586,140 @@ func.func @test_constantofshape(%arg0: tensor) -> tensor<*xi32> { // ----- -func.func @test_groupnorm(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> { - %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32> +func.func @test_hardswish_f32(%arg0: tensor) -> tensor { + %0 = "onnx.HardSwish"(%arg0) : (tensor) -> tensor + return %0 : tensor +// CHECK-LABEL: func @test_hardswish_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor +// CHECK: [[VAR_0_:%.+]] = "onnx.HardSigmoid"([[PARAM_0_]]) {alpha = 0.166666672 : f32, beta = 5.000000e-01 : f32} : (tensor) -> tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[VAR_0_]], [[PARAM_0_]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAR_1_]] : tensor +} + +// ----- + +func.func @test_groupnorm_v18(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> { + %0 = "onnx.GroupNormalizationV18"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32> onnx.Return %0 : tensor<3x4x2x2xf32> // mlir2FileCheck.py -// CHECK-LABEL: func.func @test_groupnorm +// CHECK-LABEL: func.func @test_groupnorm_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x2x2xf32> { -// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<3xi64>) -> tensor<2x1x1x1xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<3xi64>) -> tensor<2x1x1x1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<1xi64> -// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_1_]]) : (tensor<2xf32>, tensor<3xi64>) -> tensor<2x1x1x1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_1_]]) : (tensor<2xf32>, tensor<3xi64>) -> tensor<2x1x1x1xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<1xi64> // CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x2x2xf32>) -> tensor<2xi64> -// CHECK: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5xi64> -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<3x4x2x2xf32>, tensor<5xi64>) -> tensor<3x2x2x2x2xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_7_]], [[VAR_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x1x1x1xf32>, tensor<2x1x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none) -// CHECK: [[VAR_9_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<4xi64> -// CHECK: [[VAR_10_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_9_]]) {allowzero = 0 : si64} : (tensor<3x2x2x2x2xf32>, tensor<4xi64>) -> tensor<3x4x2x2xf32> -// CHECK: onnx.Return [[VAR_10_]] : tensor<3x4x2x2xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_4_]], [[VAR_0_]], [[VAR_5_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5xi64> +// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<3x4x2x2xf32>, tensor<5xi64>) -> tensor<3x2x2x2x2xf32> +// CHECK: [[Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_7_]], [[VAR_2_]], [[VAR_3_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x1x1x1xf32>, tensor<2x1x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none) +// CHECK: [[VAR_8_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<4xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x2x2x2x2xf32>, tensor<4xi64>) -> tensor<3x4x2x2xf32> +// CHECK: onnx.Return [[VAR_9_]] : tensor<3x4x2x2xf32> +} + +// ----- + +func.func @test_groupnorm_dynamic_1(%arg0: tensor<*xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_groupnorm_dynamic_1 +// CHECK: onnx.GroupNormalization +} +// ----- + +func.func @test_groupnorm_dynamic_2(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<2x3x4x5x6xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_groupnorm_dynamic_2 +// CHECK: onnx.GroupNormalization +} +// ----- + +func.func @test_groupnorm_dynamic_3(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<2x3x4x5x6xf32>, tensor, tensor) -> tensor + onnx.Return %0 : tensor +// CHECK-LABEL: func.func @test_groupnorm_dynamic_3 +// CHECK: onnx.GroupNormalization +} + +// ----- + +func.func @test_groupnorm_dynamic_4(%arg0: tensor, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<2x3x4x5x6xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5x6xf32> + onnx.Return %0 : tensor<2x3x4x5x6xf32> +// CHECK-LABEL: func.func @test_groupnorm_dynamic_4 +// CHECK: onnx.GroupNormalization +} + +// ----- + +func.func @test_groupnorm_v21(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<3x4x2x2xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<3x4x2x2xf32> + onnx.Return %0 : tensor<3x4x2x2xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_groupnorm_v21 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x2x2xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>, [[PARAM_2_:%.+]]: tensor<4xf32>) -> tensor<3x4x2x2xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1> : tensor<2xi64> +// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<2xi64>) -> tensor<4xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<4xi64>) -> tensor<2x2x1x1xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x2x2xf32>) -> tensor<2xi64> +// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_6_]], [[VAR_0_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x2x2xf32>, tensor<5xi64>) -> tensor<3x2x2x2x2xf32> +// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_4_]], [[VAR_5_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x2x2xf32>, tensor<2x2x1x1xf32>, tensor<2x2x1x1xf32>) -> (tensor<3x2x2x2x2xf32>, none, none) +// CHECK: [[VAR_10_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x2x2xf32>) -> tensor<4xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_10_]]) {allowzero = 0 : si64} : (tensor<3x2x2x2x2xf32>, tensor<4xi64>) -> tensor<3x4x2x2xf32> +// CHECK: onnx.Return [[VAR_11_]] : tensor<3x4x2x2xf32> // CHECK: } } + // ----- -func.func @group_norm5d(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { - %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32> +func.func @group_norm5d_v18(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { + %0 = "onnx.GroupNormalizationV18"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x6x8x16xf32> onnx.Return %0 : tensor<3x4x6x8x16xf32> -// CHECK-LABEL: func.func @group_norm5d +// CHECK-LABEL: func.func @group_norm5d_v18 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<2xf32>, [[PARAM_2_:%.+]]: tensor<2xf32>) -> tensor<3x4x6x8x16xf32> { -// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3, 4]> : tensor<4xi64> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_0_]]) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<1xi64> -// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 2, 3, 4]> : tensor<4xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_1_]]) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_1_]]) : (tensor<2xf32>, tensor<4xi64>) -> tensor<2x1x1x1x1xf32> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<1xi64> // CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<3xi64> -// CHECK: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_3_]], [[VAR_4_]], [[VAR_5_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<3xi64>) -> tensor<6xi64> -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_7_]], [[VAR_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x1x1x1x1xf32>, tensor<2x1x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none) -// CHECK: [[VAR_9_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<5xi64> -// CHECK: [[VAR_10_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_9_]]) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32> -// CHECK: onnx.Return [[VAR_10_]] : tensor<3x4x6x8x16xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Concat"([[VAR_4_]], [[VAR_0_]], [[VAR_5_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<3xi64>) -> tensor<6xi64> +// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_6_]]) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32> +// CHECK: [[Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_7_]], [[VAR_2_]], [[VAR_3_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x1x1x1x1xf32>, tensor<2x1x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none) +// CHECK: [[VAR_8_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<5xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[Y_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32> +// CHECK: onnx.Return [[VAR_9_]] : tensor<3x4x6x8x16xf32> +} + +// ----- + +func.func @group_norm5d_v21(%arg0: tensor<3x4x6x8x16xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<3x4x6x8x16xf32> { + %0 = "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x6x8x16xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<3x4x6x8x16xf32> + onnx.Return %0 : tensor<3x4x6x8x16xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @group_norm5d_v21 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x6x8x16xf32>, [[PARAM_1_:%.+]]: tensor<4xf32>, [[PARAM_2_:%.+]]: tensor<4xf32>) -> tensor<3x4x6x8x16xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, -1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> +// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>, tensor<3xi64>) -> tensor<5xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Reshape"([[PARAM_2_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<4xf32>, tensor<5xi64>) -> tensor<2x2x1x1x1xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {end = 1 : si64, start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<1xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 2 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<3xi64> +// CHECK: [[VAR_8_:%.+]] = "onnx.Concat"([[VAR_6_]], [[VAR_0_]], [[VAR_7_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<2xi64>, tensor<3xi64>) -> tensor<6xi64> +// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_8_]]) {allowzero = 0 : si64} : (tensor<3x4x6x8x16xf32>, tensor<6xi64>) -> tensor<3x2x2x6x8x16xf32> +// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_9_]], [[VAR_4_]], [[VAR_5_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<2x2x1x1x1xf32>, tensor<2x2x1x1x1xf32>) -> (tensor<3x2x2x6x8x16xf32>, none, none) +// CHECK: [[VAR_10_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<3x4x6x8x16xf32>) -> tensor<5xi64> +// CHECK: [[VAR_11_:%.+]] = "onnx.Reshape"([[VAR_Y_]], [[VAR_10_]]) {allowzero = 0 : si64} : (tensor<3x2x2x6x8x16xf32>, tensor<5xi64>) -> tensor<3x4x6x8x16xf32> +// CHECK: onnx.Return [[VAR_11_]] : tensor<3x4x6x8x16xf32> // CHECK: } } @@ -597,10 +734,43 @@ func.func @test_instancenorm(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<3xf32>, // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32> // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<2x3x4x5x6xf32>, tensor<3x1x1x1xf32>, tensor<3x1x1x1xf32>) -> (tensor<2x3x4x5x6xf32>, none, none) +// CHECK: [[Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<2x3x4x5x6xf32>, tensor<3x1x1x1xf32>, tensor<3x1x1x1xf32>) -> (tensor<2x3x4x5x6xf32>, none, none) // CHECK: onnx.Return [[Y_]] : tensor<2x3x4x5x6xf32> -// CHECK: } +} + +// ----- + +func.func @test_instancenorm_dynamic_1(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<*xf32> { + %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<*xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_instancenorm_dynamic_1 +// CHECK: onnx.InstanceNormalization +} + +// ----- + +func.func @test_instancenorm_dynamic_2(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { + %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5x6xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> +// CHECK-LABEL: func.func @test_instancenorm_dynamic_2 +// CHECK: onnx.InstanceNormalization +} +// ----- + +func.func @test_instancenorm_dynamic_3(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5x6xf32>, tensor, tensor) -> tensor + onnx.Return %0 : tensor +// CHECK-LABEL: func.func @test_instancenorm_dynamic_3 +// CHECK: onnx.InstanceNormalization +} + +// ----- + +func.func @test_instancenorm_dynamic_4(%arg0: tensor, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<2x3x4x5x6xf32> { + %0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5x6xf32> + onnx.Return %0 : tensor<2x3x4x5x6xf32> +// CHECK-LABEL: func.func @test_instancenorm_dynamic_4 +// CHECK: onnx.InstanceNormalization } // ----- @@ -646,3 +816,637 @@ func.func @test_castlike(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf16>) -> tensor // CHECK: onnx.Return [[RES]] : tensor<*xf16> } +// ----- + +func.func @test_sum(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<64x128x10xf32> { + %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<64x128x10xf32> + onnx.Return %0 : tensor<64x128x10xf32> + // CHECK-LABEL: func @test_sum + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) + // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]]) + // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]]) + // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]]) + // CHECK-NEXT: onnx.Return %[[SUM2]] +} + +// ----- + +func.func @test_sum_to_unranked(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<*xf32> { + %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> + // CHECK-LABEL: func @test_sum + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) + // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]]) + // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]]) + // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]]) + // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[SUM2]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32> + // CHECK-NEXT: onnx.Return %[[CAST]] +} + +// ----- + +func.func @test_sum_single_input(%arg0: tensor<64x128x10xf32>) -> tensor<64x128x10xf32> { + %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<64x128x10xf32> + onnx.Return %0 : tensor<64x128x10xf32> + // CHECK-LABEL: func @test_sum_single_input + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) + // CHECK-NEXT: onnx.Return %[[ARG0]] +} + +// ----- + +func.func @test_sum_single_input_to_unranked(%arg0: tensor<64x128x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> + // CHECK-LABEL: func @test_sum_single_input_to_unranked + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}) + // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32> + // CHECK-NEXT: onnx.Return %[[CAST]] +} +// ----- + +func.func @test_batchnorm_f32(%arg0: tensor<100x3x10x10xf32>) -> tensor<100x3x10x10xf32> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>) + return %4 : tensor<100x3x10x10xf32> +// CHECK-LABEL: func.func @test_batchnorm_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x10x10xf32>) -> tensor<100x3x10x10xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.BatchNormalizationInferenceMode"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<100x3x10x10xf32> +// CHECK: return [[VAR_4_]] : tensor<100x3x10x10xf32> +} + +// ----- + +func.func @test_batchnorm_f16_dynamic(%arg0: tensor<100x3x?x?xf16>) -> tensor<*xf16> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>) + return %4 : tensor<*xf16> +// CHECK-LABEL: func @test_batchnorm_f16_dynamic +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: "onnx.BatchNormalizationInferenceMode"(%arg0, [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32} +} + +// ----- + +func.func @test_batchnorm_bf16_dynamic(%arg0: tensor<100x3x?x?xbf16>) -> tensor<*xbf16> { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16>) + return %4 : tensor<*xbf16> +// CHECK-LABEL: func.func @test_batchnorm_bf16_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x?x?xbf16>) -> tensor<*xbf16> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xbf16> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xbf16> +// CHECK: [[VAR_4_:%.+]] = "onnx.BatchNormalizationInferenceMode"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> tensor<*xbf16> +// CHECK: return [[VAR_4_]] : tensor<*xbf16> +} + + +// ----- + +func.func @test_batchnorm_bf16_use_mean_var(%arg0: tensor<100x3x?x?xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16>) + return %4, %5, %6 : tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16> +// CHECK-LABEL: func @test_batchnorm_bf16_use_mean_var +// CHECK: onnx.BatchNormalization" +} + +// ----- + +func.func @test_batchnorm_bf16_use_mean(%arg0: tensor<100x3x?x?xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16>) + return %4, %5 : tensor<*xbf16>, tensor<*xbf16> +// CHECK-LABEL: func @test_batchnorm_bf16_use_mean +// CHECK: onnx.BatchNormalization" +} + +// ----- + +func.func @test_batchnorm_bf16_use_var(%arg0: tensor<100x3x?x?xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xbf16>} : () -> tensor<3xbf16> + %4, %5, %6 = "onnx.BatchNormalization"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>, tensor<3xbf16>) -> (tensor<*xbf16>, tensor<*xbf16>, tensor<*xbf16>) + return %4, %6 : tensor<*xbf16>, tensor<*xbf16> +// CHECK-LABEL: func @test_batchnorm_bf16_use_var +// CHECK: onnx.BatchNormalization" +} + +// ----- + +func.func @test_batchnormv9_f32_use_saved_mean_var(%arg0: tensor<100x3x?x?xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %4, %7, %8 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> +// CHECK-LABEL: func @test_batchnormv9_f32_use_saved_mean_var +// CHECK: onnx.BatchNormalizationV9" +} + +// ----- + +func.func @test_batchnormv9_f32_use_saved_mean(%arg0: tensor<100x3x?x?xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %4, %7 : tensor<*xf32>, tensor<*xf32> +// CHECK-LABEL: func @test_batchnormv9_f32_use_saved_mean +// CHECK: onnx.BatchNormalizationV9" +} + +// ----- + +func.func @test_batchnormv9_f32_use_saved_var(%arg0: tensor<100x3x?x?xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %4, %8 : tensor<*xf32>, tensor<*xf32> +// CHECK-LABEL: func @test_batchnormv9_f32_use_saved_var +// CHECK: onnx.BatchNormalizationV9" +} + +// ----- + +func.func @test_batchnormv9_f32(%arg0: tensor<100x3x10x10xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) + return %4, %5, %6 : tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32> +} + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_batchnormv9_f32 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x10x10xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf32> +// CHECK: [[Y_:%.+]], [[VAR_running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32, training_mode = 0 : si64} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>) +// CHECK: return [[Y_]], [[VAR_running_mean_]], [[VAR_running_var_]] : tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32> +// CHECK: } + +// ----- + +func.func @test_batchnormv9_f32_no_var_mean_use(%arg0: tensor<100x3x10x10xf32>) -> (tensor<100x3x10x10xf32>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) + return %4: tensor<100x3x10x10xf32> +} + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_batchnormv9_f32_no_var_mean_use +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x10x10xf32>) -> tensor<100x3x10x10xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.BatchNormalizationInferenceMode"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32} : (tensor<100x3x10x10xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<100x3x10x10xf32> +// CHECK: return [[VAR_4_]] : tensor<100x3x10x10xf32> +// CHECK: } + +// ----- + +func.func @test_batchnormv9_f16_dynamic(%arg0: tensor<100x3x?x?xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>) { + %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %1 = "onnx.Constant"() {value = dense<[2.0, 3.0, 4.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %2 = "onnx.Constant"() {value = dense<[3.0, 4.0, 5.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf16>} : () -> tensor<3xf16> + %4, %5, %6, %7, %8 = "onnx.BatchNormalizationV9"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32, momentum = 1.00000007E-3 : f32} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>,tensor<*xf16>, tensor<*xf16>) + return %4, %5, %6 : tensor<*xf16>, tensor<*xf16>, tensor<*xf16> +} + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_batchnormv9_f16_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<100x3x?x?xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf16> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf16> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf16> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf16> +// CHECK: [[Y_:%.+]], [[VAR_running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]]) {epsilon = 1.00000007E-5 : f32, momentum = 1.000000e-03 : f32, training_mode = 0 : si64} : (tensor<100x3x?x?xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>, tensor<3xf16>) -> (tensor<*xf16>, tensor<*xf16>, tensor<*xf16>) +// CHECK: return [[Y_]], [[VAR_running_mean_]], [[VAR_running_var_]] : tensor<*xf16>, tensor<*xf16>, tensor<*xf16> +// CHECK: } + +// ----- + +func.func @test_pad_slice_only_slice() -> tensor<3x1xf32> { + %data = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %pads = onnx.Constant dense<[0, -1, 0, 0]> : tensor<4xi64> + %non = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Pad"(%data, %pads, %non, %non) { mode = "constant" } : (tensor<3x2xf32>, tensor<4xi64>, none, none) -> tensor<3x1xf32> + onnx.Return %1 : tensor<3x1xf32> +} +// CHECK-LABEL: func.func @test_pad_slice_only_slice +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}[1.000000e+00, 1.200000e+00], [2.300000e+00, 3.400000e+00], [4.500000e+00, 5.700000e+00]{{.}}> : tensor<3x2xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[3, 2]> : tensor<2xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_4_:%.+]] = "onnx.Slice"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]], [[VAR_3_]], [[VAR_3_]]) : (tensor<3x2xf32>, tensor<2xi64>, tensor<2xi64>, none, none) -> tensor<3x1xf32> +// CHECK: onnx.Return [[VAR_4_]] : tensor<3x1xf32> + +// ----- + +func.func @test_pad_slice() -> tensor<4x1xf32> { + %data = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32> + %pads = onnx.Constant dense<[0, -1, 1, 0]> : tensor<4xi64> + %non = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Pad"(%data, %pads, %non, %non) { mode = "constant" } : (tensor<3x2xf32>, tensor<4xi64>, none, none) -> tensor<4x1xf32> + onnx.Return %1 : tensor<4x1xf32> +} +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 0, 1, 0]> : tensor<4xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{.}}[1.000000e+00, 1.200000e+00], [2.300000e+00, 3.400000e+00], [4.500000e+00, 5.700000e+00]{{.}}> : tensor<3x2xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> +// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<[3, 2]> : tensor<2xi64> +// CHECK: [[VAR_5_:%.+]] = "onnx.Slice"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]], [[VAR_2_]], [[VAR_2_]]) : (tensor<3x2xf32>, tensor<2xi64>, tensor<2xi64>, none, none) -> tensor<3x1xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Pad"([[VAR_5_]], [[VAR_0_]], [[VAR_2_]], [[VAR_2_]]) {mode = "constant"} : (tensor<3x1xf32>, tensor<4xi64>, none, none) -> tensor<4x1xf32> +// CHECK: onnx.Return [[VAR_6_]] : tensor<4x1xf32> + +// ----- + +func.func @test_pad_slice_dynamic(%data : tensor<*xf32>) -> tensor<*xf32> { + // Just checks that we do not crash + %pads = onnx.Constant dense<[0, -1, 1, 0]> : tensor<4xi64> + %non = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Pad"(%data, %pads, %non, %non) { mode = "constant" } : (tensor<*xf32>, tensor<4xi64>, none, none) -> tensor<*xf32> + onnx.Return %1 : tensor<*xf32> +} +// CHECK-LABEL: func.func @test_pad_slice_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, -1, 1, 0]> : tensor<4xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_2_:%.+]] = "onnx.Pad"([[PARAM_0_]], [[VAR_0_]], [[VAR_1_]], [[VAR_1_]]) {mode = "constant"} : (tensor<*xf32>, tensor<4xi64>, none, none) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<*xf32> + +// ----- +func.func @test_scatter_nd_single_split_begin(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [0, 0, 4], [0, 0, 5], [0, 0, 6], [0, 0, 7], [0, 0, 8], [0, 0, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_single_split_begin +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6x10x12xf32>, [[PARAM_1_:%.+]]: tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 1, 5]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x6x10x12xf32>, tensor<3xi64>) -> (tensor<1x0x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x5x10x12xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 1 : si64} : (tensor<1x0x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x5x10x12xf32>) -> tensor<1x6x10x12xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x6x10x12xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_single_split_end(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 5, 0], [0, 5, 1], [0, 5, 2], [0, 5, 3], [0, 5, 4], [0, 5, 5], [0, 5, 6], [0, 5, 7], [0, 5, 8], [0, 5, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_single_split_end +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6x10x12xf32>, [[PARAM_1_:%.+]]: tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[5, 1, 0]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x6x10x12xf32>, tensor<3xi64>) -> (tensor<1x5x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x0x10x12xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 1 : si64} : (tensor<1x5x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x0x10x12xf32>) -> tensor<1x6x10x12xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x6x10x12xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_double_split(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], [0, 1, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_double_split +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6x10x12xf32>, [[PARAM_1_:%.+]]: tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 1, 4]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 1 : si64} : (tensor<1x6x10x12xf32>, tensor<3xi64>) -> (tensor<1x1x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x4x10x12xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 1 : si64} : (tensor<1x1x10x12xf32>, tensor<1x1x10x12xf32>, tensor<1x4x10x12xf32>) -> tensor<1x6x10x12xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x6x10x12xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_reduction(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [0, 0, 4], [0, 0, 5], [0, 0, 6], [0, 0, 7], [0, 0, 8], [0, 0, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "add"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_reduction +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_not_const(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32>, %indices : tensor<1x1x10x3xi64> ) -> tensor<1x6x10x12xf32> { + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_not_const +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_dynamic(%data : tensor<*xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<*xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [0, 0, 4], [0, 0, 5], [0, 0, 6], [0, 0, 7], [0, 0, 8], [0, 0, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<*xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<*xf32> + onnx.Return %0 : tensor<*xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_dynamic +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_multi_dim_differ(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], [0, 1, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<2x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<2x6x10x12xf32> + onnx.Return %0 : tensor<2x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_differ +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_multi_dim_differ_multi_shift(%data : tensor<2x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<2x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[1, 1, 0], [1, 1, 1], [1, 1, 2], [1, 1, 3], [1, 1, 4], [1, 1, 5], [1, 1, 6], [1, 1, 7], [1, 1, 8], [1, 1, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<2x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<2x6x10x12xf32> + onnx.Return %0 : tensor<2x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_differ_multi_shift +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_negative_shift(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[ 0, -1, 0], [ 0, -1, 1], [ 0, -1, 2], [ 0, -1, 3], [ 0, -1, 4], [ 0, -1, 5], [ 0, -1, 6], [ 0, -1, 7], [ 0, -1, 8], [ 0, -1, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_negative_shift +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_full_overwrite(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x6x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [0, 0, 4], [0, 0, 5], [0, 0, 6], [0, 0, 7], [0, 0, 8], [0, 0, 9]], [[0, 1, 0], [0, 1, 1], [0, 1, 2], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], [0, 1, 9]], [[0, 2, 0], [0, 2, 1], [0, 2, 2], [0, 2, 3], [0, 2, 4], [0, 2, 5], [0, 2, 6], [0, 2, 7], [0, 2, 8], [0, 2, 9]], [[0, 3, 0], [0, 3, 1], [0, 3, 2], [0, 3, 3], [0, 3, 4], [0, 3, 5], [0, 3, 6], [0, 3, 7], [0, 3, 8], [0, 3, 9]], [[0, 4, 0], [0, 4, 1], [0, 4, 2], [0, 4, 3], [0, 4, 4], [0, 4, 5], [0, 4, 6], [0, 4, 7], [0, 4, 8], [0, 4, 9]], [[0, 5, 0], [0, 5, 1], [0, 5, 2], [0, 5, 3], [0, 5, 4], [0, 5, 5], [0, 5, 6], [0, 5, 7], [0, 5, 8], [0, 5, 9]]]]> : tensor<1x6x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x6x10x3xi64>, tensor<1x6x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_full_overwrite +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x6x10x12xf32>, [[PARAM_1_:%.+]]: tensor<1x6x10x12xf32>) -> tensor<1x6x10x12xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 12, 0]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 3 : si64} : (tensor<1x6x10x12xf32>, tensor<3xi64>) -> (tensor<1x6x10x0xf32>, tensor<1x6x10x12xf32>, tensor<1x6x10x0xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 3 : si64} : (tensor<1x6x10x0xf32>, tensor<1x6x10x12xf32>, tensor<1x6x10x0xf32>) -> tensor<1x6x10x12xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x6x10x12xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_multi_dim(%data : tensor<1x4x6xf32>, %updates : tensor<1x4x2xf32> ) -> tensor<1x4x6xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 1, 1]], [[0, 2, 0], [0, 2, 1]], [[0, 3, 0], [0, 3, 1]]]]> : tensor<1x4x2x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x4x6xf32>, tensor<1x4x2x3xi64>, tensor<1x4x2xf32>) -> tensor<1x4x6xf32> + onnx.Return %0 : tensor<1x4x6xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_multi_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x6xf32>, [[PARAM_1_:%.+]]: tensor<1x4x2xf32>) -> tensor<1x4x6xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 2, 4]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x4x6xf32>, tensor<3xi64>) -> (tensor<1x4x0xf32>, tensor<1x4x2xf32>, tensor<1x4x4xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 2 : si64} : (tensor<1x4x0xf32>, tensor<1x4x2xf32>, tensor<1x4x4xf32>) -> tensor<1x4x6xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x4x6xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_multi_dim_shift(%data : tensor<1x4x6xf32>, %updates : tensor<1x4x2xf32> ) -> tensor<1x4x6xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 1], [0, 0, 2]], [[0, 1, 1], [0, 1, 2]], [[0, 2, 1], [0, 2, 2]], [[0, 3, 1], [0, 3, 2]]]]> : tensor<1x4x2x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x4x6xf32>, tensor<1x4x2x3xi64>, tensor<1x4x2xf32>) -> tensor<1x4x6xf32> + onnx.Return %0 : tensor<1x4x6xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_shift +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x4x6xf32>, [[PARAM_1_:%.+]]: tensor<1x4x2xf32>) -> tensor<1x4x6xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64> +// CHECK: [[VAR_1_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x4x6xf32>, tensor<3xi64>) -> (tensor<1x4x1xf32>, tensor<1x4x2xf32>, tensor<1x4x3xf32>) +// CHECK: [[VAR_2_:%.+]] = "onnx.Concat"([[VAR_1_]]#0, [[PARAM_1_]], [[VAR_1_]]#2) {axis = 2 : si64} : (tensor<1x4x1xf32>, tensor<1x4x2xf32>, tensor<1x4x3xf32>) -> tensor<1x4x6xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x4x6xf32> +// CHECK: } + +// ----- +func.func @test_scatter_nd_multi_dim_not_in_order(%data : tensor<1x4x6xf32>, %updates : tensor<1x4x2xf32> ) -> tensor<1x4x6xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 1, 1]], [[0, 1, 0], [0, 0, 1]], [[0, 2, 0], [0, 2, 1]], [[0, 3, 0], [0, 3, 1]]]]> : tensor<1x4x2x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x4x6xf32>, tensor<1x4x2x3xi64>, tensor<1x4x2xf32>) -> tensor<1x4x6xf32> + onnx.Return %0 : tensor<1x4x6xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_multi_dim_not_in_order +// CHECK: onnx.ScatterND + +// ----- +func.func @test_scatter_nd_single_not_in_order(%data : tensor<1x6x10x12xf32>, %updates : tensor<1x1x10x12xf32> ) -> tensor<1x6x10x12xf32> { + %indices = onnx.Constant dense<[[[[0, 0, 0], [0, 0, 2], [0, 0, 1], [0, 0, 3], [0, 0, 4], [0, 0, 5], [0, 0, 6], [0, 0, 7], [0, 0, 8], [0, 0, 9]]]]> : tensor<1x1x10x3xi64> + %0 = "onnx.ScatterND"(%data, %indices, %updates) {reduction = "none"} : (tensor<1x6x10x12xf32>, tensor<1x1x10x3xi64>, tensor<1x1x10x12xf32>) -> tensor<1x6x10x12xf32> + onnx.Return %0 : tensor<1x6x10x12xf32> +} +// CHECK-LABEL: func.func @test_scatter_nd_single_not_in_order +// CHECK: onnx.ScatterND +// ----- + +func.func @sce_mean(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> tensor { + %0 = "onnx.NoValue"() {value, weight} : () -> none + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "mean"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor, none) + onnx.Return %output : tensor + // CHECK-LABEL: func @sce_mean + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32> + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceMean"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor +} + +// ----- + +func.func @sce_mean_return_log_prob(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> (tensor, tensor<64x10xf32>) { + %0 = "onnx.NoValue"() {value, weight} : () -> none + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "mean"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor, tensor<64x10xf32>) + onnx.Return %output, %log_prob : tensor, tensor<64x10xf32> + // CHECK-LABEL: func @sce_mean_return_log_prob + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32> + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceMean"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]], %[[LOG_SOFTMAX]] : tensor, tensor<64x10xf32> +} + +// ----- + +func.func @sce_mean_with_weight_NCD1D2(%arg0: tensor<64x10x2x3xf32>, %arg1: tensor<64x2x3xi64>, %arg2: tensor<10xf32>) -> tensor { + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "mean"} : (tensor<64x10x2x3xf32>, tensor<64x2x3xi64>, tensor<10xf32>) -> (tensor, none) + onnx.Return %output : tensor + // CHECK-LABEL: func @sce_mean_with_weight_NCD1D2 + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK-DAG: %[[COLLAPSED_SHAPE:.*]] = onnx.Constant dense<[384, 10]> : tensor<2xi64> + // CHECK-DAG: %[[EXPANDED_WEIGHT_SHAPE:.*]] = onnx.Constant dense<[10, 1]> : tensor<2xi64> + // CHECK-DAG: %[[W_SHAPE:.*]] = onnx.Constant dense<[64, 2, 3]> : tensor<3xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10x2x3xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x10x1x1xf32> + // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32> + // CHECK-NEXT: %[[SUML:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + + // This block is an `onnx.EinSum` expanded by a different pattern rewrite + // CHECK-NEXT: %[[TRANSPOSE_ONE_HOT:.*]] = "onnx.Transpose"(%[[ONE_HOT_LABELS_F]]) {perm = [0, 2, 3, 1]} : ({{.*}}) -> tensor<64x2x3x10xf32> + // CHECK-NEXT: %[[COLLAPSED_ONE_SHOT:.*]] = "onnx.Reshape"(%[[TRANSPOSE_ONE_HOT]], %[[COLLAPSED_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<384x10xf32> + // CHECK-NEXT: %[[EXPANDED_WEIGHT:.*]] = "onnx.Reshape"(%[[ARG2]], %[[EXPANDED_WEIGHT_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<10x1xf32> + // CHECK-NEXT: %[[MATMUL:.*]] = "onnx.MatMul"(%[[COLLAPSED_ONE_SHOT]], %[[EXPANDED_WEIGHT]]) : ({{.*}}) -> tensor<384x1xf32> + // CHECK-NEXT: %[[W:.*]] = "onnx.Reshape"(%[[MATMUL]], %[[W_SHAPE]]) {allowzero = 0 : si64} : ({{.*}}) -> tensor<64x2x3xf32> + + // CHECK-NEXT: %[[SUMW:.*]] = "onnx.ReduceSum"(%[[W]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.Div"(%[[SUML]], %[[SUMW]]) + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor +} + +// ----- + +func.func @sce_mean_with_weight_NCD1D2_dynamic_num_classes(%arg0: tensor<64x?x2x3xf32>, %arg1: tensor<64x2x3xi64>, %arg2: tensor) -> tensor { + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "mean"} : (tensor<64x?x2x3xf32>, tensor<64x2x3xi64>, tensor) -> (tensor, none) + onnx.Return %output : tensor + // CHECK-LABEL: func @sce_mean_with_weight_NCD1D2 + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = "onnx.Dim"(%[[ARG0]]) {axis = 1 : si64} : (tensor<64x?x2x3xf32>) -> tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<[0, 2, 3]> : tensor<3xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x?x2x3xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x?x1x1xf32> + // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32> + // CHECK-NEXT: %[[SUML:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[S_WEIGHTS:.*]] = "onnx.Einsum"(%[[ONE_HOT_LABELS_F]], %[[ARG2]]) {equation = "ij...,j->i..."} + // CHECK-NEXT: %[[SUMW:.*]] = "onnx.ReduceSum"(%[[S_WEIGHTS]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.Div"(%[[SUML]], %[[SUMW]]) + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor +} + +// ----- + +func.func @sce_sum(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>) -> tensor { + %0 = "onnx.NoValue"() {value, weight} : () -> none + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "sum"} : (tensor<64x10xf32>, tensor<64xi64>, none) -> (tensor, none) + onnx.Return %output : tensor + // CHECK-LABEL: func @sce_sum + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32> + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor +} + +// ----- + +func.func @sce_sum_with_weight(%arg0: tensor<64x10xf32>, %arg1: tensor<64xi64>, %arg2: tensor<10xf32>) -> tensor { + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %arg2) {reduction = "sum"} : (tensor<64x10xf32>, tensor<64xi64>, tensor<10xf32>) -> (tensor, none) + onnx.Return %output : tensor + // CHECK-LABEL: func @sce_sum_with_weight + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) + // CHECK-DAG: %[[NONE:.*]] = "onnx.NoValue"() {value} + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[UNSQUEEZE_AXES:.*]] = onnx.Constant dense<0> : tensor<1xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[UNSQUEEZE_WEIGHT:.*]] = "onnx.Unsqueeze"(%[[ARG2]], %[[UNSQUEEZE_AXES]]) : ({{.*}}) -> tensor<1x10xf32> + // CHECK-NEXT: %[[WEIGHT_PROD:.*]] = "onnx.Mul"(%[[PROD]], %[[UNSQUEEZE_WEIGHT]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[WEIGHT_PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1xf32> + // CHECK-NEXT: %[[MEAN:.*]] = "onnx.ReduceSum"(%[[SUM]], %[[NONE]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[MEAN]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor +} + +// ----- + +func.func @sce_none(%arg0: tensor<64x10x2x3xf32>, %arg1: tensor<64x2x3xi64>) -> tensor<64x2x3xf32> { + %0 = "onnx.NoValue"() {value, weight} : () -> none + %output, %log_prob = "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1, %0) {reduction = "none"} : (tensor<64x10x2x3xf32>, tensor<64x2x3xi64>, none) -> (tensor<64x2x3xf32>, none) + onnx.Return %output : tensor<64x2x3xf32> + // CHECK-LABEL: func @sce_none + // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) + // CHECK-DAG: %[[NUM_CLASSES:.*]] = onnx.Constant dense<10> : tensor<1xi64> + // CHECK-DAG: %[[ONE_HOT_VALS:.*]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> + // CHECK-DAG: %[[REDUCE_AXIS:.*]] = onnx.Constant dense<1> : tensor<1xi64> + // CHECK: %[[ONE_HOT_LABELS:.*]] = "onnx.OneHot"(%[[ARG1]], %[[NUM_CLASSES]], %[[ONE_HOT_VALS]]) {axis = 1 : si64} : ({{.*}}) -> tensor<64x10x2x3xi64> + // CHECK-NEXT: %[[ONE_HOT_LABELS_F:.*]] = "onnx.Cast"(%[[ONE_HOT_LABELS]]) {saturate = 1 : si64, to = f32} + // CHECK-NEXT: %[[SOFTMAX:.*]] = "onnx.Softmax"(%[[ARG0]]) {axis = 1 : si64} + // CHECK-NEXT: %[[LOG_SOFTMAX:.*]] = "onnx.Log"(%[[SOFTMAX]]) + // CHECK-NEXT: %[[PROD:.*]] = "onnx.Mul"(%[[LOG_SOFTMAX]], %[[ONE_HOT_LABELS_F]]) + // CHECK-NEXT: %[[SUM:.*]] = "onnx.ReduceSum"(%[[PROD]], %[[REDUCE_AXIS]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : ({{.*}}) -> tensor<64x1x2x3xf32> + // CHECK-NEXT: %[[SQUEEZE:.*]] = "onnx.Squeeze"(%[[SUM]], %[[REDUCE_AXIS]]) : ({{.*}}) -> tensor<64x2x3xf32> + // CHECK-NEXT: %[[LOSS:.*]] = "onnx.Neg"(%[[SQUEEZE]]) + // CHECK-NEXT: onnx.Return %[[LOSS]] : tensor<64x2x3xf32> +} diff --git a/test/mlir/onnx/onnx_decompose_canonicalize.mlir b/test/mlir/onnx/onnx_decompose_canonicalize.mlir new file mode 100644 index 0000000000..a132445562 --- /dev/null +++ b/test/mlir/onnx/onnx_decompose_canonicalize.mlir @@ -0,0 +1,43 @@ + +// RUN: onnx-mlir-opt --decompose-onnx --canonicalize %s -split-input-file | FileCheck %s + +// ----- + +// Test one pattern in lstm_no_data.onnx. +// The type of output of SequenceAt is not the same as the element type +// of the input sequence +func.func @sequence_at_squeezed(%arg0 : tensor<1x1x100xf32>) -> tensor<1x100xf32> { + %26 = onnx.Constant dense<0> : tensor + %27 = onnx.Constant dense<1> : tensor + %32 = "onnx.SplitToSequence"(%arg0, %27) {axis = 0 : si64, keepdims = 0 : si64} : (tensor<1x1x100xf32>, tensor) -> !onnx.Seq> + %33 = "onnx.SequenceAt"(%32, %26) : (!onnx.Seq>, tensor) -> tensor<1x100xf32> + return %33: tensor<1x100xf32> +// CHECK-LABEL: func.func @sequence_at_squeezed +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x100xf32>) -> tensor<1x100xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.Split"([[PARAM_0_]], [[VAR_1_]]) {axis = 0 : si64} : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x1x100xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Squeeze"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x1x100xf32>, tensor<1xi64>) -> tensor<1x100xf32> +// CHECK: return [[VAR_3_]] : tensor<1x100xf32> +// CHECK: } +} + +func.func @sequence_at_multi(%arg0 : tensor<1x1x400xf32>) -> tensor<1x1x100xf32> { + %15 = onnx.Constant dense<0> : tensor + %38 = onnx.Constant dense<1> : tensor + %65 = onnx.Constant dense<100> : tensor + %66 = "onnx.SplitToSequence"(%arg0, %65) {axis = 2 : si64, keepdims = 1 : si64} : (tensor<1x1x400xf32>, tensor) -> !onnx.Seq> + %67 = "onnx.SequenceAt"(%66, %15) : (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + %68 = "onnx.SequenceAt"(%66, %38) : (!onnx.Seq>, tensor) -> tensor<1x1x100xf32> + %40 = "onnx.Add"(%67, %68) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32> + return %40: tensor<1x1x100xf32> +// CHECK-LABEL: func.func @sequence_at_multi +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x400xf32>) -> tensor<1x1x100xf32> { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<100> : tensor<4xi64> +// CHECK-DAG: [[VAR_1_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>) +// CHECK-DAG: [[VAR_2_:%.+]]:4 = "onnx.Split"([[PARAM_0_]], [[VAR_0_]]) {axis = 2 : si64} : (tensor<1x1x400xf32>, tensor<4xi64>) -> (tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>, tensor<1x1x100xf32>) +// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_1_]]#0, [[VAR_2_]]#1) : (tensor<1x1x100xf32>, tensor<1x1x100xf32>) -> tensor<1x1x100xf32> +// CHECK: return [[VAR_3_]] : tensor<1x1x100xf32> +// CHECK: } +} + diff --git a/test/mlir/onnx/onnx_decompose_convtranspose.mlir b/test/mlir/onnx/onnx_decompose_convtranspose.mlir index 10912d7732..b88dd4713a 100644 --- a/test/mlir/onnx/onnx_decompose_convtranspose.mlir +++ b/test/mlir/onnx/onnx_decompose_convtranspose.mlir @@ -1,37 +1,29 @@ // RUN: onnx-mlir-opt --shape-inference --decompose-onnx %s -split-input-file | FileCheck %s + // ----- // Test unit strides. Only convert weight tensor func.func @test_convtrans_unitstrides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> { -// CHECK-LABEL: func.func @test_convtrans_unitstrides -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) - %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32> onnx.Return %1 : tensor<1x2x5x5xf32> -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK-LABEL: func.func @test_convtrans_unitstrides +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_1_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_4_]], [[VAR_1_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> // CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x1x2xf32>) -> tensor<1x2x3x3xf32> // CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [1, 0, 2, 3]} : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_7_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2]} : (tensor<1x1x3x3xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x5x5xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> -// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Pad"([[VAR_8_]], [[VAR_9_]], [[VAR_10_]], [[VAR_11_]]) {mode = "constant"} : (tensor<1x2x5x5xf32>, tensor<8xi64>, none, none) -> tensor<1x2x5x5xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> -// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_16_:%.+]] = "onnx.Pad"([[VAR_12_]], [[VAR_13_]], [[VAR_14_]], [[VAR_15_]]) {mode = "constant"} : (tensor<1x2x5x5xf32>, tensor<8xi64>, none, none) -> tensor<1x2x5x5xf32> -// CHECK: onnx.Return [[VAR_16_]] : tensor<1x2x5x5xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_7_]], [[VAR_2_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2]} : (tensor<1x1x3x3xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x5x5xf32> +// CHECK: [[VAR_9_:%.+]] = "onnx.Pad"([[VAR_8_]], [[VAR_0_]], [[VAR_2_]], [[VAR_2_]]) {mode = "constant"} : (tensor<1x2x5x5xf32>, tensor<8xi64>, none, none) -> tensor<1x2x5x5xf32> +// CHECK: [[VAR_10_:%.+]] = "onnx.Pad"([[VAR_9_]], [[VAR_0_]], [[VAR_2_]], [[VAR_2_]]) {mode = "constant"} : (tensor<1x2x5x5xf32>, tensor<8xi64>, none, none) -> tensor<1x2x5x5xf32> +// CHECK: onnx.Return [[VAR_10_]] : tensor<1x2x5x5xf32> } @@ -40,25 +32,22 @@ // Test 1d input func.func @test_convtrans1d_unitstrides(%arg0: tensor<1x1x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> { -// CHECK-LABEL: func.func @test_convtrans1d_unitstrides -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3xf32>) - %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32> onnx.Return %1 : tensor<1x2x5xf32> -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 0, 1]} : (tensor<1x2x3xf32>) -> tensor<3x1x2xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> -// CHECK: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x1x2xf32>, tensor<1xi64>) -> tensor<3x1x2xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.Transpose"([[VAR_3_]]) {perm = [1, 2, 0]} : (tensor<3x1x2xf32>) -> tensor<1x2x3xf32> -// CHECK: [[VAR_5_:%.+]] = "onnx.Transpose"([[VAR_4_]]) {perm = [1, 0, 2]} : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_5_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2]} : (tensor<1x1x3xf32>, tensor<2x1x3xf32>, none) -> tensor<1x2x5xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = onnx.Constant dense<0> : tensor<6xi64> -// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_10_:%.+]] = "onnx.Pad"([[VAR_6_]], [[VAR_7_]], [[VAR_8_]], [[VAR_9_]]) {mode = "constant"} : (tensor<1x2x5xf32>, tensor<6xi64>, none, none) -> tensor<1x2x5xf32> -// CHECK: onnx.Return [[VAR_10_]] : tensor<1x2x5xf32> +// CHECK-LABEL: func.func @test_convtrans1d_unitstrides +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<6xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 0, 1]} : (tensor<1x2x3xf32>) -> tensor<3x1x2xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_1_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x1x2xf32>, tensor<1xi64>) -> tensor<3x1x2xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Transpose"([[VAR_4_]]) {perm = [1, 2, 0]} : (tensor<3x1x2xf32>) -> tensor<1x2x3xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [1, 0, 2]} : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_6_]], [[VAR_2_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2]} : (tensor<1x1x3xf32>, tensor<2x1x3xf32>, none) -> tensor<1x2x5xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Pad"([[VAR_7_]], [[VAR_0_]], [[VAR_2_]], [[VAR_2_]]) {mode = "constant"} : (tensor<1x2x5xf32>, tensor<6xi64>, none, none) -> tensor<1x2x5xf32> +// CHECK: onnx.Return [[VAR_8_]] : tensor<1x2x5xf32> } // ----- @@ -66,41 +55,28 @@ // Test 3d input func.func @test_convtrans3d_unitstrides(%arg0: tensor<1x1x3x4x5xf32>, %arg1: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> { -// CHECK-LABEL: func.func @test_convtrans3d_unitstrides -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x4x5xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3x3xf32>) - %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32> onnx.Return %1 : tensor<1x2x5x6x7xf32> -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 4, 0, 1]} : (tensor<1x2x3x3x3xf32>) -> tensor<3x3x3x1x2xf32> +// CHECK-LABEL: func.func @test_convtrans3d_unitstrides +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x4x5xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<10xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> // CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x3x1x2xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x3x1x2xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [2, 3, 4, 0, 1]} : (tensor<3x3x3x1x2xf32>) -> tensor<3x1x2x3x3xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = onnx.Constant dense<3> : tensor<1xi64> -// CHECK: [[VAR_8_:%.+]] = "onnx.ReverseSequence"([[VAR_6_]], [[VAR_7_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x1x2x3x3xf32>, tensor<1xi64>) -> tensor<3x1x2x3x3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 4, 0, 1]} : (tensor<1x2x3x3x3xf32>) -> tensor<3x3x3x1x2xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_4_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x3x1x2xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.ReverseSequence"([[VAR_5_]], [[VAR_2_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x3x1x2xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [2, 3, 4, 0, 1]} : (tensor<3x3x3x1x2xf32>) -> tensor<3x1x2x3x3xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.ReverseSequence"([[VAR_7_]], [[VAR_1_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x1x2x3x3xf32>, tensor<1xi64>) -> tensor<3x1x2x3x3xf32> // CHECK: [[VAR_9_:%.+]] = "onnx.Transpose"([[VAR_8_]]) {perm = [1, 2, 3, 4, 0]} : (tensor<3x1x2x3x3xf32>) -> tensor<1x2x3x3x3xf32> // CHECK: [[VAR_10_:%.+]] = "onnx.Transpose"([[VAR_9_]]) {perm = [1, 0, 2, 3, 4]} : (tensor<1x2x3x3x3xf32>) -> tensor<2x1x3x3x3xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_10_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2, 2, 2]} : (tensor<1x1x3x4x5xf32>, tensor<2x1x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = onnx.Constant dense<0> : tensor<10xi64> -// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.Pad"([[VAR_11_]], [[VAR_12_]], [[VAR_13_]], [[VAR_14_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> -// CHECK-DAG: [[VAR_16_:%.+]] = onnx.Constant dense<0> : tensor<10xi64> -// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.Pad"([[VAR_15_]], [[VAR_16_]], [[VAR_17_]], [[VAR_18_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> -// CHECK-DAG: [[VAR_20_:%.+]] = onnx.Constant dense<0> : tensor<10xi64> -// CHECK-DAG: [[VAR_21_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_22_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_23_:%.+]] = "onnx.Pad"([[VAR_19_]], [[VAR_20_]], [[VAR_21_]], [[VAR_22_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> -// CHECK: onnx.Return [[VAR_23_]] : tensor<1x2x5x6x7xf32> +// CHECK: [[VAR_11_:%.+]] = "onnx.Conv"([[PARAM_0_]], [[VAR_10_]], [[VAR_3_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2, 2, 2]} : (tensor<1x1x3x4x5xf32>, tensor<2x1x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32> +// CHECK: [[VAR_12_:%.+]] = "onnx.Pad"([[VAR_11_]], [[VAR_0_]], [[VAR_3_]], [[VAR_3_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> +// CHECK: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_12_]], [[VAR_0_]], [[VAR_3_]], [[VAR_3_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> +// CHECK: [[VAR_14_:%.+]] = "onnx.Pad"([[VAR_13_]], [[VAR_0_]], [[VAR_3_]], [[VAR_3_]]) {mode = "constant"} : (tensor<1x2x5x6x7xf32>, tensor<10xi64>, none, none) -> tensor<1x2x5x6x7xf32> +// CHECK: onnx.Return [[VAR_14_]] : tensor<1x2x5x6x7xf32> } // ----- @@ -108,59 +84,36 @@ // Test non unit strides. Added pads between elements in input data. func.func @test_convtrans_strides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> { -// CHECK-LABEL: func.func @test_convtrans_strides -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) - %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32> onnx.Return %1 : tensor<1x2x7x3xf32> -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK-LABEL: func.func @test_convtrans_strides +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> // CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> -// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x1x2xf32>) -> tensor<1x2x3x3xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [1, 0, 2, 3]} : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_8_]]) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<3xi64>) -> (tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>) -// CHECK-DAG: [[VAR_10_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_9_]]#0, [[VAR_10_]], [[VAR_11_]], [[VAR_12_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_9_]]#1, [[VAR_14_]], [[VAR_15_]], [[VAR_16_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_13_]], [[VAR_17_]], [[VAR_9_]]#2) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>, tensor<1x1x1x3xf32>) -> tensor<1x1x7x3xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:3 = "onnx.Split"([[VAR_18_]], [[VAR_19_]]) {axis = 3 : si64} : (tensor<1x1x7x3xf32>, tensor<3xi64>) -> (tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>) -// CHECK-DAG: [[VAR_21_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_22_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Pad"([[VAR_20_]]#0, [[VAR_21_]], [[VAR_22_]], [[VAR_23_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> -// CHECK-DAG: [[VAR_25_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_28_:%.+]] = "onnx.Pad"([[VAR_20_]]#1, [[VAR_25_]], [[VAR_26_]], [[VAR_27_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> -// CHECK: [[VAR_29_:%.+]] = "onnx.Concat"([[VAR_24_]], [[VAR_28_]], [[VAR_20_]]#2) {axis = 3 : si64} : (tensor<1x1x7x2xf32>, tensor<1x1x7x2xf32>, tensor<1x1x7x1xf32>) -> tensor<1x1x7x5xf32> -// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Conv"([[VAR_29_]], [[VAR_7_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 0, 1, 0]} : (tensor<1x1x7x5xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x7x3xf32> -// CHECK-DAG: [[VAR_31_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> -// CHECK-DAG: [[VAR_32_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_33_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_34_:%.+]] = "onnx.Pad"([[VAR_30_]], [[VAR_31_]], [[VAR_32_]], [[VAR_33_]]) {mode = "constant"} : (tensor<1x2x7x3xf32>, tensor<8xi64>, none, none) -> tensor<1x2x7x3xf32> -// CHECK-DAG: [[VAR_35_:%.+]] = onnx.Constant dense<0> : tensor<8xi64> -// CHECK-DAG: [[VAR_36_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_37_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_38_:%.+]] = "onnx.Pad"([[VAR_34_]], [[VAR_35_]], [[VAR_36_]], [[VAR_37_]]) {mode = "constant"} : (tensor<1x2x7x3xf32>, tensor<8xi64>, none, none) -> tensor<1x2x7x3xf32> -// CHECK: onnx.Return [[VAR_38_]] : tensor<1x2x7x3xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.ReverseSequence"([[VAR_6_]], [[VAR_4_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.ReverseSequence"([[VAR_7_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_9_:%.+]] = "onnx.Transpose"([[VAR_8_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x1x2xf32>) -> tensor<1x2x3x3xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Transpose"([[VAR_9_]]) {perm = [1, 0, 2, 3]} : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32> +// CHECK-DAG: [[VAR_11_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<3xi64>) -> (tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>) +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Pad"([[VAR_11_]]#0, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_11_]]#1, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> +// CHECK: [[VAR_14_:%.+]] = "onnx.Concat"([[VAR_12_]], [[VAR_13_]], [[VAR_11_]]#2) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>, tensor<1x1x1x3xf32>) -> tensor<1x1x7x3xf32> +// CHECK: [[VAR_15_:%.+]]:3 = "onnx.Split"([[VAR_14_]], [[VAR_3_]]) {axis = 3 : si64} : (tensor<1x1x7x3xf32>, tensor<3xi64>) -> (tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>) +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Pad"([[VAR_15_]]#0, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_15_]]#1, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_15_]]#2) {axis = 3 : si64} : (tensor<1x1x7x2xf32>, tensor<1x1x7x2xf32>, tensor<1x1x7x1xf32>) -> tensor<1x1x7x5xf32> +// CHECK: [[VAR_19_:%.+]] = "onnx.Conv"([[VAR_18_]], [[VAR_10_]], [[VAR_5_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 0, 1, 0]} : (tensor<1x1x7x5xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x7x3xf32> +// CHECK: [[VAR_20_:%.+]] = "onnx.Pad"([[VAR_19_]], [[VAR_0_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x2x7x3xf32>, tensor<8xi64>, none, none) -> tensor<1x2x7x3xf32> +// CHECK: [[VAR_21_:%.+]] = "onnx.Pad"([[VAR_20_]], [[VAR_0_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x2x7x3xf32>, tensor<8xi64>, none, none) -> tensor<1x2x7x3xf32> +// CHECK: onnx.Return [[VAR_21_]] : tensor<1x2x7x3xf32> } // ----- @@ -168,59 +121,36 @@ // Test output_padding. Additional pads are inserted after Conv op func.func @test_convtrans_outputpadding(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> { -// CHECK-LABEL: func.func @test_convtrans_outputpadding -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) - %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32> onnx.Return %1 : tensor<1x2x10x8xf32> -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK-LABEL: func.func @test_convtrans_outputpadding +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x3x3xf32>, [[PARAM_1_:%.+]]: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 1, 0]> : tensor<8xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> // CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> -// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x1x2xf32>) -> tensor<1x2x3x3xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [1, 0, 2, 3]} : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_8_]]) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<3xi64>) -> (tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>) -// CHECK-DAG: [[VAR_10_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_9_]]#0, [[VAR_10_]], [[VAR_11_]], [[VAR_12_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_9_]]#1, [[VAR_14_]], [[VAR_15_]], [[VAR_16_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_13_]], [[VAR_17_]], [[VAR_9_]]#2) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>, tensor<1x1x1x3xf32>) -> tensor<1x1x7x3xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:3 = "onnx.Split"([[VAR_18_]], [[VAR_19_]]) {axis = 3 : si64} : (tensor<1x1x7x3xf32>, tensor<3xi64>) -> (tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>) -// CHECK-DAG: [[VAR_21_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_22_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Pad"([[VAR_20_]]#0, [[VAR_21_]], [[VAR_22_]], [[VAR_23_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> -// CHECK-DAG: [[VAR_25_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_28_:%.+]] = "onnx.Pad"([[VAR_20_]]#1, [[VAR_25_]], [[VAR_26_]], [[VAR_27_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> -// CHECK: [[VAR_29_:%.+]] = "onnx.Concat"([[VAR_24_]], [[VAR_28_]], [[VAR_20_]]#2) {axis = 3 : si64} : (tensor<1x1x7x2xf32>, tensor<1x1x7x2xf32>, tensor<1x1x7x1xf32>) -> tensor<1x1x7x5xf32> -// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Conv"([[VAR_29_]], [[VAR_7_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2]} : (tensor<1x1x7x5xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x9x7xf32> -// CHECK-DAG: [[VAR_31_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 1, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_32_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_33_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_34_:%.+]] = "onnx.Pad"([[VAR_30_]], [[VAR_31_]], [[VAR_32_]], [[VAR_33_]]) {mode = "constant"} : (tensor<1x2x9x7xf32>, tensor<8xi64>, none, none) -> tensor<1x2x10x7xf32> -// CHECK-DAG: [[VAR_35_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_36_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_37_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_38_:%.+]] = "onnx.Pad"([[VAR_34_]], [[VAR_35_]], [[VAR_36_]], [[VAR_37_]]) {mode = "constant"} : (tensor<1x2x10x7xf32>, tensor<8xi64>, none, none) -> tensor<1x2x10x8xf32> -// CHECK: onnx.Return [[VAR_38_]] : tensor<1x2x10x8xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor<1x2x3x3xf32>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.ReverseSequence"([[VAR_6_]], [[VAR_4_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.ReverseSequence"([[VAR_7_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x1x2xf32>, tensor<3xi64>) -> tensor<3x3x1x2xf32> +// CHECK: [[VAR_9_:%.+]] = "onnx.Transpose"([[VAR_8_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x1x2xf32>) -> tensor<1x2x3x3xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Transpose"([[VAR_9_]]) {perm = [1, 0, 2, 3]} : (tensor<1x2x3x3xf32>) -> tensor<2x1x3x3xf32> +// CHECK-DAG: [[VAR_11_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<3xi64>) -> (tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>, tensor<1x1x1x3xf32>) +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Pad"([[VAR_11_]]#0, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_11_]]#1, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x1x3xf32>, tensor<8xi64>, none, none) -> tensor<1x1x3x3xf32> +// CHECK: [[VAR_14_:%.+]] = "onnx.Concat"([[VAR_12_]], [[VAR_13_]], [[VAR_11_]]#2) {axis = 2 : si64} : (tensor<1x1x3x3xf32>, tensor<1x1x3x3xf32>, tensor<1x1x1x3xf32>) -> tensor<1x1x7x3xf32> +// CHECK: [[VAR_15_:%.+]]:3 = "onnx.Split"([[VAR_14_]], [[VAR_3_]]) {axis = 3 : si64} : (tensor<1x1x7x3xf32>, tensor<3xi64>) -> (tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>, tensor<1x1x7x1xf32>) +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Pad"([[VAR_15_]]#0, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_15_]]#1, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x1x7x1xf32>, tensor<8xi64>, none, none) -> tensor<1x1x7x2xf32> +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_15_]]#2) {axis = 3 : si64} : (tensor<1x1x7x2xf32>, tensor<1x1x7x2xf32>, tensor<1x1x7x1xf32>) -> tensor<1x1x7x5xf32> +// CHECK: [[VAR_19_:%.+]] = "onnx.Conv"([[VAR_18_]], [[VAR_10_]], [[VAR_5_]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [2, 2, 2, 2]} : (tensor<1x1x7x5xf32>, tensor<2x1x3x3xf32>, none) -> tensor<1x2x9x7xf32> +// CHECK: [[VAR_20_:%.+]] = "onnx.Pad"([[VAR_19_]], [[VAR_0_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x2x9x7xf32>, tensor<8xi64>, none, none) -> tensor<1x2x10x7xf32> +// CHECK: [[VAR_21_:%.+]] = "onnx.Pad"([[VAR_20_]], [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor<1x2x10x7xf32>, tensor<8xi64>, none, none) -> tensor<1x2x10x8xf32> +// CHECK: onnx.Return [[VAR_21_]] : tensor<1x2x10x8xf32> } // ----- @@ -228,56 +158,34 @@ // Test for unknown dimension in spatial dimensions func.func @test_convtranspose_unknown_spatial_dim(%arg0: tensor, %arg1: tensor) -> tensor { -// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor, tensor, none) -> tensor onnx.Return %1 : tensor -// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor) -> tensor<3x3x?x?xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReverseSequence"([[VAR_1_]], [[VAR_2_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x?x?xf32>, tensor<3xi64>) -> tensor<3x3x?x?xf32> +// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 1, 0]> : tensor<8xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> // CHECK-DAG: [[VAR_4_:%.+]] = onnx.Constant dense<3> : tensor<3xi64> -// CHECK: [[VAR_5_:%.+]] = "onnx.ReverseSequence"([[VAR_3_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x?x?xf32>, tensor<3xi64>) -> tensor<3x3x?x?xf32> -// CHECK: [[VAR_6_:%.+]] = "onnx.Transpose"([[VAR_5_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x?x?xf32>) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [1, 0, 2, 3]} : (tensor) -> tensor -// CHECK-DAG: [[VAR_8_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_8_]]) {axis = 2 : si64} : (tensor, tensor<3xi64>) -> (tensor, tensor, tensor) -// CHECK-DAG: [[VAR_10_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_11_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_9_]]#0, [[VAR_10_]], [[VAR_11_]], [[VAR_12_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK-DAG: [[VAR_14_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 2, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_9_]]#1, [[VAR_14_]], [[VAR_15_]], [[VAR_16_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_13_]], [[VAR_17_]], [[VAR_9_]]#2) {axis = 2 : si64} : (tensor, tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_19_:%.+]] = onnx.Constant dense<1> : tensor<3xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:3 = "onnx.Split"([[VAR_18_]], [[VAR_19_]]) {axis = 3 : si64} : (tensor, tensor<3xi64>) -> (tensor, tensor, tensor) -// CHECK-DAG: [[VAR_21_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_22_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Pad"([[VAR_20_]]#0, [[VAR_21_]], [[VAR_22_]], [[VAR_23_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK-DAG: [[VAR_25_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_28_:%.+]] = "onnx.Pad"([[VAR_20_]]#1, [[VAR_25_]], [[VAR_26_]], [[VAR_27_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK: [[VAR_29_:%.+]] = "onnx.Concat"([[VAR_24_]], [[VAR_28_]], [[VAR_20_]]#2) {axis = 3 : si64} : (tensor, tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Conv"([[VAR_29_]], [[VAR_7_]], [[VAR_0_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], pads = [2, 2, 2, 2]} : (tensor, tensor, none) -> tensor -// CHECK-DAG: [[VAR_31_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 1, 0]> : tensor<8xi64> -// CHECK-DAG: [[VAR_32_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_33_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_34_:%.+]] = "onnx.Pad"([[VAR_30_]], [[VAR_31_]], [[VAR_32_]], [[VAR_33_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK-DAG: [[VAR_35_:%.+]] = onnx.Constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi64> -// CHECK-DAG: [[VAR_36_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_37_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[VAR_38_:%.+]] = "onnx.Pad"([[VAR_34_]], [[VAR_35_]], [[VAR_36_]], [[VAR_37_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor -// CHECK: onnx.Return [[VAR_38_]] : tensor +// CHECK-DAG: [[VAR_5_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [2, 3, 0, 1]} : (tensor) -> tensor<3x3x?x?xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.ReverseSequence"([[VAR_6_]], [[VAR_4_]]) {batch_axis = 1 : si64, time_axis = 0 : si64} : (tensor<3x3x?x?xf32>, tensor<3xi64>) -> tensor<3x3x?x?xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.ReverseSequence"([[VAR_7_]], [[VAR_4_]]) {batch_axis = 0 : si64, time_axis = 1 : si64} : (tensor<3x3x?x?xf32>, tensor<3xi64>) -> tensor<3x3x?x?xf32> +// CHECK: [[VAR_9_:%.+]] = "onnx.Transpose"([[VAR_8_]]) {perm = [2, 3, 0, 1]} : (tensor<3x3x?x?xf32>) -> tensor +// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Transpose"([[VAR_9_]]) {perm = [1, 0, 2, 3]} : (tensor) -> tensor +// CHECK-DAG: [[VAR_11_:%.+]]:3 = "onnx.Split"([[PARAM_0_]], [[VAR_3_]]) {axis = 2 : si64} : (tensor, tensor<3xi64>) -> (tensor, tensor, tensor) +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Pad"([[VAR_11_]]#0, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Pad"([[VAR_11_]]#1, [[VAR_2_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK: [[VAR_14_:%.+]] = "onnx.Concat"([[VAR_12_]], [[VAR_13_]], [[VAR_11_]]#2) {axis = 2 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_15_:%.+]]:3 = "onnx.Split"([[VAR_14_]], [[VAR_3_]]) {axis = 3 : si64} : (tensor, tensor<3xi64>) -> (tensor, tensor, tensor) +// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Pad"([[VAR_15_]]#0, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Pad"([[VAR_15_]]#1, [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK: [[VAR_18_:%.+]] = "onnx.Concat"([[VAR_16_]], [[VAR_17_]], [[VAR_15_]]#2) {axis = 3 : si64} : (tensor, tensor, tensor) -> tensor +// CHECK: [[VAR_19_:%.+]] = "onnx.Conv"([[VAR_18_]], [[VAR_10_]], [[VAR_5_]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], pads = [2, 2, 2, 2]} : (tensor, tensor, none) -> tensor +// CHECK: [[VAR_20_:%.+]] = "onnx.Pad"([[VAR_19_]], [[VAR_0_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK: [[VAR_21_:%.+]] = "onnx.Pad"([[VAR_20_]], [[VAR_1_]], [[VAR_5_]], [[VAR_5_]]) {mode = "constant"} : (tensor, tensor<8xi64>, none, none) -> tensor +// CHECK: onnx.Return [[VAR_21_]] : tensor } diff --git a/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir b/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir new file mode 100644 index 0000000000..b31c6b28a6 --- /dev/null +++ b/test/mlir/onnx/onnx_decompose_convtranspose_disable.mlir @@ -0,0 +1,104 @@ +// RUN: onnx-mlir-opt --shape-inference --decompose-onnx --disable-convtranspose-decompose %s -split-input-file | FileCheck %s + + +// ----- + +// Test unit strides. Only convert weight tensor + + func.func @test_convtrans_unitstrides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32> + onnx.Return %1 : tensor<1x2x5x5xf32> +// CHECK-LABEL: func.func @test_convtrans_unitstrides( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x5x5xf32> { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x5x5xf32> +// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x5xf32> +// CHECK: } + } + +// ----- + +// Test 1d input + + func.func @test_convtrans1d_unitstrides(%arg0: tensor<1x1x3xf32>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32> + onnx.Return %1 : tensor<1x2x5xf32> +// CHECK-LABEL: func.func @test_convtrans1d_unitstrides( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3xf32>) -> tensor<1x2x5xf32> { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3xf32>, tensor<1x2x3xf32>, none) -> tensor<1x2x5xf32> +// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5xf32> +// CHECK: } + } + +// ----- + +// Test 3d input + + func.func @test_convtrans3d_unitstrides(%arg0: tensor<1x1x3x4x5xf32>, %arg1: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32> + onnx.Return %1 : tensor<1x2x5x6x7xf32> +// CHECK-LABEL: func.func @test_convtrans3d_unitstrides( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x4x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3x3xf32>) -> tensor<1x2x5x6x7xf32> { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64} : (tensor<1x1x3x4x5xf32>, tensor<1x2x3x3x3xf32>, none) -> tensor<1x2x5x6x7xf32> +// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x5x6x7xf32> +// CHECK: } + } + +// ----- + +// Test non unit strides. Added pads between elements in input data. + + func.func @test_convtrans_strides(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32> + onnx.Return %1 : tensor<1x2x7x3xf32> +// CHECK-LABEL: func.func @test_convtrans_strides( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x7x3xf32> { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, pads = [1, 2, 1, 2], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x7x3xf32> +// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x7x3xf32> +// CHECK: } + } + +// ----- + +// Test output_padding. Additional pads are inserted after Conv op + + func.func @test_convtrans_outputpadding(%arg0: tensor<1x1x3x3xf32>, %arg1: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32> + onnx.Return %1 : tensor<1x2x10x8xf32> +// CHECK-LABEL: func.func @test_convtrans_outputpadding( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x3xf32>) -> tensor<1x2x10x8xf32> { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, output_shape = [10, 8], strides = [3, 2]} : (tensor<1x1x3x3xf32>, tensor<1x2x3x3xf32>, none) -> tensor<1x2x10x8xf32> +// CHECK: onnx.Return %[[VAL_3]] : tensor<1x2x10x8xf32> +// CHECK: } + } + +// ----- + +// Test for unknown dimension in spatial dimensions + + func.func @test_convtranspose_unknown_spatial_dim(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ConvTranspose"(%arg0, %arg1, %0) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor, tensor, none) -> tensor + onnx.Return %1 : tensor +// CHECK-LABEL: func.func @test_convtranspose_unknown_spatial_dim( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = "onnx.NoValue"() {value} : () -> none +// CHECK: %[[VAL_3:.*]] = "onnx.ConvTranspose"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {auto_pad = "NOTSET", group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "test", output_padding = [1, 1], output_shape = [10, 8], strides = [3, 2]} : (tensor, tensor, none) -> tensor +// CHECK: onnx.Return %[[VAL_3]] : tensor +// CHECK: } + } diff --git a/test/mlir/onnx/onnx_decompose_customop.mlir b/test/mlir/onnx/onnx_decompose_customop.mlir index 59430228ac..46f7ec703a 100644 --- a/test/mlir/onnx/onnx_decompose_customop.mlir +++ b/test/mlir/onnx/onnx_decompose_customop.mlir @@ -40,13 +40,12 @@ func.func @customop_fusedmatmul_onnxruntime_transA(%arg0: tensor<*xf32>, %arg1:t // CHECK-LABEL: func.func @customop_fusedmatmul_onnxruntime_transA // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[VAR_0_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32> -// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_1_]], [[PARAM_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_2_]], [[PARAM_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> // CHECK: onnx.Return [[VAR_4_]] : tensor<*xf32> -// CHECK: } } // ----- @@ -58,13 +57,12 @@ func.func @customop_fusedmatmul_onnxruntime_transB(%arg0: tensor<*xf32>, %arg1:t // CHECK-LABEL: func.func @customop_fusedmatmul_onnxruntime_transB // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>, [[PARAM_1_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { -// CHECK: [[VAR_0_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[VAR_0_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32> -// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> -// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.250000e-01> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1, 3]} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [0, 1, 3, 2]} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_2_]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32> // CHECK: onnx.Return [[VAR_4_]] : tensor<*xf32> -// CHECK: } } // ----- diff --git a/test/mlir/onnx/onnx_decompose_einsum.mlir b/test/mlir/onnx/onnx_decompose_einsum.mlir index f21c4cec05..d8b1ed544c 100644 --- a/test/mlir/onnx/onnx_decompose_einsum.mlir +++ b/test/mlir/onnx/onnx_decompose_einsum.mlir @@ -19,19 +19,17 @@ func.func @test_einsum_matmul(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x4x5xf32> func.func @test_einsum_matmul_broadcast(%arg0: tensor<2x3x1xf32>, %arg1: tensor<1x4x5xf32>) -> tensor<2x3x5xf32> { %0 = "onnx.Einsum"(%arg0, %arg1) {equation = "...ij,...jk"} : (tensor<2x3x1xf32>, tensor<1x4x5xf32>) -> tensor<2x3x5xf32> onnx.Return %0 : tensor<2x3x5xf32> -// CHECK-LABEL: func @test_einsum_matmul_broadcast +// CHECK-LABEL: func.func @test_einsum_matmul_broadcast // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x3x1xf32>, [[PARAM_1_:%.+]]: tensor<1x4x5xf32>) -> tensor<2x3x5xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[VAR_0_]]) : (tensor<2x3x1xf32>, tensor<1xi64>) -> tensor<2x3xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReduceSum"([[PARAM_1_]], [[VAR_2_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x4x5xf32>, tensor<1xi64>) -> tensor<1x5xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Transpose"([[VAR_1_]]) {perm = [1, 0]} : (tensor<2x3xf32>) -> tensor<3x2xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> -// CHECK-NEXT: [[VAR_6_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_5_]]) : (tensor<3x2xf32>, tensor<1xi64>) -> tensor<3x2x1xf32> -// CHECK-NEXT: [[VAR_7_:%.+]] = "onnx.Mul"([[VAR_6_]], [[VAR_3_]]) : (tensor<3x2x1xf32>, tensor<1x5xf32>) -> tensor<3x2x5xf32> -// CHECK-NEXT: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_7_]]) {perm = [1, 0, 2]} : (tensor<3x2x5xf32>) -> tensor<2x3x5xf32> -// CHECK-NEXT: onnx.Return [[VAR_8_]] : tensor<2x3x5xf32> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<2> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Squeeze"([[PARAM_0_]], [[VAR_1_]]) : (tensor<2x3x1xf32>, tensor<1xi64>) -> tensor<2x3xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.ReduceSum"([[PARAM_1_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x4x5xf32>, tensor<1xi64>) -> tensor<1x5xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Transpose"([[VAR_2_]]) {perm = [1, 0]} : (tensor<2x3xf32>) -> tensor<3x2xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Unsqueeze"([[VAR_4_]], [[VAR_1_]]) : (tensor<3x2xf32>, tensor<1xi64>) -> tensor<3x2x1xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Mul"([[VAR_5_]], [[VAR_3_]]) : (tensor<3x2x1xf32>, tensor<1x5xf32>) -> tensor<3x2x5xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Transpose"([[VAR_6_]]) {perm = [1, 0, 2]} : (tensor<3x2x5xf32>) -> tensor<2x3x5xf32> +// CHECK: onnx.Return [[VAR_7_]] : tensor<2x3x5xf32> } // ----- @@ -87,14 +85,14 @@ func.func @test_einsum_diagonal(%arg0: tensor<3x3xf32>) -> tensor<3xf32> { %0 = "onnx.Einsum"(%arg0) {equation = "ii->i"} : (tensor<3x3xf32>) -> tensor<3xf32> onnx.Return %0 : tensor<3xf32> - // CHECK-LABEL: func @test_einsum_diagonal - // CHECK-SAME: ([[PARAM_0:%.+]]: tensor<3x3xf32>) -> tensor<3xf32> { - // CHECK-NEXT: [[MASK:%.+]] = onnx.Constant dense<{{\[\[true, false, false\], \[false, true, false\], \[false, false, true\]\]}}> : tensor<3x3xi1> - // CHECK-NEXT: [[ZERO:%.+]] = onnx.Constant dense<0.000000e+00> : tensor - // CHECK-NEXT: [[WHER:%.+]] = "onnx.Where"([[MASK]], [[PARAM_0]], [[ZERO]]) : (tensor<3x3xi1>, tensor<3x3xf32>, tensor) -> tensor<3x3xf32> - // CHECK-NEXT: [[AXES:%.+]] = onnx.Constant dense<1> : tensor<1xi64> - // CHECK-NEXT: [[RSUM:%.+]] = "onnx.ReduceSum"([[WHER]], [[AXES]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x3xf32>, tensor<1xi64>) -> tensor<3xf32> - // CHECK-NEXT: onnx.Return [[RSUM]] : tensor<3xf32> +// CHECK-LABEL: func.func @test_einsum_diagonal +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x3xf32>) -> tensor<3xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<{{\[\[true, false, false\], \[false, true, false\], \[false, false, true\]\]}}> : tensor<3x3xi1> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK: [[VAR_3_:%.+]] = "onnx.Where"([[VAR_1_]], [[PARAM_0_]], [[VAR_2_]]) : (tensor<3x3xi1>, tensor<3x3xf32>, tensor) -> tensor<3x3xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.ReduceSum"([[VAR_3_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x3xf32>, tensor<1xi64>) -> tensor<3xf32> +// CHECK: onnx.Return [[VAR_4_]] : tensor<3xf32> } // ----- @@ -102,17 +100,15 @@ func.func @test_einsum_diagonal(%arg0: tensor<3x3xf32>) -> tensor<3xf32> { func.func @test_einsum_trace(%arg0: tensor<3x3xf32>) -> tensor { %0 = "onnx.Einsum"(%arg0) {equation = "ii"} : (tensor<3x3xf32>) -> tensor onnx.Return %0 : tensor -// CHECK-LABEL: func @test_einsum_trace +// CHECK-LABEL: func.func @test_einsum_trace // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x3xf32>) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<{{\[\[true, false, false\], \[false, true, false\], \[false, false, true\]\]}}> : tensor<3x3xi1> -// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Where"([[VAR_0_]], [[PARAM_0_]], [[VAR_1_]]) : (tensor<3x3xi1>, tensor<3x3xf32>, tensor) -> tensor<3x3xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.ReduceSum"([[VAR_2_]], [[VAR_3_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x3xf32>, tensor<1xi64>) -> tensor<3xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> -// CHECK: [[VAR_6_:%.+]] = "onnx.ReduceSum"([[VAR_4_]], [[VAR_5_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3xf32>, tensor<1xi64>) -> tensor +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0> : tensor<1xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<{{\[\[true, false, false\], \[false, true, false\], \[false, false, true\]\]}}> : tensor<3x3xi1> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor +// CHECK: [[VAR_4_:%.+]] = "onnx.Where"([[VAR_2_]], [[PARAM_0_]], [[VAR_3_]]) : (tensor<3x3xi1>, tensor<3x3xf32>, tensor) -> tensor<3x3xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.ReduceSum"([[VAR_4_]], [[VAR_1_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3x3xf32>, tensor<1xi64>) -> tensor<3xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.ReduceSum"([[VAR_5_]], [[VAR_0_]]) {keepdims = 0 : si64, noop_with_empty_axes = 0 : si64} : (tensor<3xf32>, tensor<1xi64>) -> tensor // CHECK: onnx.Return [[VAR_6_]] : tensor } @@ -123,13 +119,13 @@ func.func @test_einsum_ibh_hnd(%arg0: tensor<128x1x1024xf16>, %arg1: tensor<1024 onnx.Return %0 : tensor<128x1x16x64xf16> // CHECK-LABEL: func.func @test_einsum_ibh_hnd // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<128x1x1024xf16>, [[PARAM_1_:%.+]]: tensor<1024x16x64xf16>) -> tensor<128x1x16x64xf16> { -// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[128, 1024]> : tensor<2xi64> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<128x1x1024xf16>, tensor<2xi64>) -> tensor<128x1024xf16> -// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1024> : tensor<2xi64> -// CHECK: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<1024x16x64xf16>, tensor<2xi64>) -> tensor<1024x1024xf16> -// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.MatMul"([[VAR_1_]], [[VAR_3_]]) : (tensor<128x1024xf16>, tensor<1024x1024xf16>) -> tensor<128x1024xf16> -// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<[128, 1, 16, 64]> : tensor<4xi64> -// CHECK: [[VAR_6_:%.+]] = "onnx.Reshape"([[VAR_4_]], [[VAR_5_]]) {allowzero = 0 : si64} : (tensor<128x1024xf16>, tensor<4xi64>) -> tensor<128x1x16x64xf16> +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[128, 1, 16, 64]> : tensor<4xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1024> : tensor<2xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[128, 1024]> : tensor<2xi64> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_2_]]) {allowzero = 0 : si64} : (tensor<128x1x1024xf16>, tensor<2xi64>) -> tensor<128x1024xf16> +// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Reshape"([[PARAM_1_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1024x16x64xf16>, tensor<2xi64>) -> tensor<1024x1024xf16> +// CHECK: [[VAR_5_:%.+]] = "onnx.MatMul"([[VAR_3_]], [[VAR_4_]]) : (tensor<128x1024xf16>, tensor<1024x1024xf16>) -> tensor<128x1024xf16> +// CHECK: [[VAR_6_:%.+]] = "onnx.Reshape"([[VAR_5_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<128x1024xf16>, tensor<4xi64>) -> tensor<128x1x16x64xf16> // CHECK: onnx.Return [[VAR_6_]] : tensor<128x1x16x64xf16> } diff --git a/test/mlir/onnx/onnx_dim_analysis.mlir b/test/mlir/onnx/onnx_dim_analysis.mlir index d0459f07ff..74f51c0f65 100644 --- a/test/mlir/onnx/onnx_dim_analysis.mlir +++ b/test/mlir/onnx/onnx_dim_analysis.mlir @@ -184,38 +184,38 @@ func.func @test_matmul_batchsize(%arg0: tensor) -> tensor) -> tensor<8x?x16x32xf32> { +func.func @test_matmul_batchsize_diff_rank(%arg0: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> { %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> - %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor - %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x32xf32> - "onnx.Return"(%1) : (tensor<8x?x16x32xf32>) -> () + %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor + %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x128xf32> + "onnx.Return"(%1) : (tensor<8x?x16x128xf32>) -> () // CHECK-LABEL: func.func @test_matmul_batchsize_diff_rank -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x32xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> () // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> -// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x32xf32> -// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x32xf32>) -> () -// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x32xf32> +// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor +// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x128xf32> +// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x128xf32>) -> () +// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x128xf32> // CHECK: } } // ----- -func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor { +func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor { %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> - %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor - "onnx.Return"(%0) : (tensor) -> () + %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor + "onnx.Return"(%0) : (tensor) -> () // CHECK-LABEL: func.func @test_reshape_single_dyn_dim -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> () // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> -// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: onnx.Return [[VAR_1_]] : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor +// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: onnx.Return [[VAR_1_]] : tensor // CHECK: } } diff --git a/test/mlir/onnx/onnx_fold.mlir b/test/mlir/onnx/onnx_fold.mlir index 6913cfc638..fc4cf82fcc 100644 --- a/test/mlir/onnx/onnx_fold.mlir +++ b/test/mlir/onnx/onnx_fold.mlir @@ -30,3 +30,17 @@ func.func @test_squeezev11() -> tensor<*xf32> { // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[4.000000e+00, 1.600000e+01]> : tensor<2xf32> // CHECK: onnx.Return [[VAR_0_]] : tensor<2xf32> } + + +// ----- + +func.func @test_reduceMeanIsNoopWithEmptyAxes(%arg0: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.ReduceMean"(%arg0, %0) {noop_with_empty_axes = 1: si64} : (tensor<4x512x256x8xf32>, none) -> tensor<4x512x256x8xf32> + return %1 : tensor<4x512x256x8xf32> +} + +// CHECK-LABEL: @test_reduceMeanIsNoopWithEmptyAxes +// CHECK-SAME: (%[[VAL_0:.*]]: tensor<4x512x256x8xf32>) -> tensor<4x512x256x8xf32> { +// CHECK: return %[[VAL_0]] : tensor<4x512x256x8xf32> +// CHECK: } \ No newline at end of file diff --git a/test/mlir/onnx/onnx_fuse_add_assertion_fail.mlir b/test/mlir/onnx/onnx_fuse_add_assertion_fail.mlir new file mode 100644 index 0000000000..9d09159524 --- /dev/null +++ b/test/mlir/onnx/onnx_fuse_add_assertion_fail.mlir @@ -0,0 +1,21 @@ +// RUN: onnx-mlir-opt --shape-inference --canonicalize="test-convergence=true" --shape-inference --cse %s | FileCheck %s + +func.func @main_graph(%arg0: tensor<1x180x320x3xf32> ) -> (tensor<1x16x90x160xf32> {onnx.name = "r3o"}) { + %0 = onnx.Constant dense<0.1> : tensor<3x1x1xf32> + %2 = onnx.Constant dense<0.1> : tensor<16x3x3x3xf32> + %3 = onnx.Constant dense<0.1> : tensor<16xf32> + %4 = onnx.Constant dense<0.1> : tensor + %5 = onnx.Constant dense<0.1> : tensor + %6 = onnx.Constant dense<0.1> : tensor + %7 = onnx.Constant dense<0.1> : tensor<3x1x1xf32> + %8 = "onnx.Transpose"(%arg0) {} : (tensor<1x180x320x3xf32>) -> tensor<1x3x180x320xf32> + %9 = "onnx.Add"(%8, %0) {} : (tensor<1x3x180x320xf32>, tensor<3x1x1xf32>) -> tensor<1x3x180x320xf32> + %10 = "onnx.Div"(%9, %7) {} : (tensor<1x3x180x320xf32>, tensor<3x1x1xf32>) -> tensor<1x3x180x320xf32> + %11 = "onnx.Conv"(%10, %2, %3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "Conv_9", pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<1x3x180x320xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x16x90x160xf32> + %12 = "onnx.Add"(%11, %4) {onnx_node_name = "Add_11"} : (tensor<1x16x90x160xf32>, tensor) -> tensor<1x16x90x160xf32> + return %12 : tensor<1x16x90x160xf32> +} +"onnx.EntryPoint"() {func = @main_graph} : () -> () + +//CHECK: %{{[0-9]+}} = "onnx.Conv"(%{{.*}}, %{{.*}}, %{{.*}}) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : si64, kernel_shape = [3, 3], onnx_node_name = "Conv_9", pads = [1, 1, 1, 1], strides = [2, 2]} : (tensor<1x3x180x320xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x16x90x160xf32> +//CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) {onnx_node_name = "Add_11"} : (tensor<1x16x90x160xf32>, tensor) -> tensor<1x16x90x160xf32> diff --git a/test/mlir/onnx/onnx_lowering_call_canonicalize_O3.mlir b/test/mlir/onnx/onnx_lowering_call_canonicalize_O3.mlir index 3a976908d2..1e1ea20022 100644 --- a/test/mlir/onnx/onnx_lowering_call_canonicalize_O3.mlir +++ b/test/mlir/onnx/onnx_lowering_call_canonicalize_O3.mlir @@ -1,7 +1,7 @@ -// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --mcpu=z16 --shape-inference --convert-onnx-to-krnl='ops-for-call=Conv' --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt -O3 --mtriple=s390x-ibm-loz --march=z16 --shape-inference --convert-onnx-to-krnl='ops-for-call=Conv' --canonicalize %s -split-input-file | FileCheck %s -// use --mtriple=s390x-ibm-loz --mcpu=z16 to enable SIMD as we now need a machine -// can also use -march=x86-64 instead. +// use --mtriple=s390x-ibm-loz --march=z16 to enable SIMD as we now need a machine +// can also use --march=x86-64 instead. // ----- diff --git a/test/mlir/onnx/onnx_recompose.mlir b/test/mlir/onnx/onnx_recompose.mlir index e79d4029c4..7fbf9b282d 100644 --- a/test/mlir/onnx/onnx_recompose.mlir +++ b/test/mlir/onnx/onnx_recompose.mlir @@ -6,7 +6,7 @@ func.func @layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { %x = "onnx.Add"(%input, %bias) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -23,7 +23,7 @@ func.func @layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: t // CHECK-LABEL: func.func @layernorm_with_spurious_adds // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -33,7 +33,7 @@ func.func @layernorm_with_spurious_adds(%input: tensor<1x384x768xf32>, %scale: t // Layernorm without bias func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -47,7 +47,223 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x // CHECK-LABEL: func.func @layernorm_without_bias // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: return [[Y_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_first_reduce_unsuitable_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %mean = "onnx.ReduceMeanV13"(%x) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMeanV13"(%dd) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_first_reduce_unsuitable_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) +// CHECK: return [[Y_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_second_reduce_unsuitable_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMeanV13"(%dd) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_second_reduce_unsuitable_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.ReduceMeanV13"([[VAR_3_]]) {axes = [-2], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_0_]]) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Sqrt"([[VAR_5_]]) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_2_]], [[VAR_6_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_7_]], [[PARAM_1_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> +// CHECK: return [[VAR_8_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_v18(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %axis = onnx.Constant dense<-1> : tensor<1xi64> + %mean = "onnx.ReduceMean"(%x, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_v18 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) +// CHECK: return [[Y_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_v18_dynamic_axis(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>, %axis: tensor) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %mean = "onnx.ReduceMean"(%x, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %axis) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_v18_dynamic_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>, [[PARAM_3_:%.+]]: tensor) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[PARAM_3_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.ReduceMean"([[VAR_3_]], [[PARAM_3_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x1xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_0_]]) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Sqrt"([[VAR_5_]]) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Div"([[VAR_2_]], [[VAR_6_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Mul"([[VAR_7_]], [[PARAM_1_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> +// CHECK: return [[VAR_8_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_first_reduce_unsuitable_axis_v18(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %axis1 = onnx.Constant dense<-2> : tensor<1xi64> + %axis2 = onnx.Constant dense<-1> : tensor<1xi64> + %mean = "onnx.ReduceMean"(%x, %axis1) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %axis2) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_first_reduce_unsuitable_axis_v18 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64> +// CHECK: [[VAR_2_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[VAR_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_2_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_3_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none) +// CHECK: return [[Y_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %axis1 = onnx.Constant dense<-1> : tensor<1xi64> + %axis2 = onnx.Constant dense<-2> : tensor<1xi64> + %mean = "onnx.ReduceMean"(%x, %axis1) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %axis2) {keepdims = 1 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_second_reduce_unsuitable_axis_v18 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<-2> : tensor<1xi64> +// CHECK: [[VAR_3_:%.+]] = "onnx.ReduceMean"([[PARAM_0_]], [[VAR_1_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[VAR_3_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Mul"([[VAR_4_]], [[VAR_4_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.ReduceMean"([[VAR_5_]], [[VAR_2_]]) {keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<1x384x768xf32>, tensor<1xi64>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_6_]], [[VAR_0_]]) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> +// CHECK: [[VAR_8_:%.+]] = "onnx.Sqrt"([[VAR_7_]]) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> +// CHECK: [[VAR_9_:%.+]] = "onnx.Div"([[VAR_4_]], [[VAR_8_]]) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_10_:%.+]] = "onnx.Mul"([[VAR_9_]], [[PARAM_1_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> +// CHECK: return [[VAR_10_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_v18_noop(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %none = "onnx.NoValue"() {value} : () -> none + %mean = "onnx.ReduceMean"(%x, %none) {keepdims = 1 : si64, noop_with_empty_axes = 1: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x768xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %none) {keepdims = 1 : si64, noop_with_empty_axes = 1: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x768xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x768xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_v18_noop +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.200000e+00> : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Sub"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_3_:%.+]] = "onnx.Add"([[VAR_2_]], [[VAR_0_]]) : (tensor<1x384x768xf32>, tensor) -> tensor<1x384x768xf32> +// CHECK: [[VAR_4_:%.+]] = "onnx.Sqrt"([[VAR_3_]]) : (tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_5_:%.+]] = "onnx.Div"([[VAR_1_]], [[VAR_4_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> +// CHECK: [[VAR_6_:%.+]] = "onnx.Mul"([[VAR_5_]], [[PARAM_1_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> +// CHECK: return [[VAR_6_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +func.func @layernorm_without_bias_v18_reduce_all(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor + %none = "onnx.NoValue"() {value} : () -> none + %mean = "onnx.ReduceMean"(%x, %none) {keepdims = 1 : si64, noop_with_empty_axes = 0: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x1xf32> + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> + %var = "onnx.ReduceMean"(%dd, %none) {keepdims = 1 : si64, noop_with_empty_axes = 0: si64} : (tensor<1x384x768xf32>, none) -> tensor<1x384x1xf32> + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> + return %Y : tensor<1x384x768xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @layernorm_without_bias_v18_reduce_all +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 0 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -56,8 +272,8 @@ func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768x // Layernorm, add/mul switched -func.func @layernorm_with_bias_swtiched(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor +func.func @layernorm_with_bias_switched(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -70,9 +286,9 @@ func.func @layernorm_with_bias_swtiched(%x: tensor<1x384x768xf32>, %scale: tenso return %Y : tensor<1x384x768xf32> // mlir2FileCheck.py -// CHECK-LABEL: func.func @layernorm_with_bias_swtiched +// CHECK-LABEL: func.func @layernorm_with_bias_switched // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -82,13 +298,13 @@ func.func @layernorm_with_bias_swtiched(%x: tensor<1x384x768xf32>, %scale: tenso // Recognize the bias and fold into LayerNorm. func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> { %0 = "onnx.NoValue"() {value} : () -> none - %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) + %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) %Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> return %Y : tensor<1x384x768xf32> // mlir2FileCheck.py // CHECK-LABEL: func.func @layernorm_without_bias // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -97,7 +313,7 @@ func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<76 // Not a Layernorm as top sub has inputs switched func.func @not_a_layer_norm(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%mean, %x) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -117,7 +333,7 @@ func.func @not_a_layer_norm(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // ----- // Check alternative layer norm with reciprocal instead of div func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %x = "onnx.Add"(%input, %input) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> @@ -135,7 +351,7 @@ func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: ten // CHECK-LABEL: func.func @layer_norm_with_reciprocal // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -145,7 +361,7 @@ func.func @layer_norm_with_reciprocal(%input: tensor<1x384x768xf32>, %scale: ten // Check alternative layer norm with reciprocal instead of div func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %one = onnx.Constant dense<1.0> : tensor %x = "onnx.Add"(%input, %input) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> @@ -164,7 +380,7 @@ func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: ten // CHECK-LABEL: func.func @layer_norm_with_div_by_one // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_0_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[Y_]], [[PARAM_2_]]) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> // CHECK: return [[VAR_1_]] : tensor<1x384x768xf32> // CHECK: } @@ -174,7 +390,7 @@ func.func @layer_norm_with_div_by_one(%input: tensor<1x384x768xf32>, %scale: ten // Check alternative layer norm with reciprocal instead of div, fail because it is 2 / x instead of 1 / x func.func @not_a_layer_norm_with_div_by_two(%input: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %one = onnx.Constant dense<2.0> : tensor %x = "onnx.Add"(%input, %input) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> @@ -200,7 +416,7 @@ func.func @not_a_layer_norm_with_div_by_two(%input: tensor<1x384x768xf32>, %scal // RMS Layer norm (sub switched) func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %d = "onnx.Sub"(%mean, %x) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> @@ -217,7 +433,7 @@ func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { // CHECK: [[VAR_0_:%.+]] = "onnx.ReduceMeanV13"([[PARAM_0_]]) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> // CHECK: [[VAR_1_:%.+]] = "onnx.Sub"([[VAR_0_]], [[PARAM_0_]]) : (tensor<1x384x1xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -227,7 +443,7 @@ func.func @rms_layer_norm_v1(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // RMS Layer norm func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { - %eps = onnx.Constant dense<9.99999974E-6> : tensor + %eps = onnx.Constant dense<1.2E+0> : tensor %dd = "onnx.Mul"(%x, %x) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> %var = "onnx.ReduceMeanV13"(%dd) {axes = [-1], keepdims = 1 : si64, onnx_node_name = "ReduceMean_42"} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> %varEps = "onnx.Add"(%eps, %var) : (tensor, tensor<1x384x1xf32>) -> tensor<1x384x1xf32> @@ -240,7 +456,7 @@ func.func @rms_layer_norm_v2(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, // mlir2FileCheck.py // CHECK-LABEL: func.func @rms_layer_norm_v2 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> { -// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) +// CHECK: [[Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none) // CHECK: return [[Y_]] : tensor<1x384x768xf32> // CHECK: } } @@ -261,3 +477,176 @@ func.func @qlinear_matmul(%arg0: tensor, %arg1: tensor, %arg2: // CHECK: return [[VAR_0_]] : tensor // CHECK: } } + +// ----- + + +func.func @qlinear_matmul_with_result_type(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<768x768xi8>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor) -> (tensor<1x2x768xi8>) { + %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor + %1 = "onnx.DequantizeLinear"(%arg3, %arg4, %arg5) {axis = 1 : si64} : (tensor<768x768xi8>, tensor, tensor) -> tensor<768x768xf32> + %2 = "onnx.MatMul"(%0, %1) : (tensor, tensor<768x768xf32>) -> tensor + %3 = "onnx.QuantizeLinear"(%2, %arg6, %arg7) {axis = 1 : si64} : (tensor, tensor, tensor) -> tensor<1x2x768xi8> + return %3: tensor<1x2x768xi8> + +// CHECK-LABEL: func.func @qlinear_matmul_with_result_type +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor, [[PARAM_3_:%.+]]: tensor<768x768xi8>, [[PARAM_4_:%.+]]: tensor, [[PARAM_5_:%.+]]: tensor, [[PARAM_6_:%.+]]: tensor, [[PARAM_7_:%.+]]: tensor) -> tensor<1x2x768xi8> { +// CHECK: [[VAR_0_:%.+]] = "onnx.QLinearMatMul"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]], [[PARAM_5_]], [[PARAM_6_]], [[PARAM_7_]]) : (tensor, tensor, tensor, tensor<768x768xi8>, tensor, tensor, tensor, tensor) -> tensor<1x2x768xi8> +// CHECK: return [[VAR_0_]] : tensor<1x2x768xi8> +// CHECK: } +} + +// ----- + +// gelu(x) = [x * (erf(x/1.41421354) + 1)] * 0.5 +func.func @test_gelu_erf_cst_1(%arg0 : tensor) -> tensor{ + %sqrt2 = onnx.Constant dense<1.41421354> : tensor + %one = onnx.Constant dense<1.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor, tensor) -> tensor + %1 = "onnx.Erf"(%0) : (tensor) -> tensor + %2 = "onnx.Add"(%1, %one) : (tensor, tensor) -> tensor + %3 = "onnx.Mul"(%arg0, %2) : (tensor, tensor) -> tensor + %4 = "onnx.Mul"(%3, %half) : (tensor, tensor) -> tensor + "func.return"(%4) : (tensor) -> () + +// CHECK-LABEL: func.func @test_gelu_erf_cst_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + + +func.func @test_gelu_with_result_type(%arg0 : tensor) -> tensor<1x2x3072xf32>{ + %sqrt2 = onnx.Constant dense<1.41421354> : tensor + %one = onnx.Constant dense<1.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor, tensor) -> tensor + %1 = "onnx.Erf"(%0) : (tensor) -> tensor + %2 = "onnx.Add"(%1, %one) : (tensor, tensor) -> tensor + %3 = "onnx.Mul"(%arg0, %2) : (tensor, tensor) -> tensor + %4 = "onnx.Mul"(%3, %half) : (tensor, tensor) -> tensor<1x2x3072xf32> + "func.return"(%4) : (tensor<1x2x3072xf32>) -> () + +// CHECK-LABEL: func.func @test_gelu_with_result_type +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor<1x2x3072xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor) -> tensor<1x2x3072xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2x3072xf32> +// CHECK: } +} + +// ----- + +// gelu(x) = [x * (1 + erf(x/1.41421354))] * 0.5 +func.func @test_gelu_erf_cst_change_add_operand_order(%arg0 : tensor) -> tensor{ + %sqrt2 = onnx.Constant dense<1.41421354> : tensor + %one = onnx.Constant dense<1.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor, tensor) -> tensor + %1 = "onnx.Erf"(%0) : (tensor) -> tensor + %2 = "onnx.Add"(%one, %1) : (tensor, tensor) -> tensor + %3 = "onnx.Mul"(%arg0, %2) : (tensor, tensor) -> tensor + %4 = "onnx.Mul"(%3, %half) : (tensor, tensor) -> tensor + "func.return"(%4) : (tensor) -> () + +// CHECK-LABEL: func.func @test_gelu_erf_cst_change_add_operand_order +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// gelu(x) = [(erf(x/1.41421354) + 1) * x] * 0.5 +func.func @test_gelu_erf_cst_change_mul_operand_order_1(%arg0 : tensor) -> tensor{ + %sqrt2 = onnx.Constant dense<1.41421354> : tensor + %one = onnx.Constant dense<1.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor, tensor) -> tensor + %1 = "onnx.Erf"(%0) : (tensor) -> tensor + %2 = "onnx.Add"(%1, %one) : (tensor, tensor) -> tensor + %3 = "onnx.Mul"(%2, %arg0) : (tensor, tensor) -> tensor + %4 = "onnx.Mul"(%3, %half) : (tensor, tensor) -> tensor + "func.return"(%4) : (tensor) -> () + +// CHECK-LABEL: func.func @test_gelu_erf_cst_change_mul_operand_order_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// gelu(x) = 0.5 * [x * (erf(x/1.41421354) + 1) * x] +func.func @test_gelu_erf_cst_change_mul_operand_order_2(%arg0 : tensor) -> tensor{ + %sqrt2 = onnx.Constant dense<1.41421354> : tensor + %one = onnx.Constant dense<1.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %0 = "onnx.Div"(%arg0, %sqrt2) : (tensor, tensor) -> tensor + %1 = "onnx.Erf"(%0) : (tensor) -> tensor + %2 = "onnx.Add"(%1, %one) : (tensor, tensor) -> tensor + %3 = "onnx.Mul"(%arg0, %2) : (tensor, tensor) -> tensor + %4 = "onnx.Mul"(%half, %3) : (tensor, tensor) -> tensor + "func.return"(%4) : (tensor) -> () + +// CHECK-LABEL: func.func @test_gelu_erf_cst_change_mul_operand_order_2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "none"} : (tensor) -> tensor +// CHECK: return [[VAR_0_]] : tensor +// CHECK: } +} + +// ----- + +// gelu(x) = x * (0.5 * (1 + tanh[0.797884583 * (x + 0.044715 * x^3)])) +func.func @test_gelu_tanh(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %one = onnx.Constant dense<1.000000e+00> : tensor + %three = onnx.Constant dense<3.000000e+00> : tensor + %half = onnx.Constant dense<5.000000e-01> : tensor + %sqrt2pi = onnx.Constant dense<0.797884583> : tensor + %cst044715 = onnx.Constant dense<4.471500e-02> : tensor + %0 = "onnx.Pow"(%arg0, %three) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %1 = "onnx.Mul"(%cst044715, %0) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %2 = "onnx.Add"(%arg0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %3 = "onnx.Mul"(%sqrt2pi, %2) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %4 = "onnx.Tanh"(%3) : (tensor<*xf32>) -> tensor<*xf32> + %5 = "onnx.Add"(%one, %4) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %6 = "onnx.Mul"(%half, %5) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %7 = "onnx.Mul"(%arg0, %6) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %7 : tensor<*xf32> + +// CHECK-LABEL: func.func @test_gelu_tanh +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.Gelu"([[PARAM_0_]]) {approximate = "tanh"} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> +// CHECK: } +} + +// ----- + +func.func @test_gelu_erf_two_adds(%arg0: tensor, %arg1: tensor<3072x768xf32>) -> tensor { + %0 = onnx.Constant dense<5.000000e-01> : tensor + %1 = onnx.Constant dense<1.000000e+00> : tensor + %2 = onnx.Constant dense<1.41421354> : tensor + %3 = onnx.Constant dense<3.000000e-01> : tensor<3072xf32> + %4 = "onnx.Add"(%arg0, %3) : (tensor, tensor<3072xf32>) -> tensor + %5 = "onnx.Div"(%4, %2) : (tensor, tensor) -> tensor + %6 = "onnx.Erf"(%5) : (tensor) -> tensor + %7 = "onnx.Add"(%6, %1) : (tensor, tensor) -> tensor + %8 = "onnx.Mul"(%4, %7) : (tensor, tensor) -> tensor + %9 = "onnx.Mul"(%8, %0) : (tensor, tensor) -> tensor + %10 = "onnx.MatMul"(%9, %arg1) : (tensor, tensor<3072x768xf32>) -> tensor + return %10 : tensor +} +// CHECK-LABEL: func.func @test_gelu_erf_two_adds +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<3072x768xf32>) -> tensor { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<3.000000e-01> : tensor<3072xf32> +// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_0_]]) : (tensor, tensor<3072xf32>) -> tensor +// CHECK: [[VAR_2_:%.+]] = "onnx.Gelu"([[VAR_1_]]) {approximate = "none"} : (tensor) -> tensor +// CHECK: [[VAR_3_:%.+]] = "onnx.MatMul"([[VAR_2_]], [[PARAM_1_]]) : (tensor, tensor<3072x768xf32>) -> tensor +// CHECK: return [[VAR_3_]] : tensor +// CHECK: } diff --git a/test/mlir/onnx/onnx_recompose_locations.mlir b/test/mlir/onnx/onnx_recompose_locations.mlir new file mode 100644 index 0000000000..18b0a75e4e --- /dev/null +++ b/test/mlir/onnx/onnx_recompose_locations.mlir @@ -0,0 +1,27 @@ +// RUN: onnx-mlir-opt --recompose-onnx --canonicalize %s --mlir-print-debuginfo -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @layernorm_without_bias +func.func @layernorm_without_bias(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) { + %eps = onnx.Constant dense<9.99999974E-6> : tensor + %mean = "onnx.ReduceMeanV13"(%x) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> loc("mReduce") + %d = "onnx.Sub"(%x, %mean) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> loc("sub") + %dd = "onnx.Mul"(%d, %d) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32> loc("ddMul") + %var = "onnx.ReduceMeanV13"(%dd) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32> loc("vReduce") + %varEps = "onnx.Add"(%var, %eps) : (tensor<1x384x1xf32>, tensor) -> tensor<1x384x1xf32> loc("add") + %StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32> loc("sqrt") + %Norm = "onnx.Div"(%d, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32> loc("div") + %Y = "onnx.Mul"(%Norm, %scale) : (tensor<1x384x768xf32>, tensor<768xf32>) -> tensor<1x384x768xf32> loc("lnMul") + return %Y : tensor<1x384x768xf32> loc("return") +// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none) loc([[LOC_FUSED:#.+]]) +// CHECK: return [[VAR_Y_]] : tensor<1x384x768xf32> +// CHECK-DAG: [[LOC_M_REDUCE:#.+]] = loc("mReduce") +// CHECK-DAG: [[LOC_SUB:#.+]] = loc("sub") +// CHECK-DAG: [[LOC_DD_MUL:#.+]] = loc("ddMul") +// CHECK-DAG: [[LOC_V_REDUCE:#.+]] = loc("vReduce") +// CHECK-DAG: [[LOC_ADD:#.+]] = loc("add") +// CHECK-DAG: [[LOC_SQRT:#.+]] = loc("sqrt") +// CHECK-DAG: [[LOC_DIV:#.+]] = loc("div") +// CHECK-DAG: [[LOC_LN_MUL:#.+]] = loc("lnMul") +// CHECK: [[LOC_FUSED]] = loc(fused[[[LOC_M_REDUCE]], [[LOC_SUB]], [[LOC_DD_MUL]], [[LOC_V_REDUCE]], [[LOC_ADD]], [[LOC_SQRT]], [[LOC_DIV]], [[LOC_LN_MUL]]]) +} diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 4a625c9405..9a73d90c0c 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -573,15 +573,16 @@ func.func @test_conv_transpose_output_shape(%arg0 : tensor<1x64x36x48xf32>, %arg // ----- //===----------------------------------------------------------------------===// +/// Test for pad op. +//===----------------------------------------------------------------------===// -/// Test Pad_1 -func.func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { +func.func @test_pad_const_pads(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { %0 = onnx.Constant dense<[0, 2, 2, 4]> : tensor<4xi64> %1 = onnx.Constant dense<0.000000e+00> : tensor<1xf32> %cst = "onnx.NoValue"() {value} : () -> none %2 = "onnx.Pad"(%arg0, %0, %1, %cst) {mode = "constant"} : (tensor<16x13xf32>, tensor<4xi64>, tensor<1xf32>, none) -> tensor<*xf32> "onnx.Return"(%2) : (tensor<*xf32>) -> () - // CHECK-LABEL: test_Pad_1 + // CHECK-LABEL: test_pad_const_pads // CHECK-SAME: ([[VAR_arg0:%.+]]: tensor<16x13xf32>) -> tensor<18x19xf32> { // CHECK: [[VAR_0:%.+]] = onnx.Constant dense<[0, 2, 2, 4]> : tensor<4xi64> // CHECK: [[VAR_1:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32> @@ -591,6 +592,98 @@ func.func @test_Pad_1(%arg0 : tensor<16x13xf32>) -> tensor<*xf32> { // ----- +func.func @test_pad_const_pad_unknown_axes_size(%arg0: tensor<1x3x4x5xf32>, %arg1: tensor) -> tensor { + %1 = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + %2 = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + %3 = "onnx.Pad"(%arg0, %1, %2, %arg1) {mode = "constant"}: (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor) -> tensor + return %3 : tensor + + // CHECK-LABEL: func @test_pad_const_pad_unknown_axes_size + // CHECK-SAME: (%[[VAR_arg0:.*]]: tensor<1x3x4x5xf32>, %[[VAR_arg1:.*]]: tensor) -> tensor { + // CHECK: %[[CONST_0:.*]] = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + // CHECK: %[[CONST_1:.*]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + // CHECK: %[[PAD_0:.*]] = "onnx.Pad"(%[[VAR_arg0]], %[[CONST_0]], %[[CONST_1]], %[[VAR_arg1]]) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor) -> tensor + // CHECK: return %[[PAD_0]] : tensor +} + +// ----- +func.func @test_pad_const_pad_axes(%arg0: tensor<1x3x4x5xf32>) -> tensor { + %0 = onnx.Constant dense<[1, 3]> : tensor<2xi64> + %1 = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + %2 = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + %3 = "onnx.Pad"(%arg0, %1, %2, %0) {mode = "constant"}: (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor + return %3 : tensor + + // CHECK-LABEL: func @test_pad_const_pad_axes + // CHECK-SAME: (%[[VAR_arg0:.*]]: tensor<1x3x4x5xf32>) -> tensor<1x3x4x12xf32> { + // CHECK: %[[CONST_0:.*]] = onnx.Constant dense<[1, 3]> : tensor<2xi64> + // CHECK: %[[CONST_1:.*]] = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + // CHECK: %[[CONST_2:.*]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + // CHECK: %[[PAD_0:.*]] = "onnx.Pad"(%[[VAR_arg0]], %[[CONST_1]], %[[CONST_2]], %[[CONST_0]]) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor<1x3x4x12xf32> + // %3 = "onnx.Pad"(%arg0, %1, %2, %0) // CHECK: return %[[PAD_0]] : tensor<1x3x4x12xf32> +} + +// ----- + +func.func @test_pad_const_axes(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<4xi64>) -> tensor { + %0 = onnx.Constant dense<[1, 3]> : tensor<2xi64> + %1 = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + %2 = "onnx.Pad"(%arg0, %arg1, %1, %0) {mode = "constant"} : (tensor<1x2x3x4xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor + return %2 : tensor + + // CHECK-LABEL: func @test_pad_const_axes + // CHECK-SAME: (%[[VAR_arg0:.*]]: tensor<1x2x3x4xf32>, %[[VAR_arg1:.*]]: tensor<4xi64>) -> tensor<1x?x3x?xf32> { + // CHECK: %[[CONST_0:.*]] = onnx.Constant dense<[1, 3]> : tensor<2xi64> + // CHECK: %[[CONST_1:.*]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + // CHECK: %[[PAD_0:.*]] = "onnx.Pad"(%[[VAR_arg0]], %[[VAR_arg1]], %[[CONST_1]], %[[CONST_0]]) {mode = "constant"} : (tensor<1x2x3x4xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor<1x?x3x?xf32> +} + +// ----- + +func.func @test_pad_all_dynamic(%arg0: tensor<1x3x4x5xf32>, %arg1: tensor<4xi64>, %arg2: tensor, %arg3: tensor<2xi64>) -> tensor { + %0 = "onnx.Pad"(%arg0, %arg1, %arg2, %arg3) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor, tensor<2xi64>) -> tensor + return %0 : tensor + + // CHECK-LABEL: func @test_pad_all_dynamic + // CHECK-SAME: ([[VAR_arg0:%.+]]: tensor<1x3x4x5xf32>, [[VAR_arg1:%.+]]: tensor<4xi64>, [[VAR_arg2:%.+]]: tensor, [[VAR_arg3:%.+]]: tensor<2xi64>) -> tensor { + // CHECK: [[VAR_0:%.+]] = "onnx.Pad"([[VAR_arg0]], [[VAR_arg1]], [[VAR_arg2]], [[VAR_arg3]]) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor, tensor<2xi64>) -> tensor + // CHECK: return [[VAR_0]] : tensor +} + +// ----- + +func.func @test_pad_const_negative_axes(%arg0: tensor<1x3x4x5xf32>) -> tensor { + %0 = onnx.Constant dense<[1, -2]> : tensor<2xi64> + %1 = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + %2 = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + %3 = "onnx.Pad"(%arg0, %1, %2, %0) {mode = "constant"}: (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor + return %3 : tensor + // CHECK-LABEL: func @test_pad_const_negative_axes + // CHECK-SAME: ([[VAR_arg0:%.*]]: tensor<1x3x4x5xf32>) -> tensor<1x3x11x5xf32> { + // CHECK: [[VAR_0:%.+]] = onnx.Constant dense<[1, -2]> : tensor<2xi64> + // CHECK: [[VAR_1:%.+]] = onnx.Constant dense<[0, 3, 0, 4]> : tensor<4xi64> + // CHECK: [[VAR_2:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + // CHECK: [[VAR_3:%.+]] = "onnx.Pad"([[VAR_arg0]], [[VAR_1]], [[VAR_2]], [[VAR_0]]) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor<1x3x11x5xf32> + // CHECK: return [[VAR_3]] : tensor<1x3x11x5xf32> +} + +// ----- + +func.func @test_pad_empty_axes_and_pads(%arg0: tensor<1x3x4x5xf32>) -> tensor { + %0 = onnx.Constant dense<[]> : tensor<0xi64> + %2 = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + %3 = "onnx.Pad"(%arg0, %0, %2, %0) {mode = "constant"}: (tensor<1x3x4x5xf32>, tensor<0xi64>, tensor<1xf32>, tensor<0xi64>) -> tensor + return %3 : tensor + + // CHECK-LABEL: func @test_pad_empty_axes_and_pads + // CHECK-SAME: (%[[VAR_arg0:.*]]: tensor<1x3x4x5xf32>) -> tensor<1x3x4x5xf32> { + // CHECK: %[[CONST_0:.*]] = onnx.Constant dense<> : tensor<0xi64> + // CHECK: %[[CONST_2:.*]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> + // CHECK: %[[PAD_0:.*]] = "onnx.Pad"(%[[VAR_arg0]], %[[CONST_0]], %[[CONST_2]], %[[CONST_0]]) {mode = "constant"} : (tensor<1x3x4x5xf32>, tensor<0xi64>, tensor<1xf32>, tensor<0xi64>) -> tensor<1x3x4x5xf32> +} + +// ----- + //===----------------------------------------------------------------------===// /// Test for constant op. //===----------------------------------------------------------------------===// @@ -1564,6 +1657,32 @@ func.func private @test_squeezev11_empty_axes(%arg0 : tensor<16x1x32x1x64xf32>) // CHECK: onnx.Return [[RES]] : tensor<16x32x64xf32> } + +// ----- + +func.func @test_squeeze_dyn_dimension_empty_axes_diff_known_dims2(%arg0 : tensor<1x?x1x?xf32>) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {onnx_node_name = "onnx.NoValue_0", value} : () -> none + %0 = "onnx.Squeeze"(%arg0, %cst) : (tensor<1x?x1x?xf32>, none) -> (tensor<*xf32>) + "func.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_squeeze_dyn_dimension_empty_axes_diff_known_dims2 + // CHECK: [[RES:%.+]] = "onnx.Squeeze" + // CHECK-SAME: tensor<*xf32> +} + + +// ----- + +func.func @test_squeeze_dyn_dimension_empty_axes_diff_known_dims(%arg0 : tensor<1x?x1x2xf32>) -> tensor<*xf32> { + %cst = "onnx.NoValue"() {onnx_node_name = "onnx.NoValue_0", value} : () -> none + %0 = "onnx.Squeeze"(%arg0, %cst) : (tensor<1x?x1x2xf32>, none) -> (tensor<*xf32>) + "func.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_squeeze_dyn_dimension_empty_axes_diff_known_dims + // CHECK: [[RES:%.+]] = "onnx.Squeeze" + // CHECK-SAME: tensor<*xf32> +} + // ----- func.func @test_unsqueeze(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> { @@ -3750,6 +3869,47 @@ func.func @test_custom3(%arg0: tensor<1024xi32>, %arg1: tensor<4xf32>) -> tensor // CHECK: return [[VAR_0_]] : tensor<4xf32> // CHECK: } +// ----- + +func.func @test_batch_norm_3d(%arg0: tensor<1x256x512xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256xf32>, %arg4: tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %Y, %Mean, %Var = "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %Y, %Mean, %Var : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func @test_batch_norm_3d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256x512xf32>, [[PARAM_1_:%.+]]: tensor<256xf32>, [[PARAM_2_:%.+]]: tensor<256xf32>, [[PARAM_3_:%.+]]: tensor<256xf32>, [[PARAM_4_:%.+]]: tensor<256xf32>) -> (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>) { +// CHECK: [[Y_:%.+]], [[running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>) +// CHECK: return [[Y_]], [[running_mean_]], [[VAR_running_var_]] : tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32> +} + +func.func @test_batch_norm_4d(%arg0: tensor<1x256x512x2xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256xf32>, %arg4: tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %Y, %Mean, %Var = "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, momentum = 2.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512x2xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %Y, %Mean, %Var : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func @test_batch_norm_4d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256x512x2xf32>, [[PARAM_1_:%.+]]: tensor<256xf32>, [[PARAM_2_:%.+]]: tensor<256xf32>, [[PARAM_3_:%.+]]: tensor<256xf32>, [[PARAM_4_:%.+]]: tensor<256xf32>) -> (tensor<1x256x512x2xf32>, tensor<256xf32>, tensor<256xf32>) { +// CHECK: [[Y_:%.+]], [[running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) {epsilon = 9.99999974E-6 : f32, momentum = 2.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512x2xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<1x256x512x2xf32>, tensor<256xf32>, tensor<256xf32>) +// CHECK: return [[Y_]], [[running_mean_]], [[VAR_running_var_]] : tensor<1x256x512x2xf32>, tensor<256xf32>, tensor<256xf32> +} + +func.func @test_batch_norm_dyn_shape_input(%arg0: tensor, %arg1: tensor<256xf32>, %arg2: tensor<256xf32>, %arg3: tensor<256xf32>, %arg4: tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %Y, %Mean, %Var = "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %Y, %Mean, %Var : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func @test_batch_norm_dyn_shape_input +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor<256xf32>, [[PARAM_2_:%.+]]: tensor<256xf32>, [[PARAM_3_:%.+]]: tensor<256xf32>, [[PARAM_4_:%.+]]: tensor<256xf32>) -> (tensor, tensor<256xf32>, tensor<256xf32>) { +// CHECK: [[Y_:%.+]], [[running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> (tensor, tensor<256xf32>, tensor<256xf32>) +// CHECK: return [[Y_]], [[running_mean_]], [[VAR_running_var_]] : tensor, tensor<256xf32>, tensor<256xf32> +} + +func.func @test_batch_norm_dyn_shape_mean_var(%arg0: tensor<1x256x512xf32>, %arg1: tensor<256xf32>, %arg2: tensor<256xf32>, %arg3: tensor, %arg4: tensor) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %Y, %Mean, %Var = "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>, tensor, tensor) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + return %Y, %Mean, %Var : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func @test_batch_norm_dyn_shape_mean_var +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256x512xf32>, [[PARAM_1_:%.+]]: tensor<256xf32>, [[PARAM_2_:%.+]]: tensor<256xf32>, [[PARAM_3_:%.+]]: tensor, [[PARAM_4_:%.+]]: tensor) -> (tensor<1x256x512xf32>, tensor, tensor) { +// CHECK: [[Y_:%.+]], [[running_mean_:%.+]], [[VAR_running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_3_]], [[PARAM_4_]]) {epsilon = 9.99999974E-6 : f32, momentum = 1.000000e+00 : f32, training_mode = 1 : si64} : (tensor<1x256x512xf32>, tensor<256xf32>, tensor<256xf32>, tensor, tensor) -> (tensor<1x256x512xf32>, tensor, tensor) +// CHECK: return [[Y_]], [[running_mean_]], [[VAR_running_var_]] : tensor<1x256x512xf32>, tensor, tensor +} // ----- @@ -3819,3 +3979,75 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5 // CHECK: } } +// ----- + +// Test Grid Sample + +func.func @test_grid_sample_same_dims(%arg0: tensor<1x3x1152x1344xf32>, %arg1: tensor<1x1152x1344x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_same_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x1152x1344xf32>, [[PARAM_1_:%.+]]: tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x3x1152x1344xf32>, tensor<1x1152x1344x2xf32>) -> tensor<1x3x1152x1344xf32> +// CHECK: return [[GRID]] : tensor<1x3x1152x1344xf32> +// CHECK: } +} + +func.func @test_grid_sample_diff_dims(%arg0: tensor<1x1x4x4xf32>, %arg1: tensor<1x6x6x2xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_diff_dims +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x1x4x4xf32>, tensor<1x6x6x2xf32>) -> tensor<1x1x6x6xf32> +// CHECK: return [[GRID]] : tensor<1x1x6x6xf32> +// CHECK: } +} + +func.func @test_grid_sample_6d(%arg0: tensor<1x2x4x4x4x4xf32>, %arg1: tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_6d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x4x4x4x4xf32>, [[PARAM_1_:%.+]]: tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor<1x2x4x4x4x4xf32>, tensor<1x6x6x4x4x4xf32>) -> tensor<1x2x6x6x4x4xf32> +// CHECK: return [[GRID]] : tensor<1x2x6x6x4x4xf32> +// CHECK: } +} + +func.func @test_grid_sample_dim_shape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} + +func.func @test_grid_sample_dim_shape2(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape2 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} + +func.func @test_grid_sample_dim_shape3(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor<*xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_grid_sample_dim_shape3 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor) -> tensor { +// CHECK: [[GRID:%.+]] = "onnx.GridSample"(%arg0, %arg1) {align_corners = 1 : si64, mode = "linear", onnx_node_name = "GridSample_181", padding_mode = "border"} : (tensor, tensor) -> tensor +// CHECK: return [[GRID]] : tensor +// CHECK: } + return %0 : tensor<*xf32> +} \ No newline at end of file diff --git a/test/mlir/onnx/onnx_shape_inference_error.mlir b/test/mlir/onnx/onnx_shape_inference_error.mlir index 943d220958..fa46de65cc 100644 --- a/test/mlir/onnx/onnx_shape_inference_error.mlir +++ b/test/mlir/onnx/onnx_shape_inference_error.mlir @@ -56,9 +56,7 @@ func.func @test_reshape_2D_shape(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<1x func.func @test_lstm_not_3D_input(%arg0: tensor<4x3xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none - // expected-error @+3 {{The first input tensor must have rank 3}} - // expected-error @+2 {{Failed to scan parameters successfully}} - // expected-error @+1 {{shape inference failed}} + // expected-error @+1 {{The first input tensor must have rank 3}} %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) onnx.Return %Y_h : tensor<*xf32> } @@ -67,9 +65,7 @@ func.func @test_lstm_not_3D_input(%arg0: tensor<4x3xf32>, %arg1: tensor<1x12x2xf func.func @test_lstm_not_3D_weight(%arg0: tensor<4x3x2xf32>, %arg1: tensor<12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none - // expected-error @+3 {{The second input tensor must have rank 3}} - // expected-error @+2 {{Failed to scan parameters successfully}} - // expected-error @+1 {{shape inference failed}} + // expected-error @+1 {{The second input tensor must have rank 3}} %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x2xf32>, tensor<12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) onnx.Return %Y_h : tensor<*xf32> } @@ -78,9 +74,7 @@ func.func @test_lstm_not_3D_weight(%arg0: tensor<4x3x2xf32>, %arg1: tensor<12x2x func.func @test_lstm_not_3D_recurrent(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<12x3xf32>) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none - // expected-error @+3 {{The third input tensor must have rank 3}} - // expected-error @+2 {{Failed to scan parameters successfully}} - // expected-error @+1 {{shape inference failed}} + // expected-error @+1 {{The third input tensor must have rank 3}} %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : si64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) onnx.Return %Y_h : tensor<*xf32> } @@ -89,9 +83,7 @@ func.func @test_lstm_not_3D_recurrent(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x func.func @test_lstm_wrong_direction(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { %cst = "onnx.NoValue"() {value} : () -> none - // expected-error @+3 {{direction attribute must be one of the strings: forward, reverse, and bidirectional}} - // expected-error @+2 {{Failed to scan parameters successfully}} - // expected-error @+1 {{shape inference failed}} + // expected-error @+1 {{direction attribute must be one of the strings: forward, reverse, and bidirectional}} %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : si64, direction="forwadr"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) onnx.Return %Y_h : tensor<*xf32> } @@ -115,5 +107,3 @@ func.func @test_category_mapper_diff_size_attrs (%arg0: tensor<20x1xi32>) -> ten %0 = "onnx.CategoryMapper"(%arg0) {cats_int64s = [1], cats_strings = ["cat"]} : (tensor<20x1xi32>) -> tensor<*x!onnx.String> "onnx.Return"(%0) : (tensor<*x!onnx.String>) -> () } - -// ----- diff --git a/test/mlir/onnx/parse/batchnorm_op15.onnxtext b/test/mlir/onnx/parse/batchnorm_op15.onnxtext new file mode 100644 index 0000000000..165e4a096e --- /dev/null +++ b/test/mlir/onnx/parse/batchnorm_op15.onnxtext @@ -0,0 +1,23 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 7, + opset_import: ["" : 15] +> +test_batchnorm_op15 (float[1,1,6,2048] input) => (float[1,1,6,2048] BatchNormalization_output_Y) { + Constant_output_0 = Constant () + Constant_1_output_0 = Constant () + Constant_2_output_0 = Constant () + Constant_3_output_0 = Constant () + BatchNormalization_output_Y, BatchNormalization_output_running_mean, BatchNormalization_output_running_var = BatchNormalization (input, Constant_2_output_0, Constant_3_output_0, Constant_output_0, Constant_1_output_0) +} + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x6x2048xf32> {onnx.name = "input"}) -> (tensor<1x1x6x2048xf32> {onnx.name = "BatchNormalization_output_Y"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32> +// CHECK: [[Y_:%.+]], [[running_mean_:%.+]], [[running_var_:%.+]] = "onnx.BatchNormalization"([[PARAM_0_]], [[VAR_2_]], [[VAR_3_]], [[VAR_0_]], [[VAR_1_]]) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32, training_mode = 1 : si64} : (tensor<1x1x6x2048xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> (tensor<1x1x6x2048xf32>, tensor<1xf32>, tensor<1xf32>) +// CHECK: onnx.Return [[Y_]] : tensor<1x1x6x2048xf32> +// CHECK: } diff --git a/test/mlir/onnx/parse/batchnorm_op9.onnxtext b/test/mlir/onnx/parse/batchnorm_op9.onnxtext new file mode 100644 index 0000000000..e31a2a5d0c --- /dev/null +++ b/test/mlir/onnx/parse/batchnorm_op9.onnxtext @@ -0,0 +1,23 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 4, + opset_import: ["" : 9] +> +test_batchnorm_op9 (float[1,1,6,2048] input) => (float[1,1,6,2048] BatchNormalization_output_Y) { + Constant_output_0 = Constant () + Constant_1_output_0 = Constant () + Constant_2_output_0 = Constant () + Constant_3_output_0 = Constant () + BatchNormalization_output_Y, BatchNormalization_output_mean, BatchNormalization_output_var, BatchNormalization_output_saved_mean, BatchNormalization_output_saved_var = BatchNormalization (input, Constant_2_output_0, Constant_3_output_0, Constant_output_0, Constant_1_output_0) +} + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x6x2048xf32> {onnx.name = "input"}) -> (tensor<1x1x6x2048xf32> {onnx.name = "BatchNormalization_output_Y"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant dense<0.000000e+00> : tensor<1xf32> +// CHECK: [[Y_:%.+]], [[VAR_out_mean_:%.+]], [[VAR_out_var_:%.+]], [[VAR_saved_mean_:%.+]], [[VAR_saved_var_:%.+]] = "onnx.BatchNormalizationV9"([[PARAM_0_]], [[VAR_2_]], [[VAR_3_]], [[VAR_0_]], [[VAR_1_]]) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32} : (tensor<1x1x6x2048xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> (tensor<1x1x6x2048xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) +// CHECK: onnx.Return [[Y_]] : tensor<1x1x6x2048xf32> +// CHECK: } diff --git a/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext new file mode 100644 index 0000000000..c5005ca136 --- /dev/null +++ b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext @@ -0,0 +1,19 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 10, + opset_import: ["" : 22] +> +test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) { + int8_cast_output = Cast (input) + int4_cast_output = Cast (int8_cast_output) + uint8_cast_output = Cast (input2) + uint4_cast_output = Cast (uint8_cast_output) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4> +// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4> +// CHECK: } diff --git a/test/mlir/onnx/parse/cast_with_saturate.onnxtext b/test/mlir/onnx/parse/cast_with_saturate.onnxtext new file mode 100644 index 0000000000..e094199658 --- /dev/null +++ b/test/mlir/onnx/parse/cast_with_saturate.onnxtext @@ -0,0 +1,13 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 9, + opset_import: ["" : 19], + producer_name: "backend-test" +> +test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ (float16[3,4] input) => (float8e4m3fnuz[3,4] output) { + output = Cast (input) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf16> {onnx.name = "input"}) -> (tensor<3x4xf8E4M3FNUZ> {onnx.name = "output"}) +// CHECK: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 0 : si64, to = f8E4M3FNUZ} : (tensor<3x4xf16>) -> tensor<3x4xf8E4M3FNUZ> +// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4xf8E4M3FNUZ> diff --git a/test/mlir/onnx/parse/com.microsoft.gelu.json b/test/mlir/onnx/parse/com.microsoft.gelu.json new file mode 100644 index 0000000000..8bf2ddc229 --- /dev/null +++ b/test/mlir/onnx/parse/com.microsoft.gelu.json @@ -0,0 +1,55 @@ +// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --printIR %s | FileCheck %s + +// Semi hand-written model. +// When converted to onnxtext, onnx-mlir didn't like the result. + +// CHECK: [[DQ:%.+]] = "onnx.Custom"(%arg0) {domain_name = "com.microsoft", function_name = "Gelu", onnx_node_name = "myGelu", output_element_type = bf16, shape_infer_pattern = "MDBroadcast"} : (tensor<1x64x112x112xbf16>) -> tensor<1x64x112x112xbf16> +// CHECK: return [[DQ]] +{ + "irVersion": "8", + "producerName": "pytorch", + "producerVersion": "2.1.2", + "graph": { + "node": [ + { + "input": [ + "myInput" + ], + "output": ["myGelu_output_0"], + "name": "myGelu", + "opType": "Gelu", + "domain": "com.microsoft" + } + ], + "name": "main_graph", + "input": [ + { + "name": "myInput", + "type": { + "tensorType": { + "elemType": 16, + "shape": { + "dim": [ + {"dimValue": "1"}, + {"dimValue": "64"}, + {"dimValue": "112"}, + {"dimValue": "112"} + ] + } + } + } + } + ], + "output": [ + { + "name": "myGelu_output_0", + "type": { + "tensorType": { + "elemType": 16 + } + } + } + ] + }, + "opsetImport": [{"version": "17"}] +} diff --git a/test/mlir/onnx/parse/com.microsoft.qdq_linear.json b/test/mlir/onnx/parse/com.microsoft.qdq_linear.json new file mode 100644 index 0000000000..104b61c765 --- /dev/null +++ b/test/mlir/onnx/parse/com.microsoft.qdq_linear.json @@ -0,0 +1,110 @@ +// RUN: onnx-mlir --EmitONNXIR --useOnnxModelTypes=false --printIR %s | FileCheck %s + +// Semi hand-written model. +// When converted to onnxtext, onnx-mlir didn't like the result. + +// CHECK: [[SCALE:%.+]] = onnx.Constant dense<-1.08420217E-19> : tensor +// CHECK: [[ZERO_P:%.+]] = onnx.Constant dense<0> : tensor +// CHECK: [[DQ:%.+]] = "onnx.Custom"(%arg0, [[SCALE]], [[ZERO_P]]) {domain_name = "com.microsoft", function_name = "DequantizeLinear", onnx_node_name = "myDequantizeLinear", output_element_type = f32, shape_infer_pattern = "MDBroadcast"} : (tensor<1x64x112x112xi8>, tensor, tensor) -> tensor<1x64x112x112xf32> +// CHECK: [[RELU:%.+]] = "onnx.Relu"([[DQ]]) {onnx_node_name = "myrelu1Relu"} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32> +// CHECK: [[Q:%.+]] = "onnx.Custom"([[RELU]], [[SCALE]], [[ZERO_P]]) {domain_name = "com.microsoft", function_name = "QuantizeLinear", onnx_node_name = "myQuantizeLinear_1", output_element_type = i8, shape_infer_pattern = "MDBroadcast"} : (tensor<1x64x112x112xf32>, tensor, tensor) -> tensor<1x64x112x112xi8> +// CHECK: return [[Q]] : tensor<1x64x112x112xi8> +{ + "irVersion": "8", + "producerName": "pytorch", + "producerVersion": "2.1.2", + "graph": { + "node": [ + { + "output": ["scale_output_0"], + "name": "scale", + "opType": "Constant", + "attribute": [ + { + "name": "value", + "t": {"dataType": 1, "rawData": "AAAAoD8="}, + "type": "TENSOR" + } + ] + }, + { + "output": ["zeropoint_output_0"], + "name": "zeropoint", + "opType": "Constant", + "attribute": [ + { + "name": "value", + "t": {"dataType": 3, "rawData": "AAA="}, + "type": "TENSOR" + } + ] + }, + { + "input": [ + "myQuantizeLinear_output_0", + "scale_output_0", + "zeropoint_output_0" + ], + "output": ["myDequantizeLinear_output_0"], + "name": "myDequantizeLinear", + "opType": "DequantizeLinear", + "domain": "com.microsoft" + }, + { + "input": ["myDequantizeLinear_output_0"], + "output": ["myrelu1Relu_output_0"], + "name": "myrelu1Relu", + "opType": "Relu" + }, + { + "input": [ + "myrelu1Relu_output_0", + "scale_output_0", + "zeropoint_output_0" + ], + "output": ["myQuantizeLinear_1_output_0"], + "name": "myQuantizeLinear_1", + "opType": "QuantizeLinear", + "domain": "com.microsoft" + } + ], + "name": "main_graph", + "input": [ + { + "name": "myQuantizeLinear_output_0", + "type": { + "tensorType": { + "elemType": 3, + "shape": { + "dim": [ + {"dimValue": "1"}, + {"dimValue": "64"}, + {"dimValue": "112"}, + {"dimValue": "112"} + ] + } + } + } + } + ], + "output": [ + { + "name": "myQuantizeLinear_1_output_0", + "type": { + "tensorType": { + "elemType": 3, + "shape": { + "dim": [ + {"dimValue": "1"}, + {"dimValue": "64"}, + {"dimValue": "112"}, + {"dimValue": "112"} + ] + } + } + } + } + ] + }, + "opsetImport": [{"version": "17"}] +} diff --git a/test/mlir/onnx/parse/eyelike.onnxtext b/test/mlir/onnx/parse/eyelike.onnxtext new file mode 100644 index 0000000000..0c5f7aa4c7 --- /dev/null +++ b/test/mlir/onnx/parse/eyelike.onnxtext @@ -0,0 +1,17 @@ + +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 4, + opset_import: ["" : 9] +> +test_eye_like (float[] Mul_lhs, float[unk__a,unk__b] EyeLike_in) => (float[] Mul_res) +{ + EyeLike_out = EyeLike (EyeLike_in) + Mul_res = Mul (Mul_lhs, EyeLike_out) +} + +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32> {onnx.dim_params = "0:unk__a,1:unk__b", onnx.name = "Mul_lhs"}, [[PARAM_1_:%.+]]: tensor {onnx.name = "EyeLike_in"}) -> (tensor<*xf32> {onnx.name = "Mul_res"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.EyeLike"([[PARAM_1_]]) {k = 0 : si64} : (tensor) -> tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: onnx.Return [[VAR_1_]] : tensor<*xf32> diff --git a/test/mlir/onnx/parse/eyelike_dtype.onnxtext b/test/mlir/onnx/parse/eyelike_dtype.onnxtext new file mode 100644 index 0000000000..8f69b6eaa7 --- /dev/null +++ b/test/mlir/onnx/parse/eyelike_dtype.onnxtext @@ -0,0 +1,15 @@ + +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 4, + opset_import: ["" : 9] +> +test_eye_like_dtype (float[unk__a,unk__b] EyeLike_in) => (uint16[] EyeLike_out) +{ + EyeLike_out = EyeLike (EyeLike_in) +} + +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor {onnx.dim_params = "0:unk__a,1:unk__b", onnx.name = "EyeLike_in"}) -> (tensor {onnx.dim_params = "0:unk__a,1:unk__b", onnx.name = "EyeLike_out"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.EyeLike"([[PARAM_0_]]) {dtype = 4 : si64, k = 0 : si64} : (tensor) -> tensor +// CHECK: onnx.Return [[VAR_0_]] : tensor diff --git a/test/mlir/onnx/parse/layer_normalization_recompose.onnxtext b/test/mlir/onnx/parse/layer_normalization_recompose.onnxtext new file mode 100644 index 0000000000..2ba6de57a0 --- /dev/null +++ b/test/mlir/onnx/parse/layer_normalization_recompose.onnxtext @@ -0,0 +1,24 @@ +// RUN: onnx-mlir --EmitONNXIR --printIR %s | FileCheck %s + +< + ir_version: 8, + opset_import: ["" : 17] +> +agraph (float[1,384,768] X, float[768] SCALE, float[768] BIAS) => (float[1,384,768] Y) { + EPS = Constant () + POW_EXPONENT = Constant () + MEAN = ReduceMean (X) + D = Sub (X, MEAN) + DD = Pow (D, POW_EXPONENT) + VAR = ReduceMean (DD) + VAR_EPS = Add (VAR, EPS) + STD_DEV = Sqrt (VAR_EPS) + NORM = Div (D, STD_DEV) + NORM_SCALED = Mul (NORM, SCALE) + Y = Add (NORM_SCALED, BIAS) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32> {onnx.name = "X"}, [[PARAM_1_:%.+]]: tensor<768xf32> {onnx.name = "SCALE"}, [[PARAM_2_:%.+]]: tensor<768xf32> {onnx.name = "BIAS"}) -> (tensor<1x384x768xf32> {onnx.name = "Y"}) { +// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 9.99999974E-6 : f32, onnx_node_name = "onnx.LayerNormalization_0", stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none) +// CHECK: return [[Y_]] : tensor<1x384x768xf32> +// CHECK: } diff --git a/test/mlir/onnx/parse/test_acos_22.onnxtext b/test/mlir/onnx/parse/test_acos_22.onnxtext new file mode 100644 index 0000000000..23bd2e6593 --- /dev/null +++ b/test/mlir/onnx/parse/test_acos_22.onnxtext @@ -0,0 +1,16 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s + + +< + ir_version: 7, + opset_import: ["" : 22], + producer_name: "backend-test" +> +test_abs (bfloat16[3,4,5] x) => (bfloat16[3,4,5] y) { + y = Acos (x) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xbf16> {onnx.name = "x"}) -> (tensor<3x4x5xbf16> {onnx.name = "y"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.Acos"([[PARAM_0_]]) : (tensor<3x4x5xbf16>) -> tensor<3x4x5xbf16> +// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4x5xbf16> +// CHECK: } diff --git a/test/mlir/onnx/parse/test_acos_7.onnxtext b/test/mlir/onnx/parse/test_acos_7.onnxtext new file mode 100644 index 0000000000..8c4f8ef599 --- /dev/null +++ b/test/mlir/onnx/parse/test_acos_7.onnxtext @@ -0,0 +1,16 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s + + +< + ir_version: 7, + opset_import: ["" : 7], + producer_name: "backend-test" +> +test_abs (float[3,4,5] x) => (float[3,4,5] y) { + y = Acos (x) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32> {onnx.name = "x"}) -> (tensor<3x4x5xf32> {onnx.name = "y"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.Acos"([[PARAM_0_]]) : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4x5xf32> +// CHECK: } diff --git a/test/mlir/onnx/parse/test_cast.onnxtext b/test/mlir/onnx/parse/test_cast.onnxtext new file mode 100644 index 0000000000..d4a3a16eda --- /dev/null +++ b/test/mlir/onnx/parse/test_cast.onnxtext @@ -0,0 +1,17 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s + +< + ir_version: 7, + opset_import: ["" : 19], + producer_name: "backend-test" +> +test_cast (int64[1] x) => (float[1] y) { + y = Cast(x) +} + +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi64> {onnx.name = "x"}) -> (tensor<1xf32> {onnx.name = "y"}) { +// CHECK: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 0 : si64, to = f32} : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: onnx.Return [[VAR_0_]] : tensor<1xf32> +// CHECK: } +// CHECK: "onnx.EntryPoint"() {func = @main_graph} : () -> () diff --git a/test/mlir/onnx/parse/test_dim_zero.onnxtext b/test/mlir/onnx/parse/test_dim_zero.onnxtext new file mode 100644 index 0000000000..b1d8ebdc85 --- /dev/null +++ b/test/mlir/onnx/parse/test_dim_zero.onnxtext @@ -0,0 +1,13 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s + +< + ir_version: 9, + opset_import: ["" : 19] +> +identity (float[0] x) => (float[0] y) { + y = Identity(x) +} + +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<0xf32> {onnx.name = "x"}) -> (tensor<0xf32> {onnx.name = "y"}) { + diff --git a/test/mlir/onnx/parse/test_pad_axes.onnxtext b/test/mlir/onnx/parse/test_pad_axes.onnxtext new file mode 100644 index 0000000000..e4d9e5ac5f --- /dev/null +++ b/test/mlir/onnx/parse/test_pad_axes.onnxtext @@ -0,0 +1,18 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s + +< + ir_version: 9, + opset_import: ["" : 19] +> +identity (float[1,2,3,4] x, int64[4] pads) => (float[1,2,3,4] y) { + axes = Constant () + value = Constant () + y = Pad (x, pads, value, axes) +} + +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32> {onnx.name = "x"}, [[PARAM_1_:%.+]]: tensor<4xi64> {onnx.name = "pads"}) -> (tensor<1x2x3x4xf32> {onnx.name = "y"}) { +// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[0, 1]> : tensor<2xi64> +// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<1xf32> +// CHECK: [[VAR_2_:%.+]] = "onnx.Pad"([[PARAM_0_]], [[PARAM_1_]], [[VAR_1_]], [[VAR_0_]]) {mode = "constant"} : (tensor<1x2x3x4xf32>, tensor<4xi64>, tensor<1xf32>, tensor<2xi64>) -> tensor<1x2x3x4xf32> +// CHECK: onnx.Return [[VAR_2_]] : tensor<1x2x3x4xf32> diff --git a/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir new file mode 100644 index 0000000000..08a994db70 --- /dev/null +++ b/test/mlir/parallel/krnl_parallel_clause_to_omp.mlir @@ -0,0 +1,168 @@ +// RUN: onnx-mlir-opt -O3 --process-krnl-parallel-clause %s -split-input-file | FileCheck %s + +// ----- + +func.func @omp_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1), num_threads(%c8_i32) {proc_bind = "spread"} : index + } + omp.yield + } + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_threads_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: omp.parallel num_threads([[CST_8_]] : i32) proc_bind(spread) { +} +// ----- +func.func @omp_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1), num_threads(%c8_i32) : index + } + omp.yield + } + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_threads +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK: omp.parallel num_threads([[CST_8_]] : i32) { +} + +// ----- + +func.func @omp_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + krnl.parallel_clause(%arg1) {proc_bind = "spread"} : index + } + omp.yield + } + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_affinity +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: omp.parallel proc_bind(spread) { +} + +// ----- + +func.func @omp_normal(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c16384 = arith.constant 16384 : index + %alloc = memref.alloc() {alignment = 16 : i64} : memref<16x8x128xf32> + %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_0[%c0] : memref<1xindex> + %reshape = memref.reshape %arg0(%alloc_0) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_0 : memref<1xindex> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_1[%c0] : memref<1xindex> + %reshape_2 = memref.reshape %arg0(%alloc_1) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_1 : memref<1xindex> + %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<1xindex> + memref.store %c16384, %alloc_3[%c0] : memref<1xindex> + %reshape_4 = memref.reshape %alloc(%alloc_3) : (memref<16x8x128xf32>, memref<1xindex>) -> memref<16384xf32> + memref.dealloc %alloc_3 : memref<1xindex> + omp.parallel { + omp.wsloop { + omp.loop_nest (%arg1) : index = (%c0) to (%c16384) step (%c32) { + memref.alloca_scope { + %0 = vector.load %reshape[%arg1] : memref<16384xf32>, vector<32xf32> + %1 = vector.load %reshape_2[%arg1] : memref<16384xf32>, vector<32xf32> + %2 = arith.addf %0, %1 : vector<32xf32> + vector.store %2, %reshape_4[%arg1] : memref<16384xf32>, vector<32xf32> + } + omp.yield + } + } + omp.terminator + } + return %alloc : memref<16x8x128xf32> +// mlir2FileCheck.py +// CHECK-LABEL: func.func @omp_normal +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) { +// CHECK: omp.parallel { +} \ No newline at end of file diff --git a/test/mlir/tools/onnx-mlir-opt/outputfile.mlir b/test/mlir/tools/onnx-mlir-opt/outputfile.mlir new file mode 100644 index 0000000000..4753efc0f6 --- /dev/null +++ b/test/mlir/tools/onnx-mlir-opt/outputfile.mlir @@ -0,0 +1,2 @@ +// RUN: onnx-mlir-opt %s -o %t +// RUN: test -f %t \ No newline at end of file diff --git a/test/modellib/ConvModel.cpp b/test/modellib/ConvModel.cpp index 7841e81153..251ef9fde8 100644 --- a/test/modellib/ConvModel.cpp +++ b/test/modellib/ConvModel.cpp @@ -101,7 +101,7 @@ bool Conv2DLibBuilder::build() { // Use the convOp shape inference method to compute output shape, and unset // the shape so that we don't leave IR in a inconsistent state. convOp.getX().setType(xType); // Use static dims to infer shape. - LogicalResult res = convOp.inferShapes([](mlir::Region &) {}); + LogicalResult res = convOp.inferShapes([](Region &) {}); if (failed(res)) return false; diff --git a/test/multiple-models/CMakeLists.txt b/test/multiple-models/CMakeLists.txt index 69846d5e4e..8305eac687 100644 --- a/test/multiple-models/CMakeLists.txt +++ b/test/multiple-models/CMakeLists.txt @@ -30,3 +30,4 @@ add_custom_target(check-multiple-models add_dependencies(check-onnx-backend onnx-mlir) add_dependencies(check-multiple-models PyRuntimeC) +add_dependencies(check-multiple-models PyCompileAndRuntimeC) diff --git a/test/unit/BType/CMakeLists.txt b/test/unit/BType/CMakeLists.txt index 14129b3f28..f89943fb5c 100644 --- a/test/unit/BType/CMakeLists.txt +++ b/test/unit/BType/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestBType +add_onnx_mlir_unittest(TestBType TestBType.cpp LINK_LIBS PRIVATE diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 849459b4c8..a4c3e5d192 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -13,14 +13,14 @@ set_target_properties(check-unittest PROPERTIES FOLDER "Tests") # Exclude the target from the default VS build set_target_properties(check-unittest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD ON) -# add_unittest(test_name sources... options... +# add_onnx_mlir_unittest(test_name sources... options... # This function (generally) has the same semantic as add_onnx_mlir_executable. # A test with test_name is added as a ctest to the unittest testsuite and # all the rest of the arguments are passed directly to add_onnx_mlir_executable. # The function usage is meant to look like a call to add_onnx_mlir_executable # for readability. # ) -function(add_unittest test_name) +function(add_onnx_mlir_unittest test_name) add_onnx_mlir_executable(${test_name} NO_INSTALL ${ARGN}) add_dependencies(unittest ${test_name}) @@ -34,7 +34,7 @@ function(add_unittest test_name) set_tests_properties(${test_name} PROPERTIES LABELS unittest) endfunction() -add_unittest(TestInstrumentation +add_onnx_mlir_unittest(TestInstrumentation TestInstrumentation.cpp INCLUDE_DIRS PUBLIC diff --git a/test/unit/CustomFn/CMakeLists.txt b/test/unit/CustomFn/CMakeLists.txt index 53453d7f31..4ac3481e54 100644 --- a/test/unit/CustomFn/CMakeLists.txt +++ b/test/unit/CustomFn/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestCustomFn +add_onnx_mlir_unittest(TestCustomFn TestCustomFn.cpp LINK_LIBS PRIVATE diff --git a/test/unit/DisposableElementsAttr/CMakeLists.txt b/test/unit/DisposableElementsAttr/CMakeLists.txt index 1ca064e2dd..3b90c06c6f 100644 --- a/test/unit/DisposableElementsAttr/CMakeLists.txt +++ b/test/unit/DisposableElementsAttr/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestDisposableElementsAttr +add_onnx_mlir_unittest(TestDisposableElementsAttr TestDisposableElementsAttr.cpp LINK_LIBS PRIVATE diff --git a/test/unit/Einsum/CMakeLists.txt b/test/unit/Einsum/CMakeLists.txt index 1f2f547155..25a93c60c2 100644 --- a/test/unit/Einsum/CMakeLists.txt +++ b/test/unit/Einsum/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestONNXEinsumOp +add_onnx_mlir_unittest(TestONNXEinsumOp TestONNXEinsumOp.cpp LINK_LIBS PRIVATE diff --git a/test/unit/Runtime/CMakeLists.txt b/test/unit/Runtime/CMakeLists.txt index eb3f6ea58d..fecf048a39 100644 --- a/test/unit/Runtime/CMakeLists.txt +++ b/test/unit/Runtime/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(OMTensorTest +add_onnx_mlir_unittest(OMTensorTest OMTensorTest.c INCLUDE_DIRS PRIVATE diff --git a/test/unit/SmallFP/CMakeLists.txt b/test/unit/SmallFP/CMakeLists.txt index 3a082fe43d..317e60e7b9 100644 --- a/test/unit/SmallFP/CMakeLists.txt +++ b/test/unit/SmallFP/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestSmallFP +add_onnx_mlir_unittest(TestSmallFP TestSmallFP.cpp LINK_LIBS PRIVATE diff --git a/test/unit/Strides/CMakeLists.txt b/test/unit/Strides/CMakeLists.txt index 55136d6865..a5e6e254bb 100644 --- a/test/unit/Strides/CMakeLists.txt +++ b/test/unit/Strides/CMakeLists.txt @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -add_unittest(TestStrides +add_onnx_mlir_unittest(TestStrides TestStrides.cpp LINK_LIBS PRIVATE diff --git a/third_party/benchmark b/third_party/benchmark index 2dd015dfef..a4cf155615 160000 --- a/third_party/benchmark +++ b/third_party/benchmark @@ -1 +1 @@ -Subproject commit 2dd015dfef425c866d9a43f2c67d8b52d709acb6 +Subproject commit a4cf155615c63e019ae549e31703bf367df5b471 diff --git a/third_party/onnx b/third_party/onnx index b86cc54efc..b8baa84466 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c +Subproject commit b8baa8446686496da4cc8fda09f2b6fe65c2a02c diff --git a/third_party/stablehlo b/third_party/stablehlo index 92203a9612..8c7946c8c9 160000 --- a/third_party/stablehlo +++ b/third_party/stablehlo @@ -1 +1 @@ -Subproject commit 92203a9612dbdd6681ccd3e65bc61586c4290df1 +Subproject commit 8c7946c8c94182d238600fb0775a5ba594a43bd7 diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt index 6d2df0f602..170222aa07 100644 --- a/utils/CMakeLists.txt +++ b/utils/CMakeLists.txt @@ -32,3 +32,8 @@ add_custom_target(OMONNXCheckVersion # Scan files for supported ops add_onnx_mlir_supported_ops(${ONNX_MLIR_SRC_ROOT}/test/backend/inference_backend.py cpu) add_onnx_mlir_supported_ops(${ONNX_MLIR_SRC_ROOT}/test/accelerators/NNPA/backend/CMakeLists.txt NNPA) + +add_custom_target(OMCreateONNXMLIRSource + COMMAND cp -r ${CMAKE_CURRENT_SOURCE_DIR}/onnxmlir ${CMAKE_CURRENT_BINARY_DIR} + COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/RunONNXModel.py ${CMAKE_CURRENT_BINARY_DIR}/onnxmlir/src/onnxmlir) + diff --git a/utils/CheckONNXModel.py b/utils/CheckONNXModel.py index 7b90165de9..e73895ec95 100755 --- a/utils/CheckONNXModel.py +++ b/utils/CheckONNXModel.py @@ -13,7 +13,7 @@ # This script can be used as follows: # -# CheckONNXModel.py --model=reducegpt2.mlir --test-compile-args="-O3 -march=x86-64" --shape-info=0:10x20 +# CheckONNXModel.py --model=reducegpt2.mlir --test-compile-args="-O3 --march=x86-64" --shape-info=0:10x20 # # It will compile and run the model reducegpt2.mlir twice. # * Once with the default (-O0) option, which can be overridden with @@ -166,6 +166,13 @@ def valid_onnx_input(fname): " uint64, int64, float16, float32, float64", ) +parser.add_argument( + "--rtol", type=str, default="", help="Relative tolerance for verification." +) +parser.add_argument( + "--atol", type=str, default="", help="Absolute tolerance for verification." +) + args = parser.parse_args() VERBOSE = os.environ.get("VERBOSE", False) @@ -276,6 +283,10 @@ def main(): # How to verify test_cmd += ["--verify=ref"] test_cmd += ["--verify-every-value"] + if args.atol: + test_cmd += ["--atol=" + args.atol] + if args.rtol: + test_cmd += ["--rtol=" + args.rtol] # Model name. test_cmd += [model_str] diff --git a/utils/RunONNXLib.cpp b/utils/RunONNXLib.cpp index 7a1237740f..f0b90327a3 100644 --- a/utils/RunONNXLib.cpp +++ b/utils/RunONNXLib.cpp @@ -362,7 +362,8 @@ OMTensorList *omTensorListCreateFromInputSignature( } // Create a randomly initialized tensor of the right shape. OMTensor *tensor = nullptr; - if (type.equals("float") || type.equals("f32") || type.equals("i32")) { + if (type.compare("float") == 0 || type.compare("f32") == 0 || + type.compare("i32") == 0) { // Treat floats/f32 and i32 alike as they take the same memory footprint. float *data = nullptr; if (dataPtrList) { @@ -372,8 +373,8 @@ OMTensorList *omTensorListCreateFromInputSignature( assert(data && "failed to allocate data"); } tensor = OM_TENSOR_CREATE(data, shape, rank, ONNX_TYPE_FLOAT, true); - } else if (type.equals("double") || type.equals("f64") || - type.equals("i64")) { + } else if (type.compare("double") == 0 || type.compare("f64") == 0 || + type.compare("i64") == 0) { // Treat floats/f64 and i64 alike as they take the same memory footprint. double *data = nullptr; if (dataPtrList) { @@ -383,7 +384,7 @@ OMTensorList *omTensorListCreateFromInputSignature( assert(data && "failed to allocate data"); } tensor = OM_TENSOR_CREATE(data, shape, rank, ONNX_TYPE_DOUBLE, true); - } else if (type.equals("string")) { + } else if (type.compare("string") == 0) { // Add the handling of string type. "string" is the type in function // signature. char **data = nullptr; diff --git a/utils/RunONNXModel.py b/utils/RunONNXModel.py index d58b4fff56..56c2b603cb 100755 --- a/utils/RunONNXModel.py +++ b/utils/RunONNXModel.py @@ -22,6 +22,8 @@ import tempfile import json import importlib.util +import shlex +import shutil from onnx import numpy_helper from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE @@ -61,23 +63,23 @@ def check_non_negative(argname, value): nargs="?", const="compilation.log", default=None, - help="Output compilation messages to file, default compilation.log", + help="Output compilation messages to file, default compilation.log.", ) parser.add_argument( "-m", "--model", type=lambda s: valid_onnx_input(s), - help="Path to an ONNX model (.onnx or .mlir)", + help="Path to an ONNX model (.onnx or .mlir).", ) parser.add_argument( "-c", "--compile-args", type=str, default="", - help="Arguments passed directly to onnx-mlir command." " See bin/onnx-mlir --help", + help="Arguments passed directly to onnx-mlir command." " See bin/onnx-mlir --help.", ) parser.add_argument( - "-C", "--compile-only", action="store_true", help="Only compile the input model" + "-C", "--compile-only", action="store_true", help="Only compile the input model." ) parser.add_argument( "--compile-using-input-shape", @@ -85,17 +87,22 @@ def check_non_negative(argname, value): help="Compile the model by using the shape info getting from" " the inputs in the reference folder set by --load-ref", ) -parser.add_argument("--print-input", action="store_true", help="Print out inputs") +parser.add_argument("--print-input", action="store_true", help="Print out inputs.") parser.add_argument( "--print-output", action="store_true", - help="Print out inference outputs produced by onnx-mlir", + help="Print out inference outputs produced by onnx-mlir.", +) +parser.add_argument( + "--print-signatures", + action="store_true", + help="Print out the input and output signatures of the model.", ) parser.add_argument( "--save-onnx", metavar="PATH", type=str, - help="File path to save the onnx model. Only effective if " "--verify=onnxruntime", + help="File path to save the onnx model. Only effective if --verify=onnxruntime.", ) parser.add_argument( "--verify", @@ -103,12 +110,12 @@ def check_non_negative(argname, value): help="Verify the output by using onnxruntime or reference" " inputs/outputs. By default, no verification. When being" " enabled, --verify-with-softmax or --verify-every-value" - " must be used to specify verification mode", + " must be used to specify verification mode.", ) parser.add_argument( "--verify-all-ops", action="store_true", - help="Verify all operation outputs when using onnxruntime", + help="Verify all operation outputs when using onnxruntime.", ) parser.add_argument( "--verify-with-softmax", @@ -122,35 +129,52 @@ def check_non_negative(argname, value): parser.add_argument( "--verify-every-value", action="store_true", - help="Verify every value of the output using atol and rtol", + help="Verify every value of the output using atol and rtol.", ) parser.add_argument( - "--rtol", type=str, default="0.05", help="Relative tolerance for verification" + "--rtol", type=str, default="0.05", help="Relative tolerance for verification." ) parser.add_argument( - "--atol", type=str, default="0.01", help="Absolute tolerance for verification" + "--atol", type=str, default="0.01", help="Absolute tolerance for verification." ) lib_group = parser.add_mutually_exclusive_group() lib_group.add_argument( - "--save-so", + "--save-model", + metavar="PATH", + type=str, + help="Path to a folder to save the compiled model.", +) +lib_group.add_argument( + "--load-model", metavar="PATH", type=str, - help="File path to save the generated shared library of" " the model", + help="Path to a folder to load a compiled model for " + "inference, and the ONNX model will not be re-compiled.", ) lib_group.add_argument( - "--load-so", + "--cache-model", metavar="PATH", type=str, - help="File path to load a generated shared library for " - "inference, and the ONNX model will not be re-compiled", + help="When finding a compiled model in given path, reuse it. " + "Otherwise, compile model and save it into the given path.", +) + +parser.add_argument( + "-o", + "--default-model-name", + metavar="MODEL_NAME", + type=str, + default="model", + help="Change the default model name that is used for two generated files: " + " .so and .constants.bin. Default is model.", ) parser.add_argument( "--save-ref", metavar="PATH", type=str, - help="Path to a folder to save the inputs and outputs" " in protobuf", + help="Path to a folder to save the inputs and outputs in protobuf.", ) data_group = parser.add_mutually_exclusive_group() data_group.add_argument( @@ -158,7 +182,10 @@ def check_non_negative(argname, value): metavar="PATH", type=str, help="Path to a folder containing reference inputs and outputs stored in protobuf." - " If --verify=ref, inputs and outputs are reference data for verification", + " If --verify=ref, inputs and outputs are reference data for verification.", +) +data_group.add_argument( + "--inputs-from-arrays", help="List of numpy arrays used as inputs for inference." ) data_group.add_argument( "--load-ref-from-numpy", @@ -172,7 +199,7 @@ def check_non_negative(argname, value): "--shape-info", type=str, help="Shape for each dynamic input of the model, e.g. 0:1x10x20,1:7x5x3. " - "Used to generate random inputs for the model if --load-ref is not set", + "Used to generate random inputs for the model if --load-ref is not set.", ) parser.add_argument( @@ -181,7 +208,7 @@ def check_non_negative(argname, value): help="Lower bound values for each data type. Used inputs." " E.g. --lower-bound=int64:-10,float32:-0.2,uint8:1." " Supported types are bool, uint8, int8, uint16, int16, uint32, int32," - " uint64, int64,float16, float32, float64", + " uint64, int64,float16, float32, float64.", ) parser.add_argument( "--upper-bound", @@ -189,30 +216,31 @@ def check_non_negative(argname, value): help="Upper bound values for each data type. Used to generate random inputs." " E.g. --upper-bound=int64:10,float32:0.2,uint8:9." " Supported types are bool, uint8, int8, uint16, int16, uint32, int32," - " uint64, int64, float16, float32, float64", + " uint64, int64, float16, float32, float64.", ) parser.add_argument( "-w", "--warmup", type=lambda s: check_non_negative("--warmup", s), default=0, - help="The number of warmup inference runs", + help="The number of warmup inference runs.", ) parser.add_argument( "-n", "--n-iteration", type=lambda s: check_positive("--n-iteration", s), default=1, - help="The number of inference runs excluding warmup", + help="The number of inference runs excluding warmup.", ) parser.add_argument( "--seed", type=str, default="42", - help="seed to initialize the random num generator for inputs", + help="seed to initialize the random num generator for inputs.", ) args = parser.parse_args() + if args.verify and (args.verify_with_softmax is None) and (not args.verify_every_value): raise RuntimeError( "Choose verification mode: --verify-with-softmax or " @@ -602,66 +630,111 @@ def data_without_top_bottom_quartile(data, percent): return data[trim:-trim] -def main(): - if not (args.model or args.load_so): - print("error: no input model, use argument --model and/or --load-so.") - print(parser.format_usage()) - exit(1) +class InferenceSession: + """ + in onnxruntime: + class onnxruntime.InferenceSession(path_or_bytes: str | bytes | os.PathLike, sess_options: onnxruntime.SessionOptions | None = None, providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs)[source] - # Get shape information if given. - # args.shape_info in the form of 'input_index:d1xd2, input_index:d1xd2' - input_shapes = {} - if args.shape_info: - for input_shape in args.shape_info.strip().split(","): - input_index_shape = input_shape.split(":") - input_index = input_index_shape[0] - assert not (input_index in input_shapes), "Duplicate input indices" - dims = [int(d) for d in input_index_shape[1].split("x")] - input_shapes[int(input_index)] = dims - - # Load the onnx model. - if args.model and args.model.endswith(".onnx"): - model = onnx.load(args.model) - # Get names of all intermediate tensors and modify model such that each of - # them will be an output of the model. If using onnxruntime for - # verification, we can then verify every operation output. - output_names = [o.name for o in model.graph.output] - output_names = list(OrderedDict.fromkeys(output_names)) - if args.verify and args.verify == "onnxruntime" and args.verify_all_ops: - print("Extending the onnx model to check every node output ...\n") - output_names = sum( - [[n for n in node.output if n != ""] for node in model.graph.node], [] - ) - output_names = list(OrderedDict.fromkeys(output_names)) - model = extend_model_output(model, output_names) + In onnxmlir, session_options and provider will be merged into kwargs, and + ignored. onnxruntime.SessionOptions may contain some useful info, + but onnxruntime package is needed to interpret it. Therefore, it is ignored now. + Another argument, 'options' is added for onnxmlir to specify options for RunONNXModel.py + """ - # Save the modified onnx file of the model if required. - if args.save_onnx: - print("Saving modified onnx model to ", args.save_onnx, "\n") - onnx.save(model, args.save_onnx) + def __init__(self, model_file, **kwargs): + global args + if "options" in kwargs.keys(): + options = kwargs["options"] + args = parser.parse_args(shlex.split(options)) - # Compile, run, and verify. - with tempfile.TemporaryDirectory() as temp_dir: - print("Temporary directory has been created at {}".format(temp_dir)) + if model_file: + if model_file.endswith(".onnx") or model_file.endswith(".mlir"): + args.model = model_file + else: + args.load_model = compiled_name - shared_lib_path = "" + # Default model name that will be used for the compiled model. + # e.g. model.so, model.constants.bin, ... + self.default_model_name = args.default_model_name + + # Handle cache_model. + if args.cache_model: + shared_lib_path = args.cache_model + f"/{self.default_model_name}.so" + if not os.path.exists(shared_lib_path): + print( + 'Cached compiled model not found in "' + + args.cache_model + + '": save model this run.' + ) + args.save_model = args.cache_model + else: + print( + 'Cached compiled model found in "' + + args.cache_model + + '": load model this run.' + ) + args.load_model = args.cache_model + args.cache_model = None + + # Get shape information if given. + # args.shape_info in the form of 'input_index:d1xd2, input_index:d1xd2' + input_shapes = {} + if args.shape_info: + for input_shape in args.shape_info.strip().split(","): + input_index_shape = input_shape.split(":") + input_index = input_index_shape[0] + assert not (input_index in input_shapes), "Duplicate input indices" + dims = [int(d) for d in input_index_shape[1].split("x")] + input_shapes[int(input_index)] = dims + + # Load the onnx model. + if args.model and args.model.endswith(".onnx"): + model = onnx.load(args.model) + # Get names of all intermediate tensors and modify model such that each of + # them will be an output of the model. If using onnxruntime for + # verification, we can then verify every operation output. + output_names = [o.name for o in model.graph.output] + output_names = list(OrderedDict.fromkeys(output_names)) + if args.verify and args.verify == "onnxruntime" and args.verify_all_ops: + print("Extending the onnx model to check every node output ...\n") + output_names = sum( + [[n for n in node.output if n != ""] for node in model.graph.node], + [], + ) + output_names = list(OrderedDict.fromkeys(output_names)) + model = extend_model_output(model, output_names) + + # Save the modified onnx file of the model if required. + if args.save_onnx: + print("Saving modified onnx model to ", args.save_onnx, "\n") + onnx.save(model, args.save_onnx) # If a shared library is given, use it without compiling the ONNX model. # Otherwise, compile the ONNX model. - if args.load_so: - shared_lib_path = args.load_so + if args.load_model: + self.model_dir = args.load_model else: + # Compile the ONNX model. + self.temp_dir = tempfile.TemporaryDirectory() + print("Temporary directory has been created at {}\n".format(self.temp_dir)) print("Compiling the model ...") + self.model_dir = self.temp_dir.name # Prepare input and output paths. - output_path = os.path.join(temp_dir, "model") - shared_lib_path = os.path.join(temp_dir, "model.so") + output_path = os.path.join(self.model_dir, self.default_model_name) if args.model.endswith(".onnx"): - input_model_path = os.path.join(temp_dir, "model.onnx") - onnx.save(model, input_model_path) + if args.verify and args.verify == "onnxruntime" and args.verify_all_ops: + input_model_path = os.path.join( + self.model_dir, f"{self.default_model_name}.onnx" + ) + onnx.save(model, input_model_path) + else: + input_model_path = args.model elif args.model.endswith(".mlir") or args.model.endswith(".onnxtext"): input_model_path = args.model else: - print("Invalid input model path. Must end with .onnx or .mlir") + print( + "Invalid input model path. Must end with .onnx or .mlir or .onnxtext" + ) exit(1) # Prepare compiler arguments. @@ -705,30 +778,111 @@ def main(): end = time.perf_counter() print(" took ", end - start, " seconds.\n") - # Save the generated .so file of the model if required. - if args.save_so: - print("Saving the shared library to", args.save_so, "\n") - execute_commands(["rsync", "-ar", shared_lib_path, args.save_so]) + # Save the following information: + # - .so file, + # - .constants.bin file, and + # - compilation.log containing the compilation output. + if args.save_model: + if not os.path.exists(args.save_model): + os.makedirs(args.save_model) + if not os.path.isdir(args.save_model): + print("Path to --save-model is not a folder") + exit(0) + # .so file. + shared_lib_path = self.model_dir + f"/{self.default_model_name}.so" + if os.path.exists(shared_lib_path): + print("Saving the shared library to", args.save_model) + shutil.copy2(shared_lib_path, args.save_model) + # .constants.bin file. + constants_file_path = os.path.join( + self.model_dir, f"{self.default_model_name}.constants.bin" + ) + if os.path.exists(constants_file_path): + print("Saving the constants file to", args.save_model, "\n") + shutil.copy2(constants_file_path, args.save_model) + # Compilation log. + log_file_path = os.path.join(args.save_model, "compile.log") + with open(log_file_path, "w") as f: + print("Saving the compilation log to", args.save_model, "\n") + f.write(msg) # Exit if only compiling the model. if args.compile_only: exit(0) # Use the generated shared library to create an execution session. - print("Loading the compiled model ...") start = time.perf_counter() - if args.load_so: + shared_lib_path = self.model_dir + f"/{self.default_model_name}.so" + if not os.path.exists(shared_lib_path): + print(f"Input model {shared_lib_path} does not exist") + exit(0) + print("Loading the compiled model ...") + if args.load_model: sess = OMExecutionSession(shared_lib_path, tag="None") else: sess = OMExecutionSession(shared_lib_path) end = time.perf_counter() print(" took ", end - start, " seconds.\n") + self.sess = sess + + """ + From onnxruntime API: + + run(output_names, input_feed, run_options=None) + Compute the predictions. + + PARAMETERS: + output_names – name of the outputs + input_feed – dictionary { input_name: input_value } + run_options – See onnxruntime.RunOptions. + RETURNS: + list of results, every result is either a numpy array, a sparse tensor, a list or a dictionary. + + For onnxmlir, the run_options is ignored. If 'input_feed' is None, the + input could be randomly generated or read from file, as args specified. + In future, add '--shape-info' here. Better than in InferenceSession to + allow different shape from run to run. + """ + + def run(self, outputname, input_feed, **kwargs): + # Get shape information if given. + # args.shape_info in the form of 'input_index:d1xd2, input_index:d1xd2' + input_shapes = {} + if args.shape_info: + for input_shape in args.shape_info.strip().split(","): + input_index_shape = input_shape.split(":") + input_index = input_index_shape[0] + assert not (input_index in input_shapes), "Duplicate input indices" + dims = [int(d) for d in input_index_shape[1].split("x")] + input_shapes[int(input_index)] = dims # Get the input and output signature. - input_signature = sess.input_signature() - output_signature = sess.output_signature() + input_signature = self.sess.input_signature() + output_signature = self.sess.output_signature() input_names = get_names_in_signature(input_signature) output_names = get_names_in_signature(output_signature) + if args.print_signatures: + print("Model's input signature: ", input_signature.strip()) + print("Model's output signature: ", output_signature.strip()) + + inputs = [] + # Get input from input_feed, if input_feed is provided + if input_feed: + if isinstance(input_feed, dict): + for name in input_names: + if name in input_feed: + inputs.append(input_feed[name]) + else: + print("input name given: ", input_feed.keys()) + print("input name expected by model: ", input_names) + print("do not match") + exit(1) + # Since Python guarantees the order of values in a dictionary, + # the name check could be ignored as follows: + # inputs = list(input_feed.values()) + else: + inputs = input_feed + args.inputs_from_arrays = inputs # Prepare input data. inputs = [] @@ -736,6 +890,8 @@ def main(): inputs = read_input_from_refs(len(input_names), args.load_ref) elif args.load_ref_from_numpy: inputs = read_input_from_refs(len(input_names), args.load_ref_from_numpy) + elif args.inputs_from_arrays: + inputs = args.inputs_from_arrays else: inputs = generate_random_input(input_signature, input_shapes) @@ -754,16 +910,18 @@ def main(): # Running inference. print("Running inference ...") + # Let onnx-mlir know where to find the constants file. + os.environ["OM_CONSTANT_PATH"] = self.model_dir for i in range(args.warmup): start = time.perf_counter() - outs = sess.run(inputs) + outs = self.sess.run(inputs) end = time.perf_counter() print(" {} warmup: {} seconds".format(ordinal(i + 1), end - start)) perf_results = [] for i in range(args.n_iteration): start = time.perf_counter() - outs = sess.run(inputs) + outs = self.sess.run(inputs) end = time.perf_counter() elapsed = end - start perf_results += [elapsed] @@ -868,6 +1026,22 @@ def main(): "using atol={}, rtol={} ...".format(args.atol, args.rtol), ) verify_outs(outs[i], ref_outs[i]) + if outputname: + res = {outputname[i]: outs[i] for i in range(len(outs))} + return res + else: + return outs + + +# Standalone driver +def main(): + if not (args.model or args.load_model): + print("error: no input model, use argument --model and/or --load-model.") + print(parser.format_usage()) + exit(1) + + sess = InferenceSession(None) + return sess.run(None, None) if __name__ == "__main__": diff --git a/utils/RunONNXModelZoo.py b/utils/RunONNXModelZoo.py index 91fab6a861..56fe2ab27c 100755 --- a/utils/RunONNXModelZoo.py +++ b/utils/RunONNXModelZoo.py @@ -37,7 +37,7 @@ - Use `-m model_name` to check a list of selected models. Example: - $ ONNX_MLIR_HOME=/onnx-mlir/build/Release/ /onnx-mlir/utils/RunONNXModelZoo.py -m mnist-8 -c "-O3 -mcpu=z14" + $ ONNX_MLIR_HOME=/onnx-mlir/build/Release/ /onnx-mlir/utils/RunONNXModelZoo.py -m mnist-8 -c "-O3 --march=z16" """ if not os.environ.get("ONNX_MLIR_HOME", None): diff --git a/utils/analyze-simd.py b/utils/analyze-simd.py index bb7650f9e2..6141b3957f 100755 --- a/utils/analyze-simd.py +++ b/utils/analyze-simd.py @@ -86,8 +86,10 @@ def define_arch_op_names(arch): op_name["vfma"] = "vfma" op_name["vmul"] = "vfm.b" op_name["vdiv"] = "vfd" - # vector conversion between formats (NNPA <-> fp) - op_name["vconv"] = "(vclfnh|vclfnl|vcfn|vcrnf|vcnf)" + # vector conversion between formats (NNPA <-> fp, FP <-> int, int <-> int) + op_name["vconv"] = ( + "(vclfnh|vclfnl|vcfn|vcrnf|vcnf|vclgd|vclfeb|vclgdb|vpkh|vpkf|vpkg)" + ) # add | sub| max | min | compare op_name["vadd"] = "([vw]fa|[vw]fs|[vw]fmax|[vw]fmin|[vw]f[ck][eh])" op_name["load"] = "lg" diff --git a/utils/build-mlir.sh b/utils/build-mlir.sh index 6a9ff0cddc..d3ed3772f0 100644 --- a/utils/build-mlir.sh +++ b/utils/build-mlir.sh @@ -8,7 +8,8 @@ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_RTTI=ON \ -DENABLE_LIBOMPTARGET=OFF \ - -DLLVM_ENABLE_LIBEDIT=OFF + -DLLVM_ENABLE_LIBEDIT=OFF \ + $EXTRA_CMAKE_ARGS cmake --build . -- ${MAKEFLAGS} cmake --build . --target check-mlir diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index 4965b58726..b072d0ba39 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ -git clone -n https://github.com/llvm/llvm-project.git +git clone -n https://github.com/xilinx/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout 60a7d33106d3cd645d3100a8a935a1e3837f885d && cd .. +cd llvm-project && git checkout e8be3bea2ce0ec51b614cd7eb7d5d3a1e56d9524 && cd .. diff --git a/utils/documentOps.py b/utils/documentOps.py index 199ca0cfab..b23726b1f2 100755 --- a/utils/documentOps.py +++ b/utils/documentOps.py @@ -4,7 +4,7 @@ ##################### documentOps.py ######################################## # -# Copyright 2022, 2023 The IBM Research Authors. +# Copyright 2022, 2024 The IBM Research Authors. # ################################################################################ # @@ -23,6 +23,15 @@ # Default values are used when no explicit ==MIN==/==MAX== values are used. min_opset_default = 6 max_opset_default = "*" +# NNPA supported. Ordered list, with oldest first and most recent last. +nnpa_supported_list = ["z16"] + +# Derived variables (do not update). +nnpa_level_default = nnpa_supported_list[-1] # Most recent arch (last). +nnpa_supported_set = set(nnpa_supported_list) +nnpa_most_current_str = "" +if len(nnpa_supported_list) > 1: + nnpa_most_current_str = " - ^" import sys import getopt @@ -35,7 +44,10 @@ # SEMANTIC for LABELING (one line per directive) # # ==ARCH== -# where is cpu/NNPA/... this option is valid until reset by another ARCH dir +# where is cpu or NNPA.. this option is valid until reset by another ARCH dir +# +# ==LEVEL== +# where is a comma separated names of versions supported by NNPA. # # ==OP== # where is the ONNX op name @@ -69,6 +81,8 @@ def print_usage(msg=""): print(' -a, --arch : report on "==ARCH== ".') print(" -d, --debug: include debug.") print(" -i, --input : input file.") + if "NNPA" in target_arch: + print(' -l, --level : report on "==LEVEL== ".') print(" -n, --notes: include notes/TODOs.") print(" -p, --path : path to onnx-mlir util directory.") print(" -u, --unsupported: list unsupported ops.") @@ -81,6 +95,7 @@ def print_usage(msg=""): hightest_opset = None # Highest opset found in the description. opset_dict = {} # -> in "==OP== ". limit_dict = {} # -> in "==LIM== ". +level_dict = {} # -> in "==LEVEL== ". min_dict = {} # -> in "==MIN== ". max_dict = {} # -> in "==MAX== ". todo_dict = {} # -> in "==TODO== ". @@ -106,7 +121,6 @@ def parse_file(file_name): file = open(file_name, "r") except OSError: print_usage("Could not open file `" + file_name + "`") - op = "" arch = "" for line in file: @@ -117,7 +131,7 @@ def parse_file(file_name): arch = p[1] if debug: print("process arch", arch) - continue + continue if arch != target_arch: continue # Additional top paragraph @@ -142,6 +156,50 @@ def parse_file(file_name): if debug: print("got supported op", op, "at level", list_op_version[op]) continue + # Scan NNPA Level + if "NNPA" in target_arch: + p = re.search(r"==LEVEL==\s+(\w+)(?:,\s*(\w+))*", l) + if p is not None: + assert op is not None, "Level without op." + assert op not in level_dict, "Redefinition of level for op " + op + current_set = set(p.groups()) + join_set = current_set & nnpa_supported_set + if not join_set: + if debug: + print( + "process NNPA level, no overlap between", + current_set, + "and", + nnpa_supported_set, + ) + else: + # Find the first and last in set according to the order in nnpa_supported_list. + first_in_set = None + last_in_set = None + for x in nnpa_supported_list: + if x in join_set: + last_in_set = x + if first_in_set == None: + first_in_set = x + assert first_in_set and last_in_set, "should both be defined" + if debug: + print( + "join set is", + join_set, + "first", + first_in_set, + "last", + last_in_set, + ) + if last_in_set == nnpa_level_default: # First to default (current). + level_dict[op] = first_in_set + nnpa_most_current_str + elif first_in_set == last_in_set: # Only one. + level_dict[op] = first_in_set + else: # Interval finishing before current. + level_dict[op] = first_in_set + " - " + last_in_set + if debug: + print("process NNPA level", level_dict[op]) + continue # Limits. p = re.search(r"==LIM==\s+(.*)\s*$", l) if p is not None: @@ -232,6 +290,13 @@ def print_md(): + str(hightest_opset) + "." ) + if "NNPA" in target_arch: + print( + " * A ^ indicates onnx-mlir is compatible with the latest" + + " level of the NNPA Architecture which is " + + str(nnpa_level_default) + + "." + ) print("\n") # Additional top paragraph. @@ -239,8 +304,17 @@ def print_md(): print(additional_top_paragraph) print("\n") # Table. - header = ["Op", "Supported Opsets (inclusive)", "Limitations"] - separator = ["---", "---", "---"] + if "NNPA" in target_arch: + header = [ + "Op", + "Supported Opsets (inclusive)", + "Minimum NNPA Level(Inclusive)", + "Limitations", + ] + separator = ["---", "---", "---", "---"] + else: + header = ["Op", "Supported Opsets (inclusive)", "Limitations"] + separator = ["---", "---", "---"] if emit_notes: header.append("Notes") separator.append("---") @@ -248,8 +322,16 @@ def print_md(): print_row(separator) for op in sorted(list_op_version.keys()): supported_op = op in min_dict + if supported_op and "NNPA" in target_arch and op not in level_dict: + supported_op = False if supported_op: info = ["**" + op + "**", f"{min_dict[op]} - {max_dict[op]}"] + if "NNPA" in target_arch: + info = [ + "**" + op + "**", + f"{min_dict[op]} - {max_dict[op]}", + f"{level_dict[op]}", + ] else: if not emit_unsupported: continue @@ -288,7 +370,10 @@ def main(argv): for opt, arg in opts: if opt in ("-a", "--arch"): target_arch = arg - input_command += " --arch " + arg + input_command += " --arch " + target_arch + elif opt in ("-l", "--level"): + nnpa_level = arg + input_command += " --level " + nnpa_level elif opt in ("-d", "--debug"): debug = 1 elif opt in ("-h", "--help"): diff --git a/utils/fixLitTest.py b/utils/fixLitTest.py index 717f2c3567..b9e5ab3b46 100755 --- a/utils/fixLitTest.py +++ b/utils/fixLitTest.py @@ -425,6 +425,7 @@ def main(argv): dprint("\n>> Tested with " + str(test_error_num) + " errors:") for f in test_error_functions: dprint(">> " + f) + dprint(">> Completed processing of " + lit_test_filename + "\n") if __name__ == "__main__": diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 33cdc35148..5eadfaa95a 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -66,7 +66,7 @@ # ==UPDATE_ONNX_VERSION_OPSET== # Look for tag above and update all references when upgrading the ONNX support within ONNX-MLIR. -current_onnx_version = "1.15.0" +current_onnx_version = "1.17.0" # Check the version of onnx package being used. if ( @@ -86,8 +86,8 @@ version_dict = { "Abs": [13], - "Acos": [7], - "Acosh": [9], + "Acos": [22], + "Acosh": [22], "Adagrad": [1], "Adam": [1], "Add": [14], @@ -95,13 +95,13 @@ "ArgMax": [13], "ArgMin": [13], "ArrayFeatureExtractor": [1], - "Asin": [7], - "Asinh": [9], - "Atan": [7], - "Atanh": [9], - "AveragePool": [19], - "BatchNormalization": [15], - "Bernoulli": [15], + "Asin": [22], + "Asinh": [22], + "Atan": [22], + "Atanh": [22], + "AveragePool": [22], + "BatchNormalization": [15, 9], + "Bernoulli": [22], "Binarizer": [1], "BitShift": [11], "BitwiseAnd": [18], @@ -109,7 +109,7 @@ "BitwiseOr": [18], "BitwiseXor": [18], "BlackmanWindow": [17], - "Cast": [19], + "Cast": [21], "CastLike": [19], "CastMap": [1], "CategoryMapper": [1], @@ -122,60 +122,60 @@ "ConcatFromSequence": [11], "Constant": [19], "ConstantOfShape": [20], - "Conv": [11], + "Conv": [22], "ConvInteger": [10], - "ConvTranspose": [11], - "Cos": [7], - "Cosh": [9], + "ConvTranspose": [22], + "Cos": [22], + "Cosh": [22], "Col2Im": [18], "CumSum": [14], - "DeformConv": [19], + "DeformConv": [22], "DepthToSpace": [13], "DequantizeLinear": [19], - "Det": [11], + "Det": [22], "DFT": [20, 17], "DictVectorizer": [1], "Div": [14], - "Dropout": [13], + "Dropout": [22], "DynamicQuantizeLinear": [11], "Einsum": [12], - "Elu": [6], + "Elu": [22], "Equal": [19], "Erf": [13], "Exp": [13], "Expand": [13], - "EyeLike": [9], + "EyeLike": [22], "FeatureVectorizer": [1], "Flatten": [13], "Floor": [13], - "GRU": [14], + "GRU": [22], "Gather": [13], "GatherElements": [13], "GatherND": [13], "Gelu": [20], "Gemm": [13], - "GlobalAveragePool": [1], + "GlobalAveragePool": [22], "GlobalLpPool": [2], - "GlobalMaxPool": [1], + "GlobalMaxPool": [22], "Gradient": [1], "Greater": [13], "GreaterOrEqual": [16], - "GridSample": [16], - "GroupNormalization": [18], + "GridSample": [22, 20, 16], + "GroupNormalization": [21, 18], "HammingWindow": [17], "HannWindow": [17], - "HardSigmoid": [6], + "HardSigmoid": [22], "Hardmax": [13], - "HardSwish": [14], + "HardSwish": [22], "Identity": [19], "If": [19], "Imputer": [1], - "InstanceNormalization": [6], + "InstanceNormalization": [22], "IsInf": [20], "IsNaN": [20], "LayerNormalization": [17], "LRN": [13], - "LSTM": [14], + "LSTM": [22], "LabelEncoder": [2], "LeakyRelu": [16], "Less": [13], @@ -185,25 +185,25 @@ "Log": [13], "LogSoftmax": [13], "Loop": [19], - "LpNormalization": [1], - "LpPool": [18], + "LpNormalization": [22], + "LpPool": [22], "MatMul": [13], "MatMulInteger": [10], "Max": [13], - "MaxPool": [12], - "MaxRoiPool": [1], - "MaxUnpool": [11], + "MaxPool": [22], + "MaxRoiPool": [22], + "MaxUnpool": [22], "Mean": [13], "MeanVarianceNormalization": [13], "MelWeightMatrix": [17], "Min": [13], - "Mish": [18], + "Mish": [22], "Mod": [13], "Momentum": [1], "Mul": [14], - "Multinomial": [7], + "Multinomial": [22], "Neg": [13], - "NegativeLogLikelihoodLoss": [13], + "NegativeLogLikelihoodLoss": [22], "NonMaxSuppression": [11], "NonZero": [13], "Normalizer": [1], @@ -220,11 +220,11 @@ "QLinearConv": [10], "QLinearMatMul": [10], "QuantizeLinear": [19], - "RNN": [14], - "RandomNormal": [1], - "RandomNormalLike": [1], - "RandomUniform": [1], - "RandomUniformLike": [1], + "RNN": [22], + "RandomNormal": [22], + "RandomNormalLike": [22], + "RandomUniform": [22], + "RandomUniformLike": [22], "Range": [11], "Reciprocal": [13], "ReduceL1": [18, 13], @@ -241,8 +241,8 @@ "Reshape": [19], "Resize": [19, 18, 13, 11, 10], "ReverseSequence": [10], - "RoiAlign": [16], - "Round": [11], + "RoiAlign": [22], + "Round": [22], "SVMClassifier": [1], "SVMRegressor": [1], "Scaler": [1], @@ -250,7 +250,7 @@ "Scatter": [11], "ScatterElements": [18], "ScatterND": [18], - "Selu": [6], + "Selu": [22], "SequenceAt": [11], "SequenceConstruct": [11], "SequenceEmpty": [11], @@ -262,14 +262,14 @@ "Shrink": [9], "Sigmoid": [13], "Sign": [13], - "Sin": [7], - "Sinh": [9], + "Sin": [22], + "Sinh": [22], "Size": [19], "Slice": [13], "Softmax": [13, 11], "SoftmaxCrossEntropyLoss": [13], - "Softplus": [1], - "Softsign": [1], + "Softplus": [22], + "Softsign": [22], "SpaceToDepth": [13], "Split": [18, 13, 11], "SplitToSequence": [11], @@ -279,10 +279,10 @@ "STFT": [17], "Sub": [14], "Sum": [13], - "Tan": [7], + "Tan": [22], "Tanh": [13], "TfIdfVectorizer": [9], - "ThresholdedRelu": [10], + "ThresholdedRelu": [22], "Tile": [13], "TopK": [11], "Transpose": [13], @@ -396,6 +396,8 @@ "Gelu", "Greater", "GreaterOrEqual", + "GridSample", + "GroupNormalizationV18", "Hardmax", "If", "IsInf", @@ -444,14 +446,16 @@ ] # Op with fold function -OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11"] +OpsWithFolder = ["Constant", "Squeeze", "SqueezeV11", "ReduceMean"] # Op with ConstantLike trait OpsWithConstantLike = ["Constant"] # Op with Helper functions -# Here the functions are for data flow analysis. OpsWithHelpers = { + "EyeLike": """ + mlir::Type getResultElementType(); + """, "Loop": """ mlir::Operation::result_range v_final(); mlir::Operation::result_range scan_outputs(); @@ -471,6 +475,7 @@ "Constant", "Cast", "ConstantOfShape", + "EyeLike", "If", "Loop", "RandomNormal", @@ -609,6 +614,10 @@ # FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients # FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero # +# // 4-bit integer data types +# UINT4 = 21; // Unsigned integer in range [0, 15] +# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation +# # // Future extensions go here. # } onnx_types = ( @@ -633,6 +642,8 @@ "float8e4m3fnuz", "float8e5m2", "float8e5m2fnuz", + "uint4", + "int4", ) tblgen_types = ( "BF16", @@ -656,6 +667,8 @@ "F8E4M3FNUZ", "F8E5M2", "F8E5M2FNUZ", + "AnyUI4", + "AnyI4", ) # Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode @@ -1046,10 +1059,12 @@ def parse_type_str(allowedType): "seq": "SeqOf", "map": "TupleOf", "bool": "I1", + "uint4": "UI<4>", "uint8": "UI8", "uint16": "UI16", "uint32": "UI32", "uint64": "UI64", + "int4": "I<4>", "int8": "I8", "int16": "I16", "int32": "I32", @@ -1169,7 +1184,7 @@ def gen_op_def(schema, with_version=False): regions[attr.name] = "AnyRegion" # Generate decl for op traits. - traits = ["Pure"] + traits = ["Pure", f"OpVersionTrait<{schema.since_version}>"] # Generate ConstantLike traits. if opName in OpsWithConstantLike: diff --git a/utils/install-onnx-mlir.sh b/utils/install-onnx-mlir.sh index 789d1e4d54..c870fe41d1 100755 --- a/utils/install-onnx-mlir.sh +++ b/utils/install-onnx-mlir.sh @@ -7,6 +7,7 @@ if [[ -z "$pythonLocation" ]]; then -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_DIR=${MLIR_DIR} \ + $EXTRA_CMAKE_ARGS \ .. else cmake -G Ninja \ @@ -15,6 +16,7 @@ else -DLLVM_ENABLE_ASSERTIONS=ON \ -DPython3_ROOT_DIR=$pythonLocation \ -DMLIR_DIR=${MLIR_DIR} \ + $EXTRA_CMAKE_ARGS \ .. fi cmake --build . diff --git a/utils/install-protobuf.sh b/utils/install-protobuf.sh index 70e133e48d..33f8859675 100755 --- a/utils/install-protobuf.sh +++ b/utils/install-protobuf.sh @@ -4,7 +4,9 @@ git clone -b v${PROTOBUF_VERSION} --depth 1 --recursive https://github.com/proto cd protobuf ./autogen.sh -./configure --enable-static=no -make -j$(sysctl -n hw.logicalcpu) install -cd python -python3 setup.py install --cpp_implementation +./configure --enable-static=no --prefix=/usr +sudo make -j2 install + +# Doesn't work on Ubuntu, just needed for MacOS? +#cd python +#python3 setup.py install --cpp_implementation diff --git a/utils/install-venv.sh b/utils/install-venv.sh new file mode 100755 index 0000000000..ab431fb557 --- /dev/null +++ b/utils/install-venv.sh @@ -0,0 +1,6 @@ +#!/usr/bin/bash +set -e + +python3 -m pip install --upgrade wheel +python3 -m pip install -r requirements.txt +pip install onnx==1.15.0 diff --git a/utils/mlir2FileCheck.py b/utils/mlir2FileCheck.py index 415a05cde2..44491e1ce0 100755 --- a/utils/mlir2FileCheck.py +++ b/utils/mlir2FileCheck.py @@ -268,12 +268,20 @@ def process_line(i, line): new_line = process_name(new_line, def1_pat, "BLOCK_TILE_", ",", 1) new_line = process_name(new_line, def2_pat, "BLOCK_IN_", " =", 1) else: - definitions = def_qual_pat.findall(new_line) + def_qual_pat2 = re.compile( + r"((?:%(?:[a-zA-Z0-9][a-zA-Z0-9_\-]*)(?::\d+)?,?\s+)*)=" + ) + definitions = def_qual_pat2.findall(new_line) for d in definitions: - (name, num) = d - x = use_name(name, num, " =") - y = record_name_def(name, num, "VAR_" + name, " =", 0, new_line) - new_line = new_line.replace(x, y) + arg_def_pat = re.compile(r"%([a-zA-Z0-9][a-zA-Z0-9_\-]*)(:\d+)?") + arg_defs = d.split(",") + for arg_def in arg_defs: + args = arg_def_pat.findall(arg_def) + for arg in args: + (name, num) = arg + x = use_name(name, num, "") + y = record_name_def(name, num, "VAR_" + name, "", 0, new_line) + new_line = new_line.replace(x, y) # Process uses and map use. uses = use_qual_pat.findall(new_line) diff --git a/utils/onnxmlir/LICENSE b/utils/onnxmlir/LICENSE new file mode 100644 index 0000000000..20e4bd8566 --- /dev/null +++ b/utils/onnxmlir/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/utils/onnxmlir/README.md b/utils/onnxmlir/README.md new file mode 100644 index 0000000000..af1f8b18c0 --- /dev/null +++ b/utils/onnxmlir/README.md @@ -0,0 +1,68 @@ +This package provides a python interface to use onnx-mlir compiler to run inference of an onnx model similar to onnxruntime interface. The basic parameters of the interface are supported with options ignored. + +## Description +Let's start with [an onnxrutime example](https://onnxruntime.ai/docs/get-started/with-python#pytorch-cv): +``` +import onnxruntime as ort +import numpy as np +x, y = test_data[0][0], test_data[0][1] +ort_sess = ort.InferenceSession('fashion_mnist_model.onnx') +outputs = ort_sess.run(None, {'input': x.numpy()}) + +# Print Result +predicted, actual = classes[outputs[0][0].argmax(0)], classes[y] +print(f'Predicted: "{predicted}", Actual: "{actual}"') +``` + +With onnxmlir package, onnx-mlir can be used to replace onnxrutnime as follows: +``` +import onnxmlir as ort +import numpy as np +x, y = test_data[0][0], test_data[0][1] +ort_sess = ort.InferenceSession('fashion_mnist_model.onnx') +outputs = ort_sess.run(None, {'input': x.numpy()}) + +# Print Result +predicted, actual = classes[outputs[0][0].argmax(0)], classes[y] +print(f'Predicted: "{predicted}", Actual: "{actual}"') +``` + +In current version, the onnx-mlir compiler is not contained in the python +package yet. Use env variable ONNX_MLIR_HOME to specify the location of the onnx-mlir compiler to be used. For example `export ONNX_MLIR_HOME=/mypath/onnx-mlir/build/Debug`. + +Another way to run the onnx model is to precompile it into a static library first. +``` +onnx-mlir -O3 fashin_mnist_mode.onnx +``` + +This compilation will generate fashin_mnist_mode.so. Then the library can be used as model in the Python script as follows: +``` +import onnxmlir as ort +import numpy as np +x, y = test_data[0][0], test_data[0][1] +ort_sess = ort.InferenceSession('fashion_mnist_model.so') +outputs = ort_sess.run(None, {'input': x.numpy()}) + +# Print Result +predicted, actual = classes[outputs[0][0].argmax(0)], classes[y] +print(f'Predicted: "{predicted}", Actual: "{actual}"') +``` + +This package supports list or dictionary for input and output. For example, the input for run could be list of tensor. +``` +outputs = ort_sess.run(None, [a, b, c]) +``` + +Another extra named argment for InferenceSession is introduced to specify the extra arguments accepted by onnx-mlir/utils/RunONNXModel.py. Here is an example: +``` +sess = onnxmlir.inferenceSession("test_add.onnx", options='--compile-args="-O3 --parallel" --print-output') +``` + +## Installation + +### Install from local directory +At top of onnx-mlir: `pip3 install utils/onnxmlir` + +### Install from repo +After the package is uploaded, you can install with 'pip3 install onnxmlir` + diff --git a/utils/onnxmlir/VERSION_NUMBER b/utils/onnxmlir/VERSION_NUMBER new file mode 100644 index 0000000000..6e8bf73aa5 --- /dev/null +++ b/utils/onnxmlir/VERSION_NUMBER @@ -0,0 +1 @@ +0.1.0 diff --git a/utils/onnxmlir/pyproject.toml b/utils/onnxmlir/pyproject.toml new file mode 100644 index 0000000000..096389e5ee --- /dev/null +++ b/utils/onnxmlir/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" +[project] +name = "onnxmlir" +version = "0.1.0" +authors = [ + { name="Tong Chen", email="chentong@us.ibm.com" }, + { name="Alexandre Eichenberger", email="alexe@us.ibm.com" }, +] +description = "Python driver to compile/run onnx model with onnx-mlir" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", +] + +[project.urls] +Homepage = "https://github.com/onnx/onnx-mlir" +Issues = "https://github.com/onnx/onnx-mlir/issues" diff --git a/utils/onnxmlir/src/onnxmlir/__init__.py b/utils/onnxmlir/src/onnxmlir/__init__.py new file mode 100644 index 0000000000..87080c1013 --- /dev/null +++ b/utils/onnxmlir/src/onnxmlir/__init__.py @@ -0,0 +1 @@ +from .RunONNXModel import InferenceSession diff --git a/utils/onnxmlir/tests/test_1.py b/utils/onnxmlir/tests/test_1.py new file mode 100644 index 0000000000..50564a6fde --- /dev/null +++ b/utils/onnxmlir/tests/test_1.py @@ -0,0 +1,10 @@ +# Test: no name for input and output +import numpy as np +import onnxmlir + +a = np.zeros((3, 4, 5), dtype=np.float32) +b = a + 4 + +sess = onnxmlir.InferenceSession("test_add.onnx") +r = sess.run(None, [a, b]) +print(r) diff --git a/utils/onnxmlir/tests/test_2.py b/utils/onnxmlir/tests/test_2.py new file mode 100644 index 0000000000..5cc56b2ca9 --- /dev/null +++ b/utils/onnxmlir/tests/test_2.py @@ -0,0 +1,11 @@ +import numpy as np +import onnxmlir + +a0 = np.zeros((3, 4, 5), dtype=np.float32) +a = a0 + 2 +b = a0 + 4 + +sess = onnxmlir.InferenceSession("test_add.onnx") +# The name for input is specified by onnx model. +r = sess.run(["my_out"], {"x": a, "y": b}) +print(r) diff --git a/utils/onnxmlir/tests/test_3.py b/utils/onnxmlir/tests/test_3.py new file mode 100644 index 0000000000..d34defba09 --- /dev/null +++ b/utils/onnxmlir/tests/test_3.py @@ -0,0 +1,12 @@ +import numpy as np +import onnxmlir + +a0 = np.zeros((3, 4, 5), dtype=np.float32) +a = a0 + 2 +b = a0 + 4 + +sess = onnxmlir.InferenceSession( + "test_add.onnx", options='--compile-args="-O3 --parallel" --print-output' +) +r = sess.run(["my_out"], {"x": a, "y": b}) +print(r) diff --git a/utils/onnxmlir/tests/test_add.onnx b/utils/onnxmlir/tests/test_add.onnx new file mode 100644 index 0000000000..8aa4ab2bc3 Binary files /dev/null and b/utils/onnxmlir/tests/test_add.onnx differ diff --git a/utils/onnxmlirrun.py b/utils/onnxmlirrun.py index ae1332c68f..fcca711bfc 100644 --- a/utils/onnxmlirrun.py +++ b/utils/onnxmlirrun.py @@ -101,7 +101,7 @@ def compile(self): output_path = os.path.join(self.temp_dir.name, self.temp_lib_name) command_str += ["-o", output_path] if self.target == "zAIU": - command_str += ["--maccel=NNPA", "-O3", "--mcpu=z16"] + command_str += ["--maccel=NNPA", "-O3", "--march=z16"] command_str += self.options.split() # Compile the model. diff --git a/utils/pre-onnx-mlir.py b/utils/pre-onnx-mlir.py index 85e7e49ba8..66a5ea4a2c 100644 --- a/utils/pre-onnx-mlir.py +++ b/utils/pre-onnx-mlir.py @@ -39,7 +39,7 @@ # ==UPDATE_ONNX_VERSION_OPSET== # Look for tag above and update all references when upgrading the ONNX support within ONNX-MLIR. -current_onnx_opset = 19 +current_onnx_opset = 21 converted_model = version_converter.convert_version(original_model, current_onnx_opset) diff --git a/utils/python/transformers/run_gpt2_from_huggingface.py b/utils/python/transformers/run_gpt2_from_huggingface.py new file mode 100644 index 0000000000..ffe8783858 --- /dev/null +++ b/utils/python/transformers/run_gpt2_from_huggingface.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 + +##################### run_gpt2_from_huggingface.py ############################# +# +# Copyright 2019-2024 The IBM Research Authors. +# +################################################################################ +# +# This script is to run GPT2 from HuggingFace. +# Command: ONNX_MLIR_HOME=/workdir/onnx-mlir/build/Debug python run_gpt2_from_huggingface.py 2>&1 | tee log.txt +# +# When running this script for the first time, it will download onnx models from +# HuggingFace and compile the models. The onnx models and compiled models are +# cached in the current folder (by default). +# +# Change compile_flags if targeting a different machine. +################################################################################ + +import os +import sys +import time +import json +import requests as req +from urllib.request import urlretrieve + +import numpy as np +from transformers import AutoTokenizer + +# Include runtime directory in python paths, so PyRuntime can be imported. +RUNTIME_DIR = os.path.join(os.environ["ONNX_MLIR_HOME"], "lib") +sys.path.append(RUNTIME_DIR) +try: + from PyCompileAndRuntime import OMCompileExecutionSession +except ImportError: + raise ImportError( + "Looks like you did not build the PyRuntime target, build it by running `make PyRuntime`." + "You may need to set ONNX_MLIR_HOME to `onnx-mlir/build/Debug` since `make PyRuntime` outputs to `build/Debug` by default" + ) + +# Information to download onnx models from HuggingFace. +model_name_or_path = "gpt2" # can be gpt2, gpt2-medium, gpt2-large +decoder_model_name = "decoder_model.onnx" +decoder_with_past_model_name = "decoder_with_past_model.onnx" +config_json_name = "config.json" +decoder_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_model_name}" +decoder_data_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_model_name}_data" +decoder_with_past_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_with_past_model_name}" +decoder_with_past_data_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_with_past_model_name}_data" +config_json_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{config_json_name}" + +# Local directories for caching the model. +cache_dir = "./" +decoder_model_path = f"{cache_dir}/{decoder_model_name}" +decoder_data_path = f"{cache_dir}/{decoder_model_name}_data" +decoder_with_past_model_path = f"{cache_dir}/{decoder_with_past_model_name}" +decoder_with_past_data_path = f"{cache_dir}/{decoder_with_past_model_name}_data" +config_json_path = f"{cache_dir}/{config_json_name}" + +# Download the model to a local dir. +if not os.path.exists(decoder_model_path): + print(f"Downloading {decoder_url}") + urlretrieve(decoder_url, decoder_model_path) + print("Done") +if not os.path.exists(decoder_data_path): + if req.head(decoder_data_url, allow_redirects=True).status_code == 200: + print(f"Downloading {decoder_data_url}") + urlretrieve(decoder_data_url, decoder_data_path) + print("Done") +if not os.path.exists(decoder_with_past_model_path): + print(f"Downloading {decoder_with_past_url}") + urlretrieve(decoder_with_past_url, decoder_with_past_model_path) + print("Done") +if not os.path.exists(decoder_with_past_data_path): + if req.head(decoder_with_past_data_url, allow_redirects=True).status_code == 200: + print(f"Downloading {decoder_with_past_data_url}") + urlretrieve(decoder_with_past_data_url, decoder_with_past_data_path) + print("Done") +if not os.path.exists(config_json_path): + print(f"Downloading the config json file {config_json_url}") + urlretrieve(config_json_url, config_json_path) + print("Done") + +with open(config_json_path) as f: + cfg = json.load(f) + print("Model configuration: {}\n".format(cfg)) + num_attention_heads = cfg["n_head"] + hidden_size = cfg["n_embd"] + num_layers = cfg["n_layer"] + eos_token_id = cfg["eos_token_id"] + +# Create CompileExecutionSession to compile and run the model, +compile_flags = "-O3 -v --onnx-op-stats TXT" +# compile_flags = "-O3 --march=z16 --maccel=NNPA -v --onnx-op-stats TXT" +decoder_sess = OMCompileExecutionSession( + decoder_model_path, compile_flags + " -tag=decoder", reuse_compiled_model=1 +) +decoder_with_past_sess = OMCompileExecutionSession( + decoder_with_past_model_path, + compile_flags + " -tag=decoder_with_past", + reuse_compiled_model=1, +) + +# Setup a tokenizer. +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) + +# Tokenize the input text. +prompt_text = "Which is the highest mountain in Japan?" +pt_inputs = tokenizer(prompt_text, return_tensors="pt") +output_length = 13 + +# Generate tokens. +ts = [] +num_runs = 3 +for r in range(num_runs): + output_ids = [] + kv_cache = None + + t = 0 + attention_mask = pt_inputs["attention_mask"].numpy() + inputs = [pt_inputs["input_ids"].numpy(), attention_mask] + for step in range(output_length): + t0 = time.time() + if kv_cache is None: + outputs = decoder_sess.run(inputs) + else: + outputs = decoder_with_past_sess.run(inputs) + t_elap = time.time() - t0 + t += t_elap + # Greedy approach is used here. + logits = outputs[0][:, -1, :] + next_id = np.argmax(logits, axis=1, keepdims=True) + kv_cache = outputs[1:] + # Only for batchsize = 1 + attention_mask = np.append( + attention_mask, np.array([[1]], dtype=np.int64), axis=1 + ) + inputs = [next_id] + kv_cache + [attention_mask] + output_ids += [next_id[0][0]] + + ts += [t] + if r == num_runs - 1: + # Expected answer: "The highest mountain in Japan is the Mt. Fuji." + print("Prompt: {}".format(prompt_text)) + print("Generated words: {}\n".format(tokenizer.decode(output_ids).strip())) + +print("times", ts) +t = np.min(ts) +print("t_elap: %.2f seconds" % (t)) +print( + "Latency: {} msec/token, thoughput: {} tokens/sec".format( + t / output_length * 1000, output_length / t + ) +)