From 29ce1a743bc067c259ac6646ec67c111a84ee80a Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:52:51 +0800 Subject: [PATCH] Use torch_tensorrt.Device instead of torch.device in trt compile (#8051) Fixes #8050 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f301c2dd5c..bd65ffa33e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -851,7 +851,7 @@ def _onnx_trt_compile( # wrap the serialized TensorRT engine back to a TorchScript module. trt_model = torch_tensorrt.ts.embed_engine_in_new_module( f.getvalue(), - device=torch.device(f"cuda:{device}"), + device=torch_tensorrt.Device(f"cuda:{device}"), input_binding_names=input_names, output_binding_names=output_names, )