From 4c5f0b6166882c797d01368549197e6f89d3e1b9 Mon Sep 17 00:00:00 2001 From: Xuan Hien Date: Sat, 14 Oct 2023 13:44:32 +0700 Subject: [PATCH] Add EfficientNMSCustom_TRT --- CMakeLists.txt | 1 + README.md | 267 ++----- README_ORIGIN.md | 212 ++++++ plugin/CMakeLists.txt | 1 + plugin/README.md | 1 + plugin/api/InferPlugin.cpp | 2 + .../efficientNMSCustomPlugin/CMakeLists.txt | 21 + ...EfficientNMSCustomPlugin_PluginConfig.yaml | 52 ++ plugin/efficientNMSCustomPlugin/README.md | 162 +++++ .../efficientNMSCustomInference.cu | 675 ++++++++++++++++++ .../efficientNMSCustomInference.cuh | 260 +++++++ .../efficientNMSCustomInference.h | 30 + .../efficientNMSCustomParameters.h | 60 ++ .../efficientNMSCustomPlugin.cpp | 463 ++++++++++++ .../efficientNMSCustomPlugin.h | 96 +++ 15 files changed, 2099 insertions(+), 204 deletions(-) create mode 100644 README_ORIGIN.md create mode 100644 plugin/efficientNMSCustomPlugin/CMakeLists.txt create mode 100644 plugin/efficientNMSCustomPlugin/EfficientNMSCustomPlugin_PluginConfig.yaml create mode 100644 plugin/efficientNMSCustomPlugin/README.md create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cu create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cuh create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.h create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomParameters.h create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.cpp create mode 100644 plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 880bdf48..7450bc3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -119,6 +119,7 @@ endif() include_directories( ${CUDA_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/cub ) find_library(CUDNN_LIB cudnn HINTS ${CUDA_TOOLKIT_ROOT_DIR} ${CUDNN_ROOT_DIR} PATH_SUFFIXES lib64 lib/x64 lib) diff --git a/README.md b/README.md index a9627165..b5caccef 100644 --- a/README.md +++ b/README.md @@ -1,212 +1,71 @@ -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Documentation](https://img.shields.io/badge/TensorRT-documentation-brightgreen.svg)](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html) +# TensorRT custom plugin -# TensorRT Open Source Software -This repository contains the Open Source Software (OSS) components of NVIDIA TensorRT. It includes the sources for TensorRT plugins and parsers (Caffe and ONNX), as well as sample applications demonstrating usage and capabilities of the TensorRT platform. These open source software components are a subset of the TensorRT General Availability (GA) release with some extensions and bug-fixes. +Just add some new custom tensorRT plugin -* For code contributions to TensorRT-OSS, please see our [Contribution Guide](CONTRIBUTING.md) and [Coding Guidelines](CODING-GUIDELINES.md). -* For a summary of new additions and updates shipped with TensorRT-OSS releases, please refer to the [Changelog](CHANGELOG.md). -* For business inquiries, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com) -* For press and other inquiries, please contact Hector Marinez at [hmarinez@nvidia.com](mailto:hmarinez@nvidia.com) +## New plugin -Need enterprise support? NVIDIA global support is available for TensorRT with the [NVIDIA AI Enterprise software suite](https://www.nvidia.com/en-us/data-center/products/ai-enterprise/). 
Check out [NVIDIA LaunchPad](https://www.nvidia.com/en-us/launchpad/ai/ai-enterprise/) for free access to a set of hands-on labs with TensorRT hosted on NVIDIA infrastructure.
+- [EfficientNMSCustom_TRT](./plugin/efficientNMSCustomPlugin/): Same as Efficient NMS, but also returns the indices of the selected boxes
 
-Join the [TensorRT and Triton community](https://www.nvidia.com/en-us/deep-learning-ai/triton-tensorrt-newsletter/) and stay current on the latest product updates, bug fixes, content, best practices, and more.
+## Prerequisites
+
+- DeepStream 6.3
+
+## Install
+
+Follow the TensorRT OSS build guide from [github.com/NVIDIA-AI-IOT/deepstream_tao_apps](https://github.com/NVIDIA-AI-IOT/deepstream_tao_apps/blob/master/TRT-OSS/x86/README.md), summarized below.
+
+### 1. Install CMake (>= 3.13)
+
+TensorRT OSS requires CMake >= v3.13, so install a newer CMake if your current version is older than 3.13.
+
+```
+wget https://github.com/Kitware/CMake/releases/download/v3.19.4/cmake-3.19.4.tar.gz
+tar xvf cmake-3.19.4.tar.gz
+cd cmake-3.19.4/
+mkdir $HOME/install
+./configure --prefix=$HOME/install
+make -j$(nproc)
+sudo make install
+```
 
-# Prebuilt TensorRT Python Package
-We provide the TensorRT Python package for an easy installation. \
-To install:
-```bash
-pip install tensorrt
+### 2. Build TensorRT OSS Plugin
+
+| DeepStream Release | TRT Version | TRT_OSS_CHECKOUT_TAG | Support |
+| ------------------ | ----------- | -------------------- | ------- |
+| 5.0 | TRT 7.0.0 | release/7.0 | No |
+| 5.0.1 | TRT 7.0.0 | release/7.0 | No |
+| 5.1 | TRT 7.2.X | 21.03 | No |
+| 6.0 EA | TRT 7.2.2 | 21.03 | No |
+| 6.0 GA | TRT 8.0.1 | release/8.0 | No |
+| 6.0.1 | TRT 8.2.1 | release/8.2 | Yes |
+| 6.1 | TRT 8.2.5.1 | release/8.2 | Yes |
+| 6.1.1 | TRT 8.4.1.5 | release/8.4 | Yes |
+| 6.2 | TRT 8.5.2.2 | release/8.5 | Yes |
+| 6.3 | TRT 8.5.3.1 | release/8.5 | Yes |
+
+```
+git clone -b release/8.5 https://github.com/hiennguyen9874/TensorRT
+cd TensorRT/
+git submodule update --init --recursive
+export TRT_SOURCE=`pwd`
+cd $TRT_SOURCE
+mkdir -p build && cd build
+## NOTE: make sure your GPU_ARCHS is covered by the GPU_ARCHS list in the TRT OSS CMakeLists.txt;
+## if it is not, add -DGPU_ARCHS=xy as below, where xy is your GPU's compute capability (see https://developer.nvidia.com/cuda-gpus)
+$HOME/install/bin/cmake .. -DGPU_ARCHS=xy -DTRT_LIB_DIR=/usr/lib/x86_64-linux-gnu/ -DCMAKE_C_COMPILER=/usr/bin/gcc -DTRT_BIN_DIR=`pwd`/out
+make nvinfer_plugin -j$(nproc)
 ```
-You can skip the **Build** section to enjoy TensorRT with Python.
 
-# Build
+After the build completes successfully, libnvinfer_plugin.so\* will be generated under `pwd`/out/ (i.e. build/out/).
 
-## Prerequisites
-To build the TensorRT-OSS components, you will first need the following software packages. 
- -**TensorRT GA build** -* [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download) v8.5.3.1 - -**System Packages** -* [CUDA](https://developer.nvidia.com/cuda-toolkit) - * Recommended versions: - * cuda-11.8.0 + cuDNN-8.6 - * cuda-10.2 + cuDNN-8.4 -* [GNU make](https://ftp.gnu.org/gnu/make/) >= v4.1 -* [cmake](https://github.com/Kitware/CMake/releases) >= v3.13 -* [python]() >= v3.6.9, <= v3.10.x -* [pip](https://pypi.org/project/pip/#history) >= v19.0 -* Essential utilities - * [git](https://git-scm.com/downloads), [pkg-config](https://www.freedesktop.org/wiki/Software/pkg-config/), [wget](https://www.gnu.org/software/wget/faq.html#download) - -**Optional Packages** -* Containerized build - * [Docker](https://docs.docker.com/install/) >= 19.03 - * [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) -* Toolchains and SDKs - * (Cross compilation for Jetson platform) [NVIDIA JetPack](https://developer.nvidia.com/embedded/jetpack) >= 5.0 (current support only for TensorRT 8.4.0) - * (Cross compilation for QNX platform) [QNX Toolchain](https://blackberry.qnx.com/en) -* PyPI packages (for demo applications/tests) - * [onnx](https://pypi.org/project/onnx/) 1.9.0 - * [onnxruntime](https://pypi.org/project/onnxruntime/) 1.8.0 - * [tensorflow-gpu](https://pypi.org/project/tensorflow/) >= 2.5.1 - * [Pillow](https://pypi.org/project/Pillow/) >= 9.0.1 - * [pycuda](https://pypi.org/project/pycuda/) < 2021.1 - * [numpy](https://pypi.org/project/numpy/) - * [pytest](https://pypi.org/project/pytest/) -* Code formatting tools (for contributors) - * [Clang-format](https://clang.llvm.org/docs/ClangFormat.html) - * [Git-clang-format](https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/git-clang-format) - - > NOTE: [onnx-tensorrt](https://github.com/onnx/onnx-tensorrt), [cub](http://nvlabs.github.io/cub/), and [protobuf](https://github.com/protocolbuffers/protobuf.git) packages are downloaded along with TensorRT OSS, and not required to be installed. - -## Downloading TensorRT Build - -1. #### Download TensorRT OSS - ```bash - git clone -b master https://github.com/nvidia/TensorRT TensorRT - cd TensorRT - git submodule update --init --recursive - ``` - -2. #### (Optional - if not using TensorRT container) Specify the TensorRT GA release build path - - If using the TensorRT OSS build container, TensorRT libraries are preinstalled under `/usr/lib/x86_64-linux-gnu` and you may skip this step. - - Else download and extract the TensorRT GA build from [NVIDIA Developer Zone](https://developer.nvidia.com/nvidia-tensorrt-download). - - **Example: Ubuntu 20.04 on x86-64 with cuda-11.8.0** - - ```bash - cd ~/Downloads - tar -xvzf TensorRT-8.5.3.1.Linux.x86_64-gnu.cuda-11.8.cudnn8.6.tar.gz - export TRT_LIBPATH=`pwd`/TensorRT-8.5.3.1 - ``` - - -3. #### (Optional - for Jetson builds only) Download the JetPack SDK - 1. Download and launch the JetPack SDK manager. Login with your NVIDIA developer account. - 2. Select the platform and target OS (example: Jetson AGX Xavier, `Linux Jetpack 5.0`), and click Continue. - 3. Under `Download & Install Options` change the download folder and select `Download now, Install later`. Agree to the license terms and click Continue. - 4. Move the extracted files into the `/docker/jetpack_files` folder. - - -## Setting Up The Build Environment - -For Linux platforms, we recommend that you generate a docker container for building TensorRT OSS as described below. For native builds, please install the [prerequisite](#prerequisites) *System Packages*. 
- -1. #### Generate the TensorRT-OSS build container. - The TensorRT-OSS build container can be generated using the supplied Dockerfiles and build scripts. The build containers are configured for building TensorRT OSS out-of-the-box. - - **Example: Ubuntu 20.04 on x86-64 with cuda-11.8.0 (default)** - ```bash - ./docker/build.sh --file docker/ubuntu-20.04.Dockerfile --tag tensorrt-ubuntu20.04-cuda11.8 - ``` - **Example: CentOS/RedHat 7 on x86-64 with cuda-10.2** - ```bash - ./docker/build.sh --file docker/centos-7.Dockerfile --tag tensorrt-centos7-cuda10.2 --cuda 10.2 - ``` - **Example: Ubuntu 20.04 cross-compile for Jetson (aarch64) with cuda-11.4.2 (JetPack SDK)** - ```bash - ./docker/build.sh --file docker/ubuntu-cross-aarch64.Dockerfile --tag tensorrt-jetpack-cuda11.4 - ``` - **Example: Ubuntu 20.04 on aarch64 with cuda-11.4.2** - ```bash - ./docker/build.sh --file docker/ubuntu-20.04-aarch64.Dockerfile --tag tensorrt-aarch64-ubuntu20.04-cuda11.4 - ``` - -2. #### Launch the TensorRT-OSS build container. - **Example: Ubuntu 20.04 build container** - ```bash - ./docker/launch.sh --tag tensorrt-ubuntu20.04-cuda11.8 --gpus all - ``` - > NOTE: -
1. Use the `--tag` corresponding to build container generated in Step 1. -
2. [NVIDIA Container Toolkit](#prerequisites) is required for GPU access (running TensorRT applications) inside the build container. -
3. `sudo` password for Ubuntu build containers is 'nvidia'. -
4. Specify port number using `--jupyter ` for launching Jupyter notebooks. - -## Building TensorRT-OSS -* Generate Makefiles and build. - - **Example: Linux (x86-64) build with default cuda-11.8.0** - ```bash - cd $TRT_OSSPATH - mkdir -p build && cd build - cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out - make -j$(nproc) - ``` - - > NOTE: On CentOS7, the default g++ version does not support C++14. For native builds (not using the CentOS7 build container), first install devtoolset-8 to obtain the updated g++ toolchain as follows: - ```bash - yum -y install centos-release-scl - yum-config-manager --enable rhel-server-rhscl-7-rpms - yum -y install devtoolset-8 - export PATH="/opt/rh/devtoolset-8/root/bin:${PATH} - ``` - - **Example: Linux (aarch64) build with default cuda-11.8.0** - ```bash - cd $TRT_OSSPATH - mkdir -p build && cd build - cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64-native.toolchain - make -j$(nproc) - ``` - - **Example: Native build on Jetson (aarch64) with cuda-11.4** - ```bash - cd $TRT_OSSPATH - mkdir -p build && cd build - cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DTRT_PLATFORM_ID=aarch64 -DCUDA_VERSION=11.4 - CC=/usr/bin/gcc make -j$(nproc) - ``` - > NOTE: C compiler must be explicitly specified via `CC=` for native `aarch64` builds of protobuf. - - **Example: Ubuntu 20.04 Cross-Compile for Jetson (aarch64) with cuda-11.4 (JetPack)** - ```bash - cd $TRT_OSSPATH - mkdir -p build && cd build - cmake .. -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64.toolchain -DCUDA_VERSION=11.4 -DCUDNN_LIB=/pdk_files/cudnn/usr/lib/aarch64-linux-gnu/libcudnn.so -DCUBLAS_LIB=/usr/local/cuda-11.4/targets/aarch64-linux/lib/stubs/libcublas.so -DCUBLASLT_LIB=/usr/local/cuda-11.4/targets/aarch64-linux/lib/stubs/libcublasLt.so -DTRT_LIB_DIR=/pdk_files/tensorrt/lib - - make -j$(nproc) - ``` - > NOTE: The latest JetPack SDK v5.0 only supports TensorRT 8.4.0. - - > NOTE: -
1. The default CUDA version used by CMake is 11.8.0. To override this, for example to 10.2, append `-DCUDA_VERSION=10.2` to the cmake command. -
2. If samples fail to link on CentOS7, create this symbolic link: `ln -s $TRT_OUT_DIR/libnvinfer_plugin.so $TRT_OUT_DIR/libnvinfer_plugin.so.8` -* Required CMake build arguments are: - - `TRT_LIB_DIR`: Path to the TensorRT installation directory containing libraries. - - `TRT_OUT_DIR`: Output directory where generated build artifacts will be copied. -* Optional CMake build arguments: - - `CMAKE_BUILD_TYPE`: Specify if binaries generated are for release or debug (contain debug symbols). Values consists of [`Release`] | `Debug` - - `CUDA_VERISON`: The version of CUDA to target, for example [`11.7.1`]. - - `CUDNN_VERSION`: The version of cuDNN to target, for example [`8.6`]. - - `PROTOBUF_VERSION`: The version of Protobuf to use, for example [`3.0.0`]. Note: Changing this will not configure CMake to use a system version of Protobuf, it will configure CMake to download and try building that version. - - `CMAKE_TOOLCHAIN_FILE`: The path to a toolchain file for cross compilation. - - `BUILD_PARSERS`: Specify if the parsers should be built, for example [`ON`] | `OFF`. If turned OFF, CMake will try to find precompiled versions of the parser libraries to use in compiling samples. First in `${TRT_LIB_DIR}`, then on the system. If the build type is Debug, then it will prefer debug builds of the libraries before release versions if available. - - `BUILD_PLUGINS`: Specify if the plugins should be built, for example [`ON`] | `OFF`. If turned OFF, CMake will try to find a precompiled version of the plugin library to use in compiling samples. First in `${TRT_LIB_DIR}`, then on the system. If the build type is Debug, then it will prefer debug builds of the libraries before release versions if available. - - `BUILD_SAMPLES`: Specify if the samples should be built, for example [`ON`] | `OFF`. - - `GPU_ARCHS`: GPU (SM) architectures to target. By default we generate CUDA code for all major SMs. Specific SM versions can be specified here as a quoted space-separated list to reduce compilation time and binary size. Table of compute capabilities of NVIDIA GPUs can be found [here](https://developer.nvidia.com/cuda-gpus). Examples: - - NVidia A100: `-DGPU_ARCHS="80"` - - Tesla T4, GeForce RTX 2080: `-DGPU_ARCHS="75"` - - Titan V, Tesla V100: `-DGPU_ARCHS="70"` - - Multiple SMs: `-DGPU_ARCHS="80 75"` - - `TRT_PLATFORM_ID`: Bare-metal build (unlike containerized cross-compilation) on non Linux/x86 platforms must explicitly specify the target platform. Currently supported options: `x86_64` (default), `aarch64` - -# References - -## TensorRT Resources - -* [TensorRT Developer Home](https://developer.nvidia.com/tensorrt) -* [TensorRT QuickStart Guide](https://docs.nvidia.com/deeplearning/tensorrt/quick-start-guide/index.html) -* [TensorRT Developer Guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html) -* [TensorRT Sample Support Guide](https://docs.nvidia.com/deeplearning/tensorrt/sample-support-guide/index.html) -* [TensorRT ONNX Tools](https://docs.nvidia.com/deeplearning/tensorrt/index.html#tools) -* [TensorRT Discussion Forums](https://devtalk.nvidia.com/default/board/304/tensorrt/) -* [TensorRT Release Notes](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html) - -## Known Issues - -* Please refer to [TensorRT 8.5 Release Notes](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/tensorrt-8.html#tensorrt-8) +### 3. Replace "libnvinfer_plugin.so\*" + +``` +// backup original libnvinfer_plugin.so.x.y, e.g. 
libnvinfer_plugin.so.8.0.0 +sudo mv /usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so.8.p.q ${HOME}/libnvinfer_plugin.so.8.p.q.bak +// only replace the real file, don't touch the link files, e.g. libnvinfer_plugin.so, libnvinfer_plugin.so.8 +sudo cp $TRT_SOURCE/`pwd`/out/libnvinfer_plugin.so.8.m.n /usr/lib/x86_64-linux-gnu/libnvinfer_plugin.so.8.p.q +sudo ldconfig +``` diff --git a/README_ORIGIN.md b/README_ORIGIN.md new file mode 100644 index 00000000..a9627165 --- /dev/null +++ b/README_ORIGIN.md @@ -0,0 +1,212 @@ +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Documentation](https://img.shields.io/badge/TensorRT-documentation-brightgreen.svg)](https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html) + +# TensorRT Open Source Software +This repository contains the Open Source Software (OSS) components of NVIDIA TensorRT. It includes the sources for TensorRT plugins and parsers (Caffe and ONNX), as well as sample applications demonstrating usage and capabilities of the TensorRT platform. These open source software components are a subset of the TensorRT General Availability (GA) release with some extensions and bug-fixes. + +* For code contributions to TensorRT-OSS, please see our [Contribution Guide](CONTRIBUTING.md) and [Coding Guidelines](CODING-GUIDELINES.md). +* For a summary of new additions and updates shipped with TensorRT-OSS releases, please refer to the [Changelog](CHANGELOG.md). +* For business inquiries, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com) +* For press and other inquiries, please contact Hector Marinez at [hmarinez@nvidia.com](mailto:hmarinez@nvidia.com) + +Need enterprise support? NVIDIA global support is available for TensorRT with the [NVIDIA AI Enterprise software suite](https://www.nvidia.com/en-us/data-center/products/ai-enterprise/). Check out [NVIDIA LaunchPad](https://www.nvidia.com/en-us/launchpad/ai/ai-enterprise/) for free access to a set of hands-on labs with TensorRT hosted on NVIDIA infrastructure. + +Join the [TensorRT and Triton community](https://www.nvidia.com/en-us/deep-learning-ai/triton-tensorrt-newsletter/) and stay current on the latest product updates, bug fixes, content, best practices, and more. + +# Prebuilt TensorRT Python Package +We provide the TensorRT Python package for an easy installation. \ +To install: +```bash +pip install tensorrt +``` +You can skip the **Build** section to enjoy TensorRT with Python. + +# Build + +## Prerequisites +To build the TensorRT-OSS components, you will first need the following software packages. 
+ +**TensorRT GA build** +* [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-download) v8.5.3.1 + +**System Packages** +* [CUDA](https://developer.nvidia.com/cuda-toolkit) + * Recommended versions: + * cuda-11.8.0 + cuDNN-8.6 + * cuda-10.2 + cuDNN-8.4 +* [GNU make](https://ftp.gnu.org/gnu/make/) >= v4.1 +* [cmake](https://github.com/Kitware/CMake/releases) >= v3.13 +* [python]() >= v3.6.9, <= v3.10.x +* [pip](https://pypi.org/project/pip/#history) >= v19.0 +* Essential utilities + * [git](https://git-scm.com/downloads), [pkg-config](https://www.freedesktop.org/wiki/Software/pkg-config/), [wget](https://www.gnu.org/software/wget/faq.html#download) + +**Optional Packages** +* Containerized build + * [Docker](https://docs.docker.com/install/) >= 19.03 + * [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-docker) +* Toolchains and SDKs + * (Cross compilation for Jetson platform) [NVIDIA JetPack](https://developer.nvidia.com/embedded/jetpack) >= 5.0 (current support only for TensorRT 8.4.0) + * (Cross compilation for QNX platform) [QNX Toolchain](https://blackberry.qnx.com/en) +* PyPI packages (for demo applications/tests) + * [onnx](https://pypi.org/project/onnx/) 1.9.0 + * [onnxruntime](https://pypi.org/project/onnxruntime/) 1.8.0 + * [tensorflow-gpu](https://pypi.org/project/tensorflow/) >= 2.5.1 + * [Pillow](https://pypi.org/project/Pillow/) >= 9.0.1 + * [pycuda](https://pypi.org/project/pycuda/) < 2021.1 + * [numpy](https://pypi.org/project/numpy/) + * [pytest](https://pypi.org/project/pytest/) +* Code formatting tools (for contributors) + * [Clang-format](https://clang.llvm.org/docs/ClangFormat.html) + * [Git-clang-format](https://github.com/llvm-mirror/clang/blob/master/tools/clang-format/git-clang-format) + + > NOTE: [onnx-tensorrt](https://github.com/onnx/onnx-tensorrt), [cub](http://nvlabs.github.io/cub/), and [protobuf](https://github.com/protocolbuffers/protobuf.git) packages are downloaded along with TensorRT OSS, and not required to be installed. + +## Downloading TensorRT Build + +1. #### Download TensorRT OSS + ```bash + git clone -b master https://github.com/nvidia/TensorRT TensorRT + cd TensorRT + git submodule update --init --recursive + ``` + +2. #### (Optional - if not using TensorRT container) Specify the TensorRT GA release build path + + If using the TensorRT OSS build container, TensorRT libraries are preinstalled under `/usr/lib/x86_64-linux-gnu` and you may skip this step. + + Else download and extract the TensorRT GA build from [NVIDIA Developer Zone](https://developer.nvidia.com/nvidia-tensorrt-download). + + **Example: Ubuntu 20.04 on x86-64 with cuda-11.8.0** + + ```bash + cd ~/Downloads + tar -xvzf TensorRT-8.5.3.1.Linux.x86_64-gnu.cuda-11.8.cudnn8.6.tar.gz + export TRT_LIBPATH=`pwd`/TensorRT-8.5.3.1 + ``` + + +3. #### (Optional - for Jetson builds only) Download the JetPack SDK + 1. Download and launch the JetPack SDK manager. Login with your NVIDIA developer account. + 2. Select the platform and target OS (example: Jetson AGX Xavier, `Linux Jetpack 5.0`), and click Continue. + 3. Under `Download & Install Options` change the download folder and select `Download now, Install later`. Agree to the license terms and click Continue. + 4. Move the extracted files into the `/docker/jetpack_files` folder. + + +## Setting Up The Build Environment + +For Linux platforms, we recommend that you generate a docker container for building TensorRT OSS as described below. For native builds, please install the [prerequisite](#prerequisites) *System Packages*. 
+ +1. #### Generate the TensorRT-OSS build container. + The TensorRT-OSS build container can be generated using the supplied Dockerfiles and build scripts. The build containers are configured for building TensorRT OSS out-of-the-box. + + **Example: Ubuntu 20.04 on x86-64 with cuda-11.8.0 (default)** + ```bash + ./docker/build.sh --file docker/ubuntu-20.04.Dockerfile --tag tensorrt-ubuntu20.04-cuda11.8 + ``` + **Example: CentOS/RedHat 7 on x86-64 with cuda-10.2** + ```bash + ./docker/build.sh --file docker/centos-7.Dockerfile --tag tensorrt-centos7-cuda10.2 --cuda 10.2 + ``` + **Example: Ubuntu 20.04 cross-compile for Jetson (aarch64) with cuda-11.4.2 (JetPack SDK)** + ```bash + ./docker/build.sh --file docker/ubuntu-cross-aarch64.Dockerfile --tag tensorrt-jetpack-cuda11.4 + ``` + **Example: Ubuntu 20.04 on aarch64 with cuda-11.4.2** + ```bash + ./docker/build.sh --file docker/ubuntu-20.04-aarch64.Dockerfile --tag tensorrt-aarch64-ubuntu20.04-cuda11.4 + ``` + +2. #### Launch the TensorRT-OSS build container. + **Example: Ubuntu 20.04 build container** + ```bash + ./docker/launch.sh --tag tensorrt-ubuntu20.04-cuda11.8 --gpus all + ``` + > NOTE: +
1. Use the `--tag` corresponding to build container generated in Step 1. +
2. [NVIDIA Container Toolkit](#prerequisites) is required for GPU access (running TensorRT applications) inside the build container. +
3. `sudo` password for Ubuntu build containers is 'nvidia'. +
4. Specify port number using `--jupyter ` for launching Jupyter notebooks. + +## Building TensorRT-OSS +* Generate Makefiles and build. + + **Example: Linux (x86-64) build with default cuda-11.8.0** + ```bash + cd $TRT_OSSPATH + mkdir -p build && cd build + cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out + make -j$(nproc) + ``` + + > NOTE: On CentOS7, the default g++ version does not support C++14. For native builds (not using the CentOS7 build container), first install devtoolset-8 to obtain the updated g++ toolchain as follows: + ```bash + yum -y install centos-release-scl + yum-config-manager --enable rhel-server-rhscl-7-rpms + yum -y install devtoolset-8 + export PATH="/opt/rh/devtoolset-8/root/bin:${PATH} + ``` + + **Example: Linux (aarch64) build with default cuda-11.8.0** + ```bash + cd $TRT_OSSPATH + mkdir -p build && cd build + cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64-native.toolchain + make -j$(nproc) + ``` + + **Example: Native build on Jetson (aarch64) with cuda-11.4** + ```bash + cd $TRT_OSSPATH + mkdir -p build && cd build + cmake .. -DTRT_LIB_DIR=$TRT_LIBPATH -DTRT_OUT_DIR=`pwd`/out -DTRT_PLATFORM_ID=aarch64 -DCUDA_VERSION=11.4 + CC=/usr/bin/gcc make -j$(nproc) + ``` + > NOTE: C compiler must be explicitly specified via `CC=` for native `aarch64` builds of protobuf. + + **Example: Ubuntu 20.04 Cross-Compile for Jetson (aarch64) with cuda-11.4 (JetPack)** + ```bash + cd $TRT_OSSPATH + mkdir -p build && cd build + cmake .. -DCMAKE_TOOLCHAIN_FILE=$TRT_OSSPATH/cmake/toolchains/cmake_aarch64.toolchain -DCUDA_VERSION=11.4 -DCUDNN_LIB=/pdk_files/cudnn/usr/lib/aarch64-linux-gnu/libcudnn.so -DCUBLAS_LIB=/usr/local/cuda-11.4/targets/aarch64-linux/lib/stubs/libcublas.so -DCUBLASLT_LIB=/usr/local/cuda-11.4/targets/aarch64-linux/lib/stubs/libcublasLt.so -DTRT_LIB_DIR=/pdk_files/tensorrt/lib + + make -j$(nproc) + ``` + > NOTE: The latest JetPack SDK v5.0 only supports TensorRT 8.4.0. + + > NOTE: +
1. The default CUDA version used by CMake is 11.8.0. To override this, for example to 10.2, append `-DCUDA_VERSION=10.2` to the cmake command. +
2. If samples fail to link on CentOS7, create this symbolic link: `ln -s $TRT_OUT_DIR/libnvinfer_plugin.so $TRT_OUT_DIR/libnvinfer_plugin.so.8` +* Required CMake build arguments are: + - `TRT_LIB_DIR`: Path to the TensorRT installation directory containing libraries. + - `TRT_OUT_DIR`: Output directory where generated build artifacts will be copied. +* Optional CMake build arguments: + - `CMAKE_BUILD_TYPE`: Specify if binaries generated are for release or debug (contain debug symbols). Values consists of [`Release`] | `Debug` + - `CUDA_VERISON`: The version of CUDA to target, for example [`11.7.1`]. + - `CUDNN_VERSION`: The version of cuDNN to target, for example [`8.6`]. + - `PROTOBUF_VERSION`: The version of Protobuf to use, for example [`3.0.0`]. Note: Changing this will not configure CMake to use a system version of Protobuf, it will configure CMake to download and try building that version. + - `CMAKE_TOOLCHAIN_FILE`: The path to a toolchain file for cross compilation. + - `BUILD_PARSERS`: Specify if the parsers should be built, for example [`ON`] | `OFF`. If turned OFF, CMake will try to find precompiled versions of the parser libraries to use in compiling samples. First in `${TRT_LIB_DIR}`, then on the system. If the build type is Debug, then it will prefer debug builds of the libraries before release versions if available. + - `BUILD_PLUGINS`: Specify if the plugins should be built, for example [`ON`] | `OFF`. If turned OFF, CMake will try to find a precompiled version of the plugin library to use in compiling samples. First in `${TRT_LIB_DIR}`, then on the system. If the build type is Debug, then it will prefer debug builds of the libraries before release versions if available. + - `BUILD_SAMPLES`: Specify if the samples should be built, for example [`ON`] | `OFF`. + - `GPU_ARCHS`: GPU (SM) architectures to target. By default we generate CUDA code for all major SMs. Specific SM versions can be specified here as a quoted space-separated list to reduce compilation time and binary size. Table of compute capabilities of NVIDIA GPUs can be found [here](https://developer.nvidia.com/cuda-gpus). Examples: + - NVidia A100: `-DGPU_ARCHS="80"` + - Tesla T4, GeForce RTX 2080: `-DGPU_ARCHS="75"` + - Titan V, Tesla V100: `-DGPU_ARCHS="70"` + - Multiple SMs: `-DGPU_ARCHS="80 75"` + - `TRT_PLATFORM_ID`: Bare-metal build (unlike containerized cross-compilation) on non Linux/x86 platforms must explicitly specify the target platform. 
Currently supported options: `x86_64` (default), `aarch64` + +# References + +## TensorRT Resources + +* [TensorRT Developer Home](https://developer.nvidia.com/tensorrt) +* [TensorRT QuickStart Guide](https://docs.nvidia.com/deeplearning/tensorrt/quick-start-guide/index.html) +* [TensorRT Developer Guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html) +* [TensorRT Sample Support Guide](https://docs.nvidia.com/deeplearning/tensorrt/sample-support-guide/index.html) +* [TensorRT ONNX Tools](https://docs.nvidia.com/deeplearning/tensorrt/index.html#tools) +* [TensorRT Discussion Forums](https://devtalk.nvidia.com/default/board/304/tensorrt/) +* [TensorRT Release Notes](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html) + +## Known Issues + +* Please refer to [TensorRT 8.5 Release Notes](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/tensorrt-8.html#tensorrt-8) diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index fb66a5bc..6d1321b7 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -43,6 +43,7 @@ set(PLUGIN_LISTS disentangledAttentionPlugin efficientNMSPlugin efficientNMSPlugin/tftrt + efficientNMSCustomPlugin flattenConcat generateDetectionPlugin gridAnchorPlugin diff --git a/plugin/README.md b/plugin/README.md index d9eb7da9..70dcdf3a 100644 --- a/plugin/README.md +++ b/plugin/README.md @@ -46,6 +46,7 @@ | [specialSlicePlugin](specialSlicePlugin) | SpecialSlice_TRT | 1 | | [splitPlugin](splitPlugin) | Split | 1 | | [voxelGeneratorPlugin](voxelGeneratorPlugin) | VoxelGeneratorPlugin | 1 | +| [efficientNMSCustomPlugin](efficientNMSCustomPlugin) | EfficientNMSCustom_TRT | 1 | ## Known Limitations diff --git a/plugin/api/InferPlugin.cpp b/plugin/api/InferPlugin.cpp index 1bf89202..104cf19b 100644 --- a/plugin/api/InferPlugin.cpp +++ b/plugin/api/InferPlugin.cpp @@ -38,6 +38,7 @@ using namespace nvinfer1::plugin; #include "efficientNMSPlugin.h" #include "tftrt/efficientNMSImplicitTFTRTPlugin.h" #include "tftrt/efficientNMSExplicitTFTRTPlugin.h" +#include "efficientNMSCustomPlugin.h" #include "flattenConcat.h" #include "fmhcaPlugin.h" #include "generateDetectionPlugin.h" @@ -190,6 +191,7 @@ extern "C" initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); + initializePlugin(logger, libNamespace); initializePlugin(logger, libNamespace); #if defined(ENABLE_SM75) || defined(ENABLE_SM80) || defined(ENABLE_SM86) || defined(ENABLE_SM89) initializePlugin(logger, libNamespace); diff --git a/plugin/efficientNMSCustomPlugin/CMakeLists.txt b/plugin/efficientNMSCustomPlugin/CMakeLists.txt new file mode 100644 index 00000000..53b70a7e --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/CMakeLists.txt @@ -0,0 +1,21 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +file(GLOB SRCS *.cpp) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS}) +set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE) +file(GLOB CU_SRCS *.cu) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS}) +set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE) diff --git a/plugin/efficientNMSCustomPlugin/EfficientNMSCustomPlugin_PluginConfig.yaml b/plugin/efficientNMSCustomPlugin/EfficientNMSCustomPlugin_PluginConfig.yaml new file mode 100644 index 00000000..7f81dbec --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/EfficientNMSCustomPlugin_PluginConfig.yaml @@ -0,0 +1,52 @@ +--- +name: EfficientNMSCustom_TRT +interface: "IPluginV2DynamicExt" +versions: + "1": + attributes: + - score_threshold + - iou_threshold + - max_output_boxes + - background_class + - score_activation + - box_coding + attribute_types: + score_threshold: float32 + iou_threshold: float32 + max_output_boxes: int32 + background_class: int32 + score_activation: int32 + box_coding: int32 + attribute_length: + score_threshold: 1 + iou_threshold: 1 + max_output_boxes: 1 + background_class: 1 + score_activation: 1 + box_coding: 1 + attribute_options: + score_threshold: + min: "=0" + max: "=pinf" + iou_threshold: + min: "0" + max: "=pinf" + max_output_boxes: + min: "0" + max: "=pinf" + background_class: + min: "=ninf" + max: "=pinf" + score_activation: + - 0 + - 1 + box_coding: + - 0 + - 1 + attributes_required: + - score_threshold + - iou_threshold + - max_output_boxes + - background_class + - score_activation + - box_coding diff --git a/plugin/efficientNMSCustomPlugin/README.md b/plugin/efficientNMSCustomPlugin/README.md new file mode 100644 index 00000000..1728fc06 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/README.md @@ -0,0 +1,162 @@ +# Efficient NMS Custom Plugin + +#### Table of Contents +- [Description](#description) +- [Structure](#structure) + * [Inputs](#inputs) + * [Dynamic Shape Support](#dynamic-shape-support) + * [Box Coding Type](#box-coding-type) + * [Outputs](#outputs) + * [Parameters](#parameters) +- [Algorithm](#algorithm) + * [Process Description](#process-description) + * [Performance Tuning](#performance-tuning) + * [Additional Resources](#additional-resources) +- [License](#license) + +## Description + +This TensorRT plugin implements an efficient algorithm to perform Non Maximum Suppression for object detection networks. + +This plugin is primarily intended for using with EfficientDet on TensorRT, as this network is particularly sensitive to the latencies introduced by slower NMS implementations. However, the plugin is generic enough that it will work correctly for other detections architectures, such as SSD or FasterRCNN. + +## Structure + +### Inputs + +The plugin has two modes of operation, depending on the given input data. The plugin will automatically detect which mode to operate as, depending on the number of inputs it receives, as follows: + +1. **Standard NMS Mode:** Only two input tensors are given, (i) the bounding box coordinates and (ii) the corresponding classification scores for each box. + +2. **Fused Box Decoder Mode:** Three input tensors are given, (i) the raw localization predictions for each box originating directly from the localization head of the network, (ii) the corresponding classification scores originating from the classification head of the network, and (iii) the default anchor box coordinates usually hardcoded as constant tensors in the network. 
+ +Most object detection networks work by generating raw predictions from a "localization head" which adjust the coordinates of standard non-learned anchor coordinates to produce a tighter fitting bounding box. This process is called "box decoding", and it usually involves a large number of element-wise operations to transform the anchors to final box coordinates. As this can involve exponential operations on a large number of anchors, it can be computationally expensive, so this plugin gives the option of fusing the box decoder within the NMS operation which can be done in a far more efficient manner, resulting in lower latency for the network. + +#### Boxes Input +> **Input Shape:** `[batch_size, number_boxes, 4]` or `[batch_size, number_boxes, number_classes, 4]` +> +> **Data Type:** `float32` or `float16` + +The boxes input can have 3 dimensions in case a single box prediction is produced for all classes (such as in EfficientDet or SSD), or 4 dimensions when separate box predictions are generated for each class (such as in FasterRCNN), in which case `number_classes` >= 1 and must match the number of classes in the scores input. The final dimension represents the four coordinates that define the bounding box prediction. + +For *Standard NMS* mode, this tensor should contain the final box coordinates for each predicted detection. For *Fused Box Decoder* mode, this tensor should have the raw localization predictions. In either case, this data is given as `4` coordinates which makes up the final shape dimension. + +#### Scores Input +> **Input Shape:** `[batch_size, number_boxes, number_classes]` +> +> **Data Type:** `float32` or `float16` + +The scores input has `number_classes` elements with the predicted scores for each candidate class for each of the `number_boxes` anchor boxes. + +Usually, the score values will have passed through a sigmoid activation function before reaching the NMS operation. However, as an optimization, the pre-sigmoid raw scores can also be provided to the NMS plugin to reduce overall network latency. If raw scores are given, enable the `score_activation` parameter so they are processed accordingly. + +#### Anchors Input (Optional) +> **Input Shape:** `[1, number_boxes, 4]` or `[batch_size, number_boxes, 4]` +> +> **Data Type:** `float32` or `float16` + +Only used in *Fused Box Decoder* mode. It is much more efficient to perform the box decoding within this plugin. In this case, the boxes input will be treated as the raw localization head box corrections, and this third input should contain the default anchor/prior box coordinates. + +When used, the input must have 3 dimensions, where the first one may be either `1` in case anchors are constant for all images in a batch, or `batch_size` in case each image has different anchors -- such as in the box refinement NMS of FasterRCNN's second stage. + +### Dynamic Shape Support + +Most input shape dimensions, namely `batch_size`, `number_boxes`, and `number_classes`, for all inputs can be defined dynamically at runtime if the TensorRT engine is built with dynamic input shapes. However, once defined, these dimensions must match across all tensors that use them (e.g. the same `number_boxes` dimension must be given for both boxes and scores, etc.) + +### Box Coding Type +Different object detection networks represent their box coordinate system differently. The two types supported by this plugin are: + +1. 
**BoxCorners:** The four coordinates represent `[x1, y1, x2, y2]` values, where each x,y pair defines the top-left and bottom-right corners of a bounding box.
+2. **BoxCenterSize:** The four coordinates represent `[x, y, w, h]` values, where the x,y pair defines the box center location, and the w,h pair defines its width and height.
+
+Note that for NMS purposes, horizontal and vertical coordinates are fully interchangeable. TensorFlow-trained networks, for example, often use vertical-first coordinates such as `[y1, x1, y2, x2]`, but this coordinate system will work equally well under the BoxCorner coding. Similarly, `[y, x, h, w]` will be properly covered by the BoxCenterSize coding.
+
+In *Fused Box Decoder* mode, the boxes and anchor tensors should both use the same coding.
+
+### Outputs
+
+The following five output tensors are generated:
+
+- **num_detections:**
+  This is a `[batch_size, 1]` tensor of data type `int32`. The last dimension is a scalar indicating the number of valid detections per batch image. It can be less than `max_output_boxes`. Only the top `num_detections[i]` entries in `detection_boxes[i]`, `detection_scores[i]`, `detection_classes[i]` and `detection_indices[i]` are valid.
+
+- **detection_boxes:**
+  This is a `[batch_size, max_output_boxes, 4]` tensor of data type `float32` or `float16`, containing the coordinates of non-max suppressed boxes. The output coordinates will always be in BoxCorner format, regardless of the input coding type.
+
+- **detection_scores:**
+  This is a `[batch_size, max_output_boxes]` tensor of data type `float32` or `float16`, containing the scores for the boxes.
+
+- **detection_classes:**
+  This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing the classes for the boxes.
+
+- **detection_indices:**
+  This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing, for each kept detection, the index of the selected box along the `number_boxes` dimension of the inputs.
+
+### Parameters
+
+| Type | Parameter | Description
+|----------|--------------------------|--------------------------------------------------------
+|`float` |`score_threshold` * |The scalar threshold for score (low scoring boxes are removed).
+|`float` |`iou_threshold` |The scalar threshold for IOU (additional boxes that have high IOU overlap with previously selected boxes are removed).
+|`int` |`max_output_boxes` |The maximum number of detections to output per image.
+|`int` |`background_class` |The label ID for the background class. If there is no background class, set it to `-1`.
+|`bool` |`score_activation` * |Set to true to apply sigmoid activation to the confidence scores during NMS operation.
+|`int` |`box_coding` |Coding type used for boxes (and anchors if applicable), 0 = BoxCorner, 1 = BoxCenterSize.
+
+Parameters marked with a `*` have a non-negligible effect on runtime latency. See the [Performance Tuning](#performance-tuning) section below for more details on how to set them optimally.
+
+## Algorithm
+
+### Process Description
+
+The NMS algorithm in this plugin first filters the scores below the given `scoreThreshold`. This subset of scores is then sorted, and their corresponding boxes are then further filtered out by removing boxes that overlap each other with an IOU above the given `iouThreshold`.
+
+The algorithm launcher and its relevant CUDA kernels are all defined in the `efficientNMSCustomInference.cu` file. 
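+
+Before looking at the individual kernels, the following CPU-style sketch (illustrative only, not the plugin's actual CUDA implementation) summarizes this filter, sort, and suppress flow for the shared-location input layout, including the box-index bookkeeping that feeds the `detection_indices` output. The steps listed after it describe how the CUDA kernels realize the same flow in a batched, parallel form.
+
+```cpp
+// Illustrative single-image CPU reference of the NMS flow described above.
+#include <algorithm>
+#include <vector>
+
+struct Box { float x1, y1, x2, y2; };
+struct Detection { float score; int cls; int index; Box box; };
+
+static float iou(const Box& a, const Box& b) {
+    float iw = std::max(0.f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
+    float ih = std::max(0.f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
+    float inter = iw * ih;
+    float uni = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter;
+    return uni > 0.f ? inter / uni : 0.f;
+}
+
+// boxes[i] and scores[i * numClasses + c] mirror the plugin's shared-location layout.
+std::vector<Detection> referenceNms(const std::vector<Box>& boxes, const std::vector<float>& scores,
+    int numClasses, float scoreThreshold, float iouThreshold, int maxOutputBoxes) {
+    std::vector<Detection> candidates, kept;
+    // 1. Score filtering (what the score-filter kernel does on the GPU).
+    for (int i = 0; i < static_cast<int>(boxes.size()); ++i)
+        for (int c = 0; c < numClasses; ++c)
+            if (scores[i * numClasses + c] >= scoreThreshold)
+                candidates.push_back({scores[i * numClasses + c], c, i, boxes[i]});
+    // 2. Sort the surviving scores in descending order.
+    std::sort(candidates.begin(), candidates.end(),
+        [](const Detection& a, const Detection& b) { return a.score > b.score; });
+    // 3. Greedy IOU suppression against already-kept boxes of the same class.
+    for (const auto& cand : candidates) {
+        if (static_cast<int>(kept.size()) >= maxOutputBoxes)
+            break;
+        bool suppressed = false;
+        for (const auto& k : kept)
+            if (k.cls == cand.cls && iou(k.box, cand.box) >= iouThreshold) { suppressed = true; break; }
+        if (!suppressed)
+            kept.push_back(cand); // cand.index is what detection_indices reports.
+    }
+    return kept;
+}
+```
+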
+ +Specifically, the NMS algorithm does the following: + +- The scores are filtered with the `score_threshold` parameter to reject any scores below the score threshold, while maintaining indexing to cross-reference these scores to their corresponding box coordinates. This is done with the `EfficientNMSCustomFilter` CUDA kernel. + +- If too many elements are kept, due to a very low (or zero) score threshold, the filter operation can become a bottleneck due to the atomic operations involved. To mitigate this, a fallback kernel `EfficientNMSCustomDenseIndex` is used instead which passes all the score elements densely packed and indexed. This method is heuristically selected only if the score threshold is less than 0.007. + +- The selected scores that remain after filtering are sorted in descending order. The indexing is carefully handled to still maintain score to box relationships after sorting. + +- After sorting, the highest 4096 scores are processed by the `EfficientNMSCustom` CUDA kernel. This algorithm uses the index data maintained throughout the previous steps to find the boxes corresponding to the remaining scores. If the fused box decoder is being used, decoding will happen until this stage, where only the top scoring boxes need to be decoded. + +- The NMS kernel uses an efficient filtering algorithm that largely reduces the number of IOU overlap cross-checks between box pairs. The boxes that survive the IOU filtering finally pass through to the output results. At this stage, the sigmoid activation is applied to only the final remaining scores, if `score_activation` is enabled, thereby greatly reducing the amount of sigmoid calculations required otherwise. + +### Performance Tuning + +The plugin implements a very efficient NMS algorithm which largely reduces the latency of this operation in comparison to other NMS plugins. However, there are certain considerations that can help to better fine tune its performance: + +#### Choosing the Score Threshold + +The algorithm is highly sensitive to the selected `score_threshold` parameter. With a higher threshold, fewer elements need to be processed and so the algorithm runs much faster. Therefore, it's beneficial to always select the highest possible score threshold that fulfills the application requirements. Threshold values lower than approximately 0.01 may cause substantially higher latency. + +#### Using Sigmoid Activation + +Depending on network configuration, it is usually more efficient to provide raw scores (pre-sigmoid) to the NMS plugin scores input, and enable the `score_activation` parameter. Doing so applies a sigmoid activation only to the last `max_output_boxes` selected scores, instead of all the predicted scores, largely reducing the computational cost. + +#### Using the Fused Box Decoder + +When using networks with many anchors, such as EfficientDet or SSD, it may be more efficient to do box decoding within the NMS plugin. For this, pass the raw box predictions as the boxes input, and the default anchor coordinates as the optional third input to the plugin. 
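+
+As a usage illustration of these tuning recommendations, the sketch below shows one way the plugin could be instantiated through the TensorRT plugin registry and attached to a network, with raw (pre-sigmoid) scores and the fused box decoder enabled. The helper name, tensor variables, and attribute values are placeholders chosen for this example, not part of this repository.
+
+```cpp
+#include <NvInfer.h>
+#include <NvInferPlugin.h>
+#include <vector>
+
+// Sketch: build an EfficientNMSCustom_TRT layer; `network`, `boxes`, `scores` and
+// `anchors` are assumed to already exist in the calling code.
+nvinfer1::ILayer* addEfficientNmsCustom(nvinfer1::INetworkDefinition& network,
+    nvinfer1::ITensor* boxes, nvinfer1::ITensor* scores, nvinfer1::ITensor* anchors)
+{
+    initLibNvInferPlugins(nullptr, "");
+    auto* creator = getPluginRegistry()->getPluginCreator("EfficientNMSCustom_TRT", "1");
+
+    // Placeholder attribute values; see the Parameters table above.
+    float scoreThreshold = 0.25f;
+    float iouThreshold = 0.5f;
+    int32_t maxOutputBoxes = 100;
+    int32_t backgroundClass = -1;
+    int32_t scoreActivation = 1; // scores input carries raw, pre-sigmoid values
+    int32_t boxCoding = 1;       // raw predictions + anchors in BoxCenterSize coding
+
+    std::vector<nvinfer1::PluginField> fields{
+        {"score_threshold", &scoreThreshold, nvinfer1::PluginFieldType::kFLOAT32, 1},
+        {"iou_threshold", &iouThreshold, nvinfer1::PluginFieldType::kFLOAT32, 1},
+        {"max_output_boxes", &maxOutputBoxes, nvinfer1::PluginFieldType::kINT32, 1},
+        {"background_class", &backgroundClass, nvinfer1::PluginFieldType::kINT32, 1},
+        {"score_activation", &scoreActivation, nvinfer1::PluginFieldType::kINT32, 1},
+        {"box_coding", &boxCoding, nvinfer1::PluginFieldType::kINT32, 1}};
+    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
+
+    auto* plugin = creator->createPlugin("efficient_nms_custom", &fc);
+    // Three inputs select Fused Box Decoder mode; pass only {boxes, scores} for Standard NMS mode.
+    nvinfer1::ITensor* inputs[] = {boxes, scores, anchors};
+    return network.addPluginV2(inputs, 3, *plugin);
+}
+```
+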
+ +### Additional Resources + +The following resources provide a deeper understanding of the NMS algorithm: + +#### Networks +- [EfficientDet](https://arxiv.org/abs/1911.09070) +- [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) +- [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497) +- [Mask R-CNN](https://arxiv.org/abs/1703.06870) + + +#### Documentation +- [NMS algorithm](https://www.coursera.org/lecture/convolutional-neural-networks/non-max-suppression-dvrjH) +- [NonMaxSuppression ONNX Op](https://github.com/onnx/onnx/blob/master/docs/Operators.md#NonMaxSuppression) + +## License + +For terms and conditions for use, reproduction, and distribution, see the [TensorRT Software License Agreement](https://docs.nvidia.com/deeplearning/sdk/tensorrt-sla/index.html) +documentation. diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cu b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cu new file mode 100644 index 00000000..3f225344 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cu @@ -0,0 +1,675 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxUtils.h" +#include "cub/cub.cuh" +#include "cuda_runtime_api.h" + +#include "efficientNMSCustomInference.cuh" +#include "efficientNMSCustomInference.h" + +#define NMS_TILES 5 + +using namespace nvinfer1; + +template +__device__ float IOU(EfficientNMSCustomParameters param, BoxCorner box1, BoxCorner box2) +{ + // Regardless of the selected box coding, IOU is always performed in BoxCorner coding. + // The boxes are copied so that they can be reordered without affecting the originals. + BoxCorner b1 = box1; + BoxCorner b2 = box2; + b1.reorder(); + b2.reorder(); + float intersectArea = BoxCorner::intersect(b1, b2).area(); + if (intersectArea <= 0.f) + { + return 0.f; + } + float unionArea = b1.area() + b2.area() - intersectArea; + if (unionArea <= 0.f) + { + return 0.f; + } + return intersectArea / unionArea; +} + +template +__device__ BoxCorner DecodeBoxes(EfficientNMSCustomParameters param, int boxIdx, int anchorIdx, + const Tb* __restrict__ boxesInput, const Tb* __restrict__ anchorsInput) +{ + // The inputs will be in the selected coding format, as well as the decoding function. But the decoded box + // will always be returned as BoxCorner. 
+ Tb box = boxesInput[boxIdx]; + if (!param.boxDecoder) + { + return BoxCorner(box); + } + Tb anchor = anchorsInput[anchorIdx]; + box.reorder(); + anchor.reorder(); + return BoxCorner(box.decode(anchor)); +} + +template +__device__ void MapNMSData(EfficientNMSCustomParameters param, int idx, int imageIdx, const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, + const int* __restrict__ topNumData, const T* __restrict__ sortedScoresData, const int* __restrict__ sortedIndexData, + T& scoreMap, int& classMap, BoxCorner& boxMap, int& boxIdxMap) +{ + // idx: Holds the NMS box index, within the current batch. + // idxSort: Holds the batched NMS box index, which indexes the (filtered, but sorted) score buffer. + // scoreMap: Holds the score that corresponds to the indexed box being processed by NMS. + if (idx >= topNumData[imageIdx]) + { + return; + } + int idxSort = imageIdx * param.numScoreElements + idx; + scoreMap = sortedScoresData[idxSort]; + + // idxMap: Holds the re-mapped index, which indexes the (filtered, but unsorted) buffers. + // classMap: Holds the class that corresponds to the idx'th sorted score being processed by NMS. + // anchorMap: Holds the anchor that corresponds to the idx'th sorted score being processed by NMS. + int idxMap = imageIdx * param.numScoreElements + sortedIndexData[idxSort]; + classMap = topClassData[idxMap]; + int anchorMap = topAnchorsData[idxMap]; + + // boxIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, and unsorted) boxes input buffer. + boxIdxMap = -1; + if (param.shareLocation) // Shape of boxesInput: [batchSize, numAnchors, 1, 4] + { + boxIdxMap = imageIdx * param.numAnchors + anchorMap; + } + else // Shape of boxesInput: [batchSize, numAnchors, numClasses, 4] + { + int batchOffset = imageIdx * param.numAnchors * param.numClasses; + int anchorOffset = anchorMap * param.numClasses; + boxIdxMap = batchOffset + anchorOffset + classMap; + } + // anchorIdxMap: Holds the re-re-mapped index, which indexes the (unfiltered, and unsorted) anchors input buffer. + int anchorIdxMap = -1; + if (param.shareAnchors) // Shape of anchorsInput: [1, numAnchors, 4] + { + anchorIdxMap = anchorMap; + } + else // Shape of anchorsInput: [batchSize, numAnchors, 4] + { + anchorIdxMap = imageIdx * param.numAnchors + anchorMap; + } + // boxMap: Holds the box that corresponds to the idx'th sorted score being processed by NMS. 
+ boxMap = DecodeBoxes(param, boxIdxMap, anchorIdxMap, boxesInput, anchorsInput); +} + +template +__device__ void WriteNMSResult(EfficientNMSCustomParameters param, int* __restrict__ numDetectionsOutput, + T* __restrict__ nmsScoresOutput, int* __restrict__ nmsClassesOutput, BoxCorner* __restrict__ nmsBoxesOutput, + int* __restrict__ nmsIndicesOutput, T threadScore, int threadClass, BoxCorner threadBox, int imageIdx, + unsigned int resultsCounter, int boxIdxMap) +{ + int outputIdx = imageIdx * param.numOutputBoxes + resultsCounter - 1; + if (param.scoreSigmoid) + { + nmsScoresOutput[outputIdx] = sigmoid_mp(threadScore); + } + else if (param.scoreBits > 0) + { + nmsScoresOutput[outputIdx] = add_mp(threadScore, (T) -1); + } + else + { + nmsScoresOutput[outputIdx] = threadScore; + } + nmsClassesOutput[outputIdx] = threadClass; + if (param.clipBoxes) + { + nmsBoxesOutput[outputIdx] = threadBox.clip((T) 0, (T) 1); + } + else + { + nmsBoxesOutput[outputIdx] = threadBox; + } + numDetectionsOutput[imageIdx] = resultsCounter; + + int index = boxIdxMap % param.numAnchors; + + nmsIndicesOutput[outputIdx] = index; +} + +template +__global__ void EfficientNMSCustom(EfficientNMSCustomParameters param, const int* topNumData, int* outputIndexData, + int* outputClassData, const int* sortedIndexData, const T* __restrict__ sortedScoresData, + const int* __restrict__ topClassData, const int* __restrict__ topAnchorsData, const Tb* __restrict__ boxesInput, + const Tb* __restrict__ anchorsInput, int* __restrict__ numDetectionsOutput, T* __restrict__ nmsScoresOutput, + int* __restrict__ nmsClassesOutput, int* __restrict__ nmsIndicesOutput, BoxCorner* __restrict__ nmsBoxesOutput) +{ + unsigned int thread = threadIdx.x; + unsigned int imageIdx = blockIdx.y; + unsigned int tileSize = blockDim.x; + if (imageIdx >= param.batchSize) + { + return; + } + + int numSelectedBoxes = min(topNumData[imageIdx], param.numSelectedBoxes); + int numTiles = (numSelectedBoxes + tileSize - 1) / tileSize; + if (thread >= numSelectedBoxes) + { + return; + } + + __shared__ int blockState; + __shared__ unsigned int resultsCounter; + if (thread == 0) + { + blockState = 0; + resultsCounter = 0; + } + + int threadState[NMS_TILES]; + unsigned int boxIdx[NMS_TILES]; + T threadScore[NMS_TILES]; + int threadClass[NMS_TILES]; + BoxCorner threadBox[NMS_TILES]; + int boxIdxMap[NMS_TILES]; + for (int tile = 0; tile < numTiles; tile++) + { + threadState[tile] = 0; + boxIdx[tile] = thread + tile * blockDim.x; + MapNMSData(param, boxIdx[tile], imageIdx, boxesInput, anchorsInput, topClassData, topAnchorsData, + topNumData, sortedScoresData, sortedIndexData, threadScore[tile], threadClass[tile], threadBox[tile], + boxIdxMap[tile]); + } + + // Iterate through all boxes to NMS against. + for (int i = 0; i < numSelectedBoxes; i++) + { + int tile = i / tileSize; + + if (boxIdx[tile] == i) + { + // Iteration lead thread, figure out what the other threads should do, + // this will be signaled via the blockState shared variable. + if (threadState[tile] == -1) + { + // Thread already dead, this box was already dropped in a previous iteration, + // because it had a large IOU overlap with another lead thread previously, so + // it would never be kept anyway, therefore it can safely be skip all IOU operations + // in this iteration. 
+ blockState = -1; // -1 => Signal all threads to skip iteration + } + else if (threadState[tile] == 0) + { + // As this box will be kept, this is a good place to find what index in the results buffer it + // should have, as this allows to perform an early loop exit if there are enough results. + if (resultsCounter >= param.numOutputBoxes) + { + blockState = -2; // -2 => Signal all threads to do an early loop exit. + } + else + { + // Thread is still alive, because it has not had a large enough IOU overlap with + // any other kept box previously. Therefore, this box will be kept for sure. However, + // we need to check against all other subsequent boxes from this position onward, + // to see how those other boxes will behave in future iterations. + blockState = 1; // +1 => Signal all (higher index) threads to calculate IOU against this box + threadState[tile] = 1; // +1 => Mark this box's thread to be kept and written out to results + + // If the numOutputBoxesPerClass check is enabled, write the result only if the limit for this + // class on this image has not been reached yet. Other than (possibly) skipping the write, this + // won't affect anything else in the NMS threading. + bool write = true; + if (param.numOutputBoxesPerClass >= 0) + { + int classCounterIdx = imageIdx * param.numClasses + threadClass[tile]; + write = (outputClassData[classCounterIdx] < param.numOutputBoxesPerClass); + outputClassData[classCounterIdx]++; + } + if (write) + { + // This branch is visited by one thread per iteration, so it's safe to do non-atomic increments. + resultsCounter++; + + WriteNMSResult(param, numDetectionsOutput, nmsScoresOutput, nmsClassesOutput, nmsBoxesOutput, + nmsIndicesOutput, threadScore[tile], threadClass[tile], threadBox[tile], imageIdx, + resultsCounter, boxIdxMap[tile]); + } + } + } + else + { + // This state should never be reached, but just in case... + blockState = 0; // 0 => Signal all threads to not do any updates, nothing happens. + } + } + + __syncthreads(); + + if (blockState == -2) + { + // This is the signal to exit from the loop. + return; + } + + if (blockState == -1) + { + // This is the signal for all threads to just skip this iteration, as no IOU's need to be checked. + continue; + } + + // Grab a box and class to test the current box against. The test box corresponds to iteration i, + // therefore it will have a lower index than the current thread box, and will therefore have a higher score + // than the current box because it's located "before" in the sorted score list. + T testScore; + int testClass; + BoxCorner testBox; + int testBoxIdxMap; + MapNMSData(param, i, imageIdx, boxesInput, anchorsInput, topClassData, topAnchorsData, topNumData, + sortedScoresData, sortedIndexData, testScore, testClass, testBox, testBoxIdxMap); + + for (int tile = 0; tile < numTiles; tile++) + { + // IOU + if (boxIdx[tile] > i && // Make sure two different boxes are being tested, and that it's a higher index; + boxIdx[tile] < numSelectedBoxes && // Make sure the box is within numSelectedBoxes; + blockState == 1 && // Signal that allows IOU checks to be performed; + threadState[tile] == 0 && // Make sure this box hasn't been either dropped or kept already; + threadClass[tile] == testClass && // Compare only boxes of matching classes; + lte_mp(threadScore[tile], testScore) && // Make sure the sorting order of scores is as expected; + IOU(param, threadBox[tile], testBox) >= param.iouThreshold) // And... IOU overlap. 
+ { + // Current box overlaps with the box tested in this iteration, this box will be skipped. + threadState[tile] = -1; // -1 => Mark this box's thread to be dropped. + } + } + } +} + +template +cudaError_t EfficientNMSCustomLauncher(EfficientNMSCustomParameters& param, int* topNumData, int* outputIndexData, + int* outputClassData, int* sortedIndexData, T* sortedScoresData, int* topClassData, int* topAnchorsData, + const void* boxesInput, const void* anchorsInput, int* numDetectionsOutput, T* nmsScoresOutput, + int* nmsClassesOutput, int* nmsIndicesOutput, void* nmsBoxesOutput, cudaStream_t stream) +{ + unsigned int tileSize = param.numSelectedBoxes / NMS_TILES; + if (param.numSelectedBoxes <= 512) + { + tileSize = 512; + } + if (param.numSelectedBoxes <= 256) + { + tileSize = 256; + } + + const dim3 blockSize = {tileSize, 1, 1}; + const dim3 gridSize = {1, (unsigned int) param.batchSize, 1}; + + if (param.boxCoding == 0) + { + EfficientNMSCustom><<>>(param, topNumData, outputIndexData, + outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, + (BoxCorner*) boxesInput, (BoxCorner*) anchorsInput, numDetectionsOutput, nmsScoresOutput, + nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); + } + else if (param.boxCoding == 1) + { + // Note that nmsBoxesOutput is always coded as BoxCorner, regardless of the input coding type. + EfficientNMSCustom><<>>(param, topNumData, outputIndexData, + outputClassData, sortedIndexData, sortedScoresData, topClassData, topAnchorsData, + (BoxCenterSize*) boxesInput, (BoxCenterSize*) anchorsInput, numDetectionsOutput, nmsScoresOutput, + nmsClassesOutput, nmsIndicesOutput, (BoxCorner*) nmsBoxesOutput); + } + + return cudaGetLastError(); +} + +__global__ void EfficientNMSCustomFilterSegments(EfficientNMSCustomParameters param, const int* __restrict__ topNumData, + int* __restrict__ topOffsetsStartData, int* __restrict__ topOffsetsEndData) +{ + int imageIdx = threadIdx.x; + if (imageIdx > param.batchSize) + { + return; + } + topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements; + topOffsetsEndData[imageIdx] = imageIdx * param.numScoreElements + topNumData[imageIdx]; +} + +template +__global__ void EfficientNMSCustomFilter(EfficientNMSCustomParameters param, const T* __restrict__ scoresInput, + int* __restrict__ topNumData, int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, + T* __restrict__ topScoresData, int* __restrict__ topClassData) +{ + int elementIdx = blockDim.x * blockIdx.x + threadIdx.x; + int imageIdx = blockDim.y * blockIdx.y + threadIdx.y; + + // Boundary Conditions + if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) + { + return; + } + + // Shape of scoresInput: [batchSize, numAnchors, numClasses] + int scoresInputIdx = imageIdx * param.numScoreElements + elementIdx; + + // For each class, check its corresponding score if it crosses the threshold, and if so select this anchor, + // and keep track of the maximum score and the corresponding (argmax) class id + T score = scoresInput[scoresInputIdx]; + if (gte_mp(score, (T) param.scoreThreshold)) + { + // Unpack the class and anchor index from the element index + int classIdx = elementIdx % param.numClasses; + int anchorIdx = elementIdx / param.numClasses; + + // If this is a background class, ignore it. + if (classIdx == param.backgroundClass) + { + return; + } + + // Use an atomic to find an open slot where to write the selected anchor data. 
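The filtering step that follows claims one output slot per selected anchor by doing an `atomicAdd` on a per-image counter; the returned old value is a unique, densely packed index into the `top*Data` buffers. Stripped of the plugin specifics, that compaction pattern looks roughly like the sketch below (the names are made up; the real kernel additionally clamps the counter at `numScoreElements`).

```cpp
// Minimal stream-compaction-by-atomic-counter sketch, illustrative only.
__global__ void compactByThreshold(const float* scores, int n, float threshold,
                                   int* outCount, int* outIndices)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n)
    {
        return;
    }
    if (scores[i] >= threshold)
    {
        // atomicAdd returns the previous counter value, i.e. a unique output slot.
        int slot = atomicAdd(outCount, 1);
        outIndices[slot] = i;
    }
}
```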
+ if (topNumData[imageIdx] >= param.numScoreElements) + { + return; + } + int selectedIdx = atomicAdd((unsigned int*) &topNumData[imageIdx], 1); + if (selectedIdx >= param.numScoreElements) + { + topNumData[imageIdx] = param.numScoreElements; + return; + } + + // Shape of topScoresData / topClassData: [batchSize, numScoreElements] + int topIdx = imageIdx * param.numScoreElements + selectedIdx; + + if (param.scoreBits > 0) + { + score = add_mp(score, (T) 1); + if (gt_mp(score, (T) (2.f - 1.f / 1024.f))) + { + // Ensure the incremented score fits in the mantissa without changing the exponent + score = (2.f - 1.f / 1024.f); + } + } + + topIndexData[topIdx] = selectedIdx; + topAnchorsData[topIdx] = anchorIdx; + topScoresData[topIdx] = score; + topClassData[topIdx] = classIdx; + } +} + +template +__global__ void EfficientNMSCustomDenseIndex(EfficientNMSCustomParameters param, int* __restrict__ topNumData, + int* __restrict__ topIndexData, int* __restrict__ topAnchorsData, int* __restrict__ topOffsetsStartData, + int* __restrict__ topOffsetsEndData, T* __restrict__ topScoresData, int* __restrict__ topClassData) +{ + int elementIdx = blockDim.x * blockIdx.x + threadIdx.x; + int imageIdx = blockDim.y * blockIdx.y + threadIdx.y; + + if (elementIdx >= param.numScoreElements || imageIdx >= param.batchSize) + { + return; + } + + int dataIdx = imageIdx * param.numScoreElements + elementIdx; + int anchorIdx = elementIdx / param.numClasses; + int classIdx = elementIdx % param.numClasses; + if (param.scoreBits > 0) + { + T score = topScoresData[dataIdx]; + if (lt_mp(score, (T) param.scoreThreshold)) + { + score = (T) 1; + } + else if (classIdx == param.backgroundClass) + { + score = (T) 1; + } + else + { + score = add_mp(score, (T) 1); + if (gt_mp(score, (T) (2.f - 1.f / 1024.f))) + { + // Ensure the incremented score fits in the mantissa without changing the exponent + score = (2.f - 1.f / 1024.f); + } + } + topScoresData[dataIdx] = score; + } + else + { + T score = topScoresData[dataIdx]; + if (lt_mp(score, (T) param.scoreThreshold)) + { + topScoresData[dataIdx] = -(1 << 15); + } + else if (classIdx == param.backgroundClass) + { + topScoresData[dataIdx] = -(1 << 15); + } + } + + topIndexData[dataIdx] = elementIdx; + topAnchorsData[dataIdx] = anchorIdx; + topClassData[dataIdx] = classIdx; + + if (elementIdx == 0) + { + // Saturate counters + topNumData[imageIdx] = param.numScoreElements; + topOffsetsStartData[imageIdx] = imageIdx * param.numScoreElements; + topOffsetsEndData[imageIdx] = (imageIdx + 1) * param.numScoreElements; + } +} + +template +cudaError_t EfficientNMSCustomFilterLauncher(EfficientNMSCustomParameters& param, const T* scoresInput, int* topNumData, + int* topIndexData, int* topAnchorsData, int* topOffsetsStartData, int* topOffsetsEndData, T* topScoresData, + int* topClassData, cudaStream_t stream) +{ + const unsigned int elementsPerBlock = 512; + const unsigned int imagesPerBlock = 1; + const unsigned int elementBlocks = (param.numScoreElements + elementsPerBlock - 1) / elementsPerBlock; + const unsigned int imageBlocks = (param.batchSize + imagesPerBlock - 1) / imagesPerBlock; + const dim3 blockSize = {elementsPerBlock, imagesPerBlock, 1}; + const dim3 gridSize = {elementBlocks, imageBlocks, 1}; + + float kernelSelectThreshold = 0.007f; + if (param.scoreSigmoid) + { + // Inverse Sigmoid + if (param.scoreThreshold <= 0.f) + { + param.scoreThreshold = -(1 << 15); + } + else + { + param.scoreThreshold = logf(param.scoreThreshold / (1.f - param.scoreThreshold)); + } + kernelSelectThreshold 
= logf(kernelSelectThreshold / (1.f - kernelSelectThreshold)); + // Disable Score Bits Optimization + param.scoreBits = -1; + } + + if (param.scoreThreshold < kernelSelectThreshold) + { + // A full copy of the buffer is necessary because sorting will scramble the input data otherwise. + PLUGIN_CHECK_CUDA(cudaMemcpyAsync(topScoresData, scoresInput, param.batchSize * param.numScoreElements * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + + EfficientNMSCustomDenseIndex<<>>(param, topNumData, topIndexData, topAnchorsData, + topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData); + } + else + { + EfficientNMSCustomFilter<<>>( + param, scoresInput, topNumData, topIndexData, topAnchorsData, topScoresData, topClassData); + + EfficientNMSCustomFilterSegments<<<1, param.batchSize, 0, stream>>>( + param, topNumData, topOffsetsStartData, topOffsetsEndData); + } + + return cudaGetLastError(); +} + +template +size_t EfficientNMSCustomSortWorkspaceSize(int batchSize, int numScoreElements) +{ + size_t sortedWorkspaceSize = 0; + cub::DoubleBuffer keysDB(nullptr, nullptr); + cub::DoubleBuffer valuesDB(nullptr, nullptr); + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, sortedWorkspaceSize, keysDB, valuesDB, numScoreElements, + batchSize, (const int*) nullptr, (const int*) nullptr); + return sortedWorkspaceSize; +} + +size_t EfficientNMSCustomWorkspaceSize(int batchSize, int numScoreElements, int numClasses, DataType datatype) +{ + size_t total = 0; + const size_t align = 256; + // Counters + // 3 for Filtering + // 1 for Output Indexing + // C for Max per Class Limiting + size_t size = (3 + 1 + numClasses) * batchSize * sizeof(int); + total += size + (size % align ? align - (size % align) : 0); + // Int Buffers + for (int i = 0; i < 4; i++) + { + size = batchSize * numScoreElements * sizeof(int); + total += size + (size % align ? align - (size % align) : 0); + } + // Float Buffers + for (int i = 0; i < 2; i++) + { + size = batchSize * numScoreElements * dataTypeSize(datatype); + total += size + (size % align ? align - (size % align) : 0); + } + // Sort Workspace + if (datatype == DataType::kHALF) + { + size = EfficientNMSCustomSortWorkspaceSize<__half>(batchSize, numScoreElements); + total += size + (size % align ? align - (size % align) : 0); + } + else if (datatype == DataType::kFLOAT) + { + size = EfficientNMSCustomSortWorkspaceSize(batchSize, numScoreElements); + total += size + (size % align ? align - (size % align) : 0); + } + + return total; +} + +template +T* EfficientNMSCustomWorkspace(void* workspace, size_t& offset, size_t elements) +{ + T* buffer = (T*) ((size_t) workspace + offset); + size_t align = 256; + size_t size = elements * sizeof(T); + size_t sizeAligned = size + (size % align ? 
align - (size % align) : 0); + offset += sizeAligned; + return buffer; +} + +template +pluginStatus_t EfficientNMSCustomDispatch(EfficientNMSCustomParameters param, const void* boxesInput, const void* scoresInput, + const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, + void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) +{ + // Clear Outputs (not all elements will get overwritten by the kernels, so safer to clear everything out) + CSC(cudaMemsetAsync(numDetectionsOutput, 0x00, param.batchSize * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsScoresOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsBoxesOutput, 0x00, param.batchSize * param.numOutputBoxes * 4 * sizeof(T), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsClassesOutput, 0x00, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); + CSC(cudaMemsetAsync(nmsIndicesOutput, 0xFF, param.batchSize * param.numOutputBoxes * sizeof(int), stream), STATUS_FAILURE); + + // Empty Inputs + if (param.numScoreElements < 1) + { + return STATUS_SUCCESS; + } + + // Counters Workspace + size_t workspaceOffset = 0; + int countersTotalSize = (3 + 1 + param.numClasses) * param.batchSize; + int* topNumData = EfficientNMSCustomWorkspace(workspace, workspaceOffset, countersTotalSize); + int* topOffsetsStartData = topNumData + param.batchSize; + int* topOffsetsEndData = topNumData + 2 * param.batchSize; + int* outputIndexData = topNumData + 3 * param.batchSize; + int* outputClassData = topNumData + 4 * param.batchSize; + CSC(cudaMemsetAsync(topNumData, 0x00, countersTotalSize * sizeof(int), stream), STATUS_FAILURE); + cudaError_t status = cudaGetLastError(); + CSC(status, STATUS_FAILURE); + + // Other Buffers Workspace + int* topIndexData + = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* topClassData + = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* topAnchorsData + = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + int* sortedIndexData + = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + T* topScoresData = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + T* sortedScoresData + = EfficientNMSCustomWorkspace(workspace, workspaceOffset, param.batchSize * param.numScoreElements); + size_t sortedWorkspaceSize = EfficientNMSCustomSortWorkspaceSize(param.batchSize, param.numScoreElements); + char* sortedWorkspaceData = EfficientNMSCustomWorkspace(workspace, workspaceOffset, sortedWorkspaceSize); + cub::DoubleBuffer scoresDB(topScoresData, sortedScoresData); + cub::DoubleBuffer indexDB(topIndexData, sortedIndexData); + + // Kernels + status = EfficientNMSCustomFilterLauncher(param, (T*) scoresInput, topNumData, topIndexData, topAnchorsData, + topOffsetsStartData, topOffsetsEndData, topScoresData, topClassData, stream); + CSC(status, STATUS_FAILURE); + + status = cub::DeviceSegmentedRadixSort::SortPairsDescending(sortedWorkspaceData, sortedWorkspaceSize, scoresDB, + indexDB, param.batchSize * param.numScoreElements, param.batchSize, topOffsetsStartData, topOffsetsEndData, + param.scoreBits > 0 ? (10 - param.scoreBits) : 0, param.scoreBits > 0 ? 
10 : sizeof(T) * 8, stream, false); + CSC(status, STATUS_FAILURE); + + status = EfficientNMSCustomLauncher(param, topNumData, outputIndexData, outputClassData, indexDB.Current(), + scoresDB.Current(), topClassData, topAnchorsData, boxesInput, anchorsInput, (int*) numDetectionsOutput, + (T*) nmsScoresOutput, (int*) nmsClassesOutput, (int*) nmsIndicesOutput, nmsBoxesOutput, stream); + CSC(status, STATUS_FAILURE); + + return STATUS_SUCCESS; +} + +pluginStatus_t EfficientNMSCustomInference(EfficientNMSCustomParameters param, const void* boxesInput, const void* scoresInput, + const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, + void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream) +{ + if (param.datatype == DataType::kFLOAT) + { + param.scoreBits = -1; + return EfficientNMSCustomDispatch(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); + } + else if (param.datatype == DataType::kHALF) + { + if (param.scoreBits <= 0 || param.scoreBits > 10) + { + param.scoreBits = -1; + } + return EfficientNMSCustomDispatch<__half>(param, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, + nmsBoxesOutput, nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); + } + else + { + return STATUS_NOT_SUPPORTED; + } +} diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cuh b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cuh new file mode 100644 index 00000000..491bc1a9 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.cuh @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
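One detail worth spelling out from the kernels above is the `scoreBits` optimization for FP16: thresholded scores are shifted from [0, 1) into [1, 2), where every `__half` shares the same sign and exponent bits, so the segmented radix sort only needs to compare the top `scoreBits` mantissa bits; the clamp at `2 - 1/1024` keeps a score from rounding up to 2.0, which would change the exponent. A small host-side check of that bit layout, assuming a CUDA toolkit whose `cuda_fp16.h` exposes host conversions, might look like:

```cpp
#include <cuda_fp16.h>

#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    float scores[] = {0.05f, 0.4f, 0.999f};
    for (float s : scores)
    {
        float shifted = s + 1.0f;
        if (shifted > 2.0f - 1.0f / 1024.0f)
        {
            shifted = 2.0f - 1.0f / 1024.0f; // largest __half value strictly below 2.0
        }
        __half h = __float2half(shifted);
        uint16_t bits;
        std::memcpy(&bits, &h, sizeof(bits));
        // Every shifted score has the same top bits (sign = 0, biased exponent = 15,
        // i.e. bits & 0xFC00 == 0x3C00); only the 10 mantissa bits differ, and they
        // order the scores, so sorting a few high mantissa bits is enough.
        std::printf("score=%.3f half bits=0x%04x mantissa=0x%03x\n", s, bits, bits & 0x3FF);
    }
    return 0;
}
```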
+ */ + +#ifndef TRT_EFFICIENT_NMS_CUSTOM_INFERENCE_CUH +#define TRT_EFFICIENT_NMS_CUSTOM_INFERENCE_CUH + +#include + +// FP32 Intrinsics + +float __device__ __inline__ exp_mp(const float a) +{ + return __expf(a); +} +float __device__ __inline__ sigmoid_mp(const float a) +{ + return __frcp_rn(__fadd_rn(1.f, __expf(-a))); +} +float __device__ __inline__ add_mp(const float a, const float b) +{ + return __fadd_rn(a, b); +} +float __device__ __inline__ sub_mp(const float a, const float b) +{ + return __fsub_rn(a, b); +} +float __device__ __inline__ mul_mp(const float a, const float b) +{ + return __fmul_rn(a, b); +} +bool __device__ __inline__ gt_mp(const float a, const float b) +{ + return a > b; +} +bool __device__ __inline__ lt_mp(const float a, const float b) +{ + return a < b; +} +bool __device__ __inline__ lte_mp(const float a, const float b) +{ + return a <= b; +} +bool __device__ __inline__ gte_mp(const float a, const float b) +{ + return a >= b; +} + +#if __CUDA_ARCH__ >= 530 + +// FP16 Intrinsics + +__half __device__ __inline__ exp_mp(const __half a) +{ + return hexp(a); +} +__half __device__ __inline__ sigmoid_mp(const __half a) +{ + return hrcp(__hadd((__half) 1, hexp(__hneg(a)))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) +{ + return __hadd(a, b); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) +{ + return __hsub(a, b); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) +{ + return __hmul(a, b); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) +{ + return __hgt(a, b); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) +{ + return __hlt(a, b); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) +{ + return __hle(a, b); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) +{ + return __hge(a, b); +} + +#else + +// FP16 Fallbacks on older architectures that lack support + +__half __device__ __inline__ exp_mp(const __half a) +{ + return __float2half(exp_mp(__half2float(a))); +} +__half __device__ __inline__ sigmoid_mp(const __half a) +{ + return __float2half(sigmoid_mp(__half2float(a))); +} +__half __device__ __inline__ add_mp(const __half a, const __half b) +{ + return __float2half(add_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ sub_mp(const __half a, const __half b) +{ + return __float2half(sub_mp(__half2float(a), __half2float(b))); +} +__half __device__ __inline__ mul_mp(const __half a, const __half b) +{ + return __float2half(mul_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gt_mp(const __half a, const __half b) +{ + return __float2half(gt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lt_mp(const __half a, const __half b) +{ + return __float2half(lt_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ lte_mp(const __half a, const __half b) +{ + return __float2half(lte_mp(__half2float(a), __half2float(b))); +} +bool __device__ __inline__ gte_mp(const __half a, const __half b) +{ + return __float2half(gte_mp(__half2float(a), __half2float(b))); +} + +#endif + +template +struct __align__(4 * sizeof(T)) BoxCorner; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize; + +template +struct __align__(4 * sizeof(T)) BoxCorner +{ + // For NMS/IOU purposes, YXYX coding is identical to XYXY + T y1, x1, y2, x2; + + __device__ void reorder() + { + if (gt_mp(y1, y2)) + { + // Swap values, so y1 < y2 + y1 = sub_mp(y1, y2); + y2 = add_mp(y1, y2); + y1 = 
sub_mp(y2, y1); + } + if (gt_mp(x1, x2)) + { + // Swap values, so x1 < x2 + x1 = sub_mp(x1, x2); + x2 = add_mp(x1, x2); + x1 = sub_mp(x2, x1); + } + } + + __device__ BoxCorner clip(T low, T high) const + { + return {lt_mp(y1, low) ? low : (gt_mp(y1, high) ? high : y1), + lt_mp(x1, low) ? low : (gt_mp(x1, high) ? high : x1), lt_mp(y2, low) ? low : (gt_mp(y2, high) ? high : y2), + lt_mp(x2, low) ? low : (gt_mp(x2, high) ? high : x2)}; + } + + __device__ BoxCorner decode(BoxCorner anchor) const + { + return {add_mp(y1, anchor.y1), add_mp(x1, anchor.x1), add_mp(y2, anchor.y2), add_mp(x2, anchor.x2)}; + } + + __device__ float area() const + { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + if (lte_mp(h, (T) 0)) + { + return 0; + } + if (lte_mp(w, (T) 0)) + { + return 0; + } + return (float) h * (float) w; + } + + __device__ operator BoxCenterSize() const + { + T w = sub_mp(x2, x1); + T h = sub_mp(y2, y1); + return BoxCenterSize{add_mp(y1, mul_mp((T) 0.5, h)), add_mp(x1, mul_mp((T) 0.5, w)), h, w}; + } + + __device__ static BoxCorner intersect(BoxCorner a, BoxCorner b) + { + return {gt_mp(a.y1, b.y1) ? a.y1 : b.y1, gt_mp(a.x1, b.x1) ? a.x1 : b.x1, lt_mp(a.y2, b.y2) ? a.y2 : b.y2, + lt_mp(a.x2, b.x2) ? a.x2 : b.x2}; + } +}; + +template +struct __align__(4 * sizeof(T)) BoxCenterSize +{ + // For NMS/IOU purposes, YXHW coding is identical to XYWH + T y, x, h, w; + + __device__ void reorder() {} + + __device__ BoxCenterSize clip(T low, T high) const + { + return BoxCenterSize(BoxCorner(*this).clip(low, high)); + } + + __device__ BoxCenterSize decode(BoxCenterSize anchor) const + { + return {add_mp(mul_mp(y, anchor.h), anchor.y), add_mp(mul_mp(x, anchor.w), anchor.x), + mul_mp(anchor.h, exp_mp(h)), mul_mp(anchor.w, exp_mp(w))}; + } + + __device__ float area() const + { + if (h <= (T) 0) + { + return 0; + } + if (w <= (T) 0) + { + return 0; + } + return (float) h * (float) w; + } + + __device__ operator BoxCorner() const + { + T h2 = mul_mp(h, (T) 0.5); + T w2 = mul_mp(w, (T) 0.5); + return BoxCorner{sub_mp(y, h2), sub_mp(x, w2), add_mp(y, h2), add_mp(x, w2)}; + } + __device__ static BoxCenterSize intersect(BoxCenterSize a, BoxCenterSize b) + { + return BoxCenterSize(BoxCorner::intersect(BoxCorner(a), BoxCorner(b))); + } +}; + +#endif \ No newline at end of file diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.h b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.h new file mode 100644 index 00000000..525f46b8 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomInference.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
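A small aside on `BoxCorner::reorder()` above: the three `sub_mp`/`add_mp` steps swap two coordinates without a temporary, using the identity `a = a - b; b = a + b; a = b - a`. The host-side check below (plain `float`, illustrative values) shows the effect; it is exact whenever the intermediate sum and difference are exactly representable.

```cpp
#include <cstdio>

int main()
{
    float y1 = 7.0f, y2 = 3.0f; // y1 > y2, so reorder() would swap them
    y1 = y1 - y2;               // 4
    y2 = y1 + y2;               // 7 (the old y1)
    y1 = y2 - y1;               // 3 (the old y2)
    std::printf("y1=%g y2=%g\n", y1, y2); // prints y1=3 y2=7
    return 0;
}
```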
+ */ + +#ifndef TRT_EFFICIENT_NMS_CUSTOM_INFERENCE_H +#define TRT_EFFICIENT_NMS_CUSTOM_INFERENCE_H + +#include "common/plugin.h" + +#include "efficientNMSCustomParameters.h" + +size_t EfficientNMSCustomWorkspaceSize(int batchSize, int numScoreElements, int numClasses, nvinfer1::DataType datatype); + +pluginStatus_t EfficientNMSCustomInference(nvinfer1::plugin::EfficientNMSCustomParameters param, const void* boxesInput, const void* scoresInput, + const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput, + void* nmsClassesOutput, void* nmsIndicesOutput, void* workspace, cudaStream_t stream); + +#endif diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomParameters.h b/plugin/efficientNMSCustomPlugin/efficientNMSCustomParameters.h new file mode 100644 index 00000000..4afb3b77 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomParameters.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRT_EFFICIENT_NMS_CUSTOM_PARAMETERS_H +#define TRT_EFFICIENT_NMS_CUSTOM_PARAMETERS_H + +#include "common/plugin.h" + +namespace nvinfer1 +{ +namespace plugin +{ + +struct EfficientNMSCustomParameters +{ + // Related to NMS Options + float iouThreshold = 0.5f; + float scoreThreshold = 0.5f; + int numOutputBoxes = 100; + int numOutputBoxesPerClass = -1; + bool padOutputBoxesPerClass = false; + int backgroundClass = -1; + bool scoreSigmoid = false; + bool clipBoxes = false; + int boxCoding = 0; + + // Related to NMS Internals + int numSelectedBoxes = 4096; + int scoreBits = -1; + + // Related to Tensor Configuration + // (These are set by the various plugin configuration methods, no need to define them during plugin creation.) + int batchSize = -1; + int numClasses = 1; + int numBoxElements = -1; + int numScoreElements = -1; + int numAnchors = -1; + bool shareLocation = true; + bool shareAnchors = true; + bool boxDecoder = false; + nvinfer1::DataType datatype = nvinfer1::DataType::kFLOAT; +}; + +} // namespace plugin +} // namespace nvinfer1 + +#endif diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.cpp b/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.cpp new file mode 100644 index 00000000..0a065e23 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.cpp @@ -0,0 +1,463 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "efficientNMSCustomPlugin.h" +#include "efficientNMSCustomInference.h" + +using namespace nvinfer1; +using nvinfer1::plugin::EfficientNMSCustomPlugin; +using nvinfer1::plugin::EfficientNMSCustomParameters; +using nvinfer1::plugin::EfficientNMSCustomPluginCreator; + +namespace +{ +const char* EFFICIENT_NMS_CUSTOM_PLUGIN_VERSION{"1"}; +const char* EFFICIENT_NMS_CUSTOM_PLUGIN_NAME{"EfficientNMSCustom_TRT"}; +} // namespace + +EfficientNMSCustomPlugin::EfficientNMSCustomPlugin(EfficientNMSCustomParameters param) + : mParam(param) +{ +} + +EfficientNMSCustomPlugin::EfficientNMSCustomPlugin(const void* data, size_t length) +{ + const char *d = reinterpret_cast(data), *a = d; + mParam = read(d); + PLUGIN_ASSERT(d == a + length); +} + +const char* EfficientNMSCustomPlugin::getPluginType() const noexcept +{ + return EFFICIENT_NMS_CUSTOM_PLUGIN_NAME; +} + +const char* EfficientNMSCustomPlugin::getPluginVersion() const noexcept +{ + return EFFICIENT_NMS_CUSTOM_PLUGIN_VERSION; +} + +int EfficientNMSCustomPlugin::getNbOutputs() const noexcept +{ + // Standard Plugin Implementation + return 5; +} + +int EfficientNMSCustomPlugin::initialize() noexcept +{ + if (!initialized) + { + int32_t device; + CSC(cudaGetDevice(&device), STATUS_FAILURE); + struct cudaDeviceProp properties; + CSC(cudaGetDeviceProperties(&properties, device), STATUS_FAILURE); + if (properties.regsPerBlock >= 65536) + { + // Most Devices + mParam.numSelectedBoxes = 5000; + } + else + { + // Jetson TX1/TX2 + mParam.numSelectedBoxes = 2000; + } + initialized = true; + } + return STATUS_SUCCESS; +} + +void EfficientNMSCustomPlugin::terminate() noexcept {} + +size_t EfficientNMSCustomPlugin::getSerializationSize() const noexcept +{ + return sizeof(EfficientNMSCustomParameters); +} + +void EfficientNMSCustomPlugin::serialize(void* buffer) const noexcept +{ + char *d = reinterpret_cast(buffer), *a = d; + write(d, mParam); + PLUGIN_ASSERT(d == a + getSerializationSize()); +} + +void EfficientNMSCustomPlugin::destroy() noexcept +{ + delete this; +} + +void EfficientNMSCustomPlugin::setPluginNamespace(const char* pluginNamespace) noexcept +{ + try + { + mNamespace = pluginNamespace; + } + catch (const std::exception& e) + { + caughtError(e); + } +} + +const char* EfficientNMSCustomPlugin::getPluginNamespace() const noexcept +{ + return mNamespace.c_str(); +} + +nvinfer1::DataType EfficientNMSCustomPlugin::getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept +{ + // On standard NMS, num_detections and detection_classes use integer outputs + if (index == 0 || index == 3 || index == 4) + { + return nvinfer1::DataType::kINT32; + } + // All others should use the same datatype as the input + return inputTypes[0]; +} + +IPluginV2DynamicExt* EfficientNMSCustomPlugin::clone() const noexcept +{ + try + { + auto* plugin = new EfficientNMSCustomPlugin(mParam); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (const std::exception& e) + { + caughtError(e); + } + return nullptr; +} + +DimsExprs EfficientNMSCustomPlugin::getOutputDimensions( + int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) noexcept +{ + try + { + DimsExprs out_dim; + + // When pad per class is set, the output size may need to be reduced: + // i.e.: outputBoxes = min(outputBoxes, outputBoxesPerClass * numClasses) + // As the number of classes may not be static, 
numOutputBoxes must be a dynamic + // expression. The corresponding parameter can not be set at this time, so the + // value will be calculated again in configurePlugin() and the param overwritten. + const IDimensionExpr* numOutputBoxes = exprBuilder.constant(mParam.numOutputBoxes); + if (mParam.padOutputBoxesPerClass && mParam.numOutputBoxesPerClass > 0) + { + const IDimensionExpr* numOutputBoxesPerClass = exprBuilder.constant(mParam.numOutputBoxesPerClass); + const IDimensionExpr* numClasses = inputs[1].d[2]; + numOutputBoxes = exprBuilder.operation(DimensionOperation::kMIN, *numOutputBoxes, + *exprBuilder.operation(DimensionOperation::kPROD, *numOutputBoxesPerClass, *numClasses)); + } + + // Standard NMS + PLUGIN_ASSERT(outputIndex >= 0 && outputIndex <= 4); + + // num_detections + if (outputIndex == 0) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = exprBuilder.constant(1); + } + // detection_boxes + else if (outputIndex == 1) + { + out_dim.nbDims = 3; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + out_dim.d[2] = exprBuilder.constant(4); + } + // detection_scores + else if (outputIndex == 2) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + } + // detection_classes + else if (outputIndex == 3) + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + } + // detection_indices + else + { + out_dim.nbDims = 2; + out_dim.d[0] = inputs[0].d[0]; + out_dim.d[1] = numOutputBoxes; + } + + return out_dim; + } + catch (const std::exception& e) + { + caughtError(e); + } + return DimsExprs{}; +} + +bool EfficientNMSCustomPlugin::supportsFormatCombination( + int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept +{ + if (inOut[pos].format != PluginFormat::kLINEAR) + { + return false; + } + + PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 5); + if (nbInputs == 2) + { + PLUGIN_ASSERT(0 <= pos && pos <= 6); + } + if (nbInputs == 3) + { + PLUGIN_ASSERT(0 <= pos && pos <= 7); + } + + // num_detections and detection_classes output: int + const int posOut = pos - nbInputs; + if (posOut == 0 || posOut == 3 || posOut == 4) + { + return inOut[pos].type == DataType::kINT32 && inOut[pos].format == PluginFormat::kLINEAR; + } + + // all other inputs/outputs: fp32 or fp16 + return (inOut[pos].type == DataType::kHALF || inOut[pos].type == DataType::kFLOAT) + && (inOut[0].type == inOut[pos].type); +} + +void EfficientNMSCustomPlugin::configurePlugin( + const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, int nbOutputs) noexcept +{ + try + { + // Accepts two or three inputs + // If two inputs: [0] boxes, [1] scores + // If three inputs: [0] boxes, [1] scores, [2] anchors + PLUGIN_ASSERT(nbInputs == 2 || nbInputs == 3); + PLUGIN_ASSERT(nbOutputs == 5); + + mParam.datatype = in[0].desc.type; + + // Shape of scores input should be + // [batch_size, num_boxes, num_classes] or [batch_size, num_boxes, num_classes, 1] + PLUGIN_ASSERT(in[1].desc.dims.nbDims == 3 || (in[1].desc.dims.nbDims == 4 && in[1].desc.dims.d[3] == 1)); + mParam.numScoreElements = in[1].desc.dims.d[1] * in[1].desc.dims.d[2]; + mParam.numClasses = in[1].desc.dims.d[2]; + + // When pad per class is set, the total ouput boxes size may need to be reduced. + // This operation is also done in getOutputDimension(), but for dynamic shapes, the + // numOutputBoxes param can't be set until the number of classes is fully known here. 
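The reduction described in the comments above is computed twice: symbolically in `getOutputDimensions()` with `DimensionOperation::kMIN`/`kPROD`, and numerically in `configurePlugin()` once the class count is known. With static shapes it boils down to the following arithmetic (the values are illustrative):

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    int numOutputBoxes = 100;        // max_output_boxes attribute
    int numOutputBoxesPerClass = 10; // per-class cap (> 0 and padding enabled)
    int numClasses = 4;              // known only once the score tensor shape is known

    // detection_boxes then has shape [batch, numOutputBoxes, 4] with the reduced value.
    numOutputBoxes = std::min(numOutputBoxes, numOutputBoxesPerClass * numClasses);
    std::printf("numOutputBoxes = %d\n", numOutputBoxes); // 40
    return 0;
}
```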
+ if (mParam.padOutputBoxesPerClass && mParam.numOutputBoxesPerClass > 0) + { + if (mParam.numOutputBoxesPerClass * mParam.numClasses < mParam.numOutputBoxes) + { + mParam.numOutputBoxes = mParam.numOutputBoxesPerClass * mParam.numClasses; + } + } + + // Shape of boxes input should be + // [batch_size, num_boxes, 4] or [batch_size, num_boxes, 1, 4] or [batch_size, num_boxes, num_classes, 4] + PLUGIN_ASSERT(in[0].desc.dims.nbDims == 3 || in[0].desc.dims.nbDims == 4); + if (in[0].desc.dims.nbDims == 3) + { + PLUGIN_ASSERT(in[0].desc.dims.d[2] == 4); + mParam.shareLocation = true; + mParam.numBoxElements = in[0].desc.dims.d[1] * in[0].desc.dims.d[2]; + } + else + { + mParam.shareLocation = (in[0].desc.dims.d[2] == 1); + PLUGIN_ASSERT(in[0].desc.dims.d[2] == mParam.numClasses || mParam.shareLocation); + PLUGIN_ASSERT(in[0].desc.dims.d[3] == 4); + mParam.numBoxElements = in[0].desc.dims.d[1] * in[0].desc.dims.d[2] * in[0].desc.dims.d[3]; + } + mParam.numAnchors = in[0].desc.dims.d[1]; + + if (nbInputs == 2) + { + // Only two inputs are used, disable the fused box decoder + mParam.boxDecoder = false; + } + if (nbInputs == 3) + { + // All three inputs are used, enable the box decoder + // Shape of anchors input should be + // Constant shape: [1, numAnchors, 4] or [batch_size, numAnchors, 4] + PLUGIN_ASSERT(in[2].desc.dims.nbDims == 3); + mParam.boxDecoder = true; + mParam.shareAnchors = (in[2].desc.dims.d[0] == 1); + } + } + catch (const std::exception& e) + { + caughtError(e); + } +} + +size_t EfficientNMSCustomPlugin::getWorkspaceSize( + const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const noexcept +{ + int batchSize = inputs[1].dims.d[0]; + int numScoreElements = inputs[1].dims.d[1] * inputs[1].dims.d[2]; + int numClasses = inputs[1].dims.d[2]; + return EfficientNMSCustomWorkspaceSize(batchSize, numScoreElements, numClasses, mParam.datatype); +} + +int EfficientNMSCustomPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +{ + try + { + mParam.batchSize = inputDesc[0].dims.d[0]; + + // Standard NMS Operation + const void* const boxesInput = inputs[0]; + const void* const scoresInput = inputs[1]; + const void* const anchorsInput = mParam.boxDecoder ? 
inputs[2] : nullptr; + + void* numDetectionsOutput = outputs[0]; + void* nmsBoxesOutput = outputs[1]; + void* nmsScoresOutput = outputs[2]; + void* nmsClassesOutput = outputs[3]; + void* nmsIndicesOutput = outputs[4]; + + return EfficientNMSCustomInference(mParam, boxesInput, scoresInput, anchorsInput, numDetectionsOutput, nmsBoxesOutput, + nmsScoresOutput, nmsClassesOutput, nmsIndicesOutput, workspace, stream); + } + catch (const std::exception& e) + { + caughtError(e); + } + return -1; +} + +// Standard NMS Plugin Operation + +EfficientNMSCustomPluginCreator::EfficientNMSCustomPluginCreator() + : mParam{} +{ + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1)); + mPluginAttributes.emplace_back(PluginField("max_output_boxes", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("background_class", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("score_activation", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1)); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char* EfficientNMSCustomPluginCreator::getPluginName() const noexcept +{ + return EFFICIENT_NMS_CUSTOM_PLUGIN_NAME; +} + +const char* EfficientNMSCustomPluginCreator::getPluginVersion() const noexcept +{ + return EFFICIENT_NMS_CUSTOM_PLUGIN_VERSION; +} + +const PluginFieldCollection* EfficientNMSCustomPluginCreator::getFieldNames() noexcept +{ + return &mFC; +} + +IPluginV2DynamicExt* EfficientNMSCustomPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +{ + try + { + PLUGIN_VALIDATE(fc != nullptr); + PluginField const* fields = fc->fields; + PLUGIN_VALIDATE(fields != nullptr); + plugin::validateRequiredAttributesExist({"score_threshold", "iou_threshold", "max_output_boxes", + "background_class", "score_activation", "box_coding"}, + fc); + for (int32_t i{0}; i < fc->nbFields; ++i) + { + char const* attrName = fields[i].name; + if (!strcmp(attrName, "score_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + auto const scoreThreshold = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(scoreThreshold >= 0.0F); + } + if (!strcmp(attrName, "iou_threshold")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kFLOAT32); + auto const iouThreshold = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(iouThreshold > 0.0F); + mParam.iouThreshold = iouThreshold; + } + if (!strcmp(attrName, "max_output_boxes")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + auto const numOutputBoxes = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(numOutputBoxes > 0); + mParam.numOutputBoxes = numOutputBoxes; + } + if (!strcmp(attrName, "background_class")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + mParam.backgroundClass = *(static_cast(fields[i].data)); + } + if (!strcmp(attrName, "score_activation")) + { + auto const scoreSigmoid = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(scoreSigmoid == 0 || scoreSigmoid == 1); + mParam.scoreSigmoid = static_cast(scoreSigmoid); + } + if (!strcmp(attrName, "box_coding")) + { + PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); + auto const boxCoding = *(static_cast(fields[i].data)); + PLUGIN_VALIDATE(boxCoding == 
0 || boxCoding == 1); + mParam.boxCoding = boxCoding; + } + } + + auto* plugin = new EfficientNMSCustomPlugin(mParam); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e) + { + caughtError(e); + } + return nullptr; +} + +IPluginV2DynamicExt* EfficientNMSCustomPluginCreator::deserializePlugin( + const char* name, const void* serialData, size_t serialLength) noexcept +{ + try + { + // This object will be deleted when the network is destroyed, which will + // call EfficientNMSCustomPlugin::destroy() + auto* plugin = new EfficientNMSCustomPlugin(serialData, serialLength); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (const std::exception& e) + { + caughtError(e); + } + return nullptr; +} diff --git a/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.h b/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.h new file mode 100644 index 00000000..9190ab10 --- /dev/null +++ b/plugin/efficientNMSCustomPlugin/efficientNMSCustomPlugin.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef TRT_EFFICIENT_NMS_CUSTOM_PLUGIN_H +#define TRT_EFFICIENT_NMS_CUSTOM_PLUGIN_H + +#include + +#include "common/plugin.h" +#include "efficientNMSCustomParameters.h" + +namespace nvinfer1 +{ +namespace plugin +{ + +class EfficientNMSCustomPlugin : public IPluginV2DynamicExt +{ +public: + explicit EfficientNMSCustomPlugin(EfficientNMSCustomParameters param); + EfficientNMSCustomPlugin(const void* data, size_t length); + ~EfficientNMSCustomPlugin() override = default; + + // IPluginV2 methods + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* libNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + + // IPluginV2Ext methods + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept override; + + // IPluginV2DynamicExt methods + IPluginV2DynamicExt* clone() const noexcept override; + DimsExprs getOutputDimensions( + int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination( + int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const PluginTensorDesc* 
inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs,
+        void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
+
+protected:
+    EfficientNMSCustomParameters mParam{};
+    bool initialized{false};
+    std::string mNamespace;
+};
+
+// Standard NMS Plugin Operation
+class EfficientNMSCustomPluginCreator : public nvinfer1::pluginInternal::BaseCreator
+{
+public:
+    EfficientNMSCustomPluginCreator();
+    ~EfficientNMSCustomPluginCreator() override = default;
+
+    const char* getPluginName() const noexcept override;
+    const char* getPluginVersion() const noexcept override;
+    const PluginFieldCollection* getFieldNames() noexcept override;
+
+    IPluginV2DynamicExt* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override;
+    IPluginV2DynamicExt* deserializePlugin(
+        const char* name, const void* serialData, size_t serialLength) noexcept override;
+
+protected:
+    PluginFieldCollection mFC;
+    EfficientNMSCustomParameters mParam;
+    std::vector<PluginField> mPluginAttributes;
+    std::string mPluginName;
+};
+
+} // namespace plugin
+} // namespace nvinfer1
+
+#endif // TRT_EFFICIENT_NMS_CUSTOM_PLUGIN_H
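For completeness, this is roughly how the plugin can be instantiated from the TensorRT C++ API once the rebuilt `libnvinfer_plugin` is loaded (for example after calling `initLibNvInferPlugins`). The helper name, threshold values and the two-input configuration below are illustrative rather than part of the patch; the attribute names match the fields parsed by `EfficientNMSCustomPluginCreator::createPlugin` above.

```cpp
#include <NvInfer.h>
#include <NvInferPlugin.h>

#include <array>
#include <cassert>

using namespace nvinfer1;

// Adds an EfficientNMSCustom_TRT layer to an existing network. `boxes` is the
// [batch, numBoxes, 4] tensor and `scores` the [batch, numBoxes, numClasses] tensor
// from the detection head; the optional anchors input is omitted here, so the fused
// box decoder stays disabled. Plugin ownership/cleanup is omitted for brevity.
ILayer* addEfficientNMSCustom(INetworkDefinition& network, ITensor* boxes, ITensor* scores)
{
    auto* creator = getPluginRegistry()->getPluginCreator("EfficientNMSCustom_TRT", "1");
    assert(creator != nullptr); // requires the rebuilt libnvinfer_plugin to be registered

    float scoreThreshold = 0.25F;
    float iouThreshold = 0.45F;
    int32_t maxOutputBoxes = 100;
    int32_t backgroundClass = -1;
    int32_t scoreActivation = 0; // scores are already probabilities (no fused sigmoid)
    int32_t boxCoding = 0;       // 0: corner coordinates, 1: center-size coding

    std::array<PluginField, 6> fields = {
        PluginField{"score_threshold", &scoreThreshold, PluginFieldType::kFLOAT32, 1},
        PluginField{"iou_threshold", &iouThreshold, PluginFieldType::kFLOAT32, 1},
        PluginField{"max_output_boxes", &maxOutputBoxes, PluginFieldType::kINT32, 1},
        PluginField{"background_class", &backgroundClass, PluginFieldType::kINT32, 1},
        PluginField{"score_activation", &scoreActivation, PluginFieldType::kINT32, 1},
        PluginField{"box_coding", &boxCoding, PluginFieldType::kINT32, 1},
    };
    PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    IPluginV2* plugin = creator->createPlugin("nms", &fc);
    ITensor* inputs[] = {boxes, scores};
    ILayer* layer = network.addPluginV2(inputs, 2, *plugin);
    // Outputs: 0 num_detections, 1 detection_boxes, 2 detection_scores,
    //          3 detection_classes, 4 detection_indices (anchor index of each kept box).
    return layer;
}
```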