Skip to content

Commit

Permalink
move upscale trt back
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Oct 27, 2024
1 parent a694e44 commit ed06380
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
28 changes: 17 additions & 11 deletions backend/src/UpscaleTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _load(self):
if self.trt_workspace_size > 0
else ""
)
+ ".dyn"
+ ".ts"
),
)

Expand All @@ -161,26 +161,32 @@ def _load(self):
device=self.device,
)
]
dummy_input_cpu_fp32 = [
torch.zeros(
(1, 3, 32, 32),
dtype=torch.float32,
device="cpu",
)
]

module = torch.jit.trace(model.float().cpu(), dummy_input_cpu_fp32)
module.to(device=self.device, dtype=self.dtype)
module = torch_tensorrt.compile(
model,
ir="dynamo",
module,
ir="ts",
inputs=inputs,
enabled_precisions={self.dtype},
device=self.device,
debug=self.trt_debug,
device=torch_tensorrt.Device(gpu_id=0),
workspace_size=self.trt_workspace_size,
truncate_long_and_double=True,
min_block_size=1,
max_aux_streams=self.trt_aux_streams,
optimization_level=self.trt_optimization_level,
cache_built_engines=False,
reuse_cached_engines=False,
)

printAndLog(f"Saving TensorRT engine to {trt_engine_path}")
torch_tensorrt.save(module, trt_engine_path, inputs=inputs)
torch.jit.save(module, trt_engine_path)

printAndLog(f"Loading TensorRT engine from {trt_engine_path}")
model = torch.export.load(trt_engine_path).module()
model = torch.jit.load(trt_engine_path)

self.model = model
self.prepareStream.synchronize()
Expand Down
4 changes: 2 additions & 2 deletions backend/src/spandrel/architectures/SPAN/__arch/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ def __init__(
self.in_channels = num_in_ch
self.out_channels = num_out_ch
self.img_range = img_range
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1).cuda()
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

self.no_norm: torch.Tensor | None
if not norm:
self.register_buffer("no_norm", torch.zeros(1).cuda())
self.register_buffer("no_norm", torch.zeros(1))
else:
self.no_norm = None

Expand Down

0 comments on commit ed06380

Please sign in to comment.