fix trt, thanks holywu
TNTwise committed Oct 7, 2024
1 parent 7d694c5 commit a307ad3
Showing 5 changed files with 55 additions and 26 deletions.
56 changes: 42 additions & 14 deletions backend/src/InterpolateArchs/RIFE/custom_warplayer.py
@@ -1,5 +1,8 @@
from __future__ import annotations

from typing import Self, Sequence, Union

import cupy as cp
import numpy as np
import tensorrt as trt
import torch
@@ -10,8 +13,8 @@
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter
from torch_tensorrt.dynamo.conversion.converter_utils import enforce_tensor_types, set_layer_name
from torch_tensorrt.dynamo.types import TRTTensor
import cupy as cp
import importerror # this causes an import error, because this file is still under development


class WarpPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime):
def __init__(self) -> None:
trt.IPluginV3.__init__(self)
@@ -45,6 +48,7 @@ def get_output_shapes(
def supports_format_combination(self, pos: int, in_out: list[trt.DynamicPluginTensorDesc], num_inputs: int) -> bool:
assert pos < len(in_out)
assert num_inputs == 4

desc = in_out[pos].desc
return desc.format == trt.TensorFormat.LINEAR and desc.type == trt.DataType.FLOAT

@@ -60,27 +64,46 @@ def enqueue(
workspace: int,
stream: int,
) -> None:
dtype = torch.float32
with torch.cuda.stream(torch.cuda.ExternalStream(stream)):
input0 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[0], dtype=np.float32)).reshape(input_desc[0].dims), dtype=dtype)
input1 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[1], dtype=np.float32)).reshape(input_desc[1].dims), dtype=dtype)
input2 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[2], dtype=np.float32)).reshape(input_desc[2].dims), dtype=dtype)
input3 = torch.as_tensor(torch.from_numpy(np.frombuffer(inputs[3], dtype=np.float32)).reshape(input_desc[3].dims), dtype=dtype)

out = warp_custom(input0, input1, input2, input3)

output_tensor = torch.as_tensor(torch.from_numpy(np.frombuffer(outputs[0], dtype=np.float32)).reshape(output_desc[0].dims), dtype=dtype)
output_tensor.copy_(out.reshape(-1))
dtype = trt.nptype(input_desc[0].type)
itemsize = cp.dtype(dtype).itemsize

with cp.cuda.ExternalStream(stream):
input0_mem = cp.cuda.UnownedMemory(inputs[0], np.prod(input_desc[0].dims) * itemsize, self)
input1_mem = cp.cuda.UnownedMemory(inputs[1], np.prod(input_desc[1].dims) * itemsize, self)
input2_mem = cp.cuda.UnownedMemory(inputs[2], np.prod(input_desc[2].dims) * itemsize, self)
input3_mem = cp.cuda.UnownedMemory(inputs[3], np.prod(input_desc[3].dims) * itemsize, self)
output_mem = cp.cuda.UnownedMemory(outputs[0], np.prod(output_desc[0].dims) * itemsize, self)

input0_ptr = cp.cuda.MemoryPointer(input0_mem, 0)
input1_ptr = cp.cuda.MemoryPointer(input1_mem, 0)
input2_ptr = cp.cuda.MemoryPointer(input2_mem, 0)
input3_ptr = cp.cuda.MemoryPointer(input3_mem, 0)
output_ptr = cp.cuda.MemoryPointer(output_mem, 0)

input0_d = cp.ndarray(input_desc[0].dims, dtype=dtype, memptr=input0_ptr)
input1_d = cp.ndarray(input_desc[1].dims, dtype=dtype, memptr=input1_ptr)
input2_d = cp.ndarray(input_desc[2].dims, dtype=dtype, memptr=input2_ptr)
input3_d = cp.ndarray(input_desc[3].dims, dtype=dtype, memptr=input3_ptr)
output_d = cp.ndarray((np.prod(output_desc[0].dims),), dtype=dtype, memptr=output_ptr)

input0_t = torch.as_tensor(input0_d)
input1_t = torch.as_tensor(input1_d)
input2_t = torch.as_tensor(input2_d)
input3_t = torch.as_tensor(input3_d)

out = warp_custom(input0_t, input1_t, input2_t, input3_t)
cp.copyto(output_d, cp.reshape(cp.asarray(out), (-1,)))

def get_fields_to_serialize(self) -> trt.PluginFieldCollection_:
return trt.PluginFieldCollection_(trt.PluginFieldCollection())
return trt.PluginFieldCollection()

def on_shape_change(self, inp: list[trt.PluginTensorDesc], out: list[trt.PluginTensorDesc]) -> None:
pass

def set_tactic(self, tactic: int) -> None:
pass


class WarpPluginCreator(trt.IPluginCreatorV3One):
def __init__(self) -> None:
super().__init__()
@@ -94,6 +117,7 @@ def create_plugin(
) -> WarpPlugin:
return WarpPlugin()


@custom_op("vsrife::warp", mutates_args=())
def warp_custom(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
@@ -102,12 +126,14 @@ def warp_custom(
g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
return F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)


@register_fake("vsrife::warp")
def warp_fake(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
) -> torch.Tensor:
return tenInput


@dynamo_tensorrt_converter(torch.ops.vsrife.warp.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
@@ -131,10 +157,12 @@ def ops_warp(
set_layer_name(layer, target, name)
return layer.get_output(0)


def warp(
tenInput: torch.Tensor, tenFlow: torch.Tensor, tenFlow_div: torch.Tensor, backwarp_tenGrid: torch.Tensor
) -> torch.Tensor:
dtype = tenInput.dtype
tenInput = tenInput.to(torch.float)
tenFlow = tenFlow.to(torch.float)

return torch.ops.vsrife.warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid).to(dtype)
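
The substance of the fix is in enqueue(): the old implementation pushed the raw TensorRT device pointers through np.frombuffer, which treats them as host memory, while the rewrite wraps each pointer in cp.cuda.UnownedMemory and cp.cuda.MemoryPointer so the buffers are viewed in place on the GPU and handed to PyTorch with no copy. A minimal sketch of that pattern, using a hypothetical wrap_device_ptr helper that is not part of the commit:

import cupy as cp
import numpy as np
import torch

def wrap_device_ptr(ptr: int, shape, dtype=np.float32) -> torch.Tensor:
    # total byte size of the buffer described by the TensorRT tensor descriptor
    nbytes = int(np.prod(shape)) * cp.dtype(dtype).itemsize
    # borrow the device allocation without taking ownership (owner is None)
    mem = cp.cuda.UnownedMemory(ptr, nbytes, None)
    arr = cp.ndarray(shape, dtype=dtype, memptr=cp.cuda.MemoryPointer(mem, 0))
    # torch.as_tensor consumes the CUDA array interface: same device buffer, no host round-trip
    return torch.as_tensor(arr)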
7 changes: 4 additions & 3 deletions backend/src/InterpolateArchs/RIFE/rife421IFNET.py
@@ -1,9 +1,10 @@
import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from .warplayer import warp

try:
from .custom_warplayer import warp
except:
from .warplayer import warp

class MyPixelShuffle(nn.Module):
def __init__(self, upscale_factor):
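
Note how this pairs with the deliberate `import importerror` near the top of custom_warplayer.py: while that module is under development its import fails, and the try/except above silently falls back to the pure-PyTorch warp. The bare except works, but it also swallows unrelated errors raised inside the module; a narrower form of the same gate (a sketch, not what the commit ships) would be:

try:
    from .custom_warplayer import warp  # fails while the module is gated off
except ImportError:
    from .warplayer import warp  # pure-PyTorch fallback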
10 changes: 4 additions & 6 deletions backend/src/InterpolateArchs/RIFE/warplayer.py
@@ -2,19 +2,17 @@
import torch.nn.functional as F


'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0],
tenFlow[:, 1:2] / tenFlow_div[1]], 1)
def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
dtype = tenInput.dtype
tenInput = tenInput.to(torch.float)
tenFlow = tenFlow.to(torch.float)

tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1)
g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
grid_sample = F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
return grid_sample.to(dtype)'''
return grid_sample.to(dtype)

def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
tenFlow = torch.cat(
[tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
)
@@ -26,4 +24,4 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
mode="bilinear",
padding_mode="border",
align_corners=True,
)
)'''
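
For context on the two constant arguments: in RIFE-style code, tenFlow_div typically holds the half-extents (W-1)/2 and (H-1)/2 that rescale pixel-space flow into grid_sample's [-1, 1] coordinates, and backwarp_tenGrid is the normalized identity sampling grid. A hedged sketch of how they are usually constructed (this helper is an assumption and is not shown in the commit):

import torch

def make_warp_constants(width: int, height: int, device: torch.device):
    # half-extents that map pixel displacements into [-1, 1] grid units
    tenFlow_div = torch.tensor([(width - 1) / 2.0, (height - 1) / 2.0], device=device)
    # identity sampling grid in grid_sample's normalized coordinates, shape (1, 2, H, W)
    gx = torch.linspace(-1.0, 1.0, width, device=device).view(1, 1, 1, width).expand(1, 1, height, width)
    gy = torch.linspace(-1.0, 1.0, height, device=device).view(1, 1, height, 1).expand(1, 1, height, width)
    backwarp_tenGrid = torch.cat([gx, gy], dim=1)
    return tenFlow_div, backwarp_tenGrid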
6 changes: 3 additions & 3 deletions backend/src/InterpolateTorch.py
@@ -344,9 +344,9 @@ def _load(self):
if self.backend == "tensorrt":
import tensorrt
import torch_tensorrt
#from .warplayer_custom import WarpPluginCreator
#registry = tensorrt.get_plugin_registry()
#registry.register_creator(WarpPluginCreator())
from .InterpolateArchs.RIFE.custom_warplayer import WarpPluginCreator
registry = tensorrt.get_plugin_registry()
registry.register_creator(WarpPluginCreator())

#torch_tensorrt.runtime.enable_cudagraphs()
logging.basicConfig(level=logging.INFO)
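
Registration order matters here: the creator must be in TensorRT's plugin registry before torch_tensorrt builds the engine, or the WarpPlugin layer emitted by ops_warp() cannot be resolved. A hedged sketch of the required sequence (only the registration lines come from the commit; the import path and compile call are illustrative):

import tensorrt
import torch_tensorrt
from backend.src.InterpolateArchs.RIFE.custom_warplayer import WarpPluginCreator  # import path assumed

# register the V3 plugin creator first...
tensorrt.get_plugin_registry().register_creator(WarpPluginCreator())
# ...then compiling may lower torch.ops.vsrife.warp into the plugin layer:
# trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=example_inputs)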
2 changes: 2 additions & 0 deletions src/DownloadDeps.py
@@ -178,6 +178,7 @@ def getPyTorchCUDADeps(self):
"https://download.pytorch.org/whl/nightly/cu124_pypi_pkg/torch_no_python-2.5.0.dev20240826%2Bcu124-py3-none-any.whl",
"safetensors",
"einops",
"cupy-cuda12x==13.3.0",
]
torchCUDAWindowsDeps = [
#"https://github.com/TNTwise/real-video-enhancer-models/releases/download/models/spandrel-0.3.4-py3-none-any.whl",
@@ -189,6 +190,7 @@
# "torchvision==0.19.0",
"safetensors",
"einops",
"cupy-cuda12x==13.3.0",
]
match getPlatform():
case "win32":
Expand Down
