
Commit

change where engines are built
TNTwise committed Nov 6, 2024
1 parent 7680ad9 commit 45889b0
Showing 4 changed files with 10 additions and 12 deletions.
9 changes: 5 additions & 4 deletions backend/src/InterpolateTorch.py
@@ -86,7 +86,7 @@ def frame_to_tensor(self, frame):
@torch.inference_mode()
def __init__(
self,
interpolateModelPath: str,
modelPath: str,
ceilInterpolateFactor: int = 2,
width: int = 1920,
height: int = 1080,
@@ -98,7 +98,7 @@ def __init__(
trt_workspace_size: int = 0,
trt_max_aux_streams: int | None = None,
trt_optimization_level: int = 5,
trt_cache_dir: str = modelsDirectory(),
trt_cache_dir: str = None,
trt_debug: bool = False,
rife_trt_mode: str = "accurate",
trt_static_shape: bool = True,
@@ -115,7 +115,7 @@ def __init__(

printAndLog("Using device: " + str(device))

self.interpolateModel = interpolateModelPath
self.interpolateModel = modelPath
self.width = width
self.height = height

@@ -124,6 +124,8 @@ def __init__(
self.trt_workspace_size = trt_workspace_size
self.trt_max_aux_streams = trt_max_aux_streams
self.trt_optimization_level = trt_optimization_level
if trt_cache_dir is None:
trt_cache_dir = os.path.dirname(modelPath) # use the model directory as the cache directory
self.trt_cache_dir = trt_cache_dir
self.backend = backend
self.ceilInterpolateFactor = ceilInterpolateFactor
@@ -279,7 +281,6 @@ def _load(self):

trtHandler = TorchTensorRTHandler(
trt_optimization_level=self.trt_optimization_level,
trt_cache_dir=self.trt_cache_dir,
)

base_trt_engine_path = os.path.join(
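The upshot of these hunks: when no trt_cache_dir is passed, the interpolator now caches TensorRT engines next to the model file instead of in the shared models directory, and the cache dir is no longer forwarded to TorchTensorRTHandler. A minimal sketch of the new fallback, assuming only the standard library (resolve_trt_cache_dir is an illustrative helper, not a function in the repo):

import os

def resolve_trt_cache_dir(modelPath: str, trt_cache_dir: str | None = None) -> str:
    # Mirrors the new default in InterpolateTorch.__init__: when the caller
    # does not supply a cache dir, fall back to the model's own directory.
    if trt_cache_dir is None:
        trt_cache_dir = os.path.dirname(modelPath)
    return trt_cache_dir

# Hypothetical model path, purely for illustration:
print(resolve_trt_cache_dir("/models/rife4.22.pkl"))  # -> /models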
3 changes: 1 addition & 2 deletions backend/src/RenderVideo.py
@@ -324,15 +324,14 @@ def setupInterpolate(self):

if self.backend == "pytorch" or self.backend == "tensorrt":
interpolateRifePytorch = InterpolateRifeTorch(
interpolateModelPath=self.interpolateModel,
modelPath=self.interpolateModel,
ceilInterpolateFactor=self.ceilInterpolateFactor,
width=self.width,
height=self.height,
device=self.device,
dtype=self.precision,
backend=self.backend,
trt_optimization_level=self.trt_optimization_level,
rife_trt_mode=self.rife_trt_mode,
)
self.frameSetupFunction = interpolateRifePytorch.frame_to_tensor
self.undoSetup = interpolateRifePytorch.uncacheFrame
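The call site follows the rename (interpolateModelPath -> modelPath) and leaves trt_cache_dir unset, so engines are built beside the interpolation model. A rough usage sketch under stated assumptions: the backend package is importable, the model path is hypothetical, and the keyword names are taken from this call site:

import torch
from src.InterpolateTorch import InterpolateRifeTorch  # assumed import path (backend/src on sys.path)

interpolate = InterpolateRifeTorch(
    modelPath="/models/rife4.22.pkl",  # hypothetical model file
    ceilInterpolateFactor=2,
    width=1920,
    height=1080,
    device="cuda",                     # keyword names as used in RenderVideo.setupInterpolate
    dtype=torch.float16,
    backend="tensorrt",
    trt_optimization_level=5,
    rife_trt_mode="accurate",
    # trt_cache_dir omitted -> defaults to os.path.dirname(modelPath)
)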
3 changes: 0 additions & 3 deletions backend/src/TensorRTHandler.py
@@ -3,7 +3,6 @@
import tensorrt
import torch
import torch_tensorrt
from .Util import modelsDirectory
from torch._decomp import get_decompositions


@@ -14,7 +13,6 @@ def __init__(
trt_workspace_size: int = 0,
max_aux_streams: int | None = None,
trt_optimization_level: int = 3,
trt_cache_dir: str = modelsDirectory(),
debug: bool = False,
static_shape: bool = True,
):
@@ -24,7 +22,6 @@
self.trt_workspace_size = trt_workspace_size
self.max_aux_streams = max_aux_streams
self.optimization_level = trt_optimization_level
self.cache_dir = trt_cache_dir
self.debug = debug
self.static_shape = static_shape # Unused for now

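With trt_cache_dir removed from TorchTensorRTHandler, the handler only compiles and exports engines; deciding where an engine lives is now entirely the caller's job (as InterpolateTorch does above with os.path.join). A small sketch of that pattern, assuming a hypothetical helper and engine filename:

import os

def engine_path_for(model_path: str, engine_name: str) -> str:
    # Callers derive the engine location themselves; placing it in the
    # model's directory matches the new defaults in InterpolateTorch/UpscaleTorch.
    return os.path.join(os.path.dirname(model_path), engine_name)

print(engine_path_for("/models/rife4.22.pkl", "rife4.22_1920x1080_fp16.engine"))
# -> /models/rife4.22_1920x1080_fp16.engine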
7 changes: 4 additions & 3 deletions backend/src/UpscaleTorch.py
@@ -9,7 +9,6 @@

from src.Util import (
currentDirectory,
modelsDirectory,
printAndLog,
check_bfloat16_support,
)
@@ -65,7 +64,7 @@ def __init__(
backend: str = "pytorch",
# trt options
trt_workspace_size: int = 0,
trt_cache_dir: str = modelsDirectory(),
trt_cache_dir: str = None,
trt_optimization_level: int = 3,
trt_max_aux_streams: int | None = None,
trt_debug: bool = False,
@@ -89,6 +88,8 @@ def __init__(
self.tile = [self.tilesize, self.tilesize]
self.modelPath = modelPath
self.backend = backend
if trt_cache_dir is None:
trt_cache_dir = os.path.dirname(modelPath) # use the model directory as the cache directory
self.trt_cache_dir = trt_cache_dir
self.trt_workspace_size = trt_workspace_size
self.trt_optimization_level = trt_optimization_level
@@ -135,7 +136,7 @@ def _load(self):
from .TensorRTHandler import TorchTensorRTHandler

trtHandler = TorchTensorRTHandler(
export_format="torchscript", trt_cache_dir=self.trt_cache_dir
export_format="torchscript"
)

trt_engine_path = os.path.join(
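UpscaleTorch gets the same treatment: the default cache dir becomes the model's directory and the handler is created without one. One subtlety of the os.path.dirname fallback worth noting (it applies to InterpolateTorch as well): a bare filename yields an empty string, so the engine would land in the current working directory. A short sketch illustrating that, with hypothetical file names:

import os

for model_path in ("/models/2x_upscale.pth", "2x_upscale.pth"):
    cache_dir = os.path.dirname(model_path)               # "" for a bare filename
    engine = os.path.join(cache_dir, "2x_upscale.engine")  # illustrative engine name
    print(repr(cache_dir), "->", engine)

# '/models' -> /models/2x_upscale.engine
# ''        -> 2x_upscale.engine  (relative to the current working directory)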

0 comments on commit 45889b0
