Commit
ruff
TNTwise committed Oct 14, 2024
1 parent 2d24740 commit f9b0a01
Showing 24 changed files with 269 additions and 154 deletions.
12 changes: 4 additions & 8 deletions REAL-Video-Enhancer.py
@@ -25,9 +25,9 @@
from PySide6 import QtSvg # Import the QtSvg module so svg icons can be used on windows
from src.version import version
from src.InputHandler import VideoInputHandler

# other imports
from src.Util import (

openLink,
getOSInfo,
get_gpu_info,
@@ -36,8 +36,7 @@
videosPath,
checkForWritePermissions,
getAvailableDiskSpace,
errorAndLog

errorAndLog,
)
from src.ui.ProcessTab import ProcessTab
from src.ui.DownloadTab import DownloadTab
@@ -171,7 +170,7 @@ def __init__(self):
self.settingsTab = SettingsTab(
parent=self, halfPrecisionSupport=halfPrecisionSupport
)

self.moreTab = MoreTab(parent=self)
# Startup Animation
self.animationHandler = AnimationHandler()
@@ -381,10 +380,7 @@ def disableProcessPage(self):
def enableProcessPage(self):
self.processSettingsContainer.setEnabled(True)



def loadVideo(self, inputFile):

videoHandler = VideoInputHandler(inputText=inputFile)
if videoHandler.isYoutubeLink() and videoHandler.isValidYoutubeLink():
videoHandler.getDataFromYoutubeVideo()
@@ -409,7 +405,7 @@ def loadVideo(self, inputFile):
self.outputFileSelectButton.setEnabled(True)
self.isVideoLoaded = True
self.updateVideoGUIDetails()

# input file button
def openInputFile(self):
"""
8 changes: 4 additions & 4 deletions backend/rve-backend.py
@@ -25,7 +25,7 @@ def __init__(self):
outputFile=self.args.output,
interpolateModel=self.args.interpolateModel,
interpolateFactor=self.args.interpolateFactor,
rifeVersion="v1",  # some guy was angry about rifev2 being here, so I changed it to v1
upscaleModel=self.args.upscaleModel,
tile_size=self.args.tilesize,
# backend settings
@@ -49,18 +49,18 @@ def __init__(self):
half_prec_supp = False
availableBackends = []
printMSG = ""

if checkForTensorRT():
"""
checks for TensorRT availability and whether the current GPU works with it (half precision support)
TRT 10 only supports RTX 20 series and up.
Half precision is only available on RTX 20 series and up
"""
import torch

half_prec_supp = check_bfloat16_support()
if half_prec_supp:
import tensorrt


availableBackends.append("tensorrt")
printMSG += f"TensorRT Version: {tensorrt.__version__}\n"
@@ -72,7 +72,7 @@ def __init__(self):
availableBackends.append("pytorch")
printMSG += f"PyTorch Version: {torch.__version__}\n"
half_prec_supp = check_bfloat16_support()

if checkForNCNN():
availableBackends.append("ncnn")
printMSG += f"NCNN Version: 20220729\n"
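Note on the gating above: the diff calls `check_bfloat16_support()` before choosing backends, but its body isn't shown here. A minimal sketch of such a check, assuming plain PyTorch APIs; this is illustrative, not the repo's actual implementation:

```python
import torch

def check_bfloat16_support() -> bool:
    # Sketch only: bf16 needs a CUDA device, so guard before querying.
    # The real check may also require an RTX 20-series+ GPU, per the
    # docstring above.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.is_bf16_supported()
```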
2 changes: 1 addition & 1 deletion backend/src/FFmpeg.py
@@ -127,7 +127,7 @@ def __init__(
self.previewFrame = None
self.crf = crf
self.sharedMemoryID = sharedMemoryID

self.subtitleFiles = []
self.sharedMemoryThread = Thread(
target=lambda: self.writeOutInformation(self.outputFrameChunkSize)
2 changes: 1 addition & 1 deletion backend/src/InterpolateArchs/DetectInterpolateArch.py
@@ -190,6 +190,7 @@ def excluded_keys() -> tuple:
"transformer.layers.4.self_attn.merge.weight",
]


class GMFSS:
def __init__():
pass
@@ -207,7 +208,6 @@ def excluded_keys() -> tuple:
"module.encode.1.weight",
"module.encode.1.bias",
]



archs = [RIFE46, RIFE47, RIFE413, RIFE420, RIFE421, RIFE422lite, RIFE425, GMFSS]
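The `excluded_keys` tuples above exist so a checkpoint's state dict can be matched against known architectures. A hypothetical sketch of how such detection can work; `detect_arch` and its selection rule are illustrative, only `archs` and `excluded_keys()` come from this file:

```python
def detect_arch(state_dict: dict):
    # Choose the first arch for which none of its excluded keys appear in
    # the checkpoint; an excluded key is a weight name that arch never has.
    for arch in archs:
        if not any(key in state_dict for key in arch.excluded_keys()):
            return arch
    raise ValueError("Unable to detect interpolation architecture")
```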
108 changes: 84 additions & 24 deletions backend/src/InterpolateArchs/RIFE/custom_warplayer.py
@@ -10,12 +10,19 @@
from torch.fx.node import Argument, Target
from torch.library import custom_op, register_fake
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types, set_layer_name
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import (
enforce_tensor_types,
set_layer_name,
)
from torch_tensorrt.dynamo.types import TRTTensor


class WarpPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):
class WarpPlugin(
trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime
):
def __init__(self) -> None:
trt.IPluginV3.__init__(self)
trt.IPluginV3OneCore.__init__(self)
@@ -34,25 +41,40 @@ def clone(self) -> WarpPlugin:
def get_capability_interface(self, type: trt.PluginCapabilityType) -> Self:
return self

def configure_plugin(self, inp: list[trt.DynamicPluginTensorDesc], out: list[trt.DynamicPluginTensorDesc]) -> None:
def configure_plugin(
self,
inp: list[trt.DynamicPluginTensorDesc],
out: list[trt.DynamicPluginTensorDesc],
) -> None:
pass

def get_output_data_types(self, input_types: list[trt.DataType]) -> list[trt.DataType]:
def get_output_data_types(
self, input_types: list[trt.DataType]
) -> list[trt.DataType]:
return [input_types[0]]

def get_output_shapes(
self, inputs: list[trt.DimsExprs], shape_inputs: list[trt.DimsExprs], expr_builder: trt.IExprBuilder
self,
inputs: list[trt.DimsExprs],
shape_inputs: list[trt.DimsExprs],
expr_builder: trt.IExprBuilder,
) -> list[trt.DimsExprs]:
return [inputs[0]]

def supports_format_combination(self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int) -> bool:
def supports_format_combination(
self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int
) -> bool:
assert pos < len(in_out)
assert num_inputs == 4

desc = in_out[pos].desc
return desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT
return (
desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT
)

def attach_to_context(self, resource_context: trt.IPluginResourceContext) -> WarpPlugin:
def attach_to_context(
self, resource_context: trt.IPluginResourceContext
) -> WarpPlugin:
return self.clone()

def enqueue(
@@ -68,11 +90,21 @@ def enqueue(
itemsize = cp.dtype(dtype).itemsize

with cp.cuda.ExternalStream(stream):
input0_mem = cp.cuda.UnownedMemory(inputs[0], np.prod(input_desc[0].dims) * itemsize, self)
input1_mem = cp.cuda.UnownedMemory(inputs[1], np.prod(input_desc[1].dims) * itemsize, self)
input2_mem = cp.cuda.UnownedMemory(inputs[2], np.prod(input_desc[2].dims) * itemsize, self)
input3_mem = cp.cuda.UnownedMemory(inputs[3], np.prod(input_desc[3].dims) * itemsize, self)
output_mem = cp.cuda.UnownedMemory(outputs[0], np.prod(output_desc[0].dims) * itemsize, self)
input0_mem = cp.cuda.UnownedMemory(
inputs[0], np.prod(input_desc[0].dims) * itemsize, self
)
input1_mem = cp.cuda.UnownedMemory(
inputs[1], np.prod(input_desc[1].dims) * itemsize, self
)
input2_mem = cp.cuda.UnownedMemory(
inputs[2], np.prod(input_desc[2].dims) * itemsize, self
)
input3_mem = cp.cuda.UnownedMemory(
inputs[3], np.prod(input_desc[3].dims) * itemsize, self
)
output_mem = cp.cuda.UnownedMemory(
outputs[0], np.prod(output_desc[0].dims) * itemsize, self
)

input0_ptr = cp.cuda.MemoryPointer(input0_mem, 0)
input1_ptr = cp.cuda.MemoryPointer(input1_mem, 0)
@@ -84,7 +116,9 @@
input1_d = cp.ndarray(input_desc[1].dims, dtype=dtype, memptr=input1_ptr)
input2_d = cp.ndarray(input_desc[2].dims, dtype=dtype, memptr=input2_ptr)
input3_d = cp.ndarray(input_desc[3].dims, dtype=dtype, memptr=input3_ptr)
output_d = cp.ndarray((np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr)
output_d = cp.ndarray(
(np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr
)

input0_t = torch.as_tensor(input0_d)
input1_t = torch.as_tensor(input1_d)
@@ -97,7 +131,9 @@ def enqueue(
def get_fields_to_serialize(self) -> trt.PluginFieldCollection_:
return trt.PluginFieldCollection()

def on_shape_change(self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc]) -> None:
def on_shape_change(
self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc]
) -> None:
pass

def set_tactic(self, tactic: int) -> None:
@@ -113,23 +149,40 @@ def __init__(self) -> None:
self.field_names = trt.PluginFieldCollection()

def create_plugin(
self, name: str, field_collection: trt.PluginFieldCollection_, phase: trt.TensorRTPhase
self,
name: str,
field_collection: trt.PluginFieldCollection_,
phase: trt.TensorRTPhase,
) -> WarpPlugin:
return WarpPlugin()


@custom_op("vsrife::warp", mutates_args=())
def warp_custom(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
tenInput: torch.Tensor,
tenFlow: torch.Tensor,
tenFlow_div: torch.Tensor,
backwarp_tenGrid: torch.Tensor,
) -> torch.Tensor:
tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1)
tenFlow = torch.cat(
[tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
)
g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
return F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
return F.grid_sample(
input=tenInput,
grid=g,
mode="bilinear",
padding_mode="border",
align_corners=True,
)


@register_fake("vsrife::warp")
def warp_fake(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
tenInput: torch.Tensor,
tenFlow: torch.Tensor,
tenFlow_div: torch.Tensor,
backwarp_tenGrid: torch.Tensor,
) -> torch.Tensor:
return tenInput

@@ -152,17 +205,24 @@ def ops_warp(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
creator = trt.get_plugin_registry().get_creator("WarpPlugin", version="1")
field_collection = trt.PluginFieldCollection()
plugin = creator.create_plugin("WarpPlugin", field_collection=field_collection, phase=trt.TensorRTPhase.BUILD)
plugin = creator.create_plugin(
"WarpPlugin", field_collection=field_collection, phase=trt.TensorRTPhase.BUILD
)
layer = ctx.net.add_plugin_v3(inputs=list(args), shape_inputs=[], plugin=plugin)
set_layer_name(layer, target, name)
return layer.get_output(0)


def warp(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
tenInput: torch.Tensor,
tenFlow: torch.Tensor,
tenFlow_div: torch.Tensor,
backwarp_tenGrid: torch.Tensor,
) -> torch.Tensor:
dtype = tenInput.dtype
tenInput = tenInput.to(torch.float)
tenFlow = tenFlow.to(torch.float)

return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to(dtype)
return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to(
dtype
)
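For readers following the custom op: the eager path of `vsrife::warp` is a standard backward warp, i.e. rescale the pixel-space flow into grid_sample's [-1, 1] coordinates, offset a base grid, and sample. On the TensorRT path, `enqueue` reaches the same computation by wrapping the raw device pointers with CuPy `UnownedMemory`/`MemoryPointer` and viewing them as torch tensors via `torch.as_tensor`, avoiding copies. Below is a self-contained sketch of the warp itself with illustrative shapes; the grid and divisor construction mirrors common RIFE code and is an assumption, not necessarily this repo's exact tensors:

```python
import torch
import torch.nn.functional as F

n, c, h, w = 1, 3, 64, 64
tenInput = torch.rand(n, c, h, w)  # frame to warp
tenFlow = torch.rand(n, 2, h, w)   # optical flow in pixels

# Divisor rescaling pixel-space flow into grid_sample's [-1, 1] range.
tenFlow_div = torch.tensor([(w - 1.0) / 2.0, (h - 1.0) / 2.0])
gy, gx = torch.meshgrid(
    torch.linspace(-1.0, 1.0, h), torch.linspace(-1.0, 1.0, w), indexing="ij"
)
backwarp_tenGrid = torch.stack((gx, gy), dim=0).unsqueeze(0)  # (1, 2, H, W)

flow = torch.cat(
    [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
)
g = (backwarp_tenGrid + flow).permute(0, 2, 3, 1)  # (N, H, W, 2)
warped = F.grid_sample(
    tenInput, g, mode="bilinear", padding_mode="border", align_corners=True
)
```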
3 changes: 2 additions & 1 deletion backend/src/InterpolateArchs/RIFE/rife413IFNET.py
@@ -4,11 +4,13 @@


from torch.nn.functional import interpolate

try:
from .custom_warplayer import warp
except:
from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
@@ -173,7 +175,6 @@ def __init__(
else:
raise ValueError("rife_trt_mode must be 'fast' or 'accurate'")
self.warp = warp


def forward(self, img0, img1, timestep, f0, f1):
warped_img0 = img0
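The try/except at the top of this file prefers the TensorRT custom warp and silently falls back to the plain warplayer. A sketch of the intent; the bare `except:` in the diff most plausibly means to catch import failures:

```python
try:
    # Needs tensorrt, torch_tensorrt and cupy importable at module load.
    from .custom_warplayer import warp
except ImportError:
    from .warplayer import warp
```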
1 change: 1 addition & 0 deletions backend/src/InterpolateArchs/RIFE/rife421IFNET.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from torch.nn.functional import interpolate


class MyPixelShuffle(nn.Module):
def __init__(self, upscale_factor):
super(MyPixelShuffle, self).__init__()
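`MyPixelShuffle` reads as a hand-rolled stand-in for `nn.PixelShuffle` (explicit view/permute modules often trace and export more predictably). A sketch of what such a module typically computes, assuming standard pixel-shuffle semantics rather than this file's verbatim body:

```python
import torch
import torch.nn as nn

class MyPixelShuffle(nn.Module):
    def __init__(self, upscale_factor: int):
        super().__init__()
        self.r = upscale_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C*r*r, H, W) -> (N, C, H*r, W*r), like nn.PixelShuffle.
        n, c, h, w = x.shape
        r = self.r
        x = x.view(n, c // (r * r), r, r, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3)
        return x.reshape(n, c // (r * r), h * r, w * r)
```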
4 changes: 2 additions & 2 deletions backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from torch.nn.functional import interpolate


class MyPixelShuffle(nn.Module):
def __init__(self, upscale_factor):
super(MyPixelShuffle, self).__init__()
@@ -167,7 +168,6 @@ def __init__(
else:
raise ValueError("rife_trt_mode must be 'fast' or 'accurate'")
self.warp = warp


def forward(self, img0, img1, timestep, f0, f1):
warped_img0 = img0
@@ -212,4 +212,4 @@ def forward(self, img0, img1, timestep, f0, f1):
][0]
.permute(1, 2, 0)
.mul(255)
)
2 changes: 1 addition & 1 deletion backend/src/InterpolateArchs/RIFE/rife425IFNET.py
@@ -141,7 +141,7 @@ def __init__(
self.height = height
self.backWarp = backwarp_tenGrid
self.tenFlow = tenFlow_div

self.paddedHeight = backwarp_tenGrid.shape[2]
self.paddedWidth = backwarp_tenGrid.shape[3]

8 changes: 3 additions & 5 deletions backend/src/InterpolateArchs/RIFE/rife46IFNET.py
@@ -133,7 +133,7 @@ def forward(self, img0, img1, timestep):
warped_img1 = img1
flow = None
mask = None

for i in range(4):
if flow is None:
flow, mask = self.block[i](
@@ -151,9 +151,7 @@ def forward(self, img0, img1, timestep):
mask = (mask + (-m1)) / 2
else:
f0, m0 = self.block[i](
torch.cat(
(warped_img0, warped_img1, timestep, mask), 1
),
torch.cat((warped_img0, warped_img1, timestep, mask), 1),
flow,
scale=self.scale_list[i],
)
@@ -186,4 +184,4 @@ def forward(self, img0, img1, timestep):
temp = torch.sigmoid(latest_mask)
frame = warped_img0 * temp + warped_img1 * (1 - temp)
frame = frame[:, :, : self.height, : self.width][0]
return frame.permute(1, 2, 0).mul(255)
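The tail of `forward` above composites the interpolated frame: a learned mask is squashed with sigmoid and used as a per-pixel blend weight between the two warped frames, then CHW is converted to HWC in the 0-255 range. A minimal runnable sketch of that step with illustrative shapes:

```python
import torch

warped_img0 = torch.rand(1, 3, 8, 8)   # frame 0 warped toward time t
warped_img1 = torch.rand(1, 3, 8, 8)   # frame 1 warped toward time t
latest_mask = torch.randn(1, 1, 8, 8)  # raw mask logits from the last block

temp = torch.sigmoid(latest_mask)      # per-pixel weight in (0, 1)
frame = warped_img0 * temp + warped_img1 * (1 - temp)
frame = frame[0].permute(1, 2, 0).mul(255)  # CHW -> HWC, scaled to 0-255
```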