diff --git a/backend/src/UpscaleTorch.py b/backend/src/UpscaleTorch.py index 5c49a400..e7ca5db8 100644 --- a/backend/src/UpscaleTorch.py +++ b/backend/src/UpscaleTorch.py @@ -149,7 +149,7 @@ def _load(self): if self.trt_workspace_size > 0 else "" ) - + ".dyn" + + ".ts" ), ) @@ -161,26 +161,32 @@ def _load(self): device=self.device, ) ] + dummy_input_cpu_fp32 = [ + torch.zeros( + (1, 3, 32, 32), + dtype=torch.float32, + device="cpu", + ) + ] + module = torch.jit.trace(model.float().cpu(), dummy_input_cpu_fp32) + module.to(device=self.device, dtype=self.dtype) module = torch_tensorrt.compile( - model, - ir="dynamo", + module, + ir="ts", inputs=inputs, enabled_precisions={self.dtype}, - device=self.device, - debug=self.trt_debug, + device=torch_tensorrt.Device(gpu_id=0), workspace_size=self.trt_workspace_size, + truncate_long_and_double=True, min_block_size=1, - max_aux_streams=self.trt_aux_streams, - optimization_level=self.trt_optimization_level, - cache_built_engines=False, - reuse_cached_engines=False, ) + printAndLog(f"Saving TensorRT engine to {trt_engine_path}") - torch_tensorrt.save(module, trt_engine_path, inputs=inputs) + torch.jit.save(module, trt_engine_path) printAndLog(f"Loading TensorRT engine from {trt_engine_path}") - model = torch.export.load(trt_engine_path).module() + model = torch.jit.load(trt_engine_path) self.model = model self.prepareStream.synchronize() diff --git a/backend/src/spandrel/architectures/SPAN/__arch/span.py b/backend/src/spandrel/architectures/SPAN/__arch/span.py index a982fabb..1cd7485b 100644 --- a/backend/src/spandrel/architectures/SPAN/__arch/span.py +++ b/backend/src/spandrel/architectures/SPAN/__arch/span.py @@ -254,11 +254,11 @@ def __init__( self.in_channels = num_in_ch self.out_channels = num_out_ch self.img_range = img_range - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda() + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) self.no_norm: torch.Tensor | None if not norm: - self.register_buffer("no_norm", torch.zeros(1).cuda()) + self.register_buffer("no_norm", torch.zeros(1)) else: self.no_norm = None