From 45889b02581fb6bdeeae2f5d6960e9db9ef45214 Mon Sep 17 00:00:00 2001 From: TNTwise Date: Tue, 5 Nov 2024 19:32:45 -0600 Subject: [PATCH] change where engines are build --- backend/src/InterpolateTorch.py | 9 +++++---- backend/src/RenderVideo.py | 3 +-- backend/src/TensorRTHandler.py | 3 --- backend/src/UpscaleTorch.py | 7 ++++--- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py index f68fd478..2c34fb14 100644 --- a/backend/src/InterpolateTorch.py +++ b/backend/src/InterpolateTorch.py @@ -86,7 +86,7 @@ def frame_to_tensor(self, frame): @torch.inference_mode() def __init__( self, - interpolateModelPath: str, + modelPath: str, ceilInterpolateFactor: int = 2, width: int = 1920, height: int = 1080, @@ -98,7 +98,7 @@ def __init__( trt_workspace_size: int = 0, trt_max_aux_streams: int | None = None, trt_optimization_level: int = 5, - trt_cache_dir: str = modelsDirectory(), + trt_cache_dir: str = None, trt_debug: bool = False, rife_trt_mode: str = "accurate", trt_static_shape: bool = True, @@ -115,7 +115,7 @@ def __init__( printAndLog("Using device: " + str(device)) - self.interpolateModel = interpolateModelPath + self.interpolateModel = modelPath self.width = width self.height = height @@ -124,6 +124,8 @@ def __init__( self.trt_workspace_size = trt_workspace_size self.trt_max_aux_streams = trt_max_aux_streams self.trt_optimization_level = trt_optimization_level + if trt_cache_dir is None: + trt_cache_dir = os.path.dirname(modelPath) # use the model directory as the cache directory self.trt_cache_dir = trt_cache_dir self.backend = backend self.ceilInterpolateFactor = ceilInterpolateFactor @@ -279,7 +281,6 @@ def _load(self): trtHandler = TorchTensorRTHandler( trt_optimization_level=self.trt_optimization_level, - trt_cache_dir=self.trt_cache_dir, ) base_trt_engine_path = os.path.join( diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py index 96f95918..615c1827 100644 --- a/backend/src/RenderVideo.py +++ b/backend/src/RenderVideo.py @@ -324,7 +324,7 @@ def setupInterpolate(self): if self.backend == "pytorch" or self.backend == "tensorrt": interpolateRifePytorch = InterpolateRifeTorch( - interpolateModelPath=self.interpolateModel, + modelPath=self.interpolateModel, ceilInterpolateFactor=self.ceilInterpolateFactor, width=self.width, height=self.height, @@ -332,7 +332,6 @@ def setupInterpolate(self): dtype=self.precision, backend=self.backend, trt_optimization_level=self.trt_optimization_level, - rife_trt_mode=self.rife_trt_mode, ) self.frameSetupFunction = interpolateRifePytorch.frame_to_tensor self.undoSetup = interpolateRifePytorch.uncacheFrame diff --git a/backend/src/TensorRTHandler.py b/backend/src/TensorRTHandler.py index af2e56d5..e87808e1 100644 --- a/backend/src/TensorRTHandler.py +++ b/backend/src/TensorRTHandler.py @@ -3,7 +3,6 @@ import tensorrt import torch import torch_tensorrt -from .Util import modelsDirectory from torch._decomp import get_decompositions @@ -14,7 +13,6 @@ def __init__( trt_workspace_size: int = 0, max_aux_streams: int | None = None, trt_optimization_level: int = 3, - trt_cache_dir: str = modelsDirectory(), debug: bool = False, static_shape: bool = True, ): @@ -24,7 +22,6 @@ def __init__( self.trt_workspace_size = trt_workspace_size self.max_aux_streams = max_aux_streams self.optimization_level = trt_optimization_level - self.cache_dir = trt_cache_dir self.debug = debug self.static_shape = static_shape # Unused for now diff --git a/backend/src/UpscaleTorch.py b/backend/src/UpscaleTorch.py index a90e40a4..a4200f6f 100644 --- a/backend/src/UpscaleTorch.py +++ b/backend/src/UpscaleTorch.py @@ -9,7 +9,6 @@ from src.Util import ( currentDirectory, - modelsDirectory, printAndLog, check_bfloat16_support, ) @@ -65,7 +64,7 @@ def __init__( backend: str = "pytorch", # trt options trt_workspace_size: int = 0, - trt_cache_dir: str = modelsDirectory(), + trt_cache_dir: str = None, trt_optimization_level: int = 3, trt_max_aux_streams: int | None = None, trt_debug: bool = False, @@ -89,6 +88,8 @@ def __init__( self.tile = [self.tilesize, self.tilesize] self.modelPath = modelPath self.backend = backend + if trt_cache_dir is None: + trt_cache_dir = os.path.dirname(modelPath) # use the model directory as the cache directory self.trt_cache_dir = trt_cache_dir self.trt_workspace_size = trt_workspace_size self.trt_optimization_level = trt_optimization_level @@ -135,7 +136,7 @@ def _load(self): from .TensorRTHandler import TorchTensorRTHandler trtHandler = TorchTensorRTHandler( - export_format="torchscript", trt_cache_dir=self.trt_cache_dir + export_format="torchscript" ) trt_engine_path = os.path.join(