Skip to content

Commit

Permalink
improve rife speed
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Aug 22, 2024
1 parent 4bfd1f2 commit e91bc83
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 46 deletions.
11 changes: 8 additions & 3 deletions backend/src/InterpolateArchs/RIFE/rife413IFNET.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn

import math
try:
from .interpolate import interpolate
except ImportError:
Expand Down Expand Up @@ -138,6 +138,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 16, c=192)
Expand All @@ -151,11 +153,14 @@ def __init__(
self.ensemble = ensemble
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div


# self.contextnet = Contextnet()
# self.unet = Unet()

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
# can't be cached
h, w = img0.shape[2], img0.shape[3]
imgs = torch.cat([img0, img1], dim=1)
Expand Down Expand Up @@ -237,7 +242,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
torch.split(flows, [2, 2], dim=1)[::-1], dim=1
)
precomp = (
(backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div)
(self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div)
.permute(0, 2, 3, 1)
.to(dtype=self.dtype)
)
Expand Down
14 changes: 7 additions & 7 deletions backend/src/InterpolateArchs/RIFE/rife420IFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .interpolate import interpolate
except ImportError:
from torch.nn.functional import interpolate

import math

def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
Expand Down Expand Up @@ -138,6 +138,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 16, c=384)
Expand All @@ -151,14 +153,12 @@ def __init__(
self.ensemble = ensemble
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

# self.contextnet = Contextnet()
# self.unet = Unet()

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
# can't be cached
h, w = img0.shape[2], img0.shape[3]
tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1)
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
Expand Down Expand Up @@ -240,7 +240,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
torch.split(flows, [2, 2], dim=1)[::-1], dim=1
)
precomp = (
(backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div)
(self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div)
.permute(0, 2, 3, 1)
.to(dtype=self.dtype)
)
Expand Down
11 changes: 6 additions & 5 deletions backend/src/InterpolateArchs/RIFE/rife421IFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 16, c=256)
Expand All @@ -141,13 +143,12 @@ def __init__(
self.ensemble = ensemble
self.width = width
self.height = height
# self.contextnet = Contextnet()
# self.unet = Unet()
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
# can't be cached
h, w = img0.shape[2], img0.shape[3]
tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1)
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
Expand Down Expand Up @@ -189,7 +190,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
torch.split(flows, [2, 2], dim=1)[::-1], dim=1
)
precomp = (
backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div
self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
).permute(0, 2, 3, 1)
if scale == 1:
warped_imgs = torch.nn.functional.grid_sample(
Expand Down
13 changes: 6 additions & 7 deletions backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+8, c=192)
Expand All @@ -131,16 +133,13 @@ def __init__(
self.width = width
self.height = height

# self.contextnet = Contextnet()
# self.unet = Unet()


self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
# can't be cached

h, w = img0.shape[2], img0.shape[3]
tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1)
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
Expand Down Expand Up @@ -177,7 +176,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):


precomp = (
backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div
self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
).permute(0, 2, 3, 1)
if scale == 1:
warped_imgs = torch.nn.functional.grid_sample(
Expand Down
10 changes: 7 additions & 3 deletions backend/src/InterpolateArchs/RIFE/rife46IFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7, c=192)
Expand All @@ -114,8 +116,10 @@ def __init__(
self.device = device
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
flow_list = []
merged = []
mask_list = []
Expand Down Expand Up @@ -167,8 +171,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
mask = mask + m0
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
warped_img0 = warp(img0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid)
warped_img1 = warp(img1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid)
merged.append((warped_img0, warped_img1))
mask_list[3] = torch.sigmoid(mask_list[3])
frame = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
Expand Down
14 changes: 9 additions & 5 deletions backend/src/InterpolateArchs/RIFE/rife47IFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 8, c=192)
Expand All @@ -132,11 +134,13 @@ def __init__(
self.ensemble = ensemble
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

# self.contextnet = Contextnet()
# self.unet = Unet()

def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
def forward(self, img0, img1, timestep):
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
flow_list = []
Expand All @@ -163,8 +167,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
mask = (mask + (-m_)) / 2
else:
wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
wf0 = warp(f0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid)
wf1 = warp(f1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid)
fd, m0 = block[i](
torch.cat(
(
Expand Down Expand Up @@ -203,8 +207,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
flow = flow + fd
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
warped_img0 = warp(img0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid)
warped_img1 = warp(img1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid)
merged.append((warped_img0, warped_img1))
mask = torch.sigmoid(mask)
frame= warped_img0 * mask + warped_img1 * (1 - mask)
Expand Down
14 changes: 5 additions & 9 deletions backend/src/InterpolateTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def __init__(
device=self.device,
width=self.width,
height=self.height,
backwarp_tenGrid=self.backwarp_tenGrid,
tenFlow_div=self.tenFlow_div,
)

state_dict = {
Expand Down Expand Up @@ -252,14 +254,7 @@ def __init__(
torch.zeros(
(1, 1, self.ph, self.pw), dtype=self.dtype, device=device
),
torch.zeros((2,), dtype=self.dtype, device=device),
torch.zeros(
(1, 2, self.ph, self.pw), dtype=self.dtype, device=device
)
if v1 else
torch.zeros(
(1, 2, 1, 1), dtype=self.dtype, device=device
),

]
self.flownet = torch_tensorrt.compile(
self.flownet,
Expand Down Expand Up @@ -293,7 +288,7 @@ def process(self, img0, img1, timestep):
)

output = self.flownet(
img0, img1, timestep, self.tenFlow_div, self.backwarp_tenGrid
img0, img1, timestep
)
return self.tensor_to_frame(output)

Expand All @@ -305,6 +300,7 @@ def tensor_to_frame(self, frame: torch.Tensor):

return (
frame
.byte()
.contiguous()
.cpu()
.numpy()
Expand Down
8 changes: 1 addition & 7 deletions backend/src/UpscaleTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,16 @@ def renderToNPArray(self, image: torch.Tensor) -> torch.Tensor:
.float()
.clamp(0.0, 1.0)
.mul(255)
.byte()
.contiguous()
.detach()
.cpu()
.numpy()
)

@torch.inference_mode()
def renderImagesInDirectory(self, dir):
pass

def getScale(self):
return self.scale

def saveImage(self, image: np.array, fullOutputPathLocation):
cv2.imwrite(fullOutputPathLocation, image)

@torch.inference_mode()
def renderTiledImage(
self,
Expand Down

0 comments on commit e91bc83

Please sign in to comment.