Skip to content

Commit

Permalink
vstrt/vs_tensorrt.cpp: added support for tensorrt 10
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Mar 27, 2024
1 parent 4281a4c commit 1426a2b
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions vstrt/vs_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ static const VSFrameRef *VS_CC vsTrtGetFrame(
const nvinfer1::Dims src_dim { instance.exec_context->getBindingDimensions(0) };
#endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85

const int src_planes { src_dim.d[1] };
const int src_tile_h { src_dim.d[2] };
const int src_tile_w { src_dim.d[3] };
const int src_planes { static_cast<int>(src_dim.d[1]) };
const int src_tile_h { static_cast<int>(src_dim.d[2]) };
const int src_tile_w { static_cast<int>(src_dim.d[3]) };

std::vector<const uint8_t *> src_ptrs;
src_ptrs.reserve(src_planes);
Expand All @@ -176,9 +176,9 @@ static const VSFrameRef *VS_CC vsTrtGetFrame(
const nvinfer1::Dims dst_dim { instance.exec_context->getBindingDimensions(1) };
#endif // NV_TENSORRT_MAJOR * 10 + NV_TENSORRT_MINOR >= 85

const int dst_planes { dst_dim.d[1] };
const int dst_tile_h { dst_dim.d[2] };
const int dst_tile_w { dst_dim.d[3] };
const int dst_planes { static_cast<int>(dst_dim.d[1]) };
const int dst_tile_h { static_cast<int>(dst_dim.d[2]) };
const int dst_tile_w { static_cast<int>(dst_dim.d[3]) };

std::vector<uint8_t *> dst_ptrs;
dst_ptrs.reserve(dst_planes);
Expand Down Expand Up @@ -521,7 +521,7 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
std::fprintf(stderr, "vstrt: TensorRT failed to load.\n");
return;
}
#else
#else // NV_TENSORRT_MAJOR == 9 && defined(_WIN32)
int ver = getInferLibVersion(); // must ensure this is the first nvinfer function called
#ifdef _WIN32
if (ver == 0) { // a sentinel value, see dummy function in win32.cpp.
Expand All @@ -530,9 +530,23 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit(
}
#endif // _WIN32
if (ver != NV_TENSORRT_VERSION) {
std::fprintf(stderr, "vstrt: TensorRT version mismatch, built with %d but loaded with %d; continue but fingers crossed...\n", NV_TENSORRT_VERSION, ver);
#if NV_TENSORRT_MAJOR >= 10
std::fprintf(
stderr,
"vstrt: TensorRT version mismatch, built with %ld but loaded with %d; continue but fingers crossed...\n",
NV_TENSORRT_VERSION,
ver
);
#else // NV_TENSORRT_MAJOR >= 10
std::fprintf(
stderr,
"vstrt: TensorRT version mismatch, built with %d but loaded with %d; continue but fingers crossed...\n",
NV_TENSORRT_VERSION,
ver
);
#endif // NV_TENSORRT_MAJOR >= 10
}
#endif
#endif // NV_TENSORRT_MAJOR == 9 && defined(_WIN32)

myself = plugin;

Expand Down

0 comments on commit 1426a2b

Please sign in to comment.