From ad8a89c7bf16c395c897d7797ea88a1460dfd368 Mon Sep 17 00:00:00 2001 From: tntwise Date: Thu, 23 May 2024 21:59:14 +0000 Subject: [PATCH] speed up onnx output --- src/torch/UpscaleImageTensorRT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torch/UpscaleImageTensorRT.py b/src/torch/UpscaleImageTensorRT.py index ed2c6ae6..afb89143 100644 --- a/src/torch/UpscaleImageTensorRT.py +++ b/src/torch/UpscaleImageTensorRT.py @@ -94,7 +94,7 @@ def pytorchExportToONNX(self): # Loads model via spandrel, and exports to onnx state_dict = model.state_dict() model.eval().cuda() model.load_state_dict(state_dict, strict=True) - input = torch.rand(1, 3, 256, 256).cuda() + input = torch.rand(1, 3, 20, 20).cuda() if self.half: try: model.half()