From 397f81811b489b4d592bc805e0224fb657877f85 Mon Sep 17 00:00:00 2001 From: tntwise Date: Sun, 10 Nov 2024 14:55:00 -0800 Subject: [PATCH] fix span trt --- backend/src/TensorRTHandler.py | 6 ++++-- .../spandrel/architectures/SPAN/__arch/span.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/backend/src/TensorRTHandler.py b/backend/src/TensorRTHandler.py index 31d71ae2..0bc4a6c1 100644 --- a/backend/src/TensorRTHandler.py +++ b/backend/src/TensorRTHandler.py @@ -68,6 +68,7 @@ def export_dynamo_model( ): """Exports a model using TensorRT Dynamo.""" model.to(device=device, dtype=dtype) + example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs] exported_program = torch.export.export( model, tuple(example_inputs), dynamic_shapes=None ) @@ -106,8 +107,9 @@ def export_torchscript_model( """Exports a model using TorchScript.""" # maybe try to load it onto CUDA, and clear pytorch cache after. - - module = torch.jit.trace(model.to(device=device, dtype=dtype), example_inputs) + model.to(device=device,dtype=dtype) + example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs] + module = torch.jit.trace(model, example_inputs) # have to put both on same device or sum torch.cuda.empty_cache() model = None diff --git a/backend/src/spandrel/architectures/SPAN/__arch/span.py b/backend/src/spandrel/architectures/SPAN/__arch/span.py index 89cbf5a7..852bc5a1 100644 --- a/backend/src/spandrel/architectures/SPAN/__arch/span.py +++ b/backend/src/spandrel/architectures/SPAN/__arch/span.py @@ -2,6 +2,7 @@ from collections import OrderedDict from typing import Literal +import sys import torch import torch.nn.functional as F @@ -258,7 +259,8 @@ def __init__( self.in_channels = num_in_ch self.out_channels = num_out_ch self.img_range = img_range - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + self.mean_half = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda().half() + self.mean_float = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda().float() self.no_norm: torch.Tensor | None if not norm: @@ -285,13 +287,18 @@ def __init__( @property def is_norm(self): + return self.no_norm is None def forward(self, x): - device = x.device - dtype = x.dtype + self.device = x.device + self.dtype = x.dtype + if self.dtype == torch.float16: + self.mean = self.mean_half + else: + self.mean = self.mean_float if self.is_norm: - self.mean = self.mean.type_as(x).to(device=device, dtype=dtype) + self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range out_feature = self.conv_1(x)