diff --git a/REAL-Video-Enhancer.py b/REAL-Video-Enhancer.py index 25963f39..6f923479 100644 --- a/REAL-Video-Enhancer.py +++ b/REAL-Video-Enhancer.py @@ -25,9 +25,9 @@ from PySide6 import QtSvg # Import the QtSvg module so svg icons can be used on windows from src.version import version from src.InputHandler import VideoInputHandler + # other imports from src.Util import ( - openLink, getOSInfo, get_gpu_info, @@ -36,8 +36,7 @@ videosPath, checkForWritePermissions, getAvailableDiskSpace, - errorAndLog - + errorAndLog, ) from src.ui.ProcessTab import ProcessTab from src.ui.DownloadTab import DownloadTab @@ -171,7 +170,7 @@ def __init__(self): self.settingsTab = SettingsTab( parent=self, halfPrecisionSupport=halfPrecisionSupport ) - + self.moreTab = MoreTab(parent=self) # Startup Animation self.animationHandler = AnimationHandler() @@ -381,10 +380,7 @@ def disableProcessPage(self): def enableProcessPage(self): self.processSettingsContainer.setEnabled(True) - - def loadVideo(self, inputFile): - videoHandler = VideoInputHandler(inputText=inputFile) if videoHandler.isYoutubeLink() and videoHandler.isValidYoutubeLink(): videoHandler.getDataFromYoutubeVideo() @@ -409,7 +405,7 @@ def loadVideo(self, inputFile): self.outputFileSelectButton.setEnabled(True) self.isVideoLoaded = True self.updateVideoGUIDetails() - + # input file button def openInputFile(self): """ diff --git a/backend/rve-backend.py b/backend/rve-backend.py index 4e054790..902f62ad 100644 --- a/backend/rve-backend.py +++ b/backend/rve-backend.py @@ -25,7 +25,7 @@ def __init__(self): outputFile=self.args.output, interpolateModel=self.args.interpolateModel, interpolateFactor=self.args.interpolateFactor, - rifeVersion="v1", # some guy was angry about rifev2 being here, so I changed it to v1 + rifeVersion="v1", # some guy was angry about rifev2 being here, so I changed it to v1 upscaleModel=self.args.upscaleModel, tile_size=self.args.tilesize, # backend settings @@ -49,7 +49,7 @@ def __init__(self): half_prec_supp = False availableBackends = [] printMSG = "" - + if checkForTensorRT(): """ checks for tensorrt availability, and that the current gpu works with it (if half precision is supported) @@ -57,10 +57,10 @@ def __init__(self): Half precision is only available on RTX 20 series and up """ import torch + half_prec_supp = check_bfloat16_support() if half_prec_supp: import tensorrt - availableBackends.append("tensorrt") printMSG += f"TensorRT Version: {tensorrt.__version__}\n" @@ -72,7 +72,7 @@ def __init__(self): availableBackends.append("pytorch") printMSG += f"PyTorch Version: {torch.__version__}\n" half_prec_supp = check_bfloat16_support() - + if checkForNCNN(): availableBackends.append("ncnn") printMSG += f"NCNN Version: 20220729\n" diff --git a/backend/src/FFmpeg.py b/backend/src/FFmpeg.py index 62d5adc1..6a9f6029 100644 --- a/backend/src/FFmpeg.py +++ b/backend/src/FFmpeg.py @@ -127,7 +127,7 @@ def __init__( self.previewFrame = None self.crf = crf self.sharedMemoryID = sharedMemoryID - + self.subtitleFiles = [] self.sharedMemoryThread = Thread( target=lambda: self.writeOutInformation(self.outputFrameChunkSize) ) diff --git a/backend/src/InterpolateArchs/DetectInterpolateArch.py b/backend/src/InterpolateArchs/DetectInterpolateArch.py index 129f29e7..e289ab0f 100644 --- a/backend/src/InterpolateArchs/DetectInterpolateArch.py +++ b/backend/src/InterpolateArchs/DetectInterpolateArch.py @@ -190,6 +190,7 @@ def excluded_keys() -> tuple: "transformer.layers.4.self_attn.merge.weight", ] + class GMFSS: def __init__(): pass @@ -207,7 +208,6 @@ def
excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", ] - archs = [RIFE46, RIFE47, RIFE413, RIFE420, RIFE421, RIFE422lite, RIFE425, GMFSS] diff --git a/backend/src/InterpolateArchs/RIFE/custom_warplayer.py b/backend/src/InterpolateArchs/RIFE/custom_warplayer.py index 1e382b62..3f6c8175 100644 --- a/backend/src/InterpolateArchs/RIFE/custom_warplayer.py +++ b/backend/src/InterpolateArchs/RIFE/custom_warplayer.py @@ -10,12 +10,19 @@ from torch.fx.node import Argument, Target from torch.library import custom_op, register_fake from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter -from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types, set_layer_name +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import ( + enforce_tensor_types, + set_layer_name, +) from torch_tensorrt.dynamo.types import TRTTensor -class WarpPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): +class WarpPlugin( + trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime +): def __init__(self) -> None: trt.IPluginV3.__init__(self) trt.IPluginV3OneCore.__init__(self) @@ -34,25 +41,40 @@ def clone(self) -> WarpPlugin: def get_capability_interface(self, type: trt.PluginCapabilityType) -> Self: return self - def configure_plugin(self, inp: list[trt.DynamicPluginTensorDesc], out: list[trt.DynamicPluginTensorDesc]) -> None: + def configure_plugin( + self, + inp: list[trt.DynamicPluginTensorDesc], + out: list[trt.DynamicPluginTensorDesc], + ) -> None: pass - def get_output_data_types(self, input_types: list[trt.DataType]) -> list[trt.DataType]: + def get_output_data_types( + self, input_types: list[trt.DataType] + ) -> list[trt.DataType]: return [input_types[0]] def get_output_shapes( - self, inputs: list[trt.DimsExprs], shape_inputs: list[trt.DimsExprs], expr_builder: trt.IExprBuilder + self, + inputs: list[trt.DimsExprs], + shape_inputs: list[trt.DimsExprs], + expr_builder: trt.IExprBuilder, ) -> list[trt.DimsExprs]: return [inputs[0]] - def supports_format_combination(self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int) -> bool: + def supports_format_combination( + self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int + ) -> bool: assert pos < len(in_out) assert num_inputs == 4 desc = in_out[pos].desc - return desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT + return ( + desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT + ) - def attach_to_context(self, resource_context: trt.IPluginResourceContext) -> WarpPlugin: + def attach_to_context( + self, resource_context: trt.IPluginResourceContext + ) -> WarpPlugin: return self.clone() def enqueue( @@ -68,11 +90,21 @@ def enqueue( itemsize = cp.dtype(dtype).itemsize with cp.cuda.ExternalStream(stream): - input0_mem = cp.cuda.UnownedMemory(inputs[0], np.prod(input_desc[0].dims) * itemsize, self) - input1_mem = cp.cuda.UnownedMemory(inputs[1], np.prod(input_desc[1].dims) * itemsize, self) - input2_mem = cp.cuda.UnownedMemory(inputs[2], np.prod(input_desc[2].dims) * itemsize, self) - input3_mem = cp.cuda.UnownedMemory(inputs[3], np.prod(input_desc[3].dims) * itemsize, self) - output_mem = cp.cuda.UnownedMemory(outputs[0], np.prod(output_desc[0].dims) * itemsize, 
self) + input0_mem = cp.cuda.UnownedMemory( + inputs[0], np.prod(input_desc[0].dims) * itemsize, self + ) + input1_mem = cp.cuda.UnownedMemory( + inputs[1], np.prod(input_desc[1].dims) * itemsize, self + ) + input2_mem = cp.cuda.UnownedMemory( + inputs[2], np.prod(input_desc[2].dims) * itemsize, self + ) + input3_mem = cp.cuda.UnownedMemory( + inputs[3], np.prod(input_desc[3].dims) * itemsize, self + ) + output_mem = cp.cuda.UnownedMemory( + outputs[0], np.prod(output_desc[0].dims) * itemsize, self + ) input0_ptr = cp.cuda.MemoryPointer(input0_mem, 0) input1_ptr = cp.cuda.MemoryPointer(input1_mem, 0) @@ -84,7 +116,9 @@ def enqueue( input1_d = cp.ndarray(input_desc[1].dims, dtype=dtype, memptr=input1_ptr) input2_d = cp.ndarray(input_desc[2].dims, dtype=dtype, memptr=input2_ptr) input3_d = cp.ndarray(input_desc[3].dims, dtype=dtype, memptr=input3_ptr) - output_d = cp.ndarray((np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr) + output_d = cp.ndarray( + (np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr + ) input0_t = torch.as_tensor(input0_d) input1_t = torch.as_tensor(input1_d) @@ -97,7 +131,9 @@ def enqueue( def get_fields_to_serialize(self) -> trt.PluginFieldCollection_: return trt.PluginFieldCollection() - def on_shape_change(self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc]) -> None: + def on_shape_change( + self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc] + ) -> None: pass def set_tactic(self, tactic: int) -> None: @@ -113,23 +149,40 @@ def __init__(self) -> None: self.field_names = trt.PluginFieldCollection() def create_plugin( - self, name: str, field_collection: trt.PluginFieldCollection_, phase: trt.TensorRTPhase + self, + name: str, + field_collection: trt.PluginFieldCollection_, + phase: trt.TensorRTPhase, ) -> WarpPlugin: return WarpPlugin() @custom_op("vsrife::warp", mutates_args=()) def warp_custom( - tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor + tenInput: torch.Tensor, + tenFlow: torch.Tensor, + tenFlow_div: torch.Tensor, + backwarp_tenGrid: torch.Tensor, ) -> torch.Tensor: - tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1) + tenFlow = torch.cat( + [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1 + ) g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1) - return F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True) + return F.grid_sample( + input=tenInput, + grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) @register_fake("vsrife::warp") def warp_fake( - tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor + tenInput: torch.Tensor, + tenFlow: torch.Tensor, + tenFlow_div: torch.Tensor, + backwarp_tenGrid: torch.Tensor, ) -> torch.Tensor: return tenInput @@ -152,17 +205,24 @@ def ops_warp( ) -> Union[TRTTensor, Sequence[TRTTensor]]: creator = trt.get_plugin_registry().get_creator("WarpPlugin", version="1") field_collection = trt.PluginFieldCollection() - plugin = creator.create_plugin("WarpPlugin", field_collection=field_collection, phase=trt.TensorRTPhase.BUILD) + plugin = creator.create_plugin( + "WarpPlugin", field_collection=field_collection, phase=trt.TensorRTPhase.BUILD + ) layer = ctx.net.add_plugin_v3(inputs=list(args), shape_inputs=[], plugin=plugin) set_layer_name(layer, target, name) return layer.get_output(0) def warp( - tenInput: torch.Tensor, 
tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor + tenInput: torch.Tensor, + tenFlow: torch.Tensor, + tenFlow_div: torch.Tensor, + backwarp_tenGrid: torch.Tensor, ) -> torch.Tensor: dtype = tenInput.dtype tenInput = tenInput.to(torch.float) tenFlow = tenFlow.to(torch.float) - return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to(dtype) \ No newline at end of file + return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to( + dtype + ) diff --git a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py index dbe34c35..ec1e232e 100644 --- a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py @@ -4,11 +4,13 @@ from torch.nn.functional import interpolate + try: from .custom_warplayer import warp except: from .warplayer import warp + def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): return nn.Sequential( nn.Conv2d( @@ -173,7 +175,6 @@ def __init__( else: raise ValueError("rife_trt_mode must be 'fast' or 'accurate'") self.warp = warp - def forward(self, img0, img1, timestep, f0, f1): warped_img0 = img0 diff --git a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py index d9d96b5a..c44ea21b 100644 --- a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch.nn.functional import interpolate + class MyPixelShuffle(nn.Module): def __init__(self, upscale_factor): super(MyPixelShuffle, self).__init__() diff --git a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py index 9d3866d3..7cedf409 100644 --- a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch.nn.functional import interpolate + class MyPixelShuffle(nn.Module): def __init__(self, upscale_factor): super(MyPixelShuffle, self).__init__() @@ -167,7 +168,6 @@ def __init__( else: raise ValueError("rife_trt_mode must be 'fast' or 'accurate'") self.warp = warp - def forward(self, img0, img1, timestep, f0, f1): warped_img0 = img0 @@ -212,4 +212,4 @@ def forward(self, img0, img1, timestep, f0, f1): ][0] .permute(1, 2, 0) .mul(255) - ) \ No newline at end of file + ) diff --git a/backend/src/InterpolateArchs/RIFE/rife425IFNET.py b/backend/src/InterpolateArchs/RIFE/rife425IFNET.py index 5625d798..88986dfc 100644 --- a/backend/src/InterpolateArchs/RIFE/rife425IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife425IFNET.py @@ -141,7 +141,7 @@ def __init__( self.height = height self.backWarp = backwarp_tenGrid self.tenFlow = tenFlow_div - + self.paddedHeight = backwarp_tenGrid.shape[2] self.paddedWidth = backwarp_tenGrid.shape[3] diff --git a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py index 3b5b76eb..2f4bf212 100644 --- a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py @@ -133,7 +133,7 @@ def forward(self, img0, img1, timestep): warped_img1 = img1 flow = None mask = None - + for i in range(4): if flow is None: flow, mask = self.block[i]( @@ -151,9 +151,7 @@ def forward(self, img0, img1, timestep): mask = (mask + (-m1)) / 2 else: f0, m0 = self.block[i]( - torch.cat( - (warped_img0, warped_img1, timestep, mask), 1 - 
), + torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i], ) @@ -186,4 +184,4 @@ def forward(self, img0, img1, timestep): temp = torch.sigmoid(latest_mask) frame = warped_img0 * temp + warped_img1 * (1 - temp) frame = frame[:, :, : self.height, : self.width][0] - return frame.permute(1, 2, 0).mul(255) \ No newline at end of file + return frame.permute(1, 2, 0).mul(255) diff --git a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py index ef134694..4749ef46 100644 --- a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py @@ -4,6 +4,7 @@ from torch.nn.functional import interpolate + try: from .custom_warplayer import warp except: @@ -175,8 +176,12 @@ def forward(self, img0, img1, timestep, f0, f1): flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 mask = (mask + (-m_)) / 2 else: - wf0 = self.warp(f0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid) - wf1 = self.warp(f1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid) + wf0 = self.warp( + f0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid + ) + wf1 = self.warp( + f1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid + ) fd, m0 = self.block[i]( torch.cat( ( diff --git a/backend/src/InterpolateArchs/RIFE/warplayer.py b/backend/src/InterpolateArchs/RIFE/warplayer.py index 8830cb5e..da69e221 100644 --- a/backend/src/InterpolateArchs/RIFE/warplayer.py +++ b/backend/src/InterpolateArchs/RIFE/warplayer.py @@ -7,11 +7,20 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): tenInput = tenInput.to(torch.float) tenFlow = tenFlow.to(torch.float) - tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1) + tenFlow = torch.cat( + [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1 + ) g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1) - grid_sample = F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True) + grid_sample = F.grid_sample( + input=tenInput, + grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) return grid_sample.to(dtype) + """def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): tenFlow = torch.cat( [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1 diff --git a/backend/src/InterpolateNCNN.py b/backend/src/InterpolateNCNN.py index 618cc159..0f2171e2 100644 --- a/backend/src/InterpolateNCNN.py +++ b/backend/src/InterpolateNCNN.py @@ -1,5 +1,6 @@ -from rife_ncnn_vulkan_python import wrapped +from rife_ncnn_vulkan_python import wrapped from time import sleep + # built-in imports import importlib import pathlib @@ -33,7 +34,7 @@ def __init__( self.width = width self.channels = channels self.max_timestep = max_timestep - self.output_bytes = bytearray(width*height*channels) + self.output_bytes = bytearray(width * height * channels) self.raw_out_image = wrapped.Image( self.output_bytes, self.width, self.height, self.channels ) @@ -42,20 +43,25 @@ def __init__( self.scale = scale else: raise ValueError("scale should be a power of 2") - + # determine if rife-v2 is used rife_v2 = ("rife-v2" in model) or ("rife-v3" in model) rife_v4 = "rife-v4" in model or "rife4" in model or "rife-4" in model # create raw RIFE wrapper object self._rife_object = wrapped.RifeWrapped( - gpuid, tta_mode, tta_temporal_mode, uhd_mode, num_threads, rife_v2, rife_v4, False + gpuid, + tta_mode, + tta_temporal_mode, + uhd_mode, + num_threads, + rife_v2, + 
rife_v4, + False, ) self._load(model) - def _load(self, model: str, model_dir: pathlib.Path = None): - # if model_dir is not specified if model_dir is None: model_dir = pathlib.Path(model) @@ -82,9 +88,9 @@ def process(self, image0: Image, image1: Image, timestep: float = 0.5) -> Image: # Return the image immediately instead of doing the copy in the upstream part, which causes black output problems # The reason is that the upstream code uses ncnn::Mat::operator=(const Mat& m), which does a reference copy that won't # change our OutImage data. - if timestep == 0.: + if timestep == 0.0: return image0 - elif timestep == 1.: + elif timestep == 1.0: return image1 image0_bytes = bytearray(image0.tobytes()) @@ -107,15 +113,18 @@ def process(self, image0: Image, image1: Image, timestep: float = 0.5) -> Image: return Image.frombytes( image0.mode, (image0.width, image0.height), bytes(output_bytes) ) - def process_cv2(self, image0: np.ndarray, image1: np.ndarray, timestep: float = 0.5) -> np.ndarray: - if timestep == 0.: + + def process_cv2( + self, image0: np.ndarray, image1: np.ndarray, timestep: float = 0.5 + ) -> np.ndarray: + if timestep == 0.0: return image0 - elif timestep == 1.: + elif timestep == 1.0: return image1 - + image0_bytes = bytearray(image0.tobytes()) image1_bytes = bytearray(image1.tobytes()) - + self.channels = int(len(image0_bytes) / (image0.shape[1] * image0.shape[0])) self.output_bytes = bytearray(len(image0_bytes)) @@ -131,46 +140,55 @@ def process_cv2(self, image0: np.ndarray, image1: np.ndarray, timestep: float = ) self._rife_object.process(raw_in_image0, raw_in_image1, timestep, raw_out_image) - + return np.frombuffer(self.output_bytes, dtype=np.uint8).reshape( image0.shape[0], image0.shape[1], self.channels ) - + def uncache_frame(self): """ Used in instances where the scene change is active, and the frame needs to be uncached. """ self.image0_bytes = None self.raw_in_image0 = None - - def process_bytes(self, image0_bytes, image1_bytes, timestep: float = 0.5) -> np.ndarray: - #print(timestep) - if timestep == 0.: + def process_bytes( + self, image0_bytes, image1_bytes, timestep: float = 0.5 + ) -> np.ndarray: + # print(timestep) + if timestep == 0.0: return image0_bytes - elif timestep == 1.: + elif timestep == 1.0: return image1_bytes - + if self.image0_bytes is None: self.image0_bytes = bytearray(image0_bytes) self.raw_in_image0 = wrapped.Image( self.image0_bytes, self.width, self.height, self.channels ) image1_bytes = bytearray(image1_bytes) - + raw_in_image1 = wrapped.Image( image1_bytes, self.width, self.height, self.channels ) - - self._rife_object.process(self.raw_in_image0, raw_in_image1, timestep, self.raw_out_image) + self._rife_object.process( + self.raw_in_image0, raw_in_image1, timestep, self.raw_out_image ) if timestep == self.max_timestep: self.image0_bytes = image1_bytes self.raw_in_image0 = raw_in_image1 return bytes(self.output_bytes) - - def process_fast(self, image0: np.ndarray, image1: np.ndarray, timestep: float = 0.5, shape: tuple = None, channels: int = 3) -> np.ndarray: + + def process_fast( + self, + image0: np.ndarray, + image1: np.ndarray, + timestep: float = 0.5, + shape: tuple = None, + channels: int = 3, + ) -> np.ndarray: """ An attempt at a faster implementation for NCNN that should speed it up significantly through better caching methods. @@ -182,10 +200,10 @@ def process_fast(self, image0: np.ndarray, image1: np.ndarray, timestep: float = :return: The processed image, format: np.ndarray.
""" - - if timestep == 0.: + + if timestep == 0.0: return np.array(image0) - elif timestep == 1.: + elif timestep == 1.0: return np.array(image1) if self.height == None: @@ -195,19 +213,16 @@ def process_fast(self, image0: np.ndarray, image1: np.ndarray, timestep: float = self.height, self.width = shape image1_bytes = bytearray(image1.tobytes()) - raw_in_image1 = wrapped.Image( - image1_bytes, self.width, self.height, channels - ) + raw_in_image1 = wrapped.Image(image1_bytes, self.width, self.height, channels) if self.image0_bytes is None: self.image0_bytes = bytearray(image0.tobytes()) self.output_bytes = bytearray(len(self.image0_bytes)) - + raw_in_image0 = wrapped.Image( self.image0_bytes, self.width, self.height, channels ) - raw_out_image = wrapped.Image( self.output_bytes, self.width, self.height, channels ) @@ -219,8 +234,15 @@ def process_fast(self, image0: np.ndarray, image1: np.ndarray, timestep: float = return np.frombuffer(self.output_bytes, dtype=np.uint8).reshape( self.height, self.width, self.channels ) - - def process_fast_torch(self, image0: np.ndarray, image1: np.ndarray, timestep: float = 0.5, shape: tuple = None, channels: int = 3) -> np.ndarray: + + def process_fast_torch( + self, + image0: np.ndarray, + image1: np.ndarray, + timestep: float = 0.5, + shape: tuple = None, + channels: int = 3, + ) -> np.ndarray: """ An attempt at a faster implementation for NCNN that should speed it up significantly through better caching methods. @@ -265,9 +287,9 @@ def process_fast_torch(self, image0: np.ndarray, image1: np.ndarray, timestep: f return torch.frombuffer(self.output_bytes, dtype=torch.uint8).reshape( self.height, self.width, self.channels ) - -class RIFE(Rife): - ... + + +class RIFE(Rife): ... class InterpolateRIFENCNN: @@ -280,7 +302,6 @@ def __init__( gpuid: int = 0, max_timestep: int = 1, ): - self.max_timestep = max_timestep self.interpolateModelPath = interpolateModelPath self.width = width @@ -314,12 +335,10 @@ def process(self, img0, img1, timestep) -> bytes: frame = self.render.process_bytes(img0, img1, timestep) return frame - def normFrame(self,frame:bytes): + def normFrame(self, frame: bytes): return frame frame = bytearray(frame) - frame = wrapped.Image( - frame, self.width, self.height, 3 - ) + frame = wrapped.Image(frame, self.width, self.height, 3) return frame def uncacheFrame(self): diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py index 627c028d..72668f29 100644 --- a/backend/src/InterpolateTorch.py +++ b/backend/src/InterpolateTorch.py @@ -348,10 +348,11 @@ def _load(self): import tensorrt import torch_tensorrt from .InterpolateArchs.RIFE.custom_warplayer import WarpPluginCreator + registry = tensorrt.get_plugin_registry() registry.register_creator(WarpPluginCreator()) - #torch_tensorrt.runtime.enable_cudagraphs() + # torch_tensorrt.runtime.enable_cudagraphs() logging.basicConfig(level=logging.INFO) base_trt_engine_path = os.path.join( os.path.realpath(self.trt_cache_dir), @@ -381,7 +382,6 @@ def _load(self): if self.trt_optimization_level is not None else "" ) - ), ) trt_engine_path = base_trt_engine_path + ".dyn" diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py index 55c145b6..8358c8f9 100644 --- a/backend/src/RenderVideo.py +++ b/backend/src/RenderVideo.py @@ -94,7 +94,7 @@ def __init__( self.upscaleTimes = 1 # if no upscaling, it will default to 1 self.interpolateFactor = interpolateFactor # max timestep is a hack to make sure ncnn cache frames too early, and ncnn breaks if i modify the code at all so 
I guess this is what we are doing - self.maxTimestep = (interpolateFactor- 1) / interpolateFactor + self.maxTimestep = (interpolateFactor - 1) / interpolateFactor self.ncnn = self.backend == "ncnn" self.rifeVersion = rifeVersion self.ceilInterpolateFactor = math.ceil(self.interpolateFactor) @@ -206,11 +206,11 @@ def renderInterpolate(self, frame, transition=False): timestep=timestep, ) elif self.ncnn: - self.interpolate( - img0=self.setupFrame0, - img1=self.setupFrame1, - timestep=self.maxTimestep, - ) + self.interpolate( + img0=self.setupFrame0, + img1=self.setupFrame1, + timestep=self.maxTimestep, + ) self.writeQueue.put(frame) @@ -316,7 +316,7 @@ def setupInterpolate(self): interpolateModelPath=self.interpolateModel, width=self.width, height=self.height, - max_timestep=self.maxTimestep + max_timestep=self.maxTimestep, ) self.frameSetupFunction = interpolateRifeNCNN.normFrame self.undoSetup = interpolateRifeNCNN.uncacheFrame diff --git a/backend/src/UpscaleNCNN.py b/backend/src/UpscaleNCNN.py index c34df62c..9d804377 100644 --- a/backend/src/UpscaleNCNN.py +++ b/backend/src/UpscaleNCNN.py @@ -3,12 +3,17 @@ from time import sleep import math import numpy as np + try: from upscale_ncnn_py import UPSCALE + method = "upscale_ncnn_py" except: import ncnn - method = "ncnn_vulkan" + + method = "ncnn_vulkan" + + class NCNNParam: """ Puts the last time an op shows up in a param in a dict @@ -59,7 +64,6 @@ def getNCNNScale(modelPath: str = "") -> int: return scale - class UpscaleNCNN: def __init__( self, @@ -70,7 +74,7 @@ def __init__( width: int = 1920, height: int = 1080, tilesize: int = 0, - tilePad = 10, + tilePad=10, ): # only import if necessary @@ -98,20 +102,20 @@ def _load(self): self.net.load_model(self.modelPath + ".bin") elif method == "upscale_ncnn_py": self.net = UPSCALE( - gpuid=self.gpuid, - model_str=self.modelPath, - num_threads=self.threads, - scale=self.scale, - tilesize=self.tilesize, - ) - - + gpuid=self.gpuid, + model_str=self.modelPath, + num_threads=self.threads, + scale=self.scale, + tilesize=self.tilesize, + ) + def hotUnload(self): self.model = None self.net = None def hotReload(self): self._load() + def NCNNImageMatFromNP(self, npArray: np.array): return ncnn.Mat.from_pixels( npArray, @@ -125,10 +129,8 @@ def NormalizeImage(self, mat, norm_vals): mat.substract_mean_normalize(mean_vals, norm_vals) def ClampNPArray(self, nparray: np.array) -> np.array: - return nparray.clip(0, 255) - def procNCNNVk(self, imageChunk): ex = self.net.create_extractor() frame = self.NCNNImageMatFromNP(imageChunk) @@ -153,12 +155,16 @@ def Upscale(self, imageChunk): if self.tilesize == 0: return self.procNCNNVk(imageChunk) else: - npArray = np.frombuffer(imageChunk,dtype=np.uint8).reshape(self.height,self.width,3).transpose(2,0,1) + npArray = ( + np.frombuffer(imageChunk, dtype=np.uint8) + .reshape(self.height, self.width, 3) + .transpose(2, 0, 1) + ) return self.upscaleTiledImage(npArray) elif method == "upscale_ncnn_py": return self.net.process_bytes(imageChunk, self.width, self.height, 3) - - def upscaleTiledImage(self, img:np.array): + + def upscaleTiledImage(self, img: np.array): batch, channel, height, width = img.shape output_shape = (batch, channel, height * self.scale, width * self.scale) @@ -203,13 +209,17 @@ def upscaleTiledImage(self, img:np.array): h, w = input_tile.shape[2:] pad_h = max(0, self.tilePad - h) pad_w = max(0, self.tilePad - w) - input_tile = np.pad(input_tile, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)), mode='edge') + input_tile = np.pad( + input_tile, + ((0, 0),
(0, 0), (pad_h, pad_h), (pad_w, pad_w)), + mode="edge", + ) # Process tile using the model (assuming model is a function that can process numpy arrays) output_tile = self.procNCNNVk(input_tile) # Crop output tile to the expected size - output_tile = output_tile[:, :, :h * self.scale, :w * self.scale] + output_tile = output_tile[:, :, : h * self.scale, : w * self.scale] # Output tile area on total image output_start_x = input_start_x * self.scale @@ -233,4 +243,4 @@ def upscaleTiledImage(self, img:np.array): output_start_x_tile:output_end_x_tile, ] - return output \ No newline at end of file + return output diff --git a/backend/src/Util.py b/backend/src/Util.py index 01d714d9..fa4e986d 100644 --- a/backend/src/Util.py +++ b/backend/src/Util.py @@ -230,10 +230,13 @@ def checkForNCNN() -> bool: try: from rife_ncnn_vulkan_python import Rife import ncnn + try: from upscale_ncnn_py import UPSCALE except: - printAndLog("Warning: Cannot import upscale_ncnn, falling back to ncnn processing. (This can be slow!)") + printAndLog( + "Warning: Cannot import upscale_ncnn, falling back to ncnn processing. (This can be slow!)" + ) return True except ImportError as e: log(str(e)) diff --git a/src/DiscordRPC.py b/src/DiscordRPC.py index 33e3f490..0dd0490f 100644 --- a/src/DiscordRPC.py +++ b/src/DiscordRPC.py @@ -23,6 +23,7 @@ def timeout_handler(): finally: timer.cancel() + class DiscordRPC: def start_discordRPC(self, mode: str, videoName: str, backend: str): """ @@ -42,7 +43,8 @@ def start_discordRPC(self, mode: str, videoName: str, backend: str): ipc_path = f"{os.getenv('XDG_RUNTIME_DIR')}/discord-ipc-{i}" if not os.path.exists(ipc_path) or not os.path.isfile(ipc_path): os.symlink( - f"{os.getenv('HOME')}/.config/discord/{client_id}", ipc_path + f"{os.getenv('HOME')}/.config/discord/{client_id}", + ipc_path, ) except: log("Not flatpak") @@ -50,7 +52,7 @@ def start_discordRPC(self, mode: str, videoName: str, backend: str): self.RPC = Presence(client_id) # Initialize the client class self.RPC.connect() # Start the handshake loop - self. 
RPC.update( state=f"{mode} Video", details=f"Backend: {backend}", large_image="logo-v2", @@ -62,5 +64,6 @@ def start_discordRPC(self, mode: str, videoName: str, backend: str): # Can only update rich presence every 15 seconds except Exception as e: log("Timed out!") + def closeRPC(self): self.RPC.close() diff --git a/src/DownloadDeps.py b/src/DownloadDeps.py index 06b077d2..dda8377b 100644 --- a/src/DownloadDeps.py +++ b/src/DownloadDeps.py @@ -169,7 +169,7 @@ def getPyTorchCUDADeps(self): ]""" # Nightly test torchCUDALinuxDeps = [ - #"https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl", + # "https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl", "https://download.pytorch.org/whl/nightly/pytorch_triton-3.0.0%2Bdedb7bdf33-cp311-cp311-linux_x86_64.whl", "https://download.pytorch.org/whl/nightly/cu124_pypi_pkg/torch-2.5.0.dev20240826%2Bcu124-cp311-cp311-linux_x86_64.whl", "https://download.pytorch.org/whl/nightly/cu124/torchvision-0.20.0.dev20240826%2Bcu124-cp311-cp311-linux_x86_64.whl", @@ -179,7 +179,7 @@ def getPyTorchCUDADeps(self): "cupy-cuda12x==13.3.0", ] torchCUDAWindowsDeps = [ - #"https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl", + # "https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl", # "--pre", "https://download.pytorch.org/whl/nightly/cu124/torch-2.5.0.dev20240826%2Bcu124-cp311-cp311-win_amd64.whl", # "--pre", @@ -267,7 +267,7 @@ def downloadNCNNDeps(self): "opencv-python-headless", ] + self.getPlatformIndependentDeps() self.pipInstall(ncnnDeps) - self.pipInstall(['numpy==1.26.4','sympy']) + self.pipInstall(["numpy==1.26.4", "sympy"]) def downloadPyTorchROCmDeps(self): rocmLinuxDeps = [ @@ -281,4 +281,4 @@ def downloadPyTorchROCmDeps(self): if __name__ == "__main__": downloadDependencies = DownloadDependencies() - downloadDependencies.downloadPython() \ No newline at end of file + downloadDependencies.downloadPython() diff --git a/src/InputHandler.py b/src/InputHandler.py index e131b088..c875775b 100644 --- a/src/InputHandler.py +++ b/src/InputHandler.py @@ -72,11 +72,11 @@ def isYoutubeLink(self): def isValidVideoFile(self): return checkValidVideo(self.inputText) - + def isValidYoutubeLink(self): ydl_opts = { - 'quiet': True, # Suppress output - 'noplaylist': True, # Only check single video, not playlists + "quiet": True, # Suppress output + "noplaylist": True, # Only check single video, not playlists } with yt_dlp.YoutubeDL(ydl_opts) as ydl: @@ -84,11 +84,10 @@ def isValidYoutubeLink(self): # Extract info about the video info_dict = ydl.extract_info(self.inputText, download=False) # Check if there are available formats - if info_dict.get('formats'): + if info_dict.get("formats"): return True # Video is downloadable else: return False # No formats available except Exception as e: print(f"Error occurred: {e}") return False - diff --git a/src/ModelHandler.py b/src/ModelHandler.py index 6282ab34..ef88b11a 100644 --- a/src/ModelHandler.py +++ b/src/ModelHandler.py @@ -3,6 +3,7 @@ from .DownloadModels import DownloadModel from .ui.QTcustom import NetworkCheckPopup from .Util import currentDirectory, customModelsPath, createDirectory, printAndLog + """ Key value pairs of the model name in the GUI Data inside the tuple: @@ -235,16 +236,20 @@ customPytorchUpscaleModels = {} customNCNNUpscaleModels = {} for model in
os.listdir(customModelsPath()): - pattern = r'\d+x|x+\d' + pattern = r"\d+x|x+\d" matches = re.findall(pattern, model) if len(matches) > 0: - upscaleFactor = int(matches[0].replace("x", "")) # get the integer value of the upscale factor + upscaleFactor = int( + matches[0].replace("x", "") + ) # get the integer value of the upscale factor if model.endswith(".bin"): customNCNNUpscaleModels[model] = (model, model, upscaleFactor, "custom") if model.endswith(".pth"): customPytorchUpscaleModels[model] = (model, model, upscaleFactor, "custom") else: - printAndLog(f"Custom model {model} does not have a valid upscale factor in the name") + printAndLog( + f"Custom model {model} does not have a valid upscale factor in the name" + ) pytorchUpscaleModels = pytorchUpscaleModels | customPytorchUpscaleModels tensorrtUpscaleModels = tensorrtUpscaleModels | customPytorchUpscaleModels ncnnUpscaleModels = ncnnUpscaleModels | customNCNNUpscaleModels diff --git a/src/Util.py b/src/Util.py index 0adf176c..05fe112f 100644 --- a/src/Util.py +++ b/src/Util.py @@ -153,6 +153,7 @@ def pythonPath() -> str: else os.path.join(cwd, "python", "python", "bin", "python3") ) + def customModelsPath() -> str: """ Returns the file path for the custom models directory. @@ -162,6 +163,7 @@ def customModelsPath() -> str: """ return os.path.join(cwd, "custom_models") + def modelsPath() -> str: """ Returns the file path for the models directory. @@ -293,7 +295,7 @@ def checkValidVideo(video_path): return False ret, frame = cap.read() - #if not ret: + # if not ret: # print(f"Error: Couldn't read frames from the video file '{video_path}'") # return False @@ -474,6 +476,8 @@ def openLink(link: str): :type link: str """ webbrowser.open(link) + + def errorAndLog(message: str): log("ERROR: " + message) raise os.error("ERROR: " + message) diff --git a/src/ui/ProcessTab.py b/src/ui/ProcessTab.py index 48fec0a4..e3629f12 100644 --- a/src/ui/ProcessTab.py +++ b/src/ui/ProcessTab.py @@ -99,9 +99,7 @@ def QConnect(self): # connect file select buttons self.parent.inputFileSelectButton.clicked.connect(self.parent.openInputFile) - self.parent.inputFileText.textChanged.connect( - self.parent.loadVideo - ) + self.parent.inputFileText.textChanged.connect(self.parent.loadVideo) self.parent.outputFileSelectButton.clicked.connect(self.parent.openOutputFolder) # connect render button self.parent.startRenderButton.clicked.connect(self.parent.startRender) @@ -225,7 +223,9 @@ def run( # discord rpc if self.settings["discord_rich_presence"] == "True": self.discordRPC = DiscordRPC() - self.discordRPC.start_discordRPC(method, os.path.basename(self.inputFile), backend) + self.discordRPC.start_discordRPC( + method, os.path.basename(self.inputFile), backend + ) DownloadModel( modelFile=self.modelFile, diff --git a/src/ui/SettingsTab.py b/src/ui/SettingsTab.py index f1c8e1e3..4944024e 100644 --- a/src/ui/SettingsTab.py +++ b/src/ui/SettingsTab.py @@ -4,6 +4,7 @@ from ..Util import currentDirectory, getPlatform, homedir, checkForWritePermissions from .QTcustom import RegularQTPopup + class SettingsTab: def __init__( self, @@ -19,7 +20,6 @@ def __init__( # disable half option if its not supported if not halfPrecisionSupport: self.parent.precision.removeItem(1) - """def connectWriteSettings(self): settings_and_combo_boxes = { @@ -47,7 +47,6 @@ def __init__( ) print(setting)""" - def connectWriteSettings(self): self.parent.precision.currentIndexChanged.connect( lambda: self.settings.writeSetting( @@ -98,7 +97,7 @@ def connectWriteSettings(self): ) ) 
self.parent.output_folder_location.textChanged.connect( - lambda:self.writeOutputFolder() + lambda: self.writeOutputFolder() ) self.parent.rife_trt_mode.currentIndexChanged.connect( lambda: self.settings.writeSetting( @@ -113,9 +112,9 @@ def writeOutputFolder(self): if os.path.exists(outputlocation) and os.path.isdir(outputlocation): if checkForWritePermissions(outputlocation): self.settings.writeSetting( - "output_folder_location", - str(outputlocation), - ) + "output_folder_location", + str(outputlocation), + ) else: RegularQTPopup("No permissions to export here!") @@ -170,9 +169,9 @@ def selectOutputFolder(self): if os.path.exists(outputlocation) and os.path.isdir(outputlocation): if checkForWritePermissions(outputlocation): self.settings.writeSetting( - "output_folder_location", - str(outputlocation), - ) + "output_folder_location", + str(outputlocation), + ) self.parent.output_folder_location.setText(outputlocation) else: RegularQTPopup("No permissions to export here!") @@ -195,7 +194,9 @@ def __init__(self): "discord_rich_presence": "True", "scene_detection_threshold": "2.0", "video_quality": "High", - "output_folder_location": os.path.join(f"{homedir}", "Videos") if getPlatform() != "darwin" else os.path.join(f"{homedir}", "Desktop"), + "output_folder_location": os.path.join(f"{homedir}", "Videos") + if getPlatform() != "darwin" + else os.path.join(f"{homedir}", "Desktop"), "rife_trt_mode": "accurate", } self.allowedSettings = { @@ -277,7 +278,8 @@ def writeOutCurrentSettings(self): for key, value in self.settings.items(): if key in self.defaultSettings: # check if the key is valid if ( - value in self.allowedSettings[key] or self.allowedSettings[key] == "ANY" + value in self.allowedSettings[key] + or self.allowedSettings[key] == "ANY" ): # check if it is in the allowed settings dict file.write(f"{key},{value}\n") else:
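
Notes on two behavioral details touched by this patch, with illustrative Python sketches (not part of the patch itself):

RenderVideo.py computes maxTimestep = (interpolateFactor - 1) / interpolateFactor. For an interpolation factor N, the intermediate frames of each input pair are rendered at timesteps k/N for k = 1..N-1, so (N-1)/N is the last timestep of the pair; InterpolateNCNN.process_bytes rotates its cached frame0 only when it sees exactly that value, which keeps the cache valid for all earlier timesteps. A worked example:

    # For interpolateFactor N = 3, each frame pair is rendered at k/N:
    N = 3
    timesteps = [k / N for k in range(1, N)]  # [0.333..., 0.666...]
    maxTimestep = (N - 1) / N                 # 0.666..., the final timestep
    # process_bytes caches image1 as the new image0 only when
    # timestep == maxTimestep, i.e. exactly once per input-frame pair.

The warp() in warplayer.py / custom_warplayer.py divides a pixel-space flow by tenFlow_div per axis, adds it to a precomputed identity grid (backwarp_tenGrid), and backwarps with grid_sample. A minimal runnable sketch of how those two constants are typically constructed for RIFE-style warping; the helper name make_warp_constants and the exact construction are assumptions based on the common RIFE convention, not taken from this patch:

    import torch
    import torch.nn.functional as F

    def make_warp_constants(height: int, width: int, device: str = "cpu"):
        # grid_sample expects coordinates normalized to [-1, 1], so a
        # pixel-space flow must be divided by (dim - 1) / 2 on each axis.
        tenFlow_div = torch.tensor(
            [(width - 1.0) / 2.0, (height - 1.0) / 2.0], device=device
        )
        horizontal = (
            torch.linspace(-1.0, 1.0, width, device=device)
            .view(1, 1, 1, width)
            .expand(-1, -1, height, -1)
        )
        vertical = (
            torch.linspace(-1.0, 1.0, height, device=device)
            .view(1, 1, height, 1)
            .expand(-1, -1, -1, width)
        )
        # identity sampling grid, shape (1, 2, H, W): channel 0 = x, channel 1 = y
        backwarp_tenGrid = torch.cat([horizontal, vertical], 1)
        return tenFlow_div, backwarp_tenGrid

    # Sanity check: zero flow reproduces the input exactly.
    h, w = 4, 6
    tenFlow_div, grid = make_warp_constants(h, w)
    img = torch.rand(1, 3, h, w)
    flow = torch.zeros(1, 2, h, w)
    flow = torch.cat([flow[:, 0:1] / tenFlow_div[0], flow[:, 1:2] / tenFlow_div[1]], 1)
    g = (grid + flow).permute(0, 2, 3, 1)  # (N, 2, H, W) -> (N, H, W, 2) for grid_sample
    out = F.grid_sample(img, g, mode="bilinear", padding_mode="border", align_corners=True)
    assert torch.allclose(out, img, atol=1e-6)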