forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 3
/
roIAlign2Plugin.h
143 lines (122 loc) · 5.33 KB
/
roIAlign2Plugin.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ROIALIGN2_PLUGIN_H
#define ROIALIGN2_PLUGIN_H
#include <cassert>
#include <cuda_runtime_api.h>
#include <string.h>
#include <string>
#include <vector>
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "kernel.h"
#include "maskRCNNKernels.h"
#include "mrcnn_config.h"
#include "plugin.h"
using namespace nvinfer1::plugin;
// One of the preferred ways of making our custom layer visible to TensorRT is to
// extend TensorRT's plugin interfaces (IPluginV2DynamicExt here) and the BaseCreator class.
// For the requirements on the overridden functions, check the TensorRT API docs.
namespace nvinfer1
{
namespace plugin
{
// TensorRT plugin implementing the IPluginV2DynamicExt (dynamic-shape) interface for the
// "RoIAlign2" layer (presumably an RoI Align variant; the kernel launch lives in the
// out-of-line enqueue()). Operator attributes are supplied through the constructors below.
class RoIAlign2DynamicPlugin : public IPluginV2DynamicExt
{
public:
    // Constructs a plugin from a layer name only. Any attribute that the out-of-line
    // definition does not set keeps the in-class default declared in the private section.
    RoIAlign2DynamicPlugin(const std::string name);

    // Constructs a plugin from its operator attributes. The shape-derived fields
    // (mFeatureLength, mROICount, mInputWidth, mInputHeight) are not parameters here.
    RoIAlign2DynamicPlugin(const std::string name, int pooledSize, int transformCoords, bool absCoords, bool swapCoords,
        int samplingRatio, bool legacy, int imageSize);

    // Constructs a plugin from its operator attributes plus the shape-derived values
    // (presumably used by clone() to copy a configured plugin — TODO confirm).
    RoIAlign2DynamicPlugin(const std::string name, int pooledSize, int transformCoords, bool absCoords, bool swapCoords,
        int samplingRatio, bool legacy, int imageSize, int featureLength, int roiCount, int inputWidth, int inputHeight);

    // Constructs a plugin from `length` bytes at `data`, presumably the buffer written by serialize().
    RoIAlign2DynamicPlugin(const std::string name, const void* data, size_t length);

    // It doesn't make sense to make RoIAlign2DynamicPlugin without arguments, so we delete default constructor.
    RoIAlign2DynamicPlugin() noexcept = delete;

    ~RoIAlign2DynamicPlugin() noexcept override;

    // IPluginV2 methods
    const char* getPluginType() const noexcept override;
    const char* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    int initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(const char* libNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

    // IPluginV2Ext methods
    DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept override;

    // IPluginV2DynamicExt methods
    IPluginV2DynamicExt* clone() const noexcept override;
    DimsExprs getOutputDimensions(
        int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;
    void configurePlugin(const DynamicPluginTensorDesc* in, int nbInputs, const DynamicPluginTensorDesc* out,
        int nbOutputs) noexcept override;
    size_t getWorkspaceSize(const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs,
        int nbOutputs) const noexcept override;
    int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs,
        void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

private:
    // Layer name; every constructor takes a `name` argument, presumably stored here.
    const std::string mLayerName;
    // Plugin namespace, presumably the value passed to setPluginNamespace().
    std::string mNamespace;

    // Operator attributes, named after the matching constructor parameters.
    // Every field is value-initialized in class so that a constructor which does not set it
    // (e.g. the name-only constructor) never leaves it indeterminate. A constructor's
    // mem-initializer or body assignment still overrides these defaults. The declaration
    // order is significant for mem-initializer order and must not change.
    int mPooledSize{0};
    int mImageSize{0};
    int mTransformCoords{0};
    bool mAbsCoords{false};
    bool mSwapCoords{false};
    int mSamplingRatio{0};
    bool mIsLegacy{false};

    // Shape-derived values accepted by the longest constructor (featureLength, roiCount,
    // inputWidth, inputHeight); presumably also refreshed from the input descriptors in
    // configurePlugin() — TODO confirm.
    int mFeatureLength{0};
    int mROICount{0};
    int mInputWidth{0};
    int mInputHeight{0};
};
// Common base of the RoIAlign2 plugin creators. It declares the creator identity
// (getPluginName/getPluginVersion) and the attribute list (getFieldNames); concrete
// subclasses add createPlugin() and deserializePlugin().
class RoIAlign2BasePluginCreator : public BaseCreator
{
public:
    RoIAlign2BasePluginCreator() noexcept;
    ~RoIAlign2BasePluginCreator() noexcept override = default;

    // Name and version under which this creator identifies itself (IPluginCreator API).
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Descriptors of the attributes accepted by createPlugin().
    PluginFieldCollection const* getFieldNames() noexcept override;

protected:
    // Static, so one attribute list is shared by every instance of every subclass;
    // these need an out-of-line definition in the implementation file.
    static std::vector<PluginField> mPluginAttributes;
    static PluginFieldCollection mFC;

    // Value reported by getPluginName(), presumably set in the constructor — TODO confirm.
    std::string mPluginName;
};
// Creator whose factory methods return the IPluginV2Ext (non-dynamic) interface.
// NOTE(review): this header declares no IPluginV2Ext-only plugin class, so which concrete
// type the out-of-line definitions instantiate is not visible here — confirm.
class RoIAlign2PluginCreator : public RoIAlign2BasePluginCreator
{
public:
    RoIAlign2PluginCreator() noexcept;
    ~RoIAlign2PluginCreator() noexcept override = default;

    // Builds a plugin named `name` from the attribute collection `fc`.
    IPluginV2Ext* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;

    // Rebuilds a plugin named `name` from `serialLength` bytes at `serialData`.
    IPluginV2Ext* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;
};
// Creator whose factory methods return the IPluginV2DynamicExt (dynamic-shape) interface.
class RoIAlign2DynamicPluginCreator : public RoIAlign2BasePluginCreator
{
public:
    RoIAlign2DynamicPluginCreator() noexcept;
    ~RoIAlign2DynamicPluginCreator() noexcept override = default;

    // Builds a plugin named `name` from the attribute collection `fc`.
    IPluginV2DynamicExt* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;

    // Rebuilds a plugin named `name` from `serialLength` bytes at `serialData`.
    IPluginV2DynamicExt* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;
};
} // namespace plugin
} // namespace nvinfer1
#endif // ROIALIGN2_PLUGIN_H