forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NvOnnxParser.h
282 lines (253 loc) · 9.21 KB
/
NvOnnxParser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#ifndef NV_ONNX_PARSER_H
#define NV_ONNX_PARSER_H
#include "NvInfer.h"
#include <stddef.h>
#include <vector>
//!
//! \file NvOnnxParser.h
//!
//! This is the API for the ONNX Parser
//!
//! Components of the parser API version. They are packed into the single
//! integer NV_ONNX_PARSER_VERSION as MAJOR * 10000 + MINOR * 100 + PATCH.
#define NV_ONNX_PARSER_MAJOR 0
#define NV_ONNX_PARSER_MINOR 1
#define NV_ONNX_PARSER_PATCH 0
//! Packed API version of this header; createParser() forwards it to
//! createNvOnnxParser_INTERNAL().
static const int NV_ONNX_PARSER_VERSION
    = NV_ONNX_PARSER_PATCH + (100 * NV_ONNX_PARSER_MINOR) + (10000 * NV_ONNX_PARSER_MAJOR);
//! \typedef SubGraph_t
//!
//! \brief The data structure containing the parsing capability of
//! a set of nodes in an ONNX graph.
//!
//! The type is a pair of a vector of size_t values and a bool.
//! NOTE(review): presumably the vector holds the indices of the nodes in the
//! set and the bool tells whether that set is supported by the parser --
//! confirm against the parser implementation.
//!
using SubGraph_t = std::pair<std::vector<size_t>, bool>;
//! \typedef SubGraphCollection_t
//!
//! \brief The data structure containing all SubGraph_t partitioned
//! out of an ONNX graph.
//!
//! Passed by reference to IParser::supportsModel() as the container that
//! holds the supported subgraphs.
//!
using SubGraphCollection_t = std::vector<SubGraph_t>;
//! Forward declaration only: this header refers to the type solely through a
//! pointer parameter of IParser::parseWithWeightDescriptors().
//! NOTE(review): presumably the ONNXIFI tensor descriptor -- confirm that the
//! class-key used here matches the real definition.
class onnxTensorDescriptorV1;
//!
//! \namespace nvonnxparser
//!
//! \brief The TensorRT ONNX parser API namespace
//!
namespace nvonnxparser
{
//!
//! \brief Upper bound of the enumeration type \p T.
//!
//! Only declared here; each enumeration provides its own explicit
//! specialization.
//!
template <typename T>
inline int32_t EnumMax();

//!
//! \enum ErrorCode
//!
//! \brief the type of parser error
//!
enum class ErrorCode : int
{
    kSUCCESS = 0,
    kINTERNAL_ERROR = 1,
    kMEM_ALLOC_FAILED = 2,
    kMODEL_DESERIALIZE_FAILED = 3,
    kINVALID_VALUE = 4,
    kINVALID_GRAPH = 5,
    kINVALID_NODE = 6,
    kUNSUPPORTED_GRAPH = 7,
    kUNSUPPORTED_NODE = 8
};

//!
//! \brief Number of ErrorCode enumerators, i.e. one more than the largest
//! enumerator value.
//!
template <>
inline int32_t EnumMax<ErrorCode>()
{
    // kUNSUPPORTED_NODE is the last enumerator; if another one is appended,
    // this must refer to the new last enumerator instead.
    return static_cast<int32_t>(ErrorCode::kUNSUPPORTED_NODE) + 1;
}
/** \class IParserError
*
* \brief an object containing information about an error
*
* Pure-virtual, read-only interface: every accessor is const. The destructor
* is protected, so code outside the class hierarchy cannot delete an error
* object through an IParserError pointer.
*
* NOTE(review): this header does not state who owns the strings returned by
* desc(), file() and func() or how long they stay valid -- presumably they
* belong to the parser; confirm with the implementation.
*
* \see IParser::getError()
*/
class IParserError
{
public:
/** \brief the error code
*
* \return the ErrorCode classifying this error
*/
virtual ErrorCode code() const = 0;
/** \brief description of the error
*
* \return a C string describing the error
*/
virtual const char* desc() const = 0;
/** \brief source file in which the error occurred
*
* \return the source file name as a C string
*/
virtual const char* file() const = 0;
/** \brief source line at which the error occurred
*
* \return the source line number
*/
virtual int line() const = 0;
/** \brief source function in which the error occurred
*
* \return the source function name as a C string
*/
virtual const char* func() const = 0;
/** \brief index of the ONNX model node in which the error occurred
*
* \return the node index
*
* NOTE(review): the value returned when an error is not tied to a specific
* node is not specified in this header -- confirm with the implementation.
*/
virtual int node() const = 0;
protected:
/** \brief protected so that clients cannot delete an error object through
* this interface
*/
virtual ~IParserError() {}
};
/** \class IParser
*
* \brief an object for parsing ONNX models into a TensorRT network definition
*
* Pure-virtual interface. Instances are returned by createParser() and are
* released with destroy(); the destructor is protected, so clients cannot
* delete a parser through an IParser pointer.
*/
class IParser
{
public:
/** \brief Parse a serialized ONNX model into the TensorRT network.
* This method has very limited diagnostics. If parsing the serialized model
* fails for any reason (e.g. unsupported IR version, unsupported opset, etc.)
* it is the user's responsibility to intercept and report the error.
* To obtain better diagnostics, use the parseFromFile method below.
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param model_path Absolute path to the model file for loading external weights if required
* \return true if the model was parsed successfully
* \see getNbErrors() getError()
*/
virtual bool parse(void const* serialized_onnx_model,
size_t serialized_onnx_model_size,
const char* model_path = nullptr)
= 0;
/** \brief Parse an ONNX model file, which can be a binary protobuf or a text ONNX model;
* calls the parse method inside.
*
* \param onnxModelFile Name of the model file
* \param verbosity Verbosity level
* (NOTE(review): presumably a logger severity value -- confirm the accepted range)
*
* \return true if the model was parsed successfully
*
*/
virtual bool parseFromFile(const char* onnxModelFile, int verbosity) = 0;
/** \brief Check whether TensorRT supports a particular ONNX model
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param sub_graph_collection Container to hold supported subgraphs
* \param model_path Absolute path to the model file for loading external weights if required
* \return true if the model is supported
*/
virtual bool supportsModel(void const* serialized_onnx_model,
size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection,
const char* model_path = nullptr)
= 0;
/** \brief Parse a serialized ONNX model into the TensorRT network
* with consideration of user provided weights
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param weight_count number of user provided weights
* \param weight_descriptors pointer to user provided weight array
* \return true if the model was parsed successfully
* \see getNbErrors() getError()
*/
virtual bool parseWithWeightDescriptors(
void const* serialized_onnx_model, size_t serialized_onnx_model_size,
uint32_t weight_count,
onnxTensorDescriptorV1 const* weight_descriptors)
= 0;
/** \brief Returns whether the specified operator may be supported by the
* parser.
*
* Note that a result of true does not guarantee that the operator will be
* supported in all cases (i.e., this function may return false-positives).
*
* \param op_name The name of the ONNX operator to check for support
* \return true if the operator may be supported
*/
virtual bool supportsOperator(const char* op_name) const = 0;
/** \brief destroy this object
*
* Clients release a parser through this method because the destructor is
* protected.
*/
virtual void destroy() = 0;
/** \brief Get the number of errors that occurred during prior calls to
* \p parse
*
* \return the number of errors
* \see getError() clearErrors() IParserError
*/
virtual int getNbErrors() const = 0;
/** \brief Get an error that occurred during prior calls to \p parse
*
* \param index Index of the error to return
* (NOTE(review): presumably in the range [0, getNbErrors()) -- confirm)
* \return pointer to the error object
* \see getNbErrors() clearErrors() IParserError
*/
virtual IParserError const* getError(int index) const = 0;
/** \brief Clear errors from prior calls to \p parse
*
* \see getNbErrors() getError() IParserError
*/
virtual void clearErrors() = 0;
/** \brief Get description of all ONNX weights that can be refitted.
*
* \param weightNames Where to write the weight names to
* \param layerNames Where to write the layer names to
* \param roles Where to write the roles to
*
* \return The number of weights from the ONNX model that can be refitted
*
* If weightNames or layerNames != nullptr, each written pointer points to a string owned by
* the parser, and becomes invalid when the parser is destroyed
*
* If the same weight is used in multiple TRT layers it will be represented as a new
* entry in weightNames with name <weightName>_x, with x being the number of times the weight
* has been used before the current layer
*/
virtual int getRefitMap(const char** weightNames, const char** layerNames, nvinfer1::WeightsRole* roles) = 0;
protected:
/** \brief protected so that clients go through destroy() instead of delete
*/
virtual ~IParser() {}
};
} // namespace nvonnxparser
//! C-linkage factory entry point. In this header it is only called by
//! nvonnxparser::createParser(), which passes the addresses of an
//! nvinfer1::INetworkDefinition and an nvinfer1::ILogger together with
//! NV_ONNX_PARSER_VERSION, and casts the returned pointer to
//! nvonnxparser::IParser*.
extern "C" TENSORRTAPI void* createNvOnnxParser_INTERNAL(void* network, void* logger, int version);
//! C-linkage query for the parser version.
//! NOTE(review): presumably reports the version of the linked parser library
//! in the same encoding as NV_ONNX_PARSER_VERSION -- confirm.
extern "C" TENSORRTAPI int getNvOnnxParserVersion();
namespace nvonnxparser
{
#ifdef SWIG
//! \brief Pointer-taking overload of createParser, compiled only when SWIG is defined.
//!
//! \param network The network definition that the parser will write to
//! \param logger The logger to use
//! \return the object returned by createNvOnnxParser_INTERNAL(), viewed as an IParser
inline IParser* createParser(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
{
    void* const rawParser = createNvOnnxParser_INTERNAL(network, logger, NV_ONNX_PARSER_VERSION);
    return static_cast<IParser*>(rawParser);
}
#endif // SWIG
// The unnamed namespace gives createParser internal linkage, so every
// translation unit that includes this header gets its own copy. That matters
// for the _MSC_VER variant below, which is not declared inline.
namespace
{
//!
//! \brief Create a new parser object
//!
//! \param network The network definition that the parser will write to
//! \param logger The logger to use
//! \return a new parser object or NULL if an error occurred
//! \see IParser
//!
#ifdef _MSC_VER
TENSORRTAPI IParser* createParser(nvinfer1::INetworkDefinition& network, nvinfer1::ILogger& logger)
#else
inline IParser* createParser(nvinfer1::INetworkDefinition& network, nvinfer1::ILogger& logger)
#endif
{
    // The library entry point is type-erased (void* in, void* out); the header
    // version is passed along with the two object addresses.
    void* const rawParser = createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION);
    return static_cast<IParser*>(rawParser);
}
} // namespace
} // namespace nvonnxparser
#endif // NV_ONNX_PARSER_H