-
Notifications
You must be signed in to change notification settings - Fork 546
/
Copy pathNvOnnxParser.h
541 lines (488 loc) · 17.8 KB
/
NvOnnxParser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
/*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef NV_ONNX_PARSER_H
#define NV_ONNX_PARSER_H
#include "NvInfer.h"
#include <stddef.h>
#include <string>
#include <vector>
//!
//! \file NvOnnxParser.h
//!
//! This is the API for the ONNX Parser
//!
//! Major, minor and patch components of the ONNX parser API version.
#define NV_ONNX_PARSER_MAJOR 0
#define NV_ONNX_PARSER_MINOR 1
#define NV_ONNX_PARSER_PATCH 0
//! Single-integer encoding of the version: MAJOR * 10000 + MINOR * 100 + PATCH.
static constexpr int32_t NV_ONNX_PARSER_VERSION{
    NV_ONNX_PARSER_MAJOR * 10000 + NV_ONNX_PARSER_MINOR * 100 + NV_ONNX_PARSER_PATCH};
//!
//! \typedef SubGraph_t
//!
//! \brief Describes the parsing capability of a set of nodes in an ONNX graph,
//! stored as a pair of a size_t list and a bool.
//!
using SubGraph_t = std::pair<std::vector<size_t>, bool>;
//!
//! \typedef SubGraphCollection_t
//!
//! \brief Holds every SubGraph_t partitioned out of an ONNX graph.
//!
using SubGraphCollection_t = std::vector<SubGraph_t>;
//!
//! \namespace nvonnxparser
//!
//! \brief The TensorRT ONNX parser API namespace
//!
namespace nvonnxparser
{
//!
//! \brief Number of enumerators defined by the enum type \p T.
//!
//! The primary template is only declared here; each enum supplies an explicit specialization.
//!
template <typename T>
constexpr inline int32_t EnumMax() noexcept;
//!
//! \enum ErrorCode
//!
//! \brief The type of error that the parser or refitter may return
//!
enum class ErrorCode : int
{
    kSUCCESS = 0,
    kINTERNAL_ERROR = 1,
    kMEM_ALLOC_FAILED = 2,
    kMODEL_DESERIALIZE_FAILED = 3,
    kINVALID_VALUE = 4,
    kINVALID_GRAPH = 5,
    kINVALID_NODE = 6,
    kUNSUPPORTED_GRAPH = 7,
    kUNSUPPORTED_NODE = 8,
    kUNSUPPORTED_NODE_ATTR = 9,
    kUNSUPPORTED_NODE_INPUT = 10,
    kUNSUPPORTED_NODE_DATATYPE = 11,
    kUNSUPPORTED_NODE_DYNAMIC = 12,
    kUNSUPPORTED_NODE_SHAPE = 13,
    kREFIT_FAILED = 14 // Must remain the last enumerator: EnumMax<ErrorCode>() is derived from it.
};
//!
//! Number of enumerators in the ErrorCode enum.
//!
//! \see ErrorCode
//!
template <>
constexpr inline int32_t EnumMax<ErrorCode>() noexcept
{
    // The enumerators are contiguous and start at 0, so the count is the last value plus one.
    // (Previously hard-coded to 14, which left kREFIT_FAILED out of the count.)
    return static_cast<int32_t>(ErrorCode::kREFIT_FAILED) + 1;
}
//!
//! \brief Bitmask holding one or more OnnxParserFlag values combined with binary OR
//! operations, e.g., 1U << static_cast<uint32_t>(OnnxParserFlag::kNATIVE_INSTANCENORM).
//! OnnxParserFlag is a scoped enum, so the cast is required before shifting.
//!
//! \see IParser::setFlags() and IParser::getFlags()
//!
using OnnxParserFlags = uint32_t;
//!
//! \enum OnnxParserFlag
//!
//! \brief Individual parser options. Each enumerator's value is the bit position of that
//! option within an OnnxParserFlags bitmask.
//!
enum class OnnxParserFlag : int32_t
{
//! Parse the ONNX model into the INetworkDefinition with the intention of using TensorRT's native layer
//! implementation over the plugin implementation for InstanceNormalization nodes.
//! This flag is required when building version-compatible or hardware-compatible engines.
//! This flag is set to be ON by default.
kNATIVE_INSTANCENORM = 0
};
//!
//! Number of flags defined in the OnnxParserFlag enum (currently one: kNATIVE_INSTANCENORM).
//!
//! \see OnnxParserFlag
//!
template <>
constexpr inline int32_t EnumMax<OnnxParserFlag>() noexcept
{
return 1;
}
//!
//! \class IParserError
//!
//! \brief An object containing information about an error.
//!
//! All accessors are pure virtual. The order of the virtual member declarations should be
//! treated as part of the binary interface and left unchanged.
//!
class IParserError
{
public:
//!
//! \brief The error code.
//!
virtual ErrorCode code() const = 0;
//!
//! \brief Description of the error.
//!
virtual char const* desc() const = 0;
//!
//! \brief Source file in which the error occurred.
//!
virtual char const* file() const = 0;
//!
//! \brief Source line at which the error occurred.
//!
virtual int line() const = 0;
//!
//! \brief Source function in which the error occurred.
//!
virtual char const* func() const = 0;
//!
//! \brief Index of the ONNX model node in which the error occurred.
//!
virtual int node() const = 0;
//!
//! \brief Name of the node in which the error occurred.
//!
virtual char const* nodeName() const = 0;
//!
//! \brief Name of the node operation in which the error occurred.
//!
virtual char const* nodeOperator() const = 0;
//!
//! \brief A list of the local function names, from the top level down, constituting the current
//! stack trace in which the error occurred. A top-level node that is not inside any
//! local function would return a nullptr.
//!
//! \see localFunctionStackSize() for the number of entries in the list.
//!
virtual char const* const* localFunctionStack() const = 0;
//!
//! \brief The size of the stack of local functions at the point where the error occurred.
//! A top-level node that is not inside any local function would correspond to
//! a stack size of 0.
//!
virtual int32_t localFunctionStackSize() const = 0;
protected:
//! Protected destructor: client code cannot delete an error object through an IParserError pointer.
virtual ~IParserError() {}
};
//!
//! \class IParser
//!
//! \brief An object for parsing ONNX models into a TensorRT network definition.
//!
//! \warning If the ONNX model has a graph output with the same name as a graph input,
//! the output will be renamed by prepending "__".
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
class IParser
{
public:
//!
//! \brief Parse a serialized ONNX model into the TensorRT network.
//! This method has very limited diagnostics. If parsing the serialized model
//! fails for any reason (e.g. unsupported IR version, unsupported opset, etc.)
//! it is the user's responsibility to intercept and report the error.
//! To obtain a better diagnostic, use the parseFromFile method below.
//!
//! \param serialized_onnx_model Pointer to the serialized ONNX model
//! \param serialized_onnx_model_size Size of the serialized ONNX model
//! in bytes
//! \param model_path Absolute path to the model file for loading external weights if required
//! \return true if the model was parsed successfully
//! \see getNbErrors() getError()
//!
virtual bool parse(
void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) noexcept
= 0;
//!
//! \brief Parse an ONNX model file, which can be a binary protobuf or a text ONNX model.
//! Calls the parse method internally.
//!
//! \param onnxModelFile Name of the ONNX model file
//! \param verbosity Verbosity level
//!
//! \return true if the model was parsed successfully
//!
virtual bool parseFromFile(const char* onnxModelFile, int verbosity) noexcept = 0;
//!
//! \brief Check whether TensorRT supports a particular ONNX model.
//! If the function returns True, one can proceed to engine building
//! without having to call \p parse or \p parseFromFile.
//!
//! \deprecated Deprecated in TensorRT 10.1. See supportsModelV2.
//!
//! \param serialized_onnx_model Pointer to the serialized ONNX model
//! \param serialized_onnx_model_size Size of the serialized ONNX model
//! in bytes
//! \param sub_graph_collection Container to hold supported subgraphs
//! \param model_path Absolute path to the model file for loading external weights if required
//! \return true if the model is supported
//!
TRT_DEPRECATED virtual bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) noexcept = 0;
//!
//! \brief Parse a serialized ONNX model into the TensorRT network
//! with consideration of user provided weights
//!
//! \param serialized_onnx_model Pointer to the serialized ONNX model
//! \param serialized_onnx_model_size Size of the serialized ONNX model
//! in bytes
//! \return true if the model was parsed successfully
//! \see getNbErrors() getError()
//!
virtual bool parseWithWeightDescriptors(
void const* serialized_onnx_model, size_t serialized_onnx_model_size) noexcept
= 0;
//!
//! \brief Returns whether the specified operator may be supported by the
//! parser.
//!
//! Note that a result of true does not guarantee that the operator will be
//! supported in all cases (i.e., this function may return false-positives).
//!
//! \param op_name The name of the ONNX operator to check for support
//! \return true if the operator may be supported
//!
virtual bool supportsOperator(const char* op_name) const noexcept = 0;
//!
//! \brief Get the number of errors that occurred during prior calls to
//! \p parse
//!
//! \see getError() clearErrors() IParserError
//!
virtual int getNbErrors() const noexcept = 0;
//!
//! \brief Get an error that occurred during prior calls to \p parse
//!
//! \param index Index of the error to retrieve (presumably in [0, getNbErrors()) -- TODO confirm)
//!
//! \see getNbErrors() clearErrors() IParserError
//!
virtual IParserError const* getError(int index) const noexcept = 0;
//!
//! \brief Clear errors from prior calls to \p parse
//!
//! \see getNbErrors() getError() IParserError
//!
virtual void clearErrors() noexcept = 0;
//! Public virtual destructor. It is declared between other virtual members; keep its position unchanged.
virtual ~IParser() noexcept = default;
//!
//! \brief Query the plugin libraries needed to implement operations used by the parser in a version-compatible
//! engine.
//!
//! This provides a list of plugin libraries on the filesystem needed to implement operations
//! in the parsed network. If you are building a version-compatible engine using this network,
//! provide this list to IBuilderConfig::setPluginsToSerialize to serialize these plugins along
//! with the version-compatible engine, or, if you want to ship these plugin libraries externally
//! to the engine, ensure that IPluginRegistry::loadLibrary is used to load these libraries in the
//! appropriate runtime before deserializing the corresponding engine.
//!
//! \param[out] nbPluginLibs Returns the number of plugin libraries in the array, or -1 if there was an error.
//! \return Array of `nbPluginLibs` C-strings describing plugin library paths on the filesystem if nbPluginLibs > 0,
//! or nullptr otherwise. This array is owned by the IParser, and the pointers in the array are only valid until
//! the next call to parse(), supportsModel(), parseFromFile(), or parseWithWeightDescriptors().
//!
virtual char const* const* getUsedVCPluginLibraries(int64_t& nbPluginLibs) const noexcept = 0;
//!
//! \brief Set the parser flags.
//!
//! The flags are listed in the OnnxParserFlag enum.
//!
//! \param onnxParserFlags The flags used when parsing an ONNX model, as a bitmask.
//!
//! \note This function will override the previously set flags, rather than bitwise ORing the new flags.
//!
//! \see getFlags()
//!
virtual void setFlags(OnnxParserFlags onnxParserFlags) noexcept = 0;
//!
//! \brief Get the parser flags. Defaults to 0.
//!
//! NOTE(review): the OnnxParserFlag documentation states that kNATIVE_INSTANCENORM is ON by default,
//! which conflicts with "Defaults to 0" -- confirm against the implementation.
//!
//! \return The parser flags as a bitmask.
//!
//! \see setFlags()
//!
virtual OnnxParserFlags getFlags() const noexcept = 0;
//!
//! \brief Clear a parser flag.
//!
//! Clears the parser flag from the enabled flags.
//!
//! \param onnxParserFlag The flag to clear.
//!
//! \see setFlags()
//!
virtual void clearFlag(OnnxParserFlag onnxParserFlag) noexcept = 0;
//!
//! \brief Set a single parser flag.
//!
//! Add the input parser flag to the already enabled flags.
//!
//! \param onnxParserFlag The flag to enable.
//!
//! \see setFlags()
//!
virtual void setFlag(OnnxParserFlag onnxParserFlag) noexcept = 0;
//!
//! \brief Returns true if the parser flag is set
//!
//! \param onnxParserFlag The flag to query.
//!
//! \see getFlags()
//!
//! \return True if flag is set, false if unset.
//!
virtual bool getFlag(OnnxParserFlag onnxParserFlag) const noexcept = 0;
//!
//! \brief Return the i-th output ITensor object for the ONNX layer "name".
//!
//! If "name" is not found or i is out of range, return nullptr.
//! In the case of multiple nodes sharing the same name this function will return
//! the output tensors of the first instance of the node in the ONNX graph.
//!
//! \param name The name of the ONNX layer.
//!
//! \param i The index of the output. i must be in range [0, layer.num_outputs).
//!
//! \return The requested output tensor, or nullptr as described above.
//!
virtual nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i) noexcept = 0;
//!
//! \brief Check whether TensorRT supports a particular ONNX model.
//! If the function returns True, one can proceed to engine building
//! without having to call \p parse or \p parseFromFile.
//! Results can be queried through \p getNbSubgraphs, \p isSubgraphSupported,
//! \p getSubgraphNodes.
//!
//! \param serializedOnnxModel Pointer to the serialized ONNX model
//! \param serializedOnnxModelSize Size of the serialized ONNX model in bytes
//! \param modelPath Absolute path to the model file for loading external weights if required
//! \return true if the model is supported
//!
virtual bool supportsModelV2(
void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath = nullptr) noexcept = 0;
//!
//! \brief Get the number of subgraphs. Calling this function before calling \p supportsModelV2 results in undefined
//! behavior.
//!
//! \return Number of subgraphs.
//!
virtual int64_t getNbSubgraphs() noexcept = 0;
//!
//! \brief Returns whether the subgraph is supported. Calling this function before calling \p supportsModelV2
//! results in undefined behavior.
//!
//! \param index Index of the subgraph.
//! \return Whether the subgraph is supported.
//!
virtual bool isSubgraphSupported(int64_t const index) noexcept = 0;
//!
//! \brief Get the nodes of the specified subgraph. Calling this function before calling \p supportsModelV2 results
//! in undefined behavior.
//!
//! \param index Index of the subgraph.
//! \param[out] subgraphLength Receives the length of the subgraph.
//!
//! \return Pointer to the subgraph nodes array. This pointer is owned by the Parser.
//!
virtual int64_t* getSubgraphNodes(int64_t const index, int64_t& subgraphLength) noexcept = 0;
};
//!
//! \class IParserRefitter
//!
//! \brief An interface designed to refit weights from an ONNX model.
//!
//! \warning Do not inherit from this class, as doing so will break forward-compatibility of the API and ABI.
//!
class IParserRefitter
{
public:
//!
//! \brief Load a serialized ONNX model from memory and perform weight refit.
//!
//! \param serializedOnnxModel Pointer to the serialized ONNX model
//! \param serializedOnnxModelSize Size of the serialized ONNX model
//! in bytes
//! \param modelPath Absolute path to the model file for loading external weights if required
//! \return true if all the weights in the engine were refit successfully.
//!
//! The serialized ONNX model must be identical to the one used to generate the engine
//! that will be refit.
//!
virtual bool refitFromBytes(
void const* serializedOnnxModel, size_t serializedOnnxModelSize, char const* modelPath = nullptr) noexcept
= 0;
//!
//! \brief Load and parse an ONNX model from disk and perform weight refit.
//!
//! \param onnxModelFile Path to the ONNX model to load from disk.
//!
//! \return true if the model was loaded successfully, and if all the weights in the engine were refit successfully.
//!
//! The provided ONNX model must be identical to the one used to generate the engine
//! that will be refit.
//!
virtual bool refitFromFile(char const* onnxModelFile) noexcept = 0;
//!
//! \brief Get the number of errors that occurred during prior calls to \p refitFromBytes or \p refitFromFile
//!
//! \see getError() IParserError
//!
virtual int32_t getNbErrors() const noexcept = 0;
//!
//! \brief Get an error that occurred during prior calls to \p refitFromBytes or \p refitFromFile
//!
//! \param index Index of the error to retrieve (presumably in [0, getNbErrors()) -- TODO confirm)
//!
//! \see getNbErrors() IParserError
//!
virtual IParserError const* getError(int32_t index) const noexcept = 0;
//!
//! \brief Clear errors from prior calls to \p refitFromBytes or \p refitFromFile
//!
//! NOTE(review): unlike every other member of this interface, this one is not declared noexcept.
//! Confirm whether that is intentional before changing it; adding noexcept here would also require
//! it on every override.
//!
//! \see getNbErrors() getError() IParserError
//!
virtual void clearErrors() = 0;
//! Public virtual destructor. It is declared after the other virtual members; keep its position unchanged.
virtual ~IParserRefitter() noexcept = default;
};
} // namespace nvonnxparser
//! C-linkage entry points exported by the parser library. They take and return untyped
//! pointers; the inline createParser() / createParserRefitter() wrappers in this header pass
//! the addresses of the typed objects plus NV_ONNX_PARSER_VERSION and cast the result back.
//! Prefer those wrappers over calling these functions directly.
extern "C" TENSORRTAPI void* createNvOnnxParser_INTERNAL(void* network, void* logger, int version) noexcept;
extern "C" TENSORRTAPI void* createNvOnnxParserRefitter_INTERNAL(
void* refitter, void* logger, int32_t version) noexcept;
//! Returns a parser version number (presumably that of the loaded library, using the same
//! encoding as NV_ONNX_PARSER_VERSION -- TODO confirm).
extern "C" TENSORRTAPI int getNvOnnxParserVersion() noexcept;
namespace nvonnxparser
{
namespace
{
//!
//! \brief Construct an ONNX parser that writes into the given network definition.
//!
//! \param network The network definition that the parser will write to
//! \param logger The logger to use
//! \return The newly created parser, or nullptr if an error occurred
//!
//! Input dimensions that are constant should not be changed after parsing,
//! because correctness of the translation may rely on those constants.
//! Turning a dynamic input dimension (one that translates to -1 in TensorRT)
//! into a constant is okay if the constant is consistent with the model.
//! Each parser instance is designed to parse only one ONNX model, once.
//!
//! \see IParser
//!
inline IParser* createParser(nvinfer1::INetworkDefinition& network, nvinfer1::ILogger& logger) noexcept
{
    IParser* parser{nullptr};
    try
    {
        // The library entry point is type-erased; this header's version is passed alongside.
        void* const handle = createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION);
        parser = static_cast<IParser*>(handle);
    }
    catch (std::exception const& err)
    {
        // Report through the caller's logger; the nullptr result signals the failure.
        logger.log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, err.what());
    }
    return parser;
}
//!
//! \brief Construct an ONNX parser-refitter bound to the given refitter.
//!
//! \param refitter The Refitter object used to refit the model
//! \param logger The logger to use
//! \return The newly created ParserRefitter, or nullptr if an error occurred
//!
//! \see IParserRefitter
//!
inline IParserRefitter* createParserRefitter(nvinfer1::IRefitter& refitter, nvinfer1::ILogger& logger) noexcept
{
    IParserRefitter* parserRefitter{nullptr};
    try
    {
        // The library entry point is type-erased; this header's version is passed alongside.
        void* const handle = createNvOnnxParserRefitter_INTERNAL(&refitter, &logger, NV_ONNX_PARSER_VERSION);
        parserRefitter = static_cast<IParserRefitter*>(handle);
    }
    catch (std::exception const& err)
    {
        // Report through the caller's logger; the nullptr result signals the failure.
        logger.log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, err.what());
    }
    return parserRefitter;
}
} // namespace
} // namespace nvonnxparser
#endif // NV_ONNX_PARSER_H