From ddec64b90875c91cb04d11f56c55514540cb2fe0 Mon Sep 17 00:00:00 2001 From: TNTwise Date: Sun, 3 Nov 2024 12:40:42 -0600 Subject: [PATCH] simplify engine building --- REAL-Video-Enhancer.py | 7 +- backend/rve-backend.py | 2 - backend/src/FFmpeg.py | 19 - .../InterpolateArchs/DetectInterpolateArch.py | 1 + .../src/InterpolateArchs/GMFSS/FeatureNet.py | 1 + backend/src/InterpolateArchs/GMFSS/GMFSS.py | 23 +- .../src/InterpolateArchs/GMFSS/softsplat.py | 2 +- .../src/InterpolateArchs/RIFE/warplayer.py | 6 +- backend/src/InterpolateNCNN.py | 2 +- backend/src/InterpolateTorch.py | 344 ++---------------- backend/src/TensorRTHandler.py | 66 ++++ backend/src/Util.py | 2 + src/Backendhandler.py | 1 - src/ModelHandler.py | 30 +- src/Util.py | 1 + src/ui/ProcessTab.py | 15 +- src/ui/SettingsTab.py | 2 +- 17 files changed, 167 insertions(+), 357 deletions(-) create mode 100644 backend/src/TensorRTHandler.py diff --git a/REAL-Video-Enhancer.py b/REAL-Video-Enhancer.py index 3862ad54..b6060f38 100644 --- a/REAL-Video-Enhancer.py +++ b/REAL-Video-Enhancer.py @@ -1,5 +1,6 @@ import sys import os + # patch for macos if sys.platform == "darwin": os.chdir(os.path.dirname(os.path.abspath(__file__))) @@ -15,7 +16,7 @@ ) from PySide6.QtGui import QIcon from src.Util import printAndLog -from mainwindow import Ui_MainWindow +from mainwindow import Ui_MainWindow from PySide6 import QtSvg # Import the QtSvg module so svg icons can be used on windows from src.version import version from src.InputHandler import VideoInputHandler @@ -464,11 +465,11 @@ def closeEvent(self, event): app = QApplication(sys.argv) # setting the pallette - + app.setPalette(Palette()) window = MainWindow() if len(sys.argv) > 1: - if sys.argv[1] == '--fullscreen': + if sys.argv[1] == "--fullscreen": window.showFullScreen() window.show() sys.exit(app.exec()) diff --git a/backend/rve-backend.py b/backend/rve-backend.py index 5d2d273d..2c05dd76 100644 --- a/backend/rve-backend.py +++ b/backend/rve-backend.py @@ -247,8 +247,6 @@ def checkArguments(self): raise ValueError( "Interpolation factor must be 1 if no interpolation model is used.\nPlease use --interpolateFactor 1 for no interpolation!" ) - - if __name__ == "__main__": diff --git a/backend/src/FFmpeg.py b/backend/src/FFmpeg.py index ef782922..606bcf65 100644 --- a/backend/src/FFmpeg.py +++ b/backend/src/FFmpeg.py @@ -377,22 +377,6 @@ def writeOutInformation(self, fcs): time.sleep(0.1) - def openMPVProc(self): - self.mpv_process = subprocess.Popen( - [ - "mpv", - "--no-correct-pts", - f"--fps={self.fps * self.ceilInterpolateFactor}", - "--demuxer-thread=no", - "--", - "-", - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=False, - ) - def writeOutVideoFrames(self): """ Writes out frames either to ffmpeg or to pipe @@ -401,9 +385,6 @@ def writeOutVideoFrames(self): ffmpeg -f rawvideo -pix_fmt rgb24 -s 1920x1080 -framerate 24 -i - -c:v libx264 -crf 18 -pix_fmt yuv420p -c:a copy out.mp4 """ log("Rendering") - # - - # self.openMPVProc() self.startTime = time.time() self.framesRendered: int = 1 self.last_length: int = 0 diff --git a/backend/src/InterpolateArchs/DetectInterpolateArch.py b/backend/src/InterpolateArchs/DetectInterpolateArch.py index 5df25513..0d4b0f8c 100644 --- a/backend/src/InterpolateArchs/DetectInterpolateArch.py +++ b/backend/src/InterpolateArchs/DetectInterpolateArch.py @@ -1,5 +1,6 @@ import torch + class RIFE46: def __init__(): pass diff --git a/backend/src/InterpolateArchs/GMFSS/FeatureNet.py b/backend/src/InterpolateArchs/GMFSS/FeatureNet.py index c3c9e087..dd28876b 100644 --- a/backend/src/InterpolateArchs/GMFSS/FeatureNet.py +++ b/backend/src/InterpolateArchs/GMFSS/FeatureNet.py @@ -1,6 +1,7 @@ import torch.nn as nn from .util import MyPReLU + class FeatureNet(nn.Module): """The quadratic model""" diff --git a/backend/src/InterpolateArchs/GMFSS/GMFSS.py b/backend/src/InterpolateArchs/GMFSS/GMFSS.py index 19409a95..6cf083e0 100644 --- a/backend/src/InterpolateArchs/GMFSS/GMFSS.py +++ b/backend/src/InterpolateArchs/GMFSS/GMFSS.py @@ -12,9 +12,18 @@ class GMFSS(nn.Module): - def __init__(self, model_path, model_type:str="union", scale:int=1, ensemble:bool=False, width:int=1920, height:int=1080): + def __init__( + self, + model_path, + model_type: str = "union", + scale: int = 1, + ensemble: bool = False, + width: int = 1920, + height: int = 1080, + ): super(GMFSS, self).__init__() from .FusionNet_u import GridNet + # get gmfss from here, as its a combination of all the models https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/GMFSS.pkl self.width = width self.height = height @@ -25,12 +34,12 @@ def __init__(self, model_path, model_type:str="union", scale:int=1, ensemble:boo self.fusionnet = GridNet() combined_state_dict = torch.load(model_path, map_location="cpu") if model_type != "base": - self.ifnet.load_state_dict(combined_state_dict['rife']) - self.flownet.load_state_dict(combined_state_dict['flownet']) - self.metricnet.load_state_dict(combined_state_dict['metricnet']) - self.feat_ext.load_state_dict(combined_state_dict['feat_ext']) - self.fusionnet.load_state_dict(combined_state_dict['fusionnet']) - + self.ifnet.load_state_dict(combined_state_dict["rife"]) + self.flownet.load_state_dict(combined_state_dict["flownet"]) + self.metricnet.load_state_dict(combined_state_dict["metricnet"]) + self.feat_ext.load_state_dict(combined_state_dict["feat_ext"]) + self.fusionnet.load_state_dict(combined_state_dict["fusionnet"]) + self.model_type = model_type self.scale = scale diff --git a/backend/src/InterpolateArchs/GMFSS/softsplat.py b/backend/src/InterpolateArchs/GMFSS/softsplat.py index 4bf23516..c8aa7146 100644 --- a/backend/src/InterpolateArchs/GMFSS/softsplat.py +++ b/backend/src/InterpolateArchs/GMFSS/softsplat.py @@ -642,4 +642,4 @@ def backward(self, tenOutgrad): # end -# end \ No newline at end of file +# end diff --git a/backend/src/InterpolateArchs/RIFE/warplayer.py b/backend/src/InterpolateArchs/RIFE/warplayer.py index e8e671c9..77b4a666 100644 --- a/backend/src/InterpolateArchs/RIFE/warplayer.py +++ b/backend/src/InterpolateArchs/RIFE/warplayer.py @@ -31,6 +31,8 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): tenInput = tenInput.to(torch.float) tenFlow = tenFlow.to(torch.float) - tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1) + tenFlow = torch.cat( + [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1 + ) g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1) - return torch.ops.aten.grid_sampler_2d(tenInput, g, 0, 1, True).to(dtype) \ No newline at end of file + return torch.ops.aten.grid_sampler_2d(tenInput, g, 0, 1, True).to(dtype) diff --git a/backend/src/InterpolateNCNN.py b/backend/src/InterpolateNCNN.py index c93fe366..0f787d00 100644 --- a/backend/src/InterpolateNCNN.py +++ b/backend/src/InterpolateNCNN.py @@ -237,7 +237,7 @@ def process_fast( self.height, self.width, self.channels ) - + class RIFE(Rife): ... diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py index 0c1b3af9..e8702049 100644 --- a/backend/src/InterpolateTorch.py +++ b/backend/src/InterpolateTorch.py @@ -1,6 +1,6 @@ import torch import torch.nn.functional as F -from torch._decomp import get_decompositions + from .InterpolateArchs.DetectInterpolateArch import ArchDetect import math import os @@ -20,6 +20,7 @@ torch.set_grad_enabled(False) logging.basicConfig(level=logging.INFO) + class InterpolateRifeTorch: """InterpolateRifeTorch class for video interpolation using RIFE model in PyTorch. @@ -135,27 +136,6 @@ def __init__( self.trt_debug = trt_debug # too much output, i would like a progress bar tho self.rife_trt_mode = rife_trt_mode self.trt_static_shape = trt_static_shape - if not self.trt_static_shape: - if self.height < 256 and self.width < 256: - trt_min_shape: list[int] = [1, 1] - trt_opt_shape: list[int] = [128, 128] - trt_max_shape: list[int] = [256, 256] - elif self.height <= 1080 and self.width <= 1920: # if its 1080p or lower - trt_min_shape: list[int] = [128, 128] - trt_opt_shape: list[int] = [1920, 1080] - trt_max_shape: list[int] = [1920, 1080] - elif self.height <= 2160 and self.width <= 3840: # if its 4k or lower - trt_min_shape: list[int] = [1921, 1921] - trt_opt_shape: list[int] = [3840, 2160] - trt_max_shape: list[int] = [3840, 2160] - else: # too big, give warning and switch to static - warnAndLog( - "Warning: Reso lution is really high, this will most likely not work at all!\nFalling back to static shape." - ) - trt_static_shape = True - self.trt_min_shape = trt_min_shape - self.trt_opt_shape = trt_opt_shape - self.trt_max_shape = trt_max_shape if UHDMode: self.scale = 0.5 @@ -183,11 +163,13 @@ def _load(self): match interpolateArch.lower(): case "gmfss": from .InterpolateArchs.GMFSS.GMFSS import GMFSS + _pad = 64 self.gmfss = True self.rife46 = True case "rife46": from .InterpolateArchs.RIFE.rife46IFNET import IFNet + self.rife46 = True case "rife47": from .InterpolateArchs.RIFE.rife47IFNET import IFNet @@ -223,11 +205,10 @@ def _load(self): _pad = 64 num_ch_for_encode = 4 self.encode = Head() - + case _: errorAndLog("Invalid Interpolation Arch") - # model unspecific setup tmp = max(_pad, int(_pad / self.scale)) self.pw = math.ceil(self.width / tmp) * tmp @@ -238,17 +219,24 @@ def _load(self): if GMFSS is not None: for n in range(self.ceilInterpolateFactor): timestep = n / (self.ceilInterpolateFactor) - timestep_tens = torch.tensor([timestep], dtype=self.dtype, device=self.device).to(non_blocking=True) + timestep_tens = torch.tensor( + [timestep], dtype=self.dtype, device=self.device + ).to(non_blocking=True) self.timestepDict[timestep] = timestep_tens self.flownet = GMFSS( - model_path=self.interpolateModel, scale=self.scale, width=self.width, height=self.height + model_path=self.interpolateModel, + scale=self.scale, + width=self.width, + height=self.height, ) - #self.dtype = torch.float32 - #warnAndLog("GMFSS does not support float16, switching to float32") + # self.dtype = torch.float32 + # warnAndLog("GMFSS does not support float16, switching to float32") self.flownet.eval().to(device=self.device, dtype=self.dtype) if self.backend == "tensorrt": - warnAndLog("TensorRT is not implemented for GMFSS yet, falling back to PyTorch") - + warnAndLog( + "TensorRT is not implemented for GMFSS yet, falling back to PyTorch" + ) + elif IFNet is not None: for n in range(self.ceilInterpolateFactor): timestep = n / (self.ceilInterpolateFactor) @@ -261,23 +249,6 @@ def _load(self): self.timestepDict[timestep] = timestep_tens # rife specific setup self.set_rife_args() # sets backwarp_tenGrid and tenFlow_div - # set up dynamic - if self.trt_static_shape: - self.dimensions = f"{self.pw}x{self.ph}" - else: - for i in range(2): - self.trt_min_shape[i] = math.ceil(self.trt_min_shape[i] / tmp) * tmp - self.trt_opt_shape[i] = math.ceil(self.trt_opt_shape[i] / tmp) * tmp - self.trt_max_shape[i] = math.ceil(self.trt_max_shape[i] / tmp) * tmp - - self.dimensions = ( - f"min-{self.trt_min_shape[0]}x{self.trt_min_shape[1]}" - f"_opt-{self.trt_opt_shape[0]}x{self.trt_opt_shape[1]}" - f"_max-{self.trt_max_shape[0]}x{self.trt_max_shape[1]}" - ) - - - self.flownet = IFNet( scale=self.scale, ensemble=False, @@ -306,12 +277,14 @@ def _load(self): if self.backend == "tensorrt": import tensorrt import torch_tensorrt + from .TensorRTHandler import TorchTensorRTHandler + trtHandler = TorchTensorRTHandler(trt_optimization_level=self.trt_optimization_level,trt_cache_dir=self.trt_cache_dir) base_trt_engine_path = os.path.join( os.path.realpath(self.trt_cache_dir), ( f"{os.path.basename(self.interpolateModel)}" - + f"_{self.dimensions}" + + f"_{self.width}x{self.height}" + f"_{'fp16' if self.dtype == torch.float16 else 'fp32'}" + f"_scale-{self.scale}" + "_ensemble-False" @@ -366,84 +339,7 @@ def _load(self): device=self.device, ), ] - if self.trt_static_shape: - dynamic_shapes = None - inputs = [ - torch_tensorrt.Input( - shape=[1, 3, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input( - shape=[1, 3, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input( - shape=[1, 1, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input(shape=[2], dtype=torch.float), - torch_tensorrt.Input( - shape=[1, 2, self.ph, self.pw], dtype=torch.float - ), - ] - else: - self.trt_min_shape.reverse() - self.trt_opt_shape.reverse() - self.trt_max_shape.reverse() - - _height = torch.export.Dim( - "height", - min=self.trt_min_shape[0] // tmp, - max=self.trt_max_shape[0] // tmp, - ) - _width = torch.export.Dim( - "width", - min=self.trt_min_shape[1] // tmp, - max=self.trt_max_shape[1] // tmp, - ) - dim_height = _height * tmp - dim_width = _width * tmp - dynamic_shapes = { - "img0": {2: dim_height, 3: dim_width}, - "img1": {2: dim_height, 3: dim_width}, - "timestep": {2: dim_height, 3: dim_width}, - "tenFlow_div": {}, - "backwarp_tenGrid": {2: dim_height, 3: dim_width}, - } - - inputs = [ - torch_tensorrt.Input( - min_shape=[1, 3] + self.trt_min_shape, - opt_shape=[1, 3] + self.trt_opt_shape, - max_shape=[1, 3] + self.trt_max_shape, - dtype=self.dtype, - name="img0", - ), - torch_tensorrt.Input( - min_shape=[1, 3] + self.trt_min_shape, - opt_shape=[1, 3] + self.trt_opt_shape, - max_shape=[1, 3] + self.trt_max_shape, - dtype=self.dtype, - name="img1", - ), - torch_tensorrt.Input( - min_shape=[1, 1] + self.trt_min_shape, - opt_shape=[1, 1] + self.trt_opt_shape, - max_shape=[1, 1] + self.trt_max_shape, - dtype=self.dtype, - name="timestep", - ), - torch_tensorrt.Input( - shape=[2], - dtype=torch.float, - name="tenFlow_div", - ), - torch_tensorrt.Input( - min_shape=[1, 2] + self.trt_min_shape, - opt_shape=[1, 2] + self.trt_opt_shape, - max_shape=[1, 2] + self.trt_max_shape, - dtype=torch.float, - name="backwarp_tenGrid", - ), - ] else: # if not rife46 exampleInput = [ @@ -479,112 +375,6 @@ def _load(self): device=self.device, ), ] - if self.trt_static_shape: - dynamic_shapes = None - - inputs = [ - torch_tensorrt.Input( - shape=[1, 3, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input( - shape=[1, 3, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input( - shape=[1, 1, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input(shape=[2], dtype=torch.float), - torch_tensorrt.Input( - shape=[1, 2, self.ph, self.pw], dtype=torch.float - ), - torch_tensorrt.Input( - shape=[1, 1, self.ph, self.pw], dtype=self.dtype - ), - torch_tensorrt.Input( - shape=[1, 1, self.ph, self.pw], dtype=self.dtype - ), - ] - else: - self.trt_min_shape.reverse() - self.trt_opt_shape.reverse() - self.trt_max_shape.reverse() - - _height = torch.export.Dim( - "height", - min=self.trt_min_shape[0] // tmp, - max=self.trt_max_shape[0] // tmp, - ) - _width = torch.export.Dim( - "width", - min=self.trt_min_shape[1] // tmp, - max=self.trt_max_shape[1] // tmp, - ) - dim_height = _height * tmp - dim_width = _width * tmp - dynamic_shapes = { - "img0": {2: dim_height, 3: dim_width}, - "img1": {2: dim_height, 3: dim_width}, - "timestep": {2: dim_height, 3: dim_width}, - "tenFlow_div": {}, - "backwarp_tenGrid": {2: dim_height, 3: dim_width}, - "f0": {2: dim_height, 3: dim_width}, - "f1": {2: dim_height, 3: dim_width}, - } - - inputs = [ - torch_tensorrt.Input( - min_shape=[1, 3] + self.trt_min_shape, - opt_shape=[1, 3] + self.trt_opt_shape, - max_shape=[1, 3] + self.trt_max_shape, - dtype=self.dtype, - name="img0", - ), - torch_tensorrt.Input( - min_shape=[1, 3] + self.trt_min_shape, - opt_shape=[1, 3] + self.trt_opt_shape, - max_shape=[1, 3] + self.trt_max_shape, - dtype=self.dtype, - name="img1", - ), - torch_tensorrt.Input( - min_shape=[1, 1] + self.trt_min_shape, - opt_shape=[1, 1] + self.trt_opt_shape, - max_shape=[1, 1] + self.trt_max_shape, - dtype=self.dtype, - name="timestep", - ), - torch_tensorrt.Input( - shape=[2], - dtype=torch.float, - name="tenFlow_div", - ), - torch_tensorrt.Input( - min_shape=[1, 2] + self.trt_min_shape, - opt_shape=[1, 2] + self.trt_opt_shape, - max_shape=[1, 2] + self.trt_max_shape, - dtype=torch.float, - name="backwarp_tenGrid", - ), - torch_tensorrt.Input( - min_shape=[1, num_ch_for_encode] - + self.trt_min_shape, - opt_shape=[1, num_ch_for_encode] - + self.trt_opt_shape, - max_shape=[1, num_ch_for_encode] - + self.trt_max_shape, - dtype=self.dtype, - name="f0", - ), - torch_tensorrt.Input( - min_shape=[1, num_ch_for_encode] - + self.trt_min_shape, - opt_shape=[1, num_ch_for_encode] - + self.trt_opt_shape, - max_shape=[1, num_ch_for_encode] - + self.trt_max_shape, - dtype=self.dtype, - name="f1", - ), - ] if not os.path.isfile(encode_trt_engine_path): # build encode engine @@ -600,54 +390,7 @@ def _load(self): device=self.device, ), ] - if self.trt_static_shape: - dynamic_encode_shapes = None - encodedInput = ( - torch_tensorrt.Input( - shape=[1, 3, self.ph, self.pw], - dtype=torch.float, - ), - ) - else: - dynamic_encode_shapes = { - "x": {2: dim_height, 3: dim_width}, - } - encodedInput = [ - torch_tensorrt.Input( - min_shape=[1, 3] + self.trt_min_shape, - opt_shape=[1, 3] + self.trt_opt_shape, - max_shape=[1, 3] + self.trt_max_shape, - dtype=self.dtype, - name="x", - ), - ] - exported_encode_program = torch.export.export( - self.encode, - tuple(encodedExampleInputs), - dynamic_shapes=dynamic_encode_shapes, - ) - - self.encode = torch_tensorrt.dynamo.compile( - exported_encode_program, - tuple(encodedInput), - device=self.device, - enabled_precisions={self.dtype}, - debug=self.trt_debug, - num_avg_timing_iters=4, - workspace_size=self.trt_workspace_size, - min_block_size=1, - max_aux_streams=self.trt_max_aux_streams, - optimization_level=self.trt_optimization_level, - ) - printAndLog( - f"Saving TensorRT engine to {encode_trt_engine_path}" - ) - torch_tensorrt.save( - self.encode, - encode_trt_engine_path, - output_format="torchscript", - inputs=tuple(encodedExampleInputs), - ) + trtHandler.build_engine(model=self.encode,dtype=self.dtype,example_inputs=encodedExampleInputs,device=self.device,trt_engine_path=encode_trt_engine_path) printAndLog( f"Loading TensorRT engine from {encode_trt_engine_path}" @@ -655,33 +398,11 @@ def _load(self): self.encode = torch.jit.load(encode_trt_engine_path).eval() # export flow engine - printAndLog("Building TensorRT engine {}".format(trt_engine_path)) - exported_program = torch.export.export( - self.flownet, - tuple(exampleInput), - dynamic_shapes=dynamic_shapes, - ) - exported_program = exported_program.run_decompositions(get_decompositions([torch.ops.aten.grid_sampler_2d])) - - self.flownet = torch_tensorrt.dynamo.compile( - exported_program, - tuple(inputs), - device=self.device, - use_explicit_typing=True, - debug=self.trt_debug, - num_avg_timing_iters=4, - workspace_size=self.trt_workspace_size, - min_block_size=1, - max_aux_streams=self.trt_max_aux_streams, - optimization_level=self.trt_optimization_level, - ) - printAndLog(f"Saving TensorRT engine to {trt_engine_path}") - torch_tensorrt.save( - self.flownet, - trt_engine_path, - output_format="torchscript", - inputs=tuple(exampleInput), + printAndLog( + "Building TensorRT engine {}".format(trt_engine_path) ) + trtHandler.build_engine(model=self.flownet,dtype=self.dtype,example_inputs=exampleInput,device=self.device,trt_engine_path=trt_engine_path) + printAndLog(f"Loading TensorRT engine from {trt_engine_path}") self.flownet = torch.jit.load(trt_engine_path).eval() self.prepareStream.synchronize() @@ -758,7 +479,7 @@ def process(self, img0, img1, timestep, f0encode=None, f1encode=None): img0, img1, timestep, self.tenFlow_div, self.backwarp_tenGrid ) else: - #output = F.interpolate(self.flownet(img0, img1, timestep), (self.height, self.width), mode="bilinear") + # output = F.interpolate(self.flownet(img0, img1, timestep), (self.height, self.width), mode="bilinear") output = self.flownet(img0, img1, timestep) self.stream.synchronize() return self.tensor_to_frame(output) @@ -783,8 +504,13 @@ def encode_Frame(self, frame: torch.Tensor): @torch.inference_mode() def norm(self, frame: torch.Tensor): - return frame.reshape(self.height, self.width, 3).permute(2, 0, 1).unsqueeze(0).div(255.0) - + return ( + frame.reshape(self.height, self.width, 3) + .permute(2, 0, 1) + .unsqueeze(0) + .div(255.0) + ) + @torch.inference_mode() def frame_to_tensor(self, frame) -> torch.Tensor: with torch.cuda.stream(self.prepareStream): @@ -794,7 +520,7 @@ def frame_to_tensor(self, frame) -> torch.Tensor: dtype=torch.uint8, ).to(device=self.device, dtype=self.dtype, non_blocking=True) ) - frame = F.pad(frame,self.padding) + frame = F.pad(frame, self.padding) self.prepareStream.synchronize() return frame diff --git a/backend/src/TensorRTHandler.py b/backend/src/TensorRTHandler.py new file mode 100644 index 00000000..4d6aff48 --- /dev/null +++ b/backend/src/TensorRTHandler.py @@ -0,0 +1,66 @@ +import tensorrt +import torch +import torch_tensorrt +from .Util import modelsDirectory +from torch._decomp import get_decompositions + +class TorchTensorRTHandler: + def __init__( + self, + export_format: str = "dynamo", + trt_workspace_size: int = 0, + trt_max_aux_streams: int | None = None, + trt_optimization_level: int = 5, + trt_cache_dir: str = modelsDirectory(), + trt_debug: bool = False, + trt_static_shape: bool = True, + ): + self.export_format = export_format + self.trt_workspace_size = trt_workspace_size + self.trt_max_aux_streams = trt_max_aux_streams + self.trt_optimization_level = trt_optimization_level + self.trt_cache_dir = trt_cache_dir + self.trt_debug = trt_debug + self.trt_static_shape = trt_static_shape # unused for now + + def prepare_inputs(self, example_inputs): + inputs = [] + for input in example_inputs: + inputs.append(torch_tensorrt.Input(shape=input.shape, dtype=input.dtype)) + return inputs + + def build_engine( + self, + model: torch.nn.Module, + dtype: torch.dtype, + device: torch.device, + example_inputs: list[torch.Tensor], + trt_engine_path: str, + ): + model.to(device=device,dtype=dtype) + exported_program = torch.export.export( + model, + tuple(example_inputs), + dynamic_shapes=None, + ) + exported_program = exported_program.run_decompositions( + get_decompositions([torch.ops.aten.grid_sampler_2d]) + ) # this is a workaround for a bug in tensorrt where grid_sample has a bad output + model = torch_tensorrt.dynamo.compile( + exported_program, + tuple(self.prepare_inputs(example_inputs)), + device=device, + use_explicit_typing=True, # this allows for multi-precision engines + debug=self.trt_debug, + num_avg_timing_iters=4, + workspace_size=self.trt_workspace_size, + min_block_size=1, + max_aux_streams=self.trt_max_aux_streams, + optimization_level=self.trt_optimization_level, + ) + torch_tensorrt.save( + model, + trt_engine_path, + output_format="torchscript", + inputs=tuple(example_inputs), + ) diff --git a/backend/src/Util.py b/backend/src/Util.py index 2b9601d2..33967d1d 100644 --- a/backend/src/Util.py +++ b/backend/src/Util.py @@ -124,6 +124,7 @@ def checkForTensorRT() -> bool: except Exception as e: log(str(e)) + def checkForGMFSS() -> bool: try: import torch @@ -136,6 +137,7 @@ def checkForGMFSS() -> bool: return False return True + def check_bfloat16_support() -> bool: """ Function that checks if the torch backend supports bfloat16 diff --git a/src/Backendhandler.py b/src/Backendhandler.py index 616b0dc3..5daf4826 100644 --- a/src/Backendhandler.py +++ b/src/Backendhandler.py @@ -88,7 +88,6 @@ def recursivlyCheckIfDepsOnFirstInstallToMakeSureUserHasInstalledAtLeastOneBacke try: self.availableBackends, self.fullOutput = self.getAvailableBackends() if not len(self.availableBackends) == 0: - return self.availableBackends, self.fullOutput except SyntaxError as e: printAndLog(str(e)) diff --git a/src/ModelHandler.py b/src/ModelHandler.py index d0e3a554..6c8f1456 100644 --- a/src/ModelHandler.py +++ b/src/ModelHandler.py @@ -42,7 +42,7 @@ ), } pytorchInterpolateModels = { - "GMFSS (Slowest Model, Animation)": ("GMFSS.pkl","GMFSS.pkl",1,"gmfss"), + "GMFSS (Slowest Model, Animation)": ("GMFSS.pkl", "GMFSS.pkl", 1, "gmfss"), "RIFE 4.6 (Fastest Model)": ("rife4.6.pkl", "rife4.6.pkl", 1, "rife46"), "RIFE 4.7 (Smoothest Model)": ("rife4.7.pkl", "rife4.7.pkl", 1, "rife47"), "RIFE 4.15": ("rife4.15.pkl", "rife4.15.pkl", 1, "rife413"), @@ -52,14 +52,24 @@ 1, "rife413", ), - "RIFE 4.22 (Slowest Model, Animation)": ("rife4.22.pkl", "rife4.22.pkl", 1, "rife421"), + "RIFE 4.22 (Slowest Model, Animation)": ( + "rife4.22.pkl", + "rife4.22.pkl", + 1, + "rife421", + ), "RIFE 4.22-lite (Latest LITE model)": ( "rife4.22-lite.pkl", "rife4.22-lite.pkl", 1, "rife422-lite", ), - "RIFE 4.25 (Latest General Model, Recommended)": ("rife4.25.pkl", "rife4.25.pkl", 1, "rife425"), + "RIFE 4.25 (Latest General Model, Recommended)": ( + "rife4.25.pkl", + "rife4.25.pkl", + 1, + "rife425", + ), } tensorrtInterpolateModels = { "RIFE 4.6 (Fastest Model)": ("rife4.6.pkl", "rife4.6.pkl", 1, "rife46"), @@ -71,14 +81,24 @@ 1, "rife413", ), - "RIFE 4.22 (Slowest Model, Animation)": ("rife4.22.pkl", "rife4.22.pkl", 1, "rife421"), + "RIFE 4.22 (Slowest Model, Animation)": ( + "rife4.22.pkl", + "rife4.22.pkl", + 1, + "rife421", + ), "RIFE 4.22-lite (Latest LITE model)": ( "rife4.22-lite.pkl", "rife4.22-lite.pkl", 1, "rife422-lite", ), - "RIFE 4.25 (Latest General Model, Recommended)": ("rife4.25.pkl", "rife4.25.pkl", 1, "rife425"), + "RIFE 4.25 (Latest General Model, Recommended)": ( + "rife4.25.pkl", + "rife4.25.pkl", + 1, + "rife425", + ), } ncnnUpscaleModels = { "SPAN (Animation) (2X) (Fast)": ( diff --git a/src/Util.py b/src/Util.py index 58b21b25..a819f8a6 100644 --- a/src/Util.py +++ b/src/Util.py @@ -88,6 +88,7 @@ def backendDirectory(): else: return os.path.join(cwd, "backend") + def downloadTempDirectory() -> str: tmppath = os.path.join(cwd, "temp") createDirectory(tmppath) diff --git a/src/ui/ProcessTab.py b/src/ui/ProcessTab.py index 0f5f97f3..bdd07546 100644 --- a/src/ui/ProcessTab.py +++ b/src/ui/ProcessTab.py @@ -111,7 +111,7 @@ def QConnect(self): self.parent.outputFileSelectButton.clicked.connect(self.parent.openOutputFolder) # connect render button self.parent.startRenderButton.clicked.connect(self.parent.startRender) - cbs = (self.parent.methodComboBox,self.parent.backendComboBox) + cbs = (self.parent.methodComboBox, self.parent.backendComboBox) for combobox in cbs: combobox.currentIndexChanged.connect(self.switchInterpolationAndUpscale) # set tile size visible to false by default @@ -135,7 +135,6 @@ def killRenderProcess(self): except AttributeError: printAndLog("No render process!") - def switchInterpolationAndUpscale(self): """ Called every render, gets the correct model based on the backend and the method. @@ -146,7 +145,7 @@ def switchInterpolationAndUpscale(self): method = self.parent.methodComboBox.currentText() backend = self.parent.backendComboBox.currentText() models = self.getTotalModels(method=method, backend=backend) - + self.parent.modelComboBox.addItems(models) total_items = self.parent.modelComboBox.count() if total_items > 0 and method.lower() == "interpolate": @@ -159,8 +158,13 @@ def switchInterpolationAndUpscale(self): if not self.gmfssSupport: # Disable specific options based on the selected text for i in range(self.parent.modelComboBox.count()): - if self.parent.modelComboBox.itemText(i) == "GMFSS (Slowest Model, Animation)": # hacky solution, just straight copy pasted - self.parent.modelComboBox.model().item(i).setEnabled(self.gmfssSupport) + if ( + self.parent.modelComboBox.itemText(i) + == "GMFSS (Slowest Model, Animation)" + ): # hacky solution, just straight copy pasted + self.parent.modelComboBox.model().item(i).setEnabled( + self.gmfssSupport + ) else: self.parent.interpolationContainer.setVisible(False) self.parent.upscaleContainer.setVisible(True) @@ -334,7 +338,6 @@ def renderToPipeThread(self, method: str, backend: str, interpolateTimes: int): ), "--interpolateFactor", f"{interpolateTimes}", - ] if self.settings["preview_enabled"] == "True": command += [ diff --git a/src/ui/SettingsTab.py b/src/ui/SettingsTab.py index a77b273d..dfb09c02 100644 --- a/src/ui/SettingsTab.py +++ b/src/ui/SettingsTab.py @@ -105,7 +105,7 @@ def connectWriteSettings(self): self.parent.output_folder_location.textChanged.connect( lambda: self.writeOutputFolder() ) - + self.parent.resetSettingsBtn.clicked.connect(self.resetSettings) def writeOutputFolder(self):