Skip to content

Commit

Permalink
Improve RIFE CUDA performance
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Sep 8, 2024
1 parent 12f7496 commit 8cf5173
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 212 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ python3 build.py --build_exe
<li> <a rel="noopener noreferrer" href="https://github.com/styler00dollar" target="_blank">Styler00dollar (For RIFE models [4.1-4.5],[4.7-4.12-lite]) and Sudo Shuffle Span</a> </li>
<li> <a rel="noopener noreferrer" href="https://github.com/hzwer/Practical-RIFE" target="_blank" >RIFE</a> </li>
<li> <a rel="noopener noreferrer" href="https://github.com/Breakthrough/PySceneDetect" target="_blank" >PySceneDetect</a> </li>
<li> <a rel="noopener noreferrer" href="https://github.com/NevermindNilas/TheAnimeScripter" target="_blank" >TheAnimeScripter for inspiration</a></li>
<li> <a rel="noopener noreferrer" href="https://github.com/NevermindNilas/TheAnimeScripter" target="_blank" >TheAnimeScripter for inspiration and mods to the RIFE arch.</a></li>
<li> <a rel="noopener noreferrer" href="https://github.com/chaiNNer-org/spandrel" target="_blank">Spandrel (For CUDA upscaling model arch support)</a></li>
<li> <a rel="noopener noreferrer" href="https://github.com/Final2x/realesrgan-ncnn-py" target="_blank">RealESRGAN NCNN python</a></li>
<li> <a rel="noopener noreferrer" href="https://github.com/marcelotduarte/cx_Freeze" target="_blank">cx_Freeze</a></li>
Expand Down
75 changes: 46 additions & 29 deletions backend/src/InterpolateArchs/RIFE/rife413IFNET.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,35 +148,44 @@ def __init__(
self.encode = Head()
self.device = device
self.dtype = dtype
self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.ensemble = ensemble
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div
self.backWarp = backwarp_tenGrid
self.tenFlow = tenFlow_div

# self.contextnet = Contextnet()
# self.unet = Unet()
self.paddedHeight = backwarp_tenGrid.shape[2]
self.paddedWidth = backwarp_tenGrid.shape[3]

def forward(self, img0, img1, timestep):
# cant be cached
h, w = img0.shape[2], img0.shape[3]
self.blocks = [self.block0, self.block1, self.block2, self.block3]

def forward(self, img0, img1, timestep, f0):
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
fs = torch.reshape(fs_2, (1, 16, h, w))
imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
f1 = self.encode(img1[:, :3])
fs = torch.cat([f0, f1], dim=1)
fs_2 = torch.reshape(fs, (2, 8, self.paddedHeight, self.paddedWidth))
if self.ensemble:
fs_rev = torch.cat(torch.split(fs, [8, 8], dim=1)[::-1], dim=1)
imgs_rev = torch.cat([img1, img0], dim=1)

flows = None
mask = None
blocks = [self.block0, self.block1, self.block2, self.block3]
for block, scale in zip(blocks, self.scale_list):
for block, scale in zip(self.blocks, self.scaleList):
if flows is None:
if self.ensemble:
temp_ = torch.cat((imgs_rev, fs_rev, 1 - timestep), 1)
flowss, masks = block(torch.cat((temp, temp_), 0), scale=scale)
flowss, masks = block(
torch.cat(
(
temp, # noqa
temp_,
),
0,
),
scale=scale,
)
flows, flows_ = torch.split(flowss, [1, 1], dim=0)
mask, mask_ = torch.split(masks, [1, 1], dim=0)
flows = (
Expand All @@ -195,8 +204,8 @@ def forward(self, img0, img1, timestep):
if self.ensemble:
temp = torch.cat(
(
wimg,
wf,
wimg, # noqa
wf, # noqa
timestep,
mask,
(flows * (1 / scale) if scale != 1 else flows),
Expand All @@ -205,8 +214,8 @@ def forward(self, img0, img1, timestep):
)
temp_ = torch.cat(
(
wimg_rev,
wf_rev,
wimg_rev, # noqa
wf_rev, # noqa
1 - timestep,
-mask,
(flows_rev * (1 / scale) if scale != 1 else flows_rev),
Expand All @@ -223,8 +232,8 @@ def forward(self, img0, img1, timestep):
else:
temp = torch.cat(
(
wimg,
wf,
wimg, # noqa
wf, # noqa
timestep,
mask,
(flows * (1 / scale) if scale != 1 else flows),
Expand All @@ -240,7 +249,11 @@ def forward(self, img0, img1, timestep):
torch.split(flows, [2, 2], dim=1)[::-1], dim=1
)
precomp = (
(self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div)
(
self.backWarp
+ flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
* self.tenFlow
)
.permute(0, 2, 3, 1)
.to(dtype=self.dtype)
)
Expand All @@ -263,14 +276,18 @@ def forward(self, img0, img1, timestep):
align_corners=True,
)
wimg, wf = torch.split(warps, [3, 8], dim=1)
wimg = torch.reshape(wimg, (1, 6, h, w))
wf = torch.reshape(wf, (1, 16, h, w))
wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
wf = torch.reshape(wf, (1, 16, self.paddedHeight, self.paddedWidth))
if self.ensemble:
wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1)
wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1)
wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1) # noqa
wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1) # noqa
mask = torch.sigmoid(mask)
warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])

