From 6faf3e3599455869415d94a8dd154f1d6dcb07e7 Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Mon, 8 Jul 2024 18:08:16 +0530 Subject: [PATCH] updating trt workspace int64 (#1094) Signed-off-by: Manix --- sherpa-onnx/csrc/provider-config.cc | 2 +- sherpa-onnx/csrc/provider-config.h | 4 ++-- sherpa-onnx/python/csrc/tensorrt-config.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 8a58746c74..3c8f0ee473 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size < 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", + SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.", trt_max_workspace_size); return false; } diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index ff96079090..fdc875e0a4 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -27,7 +27,7 @@ struct CudaConfig { }; struct TensorrtConfig { - int32_t trt_max_workspace_size = 2147483647; + int64_t trt_max_workspace_size = 2147483647; int32_t trt_max_partition_iterations = 10; int32_t trt_min_subgraph_size = 5; bool trt_fp16_enable = true; @@ -39,7 +39,7 @@ struct TensorrtConfig { bool trt_dump_subgraphs = false; TensorrtConfig() = default; - TensorrtConfig(int32_t trt_max_workspace_size, + TensorrtConfig(int64_t trt_max_workspace_size, int32_t trt_max_partition_iterations, int32_t trt_min_subgraph_size, bool trt_fp16_enable, diff --git a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc index 87962a2d34..ae48a945b4 100644 --- a/sherpa-onnx/python/csrc/tensorrt-config.cc +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) { using PyClass = TensorrtConfig; py::class_(*m, "TensorrtConfig") .def(py::init<>()) - .def(py::init([](int32_t trt_max_workspace_size, + .def(py::init([](int64_t trt_max_workspace_size, int32_t trt_max_partition_iterations, int32_t trt_min_subgraph_size, bool trt_fp16_enable,