Skip to content

Commit

Permalink
fix span trt
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Nov 10, 2024
1 parent 910a239 commit 397f818
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
6 changes: 4 additions & 2 deletions backend/src/TensorRTHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def export_dynamo_model(
):
"""Exports a model using TensorRT Dynamo."""
model.to(device=device, dtype=dtype)
example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs]
exported_program = torch.export.export(
model, tuple(example_inputs), dynamic_shapes=None
)
Expand Down Expand Up @@ -106,8 +107,9 @@ def export_torchscript_model(
"""Exports a model using TorchScript."""

# maybe try to load it onto CUDA, and clear pytorch cache after.

module = torch.jit.trace(model.to(device=device, dtype=dtype), example_inputs)
model.to(device=device,dtype=dtype)
example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs]
module = torch.jit.trace(model, example_inputs) # have to put both on same device or sum
torch.cuda.empty_cache()
model = None

Expand Down
15 changes: 11 additions & 4 deletions backend/src/spandrel/architectures/SPAN/__arch/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import OrderedDict
from typing import Literal
import sys

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -258,7 +259,8 @@ def __init__(
self.in_channels = num_in_ch
self.out_channels = num_out_ch
self.img_range = img_range
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
self.mean_half = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda().half()
self.mean_float = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda().float()

self.no_norm: torch.Tensor | None
if not norm:
Expand All @@ -285,13 +287,18 @@ def __init__(

@property
def is_norm(self):

return self.no_norm is None

def forward(self, x):
device = x.device
dtype = x.dtype
self.device = x.device
self.dtype = x.dtype
if self.dtype == torch.float16:
self.mean = self.mean_half
else:
self.mean = self.mean_float
if self.is_norm:
self.mean = self.mean.type_as(x).to(device=device, dtype=dtype)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range

out_feature = self.conv_1(x)
Expand Down

0 comments on commit 397f818

Please sign in to comment.