frame = warped_img0 * mask + warped_img1 * (1 - mask)
frame = frame[:, :, : self.height, : self.width][0]
return frame.permute(1, 2, 0).mul(255).float()
return (
(warped_img0 * mask + warped_img1 * (1 - mask))[
:, :, : self.height, : self.width
][0]
.permute(1, 2, 0)
.mul(255)
.float()
), f1
80 changes: 38 additions & 42 deletions backend/src/InterpolateArchs/RIFE/rife421IFNET.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import torch
import torch.nn as nn


from torch.nn.functional import interpolate


Expand Down Expand Up @@ -108,7 +106,6 @@ def forward(self, x, scale=1):
tmp, scale_factor=scale, mode="bilinear", align_corners=False
)

# flows, mask, _ = torch.split(tmp, split_size_or_sections=[4, 1, 1], dim=1)
flow = tmp[:, :4]
mask = tmp[:, 4:5]
feat = tmp[:, 5:]
Expand Down Expand Up @@ -138,80 +135,79 @@ def __init__(
self.encode = Head()
self.device = device
self.dtype = dtype
self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.ensemble = ensemble
self.width = width
self.height = height
self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div
self.backWarp = backwarp_tenGrid
self.tenFlow = tenFlow_div

self.paddedHeight = backwarp_tenGrid.shape[2]
self.paddedWidth = backwarp_tenGrid.shape[3]

def forward(self, img0, img1, timestep):
# cant be cached
h, w = img0.shape[2], img0.shape[3]
self.blocks = [self.block0, self.block1, self.block2, self.block3]

def forward(self, img0, img1, timeStep, f0):
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
fs = torch.reshape(fs_2, (1, 16, h, w))
if self.ensemble:
fs_rev = torch.cat(torch.split(fs, [8, 8], dim=1)[::-1], dim=1)
imgs_rev = torch.cat([img1, img0], dim=1)

warped_img0 = img0
warped_img1 = img1
flows = None
imgs2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
f1 = self.encode(img1[:, :3])
fs = torch.cat([f0, f1], dim=1)
fs2 = torch.reshape(fs, (2, 8, self.paddedHeight, self.paddedWidth))
warpedImg0 = img0
warpedImg1 = img1
flows = None
blocks = [self.block0, self.block1, self.block2, self.block3]
scale_list = [8, 4, 2, 1]
for block, scale in zip(blocks, scale_list):
for block, scale in zip(self.blocks, self.scaleList):
if flows is None:
temp = torch.cat((imgs, fs, timestep), 1)
temp = torch.cat((imgs, fs, timeStep), 1)
flows, mask, feat = block(temp, scale=scale)
else:
temp = torch.cat(
(
wimg,
wf,
timestep,
wimg, # noqa
wf, # noqa
timeStep,
mask,
feat,
(flows * (1 / scale) if scale != 1 else flows),
),
1,
)
fds, mask, feat = block(temp, scale=scale)

flows = flows + fds

if self.ensemble:
flows_rev = torch.cat(
torch.split(flows, [2, 2], dim=1)[::-1], dim=1
)
precomp = (
self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
self.backWarp
+ flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
* self.tenFlow
).permute(0, 2, 3, 1)
if scale == 1:
warped_imgs = torch.nn.functional.grid_sample(
imgs_2,
warpedImgs = torch.nn.functional.grid_sample(
imgs2,
precomp,
mode="bilinear",
padding_mode="border",
align_corners=True,
)
else:
imgs_fs_2 = torch.cat((imgs_2, fs_2), 1)
imgsFs2 = torch.cat((imgs2, fs2), 1)
warps = torch.nn.functional.grid_sample(
imgs_fs_2,
imgsFs2,
precomp,
mode="bilinear",
padding_mode="border",
align_corners=True,
)
wimg, wf = torch.split(warps, [3, 8], dim=1)
wimg = torch.reshape(wimg, (1, 6, h, w))
wf = torch.reshape(wf, (1, 16, h, w))
wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
wf = torch.reshape(wf, (1, 16, self.paddedHeight, self.paddedWidth))

mask = torch.sigmoid(mask)
warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
frame = warped_img0 * mask + warped_img1 * (1 - mask)
frame = frame[:, :, : self.height, : self.width][0]
return frame.permute(1, 2, 0).mul(255).float()
warpedImg0, warpedImg1 = torch.split(warpedImgs, [1, 1])
return (
(warpedImg0 * mask + warpedImg1 * (1 - mask))[
:, :, : self.height, : self.width
][0]
.permute(1, 2, 0)
.mul(255)
.float()
), f1
54 changes: 28 additions & 26 deletions backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn


from torch.nn.functional import interpolate


Expand Down Expand Up @@ -139,8 +140,6 @@ def __init__(
height=1080,
backwarp_tenGrid=None,
tenFlow_div=None,
pw=1920,
ph=1088,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 8, c=192)
Expand All @@ -150,39 +149,35 @@ def __init__(
self.encode = Head()
self.device = device
self.dtype = dtype
self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
self.ensemble = ensemble
self.width = width
self.height = height
self.backWarp = backwarp_tenGrid
self.tenFlow = tenFlow_div
self.blocks = [self.block0, self.block1, self.block2, self.block3]

self.backwarp_tenGrid = backwarp_tenGrid
self.tenFlow_div = tenFlow_div

self.pw = pw
self.ph = ph
self.paddedHeight = backwarp_tenGrid.shape[2]
self.paddedWidth = backwarp_tenGrid.shape[3]

def forward(self, img0, img1, timestep):
h, w = img0.shape[2], img0.shape[3]
def forward(self, img0, img1, timestep, f0):
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, h, w))
fs_2 = self.encode(imgs_2)
fs = torch.reshape(fs_2, (1, 8, h, w))

imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
f1 = self.encode(img1[:, :3])
fs = torch.cat([f0, f1], dim=1)
fs_2 = torch.reshape(fs, (2, 4, self.paddedHeight, self.paddedWidth))
warped_img0 = img0
warped_img1 = img1
flows = None
flows = None
blocks = [self.block0, self.block1, self.block2, self.block3]
scale_list = [8, 4, 2, 1]
for block, scale in zip(blocks, scale_list):
for block, scale in zip(self.blocks, self.scaleList):
if flows is None:
temp = torch.cat((imgs, fs, timestep), 1)
flows, mask, feat = block(temp, scale=scale)
else:
temp = torch.cat(
(
wimg,
wf,
wimg, # noqa
wf, # noqa
timestep,
mask,
feat,
Expand All @@ -195,7 +190,9 @@ def forward(self, img0, img1, timestep):
flows = flows + fds

precomp = (
self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
self.backWarp
+ flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
* self.tenFlow
).permute(0, 2, 3, 1)
if scale == 1:
warped_imgs = torch.nn.functional.grid_sample(
Expand All @@ -214,11 +211,16 @@ def forward(self, img0, img1, timestep):
align_corners=True,
)
wimg, wf = torch.split(warps, [3, 4], dim=1)
wimg = torch.reshape(wimg, (1, 6, h, w))
wf = torch.reshape(wf, (1, 8, h, w))
wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
wf = torch.reshape(wf, (1, 8, self.paddedHeight, self.paddedWidth))

mask = torch.sigmoid(mask)
warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
frame = warped_img0 * mask + warped_img1 * (1 - mask)
frame = frame[:, :, : self.height, : self.width][0]
return frame.permute(1, 2, 0).mul(255).float()
return (
(warped_img0 * mask + warped_img1 * (1 - mask))[
:, :, : self.height, : self.width
][0]
.permute(1, 2, 0)
.mul(255)
.float()
), f1
Loading

0 comments on commit 8cf5173

Please sign in to comment.