diff --git a/backend/src/InterpolateArchs/RIFE/custom_warplayer.py b/backend/src/InterpolateArchs/RIFE/custom_warplayer.py
index 9e7ca840..1e382b62 100644
--- a/backend/src/InterpolateArchs/RIFE/custom_warplayer.py
+++ b/backend/src/InterpolateArchs/RIFE/custom_warplayer.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
+
 from typing import Self, Sequence, Union
+
+import cupy as cp
 import numpy as np
 import tensorrt as trt
 import torch
@@ -10,8 +13,8 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter
 from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types, set_layer_name
 from torch_tensorrt.dynamo.types import TRTTensor
-import cupy as cp
-import importerror # this causes an import error, because this file is still under development
+
+
 class WarpPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):
     def __init__(self) -> None:
         trt.IPluginV3.__init__(self)
@@ -45,6 +48,7 @@ def get_output_shapes(
     def supports_format_combination(self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int) -> bool:
         assert pos < len(in_out)
         assert num_inputs == 4
+
         desc = in_out[pos].desc
         return desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT
 
@@ -60,20 +64,38 @@ def enqueue(
         workspace: int,
         stream: int,
     ) -> None:
-        dtype = torch.float32
-        with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
-            input0 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[0], dtype=np.float32)).reshape(input_desc[0].dims), dtype=dtype)
-            input1 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[1], dtype=np.float32)).reshape(input_desc[1].dims), dtype=dtype)
-            input2 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[2], dtype=np.float32)).reshape(input_desc[2].dims), dtype=dtype)
-            input3 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[3], dtype=np.float32)).reshape(input_desc[3].dims), dtype=dtype)
-
-            out = warp_custom(input0, input1, input2, input3)
-
-            output_tensor = torch.as_tensor(torch.from_numpy(np.frombuffer(outputs[0], dtype=np.float32)).reshape(output_desc[0].dims), dtype=dtype)
-            output_tensor.copy_(out.reshape(-1))
+        dtype = trt.nptype(input_desc[0].type)
+        itemsize = cp.dtype(dtype).itemsize
+
+        with cp.cuda.ExternalStream(stream):
+            input0_mem = cp.cuda.UnownedMemory(inputs[0], np.prod(input_desc[0].dims) * itemsize, self)
+            input1_mem = cp.cuda.UnownedMemory(inputs[1], np.prod(input_desc[1].dims) * itemsize, self)
+            input2_mem = cp.cuda.UnownedMemory(inputs[2], np.prod(input_desc[2].dims) * itemsize, self)
+            input3_mem = cp.cuda.UnownedMemory(inputs[3], np.prod(input_desc[3].dims) * itemsize, self)
+            output_mem = cp.cuda.UnownedMemory(outputs[0], np.prod(output_desc[0].dims) * itemsize, self)
+
+            input0_ptr = cp.cuda.MemoryPointer(input0_mem, 0)
+            input1_ptr = cp.cuda.MemoryPointer(input1_mem, 0)
+            input2_ptr = cp.cuda.MemoryPointer(input2_mem, 0)
+            input3_ptr = cp.cuda.MemoryPointer(input3_mem, 0)
+            output_ptr = cp.cuda.MemoryPointer(output_mem, 0)
+
+            input0_d = cp.ndarray(input_desc[0].dims, dtype=dtype, memptr=input0_ptr)
+            input1_d = cp.ndarray(input_desc[1].dims, dtype=dtype, memptr=input1_ptr)
+            input2_d = cp.ndarray(input_desc[2].dims, dtype=dtype, memptr=input2_ptr)
+            input3_d = cp.ndarray(input_desc[3].dims, dtype=dtype, memptr=input3_ptr)
+            output_d = cp.ndarray((np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr)
+
+            input0_t = torch.as_tensor(input0_d)
+            input1_t = torch.as_tensor(input1_d)
+            input2_t = torch.as_tensor(input2_d)
+            input3_t = torch.as_tensor(input3_d)
+
+            out = warp_custom(input0_t, input1_t, input2_t, input3_t)
+            cp.copyto(output_d, cp.reshape(cp.asarray(out), (-1,)))
 
     def get_fields_to_serialize(self) -> trt.PluginFieldCollection_:
-        return trt.PluginFieldCollection_(trt.PluginFieldCollection())
+        return trt.PluginFieldCollection()
 
     def on_shape_change(self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc]) -> None:
         pass
@@ -81,6 +103,7 @@ def on_shape_change(self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginT
     def set_tactic(self, tactic: int) -> None:
         pass
 
+
 class WarpPluginCreator(trt.IPluginCreatorV3One):
     def __init__(self) -> None:
         super().__init__()
@@ -94,6 +117,7 @@ def create_plugin(
     ) -> WarpPlugin:
         return WarpPlugin()
 
+
 @custom_op("vsrife::warp", mutates_args=())
 def warp_custom(
     tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
@@ -102,12 +126,14 @@ ) -> torch.Tensor:
     g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
     return F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
 
+
 @register_fake("vsrife::warp")
 def warp_fake(
     tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
 ) -> torch.Tensor:
     return tenInput
 
+
 @dynamo_tensorrt_converter(torch.ops.vsrife.warp.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
@@ -131,10 +157,12 @@ def ops_warp(
     set_layer_name(layer, target, name)
     return layer.get_output(0)
 
+
 def warp(
     tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
 ) -> torch.Tensor:
     dtype = tenInput.dtype
     tenInput = tenInput.to(torch.float)
     tenFlow = tenFlow.to(torch.float)
+
     return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to(dtype)
\ No newline at end of file
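Review note — custom_warplayer.py: the old enqueue() tried to reinterpret raw device pointers with np.frombuffer() on the host (which is why the file was fenced off behind a deliberate import error); the new enqueue() wraps TensorRT's pointers with CuPy and lets torch.as_tensor() adopt them zero-copy through __cuda_array_interface__, on the execution stream TensorRT passes in. A minimal standalone sketch of the same pattern, assuming only a valid device pointer plus its shape/dtype (the helper name is illustrative, not part of the patch):

    import cupy as cp
    import numpy as np
    import torch

    def tensor_from_device_ptr(ptr: int, shape, dtype=np.float32, owner=None) -> torch.Tensor:
        # Wrap an existing CUDA allocation without copying or taking ownership.
        nbytes = int(np.prod(shape)) * cp.dtype(dtype).itemsize
        mem = cp.cuda.UnownedMemory(ptr, nbytes, owner)
        arr = cp.ndarray(shape, dtype=dtype, memptr=cp.cuda.MemoryPointer(mem, 0))
        # torch.as_tensor() shares the same device memory; no host round-trip.
        return torch.as_tensor(arr)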
diff --git a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
index f48b3a73..1dc461c7 100644
--- a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
+++ b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
 from torch.nn.functional import interpolate
-
-from .warplayer import warp
-
+try:
+    from .custom_warplayer import warp
+except ImportError:
+    from .warplayer import warp
 
 class MyPixelShuffle(nn.Module):
     def __init__(self, upscale_factor):
diff --git a/backend/src/InterpolateArchs/RIFE/warplayer.py b/backend/src/InterpolateArchs/RIFE/warplayer.py
index 22f35d45..85f06cf8 100644
--- a/backend/src/InterpolateArchs/RIFE/warplayer.py
+++ b/backend/src/InterpolateArchs/RIFE/warplayer.py
@@ -2,9 +2,7 @@
 import torch.nn.functional as F
 
 
-'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
-    tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0],
-                         tenFlow[:, 1:2] / tenFlow_div[1]], 1)
+def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
     dtype = tenInput.dtype
     tenInput = tenInput.to(torch.float)
     tenFlow = tenFlow.to(torch.float)
@@ -12,9 +10,9 @@ tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0],
                          tenFlow[:, 1:2] / tenFlow_div[1]], 1)
     g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
     grid_sample = F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
-    return grid_sample.to(dtype)'''
+    return grid_sample.to(dtype)
 
-def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
+'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
     tenFlow = torch.cat(
         [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
     )
@@ -26,4 +24,4 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
         mode="bilinear",
         padding_mode="border",
         align_corners=True,
-    )
\ No newline at end of file
+    )'''
\ No newline at end of file
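Review note — warplayer.py: the now-active warp() rescales the pixel-space flow by tenFlow_div into grid_sample()'s normalized [-1, 1] coordinates, adds it to the precomputed backwarp_tenGrid, samples with align_corners=True, and restores the caller's dtype. A quick sanity check under the usual RIFE conventions for the grid and divisors (an assumption here, since their construction is outside this diff): zero flow should give an identity resample.

    import torch

    from warplayer import warp  # adjust the import to your package layout

    n, c, h, w = 1, 3, 8, 8
    hor = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1)
    ver = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w)
    backwarp_tenGrid = torch.cat([hor, ver], 1)  # normalized sampling grid
    tenFlow_div = torch.tensor([(w - 1.0) / 2.0, (h - 1.0) / 2.0])

    x = torch.rand(n, c, h, w)
    out = warp(x, torch.zeros(n, 2, h, w), tenFlow_div, backwarp_tenGrid)
    assert torch.allclose(out, x, atol=1e-6)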
diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py
index 5c880656..01db3568 100644
--- a/backend/src/InterpolateTorch.py
+++ b/backend/src/InterpolateTorch.py
@@ -344,9 +344,9 @@ def _load(self):
         if self.backend == "tensorrt":
             import tensorrt
             import torch_tensorrt
-            #from .warplayer_custom import WarpPluginCreator
-            #registry = tensorrt.get_plugin_registry()
-            #registry.register_creator(WarpPluginCreator())
+            from .InterpolateArchs.RIFE.custom_warplayer import WarpPluginCreator
+            registry = tensorrt.get_plugin_registry()
+            registry.register_creator(WarpPluginCreator())
             #torch_tensorrt.runtime.enable_cudagraphs()
 
             logging.basicConfig(level=logging.INFO)
diff --git a/src/DownloadDeps.py b/src/DownloadDeps.py
index 708dfc3b..1bdfb075 100644
--- a/src/DownloadDeps.py
+++ b/src/DownloadDeps.py
@@ -178,6 +178,7 @@ def getPyTorchCUDADeps(self):
             "https://download.pytorch.org/whl/nightly/cu124_pypi_pkg/torch_no_python-2.5.0.dev20240826%2Bcu124-py3-none-any.whl",
             "safetensors",
             "einops",
+            "cupy-cuda12x==13.3.0",
         ]
         torchCUDAWindowsDeps = [
             #"https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl",
@@ -189,6 +190,7 @@
             # "torchvision==0.19.0",
             "safetensors",
             "einops",
+            "cupy-cuda12x==13.3.0",
         ]
         match getPlatform():
             case "win32":
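Review note — InterpolateTorch.py / DownloadDeps.py: the previously commented-out registration is now live, so WarpPluginCreator is added to TensorRT's plugin registry before any engine referencing vsrife::warp is built or deserialized, and cupy-cuda12x is pinned to match the cu124 nightly wheels. Since _load() can run more than once per process, a defensive variant like the sketch below may be worth considering (illustrative, not part of the patch; register_creator() should return False rather than raise when the creator is already present):

    import tensorrt as trt

    # Relative import in the patch; absolute here for a standalone sketch.
    from backend.src.InterpolateArchs.RIFE.custom_warplayer import WarpPluginCreator

    registry = trt.get_plugin_registry()
    if not registry.register_creator(WarpPluginCreator()):
        # Creator was already registered (e.g. a second _load() call); continue.
        pass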