From 8cf5173c964fc594a471a28be6b9ea32e38f84c2 Mon Sep 17 00:00:00 2001
From: TNTwise
Date: Sun, 8 Sep 2024 22:55:16 +0000
Subject: [PATCH] Improve RIFE CUDA performance

---
 README.md                                     |   2 +-
 .../src/InterpolateArchs/RIFE/rife413IFNET.py |  75 +++++++----
 .../src/InterpolateArchs/RIFE/rife421IFNET.py |  80 ++++++-----
 .../RIFE/rife422_liteIFNET.py                 |  54 ++++----
 .../src/InterpolateArchs/RIFE/rife47IFNET.py  | 125 +++++-------
 backend/src/InterpolateTorch.py               |  81 +++++++++---
 backend/src/RenderVideo.py                    |   5 +-
 7 files changed, 210 insertions(+), 212 deletions(-)

diff --git a/README.md b/README.md
index 173a0efb..580090ea 100644
--- a/README.md
+++ b/README.md
@@ -76,7 +76,7 @@ python3 build.py --build_exe
 - Styler00dollar (For RIFE models [4.1-4.5],[4.7-4.12-lite]) and Sudo Shuffle Span
 - RIFE
 - PySceneDetect
-- TheAnimeScripter for inspiration
+- TheAnimeScripter for inspiration and modifications to the RIFE architecture
 - Spandrel (For CUDA upscaling model arch support)
 - RealESRGAN NCNN python
 - cx_Freeze
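The common thread in the interpolation diffs that follow is feature-cache reuse: each IFNet.forward now takes the cached encoder features f0 of the previous pair's second frame, encodes only img1, and returns the fresh features f1 for the caller to reuse on the next pair. Because consecutive frame pairs share a frame, each frame passes through the encoder Head once instead of twice. A minimal sketch of the scheme, with illustrative names rather than the repository's API:

    class CachedPairInterpolator:
        def __init__(self, encode, flownet):
            self.encode = encode    # feature head, e.g. a Head module
            self.flownet = flownet  # flownet(img0, img1, timestep, f0) -> (frame, f1)
            self.f1encode = None    # features of the last pair's second frame

        def process(self, img0, img1, timestep):
            if self.f1encode is None:
                # first pair: nothing cached yet, so encode img0 directly
                self.f1encode = self.encode(img0[:, :3])
            # f0 comes from the cache; flownet encodes only img1 and returns
            # its features so the next call can reuse them as f0
            frame, self.f1encode = self.flownet(img0, img1, timestep, self.f1encode)
            return frame

    # usage sketch: interp = CachedPairInterpolator(head_module, compiled_flownet)
    # frame = interp.process(img0, img1, timestep)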
diff --git a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py
index fac1af8d..2a175c38 100644
--- a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py
+++ b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py
@@ -148,35 +148,44 @@ def __init__(
         self.encode = Head()
         self.device = device
         self.dtype = dtype
-        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
+        self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
         self.ensemble = ensemble
         self.width = width
         self.height = height
-        self.backwarp_tenGrid = backwarp_tenGrid
-        self.tenFlow_div = tenFlow_div
+        self.backWarp = backwarp_tenGrid
+        self.tenFlow = tenFlow_div
 
-        # self.contextnet = Contextnet()
-        # self.unet = Unet()
+        self.paddedHeight = backwarp_tenGrid.shape[2]
+        self.paddedWidth = backwarp_tenGrid.shape[3]
 
-    def forward(self, img0, img1, timestep):
-        # cant be cached
-        h, w = img0.shape[2], img0.shape[3]
+        self.blocks = [self.block0, self.block1, self.block2, self.block3]
+
+    def forward(self, img0, img1, timestep, f0):
         imgs = torch.cat([img0, img1], dim=1)
-        imgs_2 = torch.reshape(imgs, (2, 3, h, w))
-        fs_2 = self.encode(imgs_2)
-        fs = torch.reshape(fs_2, (1, 16, h, w))
+        imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
+        f1 = self.encode(img1[:, :3])
+        fs = torch.cat([f0, f1], dim=1)
+        fs_2 = torch.reshape(fs, (2, 8, self.paddedHeight, self.paddedWidth))
         if self.ensemble:
             fs_rev = torch.cat(torch.split(fs, [8, 8], dim=1)[::-1], dim=1)
             imgs_rev = torch.cat([img1, img0], dim=1)
         flows = None
         mask = None
-        blocks = [self.block0, self.block1, self.block2, self.block3]
-        for block, scale in zip(blocks, self.scale_list):
+        for block, scale in zip(self.blocks, self.scaleList):
             if flows is None:
                 if self.ensemble:
                     temp_ = torch.cat((imgs_rev, fs_rev, 1 - timestep), 1)
-                    flowss, masks = block(torch.cat((temp, temp_), 0), scale=scale)
+                    flowss, masks = block(
+                        torch.cat(
+                            (
+                                temp,  # noqa
+                                temp_,
+                            ),
+                            0,
+                        ),
+                        scale=scale,
+                    )
                     flows, flows_ = torch.split(flowss, [1, 1], dim=0)
                     mask, mask_ = torch.split(masks, [1, 1], dim=0)
                     flows = (
@@ -195,8 +204,8 @@ def forward(self, img0, img1, timestep):
             if self.ensemble:
                 temp = torch.cat(
                     (
-                        wimg,
-                        wf,
+                        wimg,  # noqa
+                        wf,  # noqa
                         timestep,
                         mask,
                         (flows * (1 / scale) if scale != 1 else flows),
@@ -205,8 +214,8 @@ def forward(self, img0, img1, timestep):
                 )
                 temp_ = torch.cat(
                     (
-                        wimg_rev,
-                        wf_rev,
+                        wimg_rev,  # noqa
+                        wf_rev,  # noqa
                         1 - timestep,
                         -mask,
                         (flows_rev * (1 / scale) if scale != 1 else flows_rev),
@@ -223,8 +232,8 @@ def forward(self, img0, img1, timestep):
             else:
                 temp = torch.cat(
                     (
-                        wimg,
-                        wf,
+                        wimg,  # noqa
+                        wf,  # noqa
                         timestep,
                         mask,
                         (flows * (1 / scale) if scale != 1 else flows),
@@ -240,7 +249,11 @@ def forward(self, img0, img1, timestep):
                 torch.split(flows, [2, 2], dim=1)[::-1], dim=1
             )
         precomp = (
-            (self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div)
+            (
+                self.backWarp
+                + flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
+                * self.tenFlow
+            )
             .permute(0, 2, 3, 1)
            .to(dtype=self.dtype)
         )
@@ -263,14 +276,18 @@ def forward(self, img0, img1, timestep):
             align_corners=True,
         )
         wimg, wf = torch.split(warps, [3, 8], dim=1)
-        wimg = torch.reshape(wimg, (1, 6, h, w))
-        wf = torch.reshape(wf, (1, 16, h, w))
+        wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
+        wf = torch.reshape(wf, (1, 16, self.paddedHeight, self.paddedWidth))
         if self.ensemble:
-            wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1)
-            wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1)
+            wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1)  # noqa
+            wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1)  # noqa
         mask = torch.sigmoid(mask)
         warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
-
-        frame = warped_img0 * mask + warped_img1 * (1 - mask)
-        frame = frame[:, :, : self.height, : self.width][0]
-        return frame.permute(1, 2, 0).mul(255).float()
+        return (
+            (warped_img0 * mask + warped_img1 * (1 - mask))[
+                :, :, : self.height, : self.width
+            ][0]
+            .permute(1, 2, 0)
+            .mul(255)
+            .float()
+        ), f1
\ No newline at end of file
diff --git a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
index e0a6afc0..a4f34f34 100644
--- a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
+++ b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py
@@ -1,7 +1,5 @@
 import torch
 import torch.nn as nn
-
-
 from torch.nn.functional import interpolate
 
 
@@ -108,7 +106,6 @@ def forward(self, x, scale=1):
             tmp, scale_factor=scale, mode="bilinear", align_corners=False
         )
-        # flows, mask, _ = torch.split(tmp, split_size_or_sections=[4, 1, 1], dim=1)
         flow = tmp[:, :4]
         mask = tmp[:, 4:5]
         feat = tmp[:, 5:]
@@ -138,40 +135,37 @@ def __init__(
         self.encode = Head()
         self.device = device
         self.dtype = dtype
-        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
+        self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
         self.ensemble = ensemble
         self.width = width
         self.height = height
-        self.backwarp_tenGrid = backwarp_tenGrid
-        self.tenFlow_div = tenFlow_div
+        self.backWarp = backwarp_tenGrid
+        self.tenFlow = tenFlow_div
+
+        self.paddedHeight = backwarp_tenGrid.shape[2]
+        self.paddedWidth = backwarp_tenGrid.shape[3]
 
-    def forward(self, img0, img1, timestep):
-        # cant be cached
-        h, w = img0.shape[2], img0.shape[3]
+        self.blocks = [self.block0, self.block1, self.block2, self.block3]
+
+    def forward(self, img0, img1, timeStep, f0):
         imgs = torch.cat([img0, img1], dim=1)
-        imgs_2 = torch.reshape(imgs, (2, 3, h, w))
-        fs_2 = self.encode(imgs_2)
-        fs = torch.reshape(fs_2, (1, 16, h, w))
-        if self.ensemble:
-            fs_rev = torch.cat(torch.split(fs, [8, 8], dim=1)[::-1], dim=1)
-            imgs_rev = torch.cat([img1, img0], dim=1)
-
-        warped_img0 = img0
-        warped_img1 = img1
-        flows = None
+        imgs2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
+        f1 = self.encode(img1[:, :3])
+        fs = torch.cat([f0, f1], dim=1)
+        fs2 = torch.reshape(fs, (2, 8, self.paddedHeight, self.paddedWidth))
+        warpedImg0 = img0
+        warpedImg1 = img1
         flows = None
-        blocks = [self.block0, self.block1, self.block2, self.block3]
-        scale_list = [8, 4, 2, 1]
-        for block, scale in zip(blocks, scale_list):
+        for block, scale in zip(self.blocks, self.scaleList):
             if flows is None:
-                temp = torch.cat((imgs, fs, timestep), 1)
+                temp = torch.cat((imgs, fs, timeStep), 1)
                 flows, mask, feat = block(temp, scale=scale)
             else:
                 temp = torch.cat(
                     (
-                        wimg,
-                        wf,
-                        timestep,
+                        wimg,  # noqa
+                        wf,  # noqa
+                        timeStep,
                         mask,
                         feat,
                         (flows * (1 / scale) if scale != 1 else flows),
@@ -179,39 +173,41 @@ def forward(self, img0, img1, timestep):
                     1,
                 )
                 fds, mask, feat = block(temp, scale=scale)
-
                 flows = flows + fds
-            if self.ensemble:
-                flows_rev = torch.cat(
-                    torch.split(flows, [2, 2], dim=1)[::-1], dim=1
-                )
         precomp = (
-            self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
+            self.backWarp
+            + flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
+            * self.tenFlow
         ).permute(0, 2, 3, 1)
         if scale == 1:
-            warped_imgs = torch.nn.functional.grid_sample(
-                imgs_2,
+            warpedImgs = torch.nn.functional.grid_sample(
+                imgs2,
                 precomp,
                 mode="bilinear",
                 padding_mode="border",
                 align_corners=True,
             )
         else:
-            imgs_fs_2 = torch.cat((imgs_2, fs_2), 1)
+            imgsFs2 = torch.cat((imgs2, fs2), 1)
             warps = torch.nn.functional.grid_sample(
-                imgs_fs_2,
+                imgsFs2,
                 precomp,
                 mode="bilinear",
                 padding_mode="border",
                 align_corners=True,
             )
             wimg, wf = torch.split(warps, [3, 8], dim=1)
-            wimg = torch.reshape(wimg, (1, 6, h, w))
-            wf = torch.reshape(wf, (1, 16, h, w))
+            wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
+            wf = torch.reshape(wf, (1, 16, self.paddedHeight, self.paddedWidth))
         mask = torch.sigmoid(mask)
-        warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
-        frame = warped_img0 * mask + warped_img1 * (1 - mask)
-        frame = frame[:, :, : self.height, : self.width][0]
-        return frame.permute(1, 2, 0).mul(255).float()
+        warpedImg0, warpedImg1 = torch.split(warpedImgs, [1, 1])
+        return (
+            (warpedImg0 * mask + warpedImg1 * (1 - mask))[
+                :, :, : self.height, : self.width
+            ][0]
+            .permute(1, 2, 0)
+            .mul(255)
+            .float()
+        ), f1
\ No newline at end of file
diff --git a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
index fc4a190a..59a1a338 100644
--- a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
+++ b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 
+
 from torch.nn.functional import interpolate
 
 
@@ -139,8 +140,6 @@ def __init__(
         height=1080,
         backwarp_tenGrid=None,
         tenFlow_div=None,
-        pw=1920,
-        ph=1088,
     ):
         super(IFNet, self).__init__()
         self.block0 = IFBlock(7 + 8, c=192)
@@ -150,39 +149,35 @@ def __init__(
         self.encode = Head()
         self.device = device
         self.dtype = dtype
-        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
+        self.scaleList = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
         self.ensemble = ensemble
         self.width = width
         self.height = height
+        self.backWarp = backwarp_tenGrid
+        self.tenFlow = tenFlow_div
+        self.blocks = [self.block0, self.block1, self.block2, self.block3]
 
-        self.backwarp_tenGrid = backwarp_tenGrid
-        self.tenFlow_div = tenFlow_div
-
-        self.pw = pw
-        self.ph = ph
+        self.paddedHeight = backwarp_tenGrid.shape[2]
+        self.paddedWidth = backwarp_tenGrid.shape[3]
 
-    def forward(self, img0, img1, timestep):
-        h, w = img0.shape[2], img0.shape[3]
+    def forward(self, img0, img1, timestep, f0):
         imgs = torch.cat([img0, img1], dim=1)
-        imgs_2 = torch.reshape(imgs, (2, 3, h, w))
-        fs_2 = self.encode(imgs_2)
-        fs = torch.reshape(fs_2, (1, 8, h, w))
-
+        imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
+        f1 = self.encode(img1[:, :3])
+        fs = torch.cat([f0, f1], dim=1)
+        fs_2 = torch.reshape(fs, (2, 4, self.paddedHeight, self.paddedWidth))
         warped_img0 = img0
         warped_img1 = img1
         flows = None
-        flows = None
-        blocks = [self.block0, self.block1, self.block2, self.block3]
-        scale_list = [8, 4, 2, 1]
-        for block, scale in zip(blocks, scale_list):
+        for block, scale in zip(self.blocks, self.scaleList):
             if flows is None:
                 temp = torch.cat((imgs, fs, timestep), 1)
                 flows, mask, feat = block(temp, scale=scale)
             else:
                 temp = torch.cat(
                     (
-                        wimg,
-                        wf,
+                        wimg,  # noqa
+                        wf,  # noqa
                         timestep,
                         mask,
                         feat,
@@ -195,7 +190,9 @@ def forward(self, img0, img1, timestep):
                 flows = flows + fds
 
         precomp = (
-            self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div
+            self.backWarp
+            + flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
+            * self.tenFlow
         ).permute(0, 2, 3, 1)
         if scale == 1:
             warped_imgs = torch.nn.functional.grid_sample(
@@ -214,11 +211,16 @@ def forward(self, img0, img1, timestep):
                 align_corners=True,
             )
             wimg, wf = torch.split(warps, [3, 4], dim=1)
-            wimg = torch.reshape(wimg, (1, 6, h, w))
-            wf = torch.reshape(wf, (1, 8, h, w))
+            wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
+            wf = torch.reshape(wf, (1, 8, self.paddedHeight, self.paddedWidth))
         mask = torch.sigmoid(mask)
         warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
-        frame = warped_img0 * mask + warped_img1 * (1 - mask)
-        frame = frame[:, :, : self.height, : self.width][0]
-        return frame.permute(1, 2, 0).mul(255).float()
+        return (
+            (warped_img0 * mask + warped_img1 * (1 - mask))[
+                :, :, : self.height, : self.width
+            ][0]
+            .permute(1, 2, 0)
+            .mul(255)
+            .float()
+        ), f1
\ No newline at end of file
diff --git a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py
index 314fd238..4c8ff60f 100644
--- a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py
+++ b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py
@@ -53,26 +53,6 @@ def forward(self, x):
         return x_view.permute(0, 1, 4, 2, 5, 3).reshape(b, out_channel, h, w)
 
 
-class Head(nn.Module):
-    def __init__(self):
-        super(Head, self).__init__()
-        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
-        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
-        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
-        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
-        self.relu = nn.LeakyReLU(0.2, True)
-
-    def forward(self, x, feat=False):
-        x0 = self.cnn0(x)
-        x = self.relu(x0)
-        x1 = self.cnn1(x)
-        x = self.relu(x1)
-        x2 = self.cnn2(x)
-        x = self.relu(x2)
-        x3 = self.cnn3(x)
-        if feat:
-            return [x0, x1, x2, x3]
-        return x3
 
 
 class ResConv(nn.Module):
@@ -154,95 +134,56 @@ def __init__(
         self.ensemble = ensemble
         self.width = width
         self.height = height
-        self.backwarp_tenGrid = backwarp_tenGrid
-        self.tenFlow_div = tenFlow_div
+        self.backWarp = backwarp_tenGrid
+        self.tenFlow = tenFlow_div
+        self.blocks = [self.block0, self.block1, self.block2, self.block3]
+
+        self.paddedHeight = backwarp_tenGrid.shape[2]
+        self.paddedWidth = backwarp_tenGrid.shape[3]
 
         # self.contextnet = Contextnet()
         # self.unet = Unet()
 
-    def forward(self, img0, img1, timestep):
+    def forward(self, img0, img1, timestep, f0):
         # cant be cached
         h, w = img0.shape[2], img0.shape[3]
         imgs = torch.cat([img0, img1], dim=1)
-        imgs_2 = torch.reshape(imgs, (2, 3, h, w))
-        fs_2 = self.encode(imgs_2)
-        fs = torch.reshape(fs_2, (1, 8, h, w))
-        if self.ensemble:
-            fs_rev = torch.cat(torch.split(fs, [8, 8], dim=1)[::-1], dim=1)
-            imgs_rev = torch.cat([img1, img0], dim=1)
+        imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
+        f1 = self.encode(img1[:, :3])
+        fs = torch.cat([f0, f1], dim=1)
+        fs_2 = torch.reshape(fs, (2, 4, self.paddedHeight, self.paddedWidth))
+
+
+        flows = None
+        mask = None
 
         flows = None
         mask = None
         blocks = [self.block0, self.block1, self.block2, self.block3]
         for block, scale in zip(blocks, self.scale_list):
             if flows is None:
-                if self.ensemble:
-                    temp_ = torch.cat((imgs_rev, fs_rev, 1 - timestep), 1)
-                    flowss, masks = block(torch.cat((temp, temp_), 0), scale=scale)
-                    flows, flows_ = torch.split(flowss, [1, 1], dim=0)
-                    mask, mask_ = torch.split(masks, [1, 1], dim=0)
-                    flows = (
-                        flows
-                        + torch.cat(torch.split(flows_, [2, 2], dim=1)[::-1], dim=1)
-                    ) / 2
-                    mask = (mask - mask_) / 2
-
-                    flows_rev = torch.cat(
-                        torch.split(flows, [2, 2], dim=1)[::-1], dim=1
-                    )
-                else:
-                    temp = torch.cat((imgs, fs, timestep), 1)
-                    flows, mask = block(temp, scale=scale)
+
+                temp = torch.cat((imgs, fs, timestep), 1)
+                flows, mask = block(temp, scale=scale)
             else:
-                if self.ensemble:
-                    temp = torch.cat(
-                        (
-                            wimg,
-                            wf,
-                            timestep,
-                            mask,
-                            (flows * (1 / scale) if scale != 1 else flows),
-                        ),
-                        1,
-                    )
-                    temp_ = torch.cat(
-                        (
-                            wimg_rev,
-                            wf_rev,
-                            1 - timestep,
-                            -mask,
-                            (flows_rev * (1 / scale) if scale != 1 else flows_rev),
-                        ),
-                        1,
-                    )
-                    fdss, masks = block(torch.cat((temp, temp_), 0), scale=scale)
-                    fds, fds_ = torch.split(fdss, [1, 1], dim=0)
-                    mask, mask_ = torch.split(masks, [1, 1], dim=0)
-                    fds = (
-                        fds + torch.cat(torch.split(fds_, [2, 2], dim=1)[::-1], dim=1)
-                    ) / 2
-                    mask = (mask - mask_) / 2
-                else:
-                    temp = torch.cat(
-                        (
-                            wimg,
-                            wf,
-                            timestep,
-                            mask,
-                            (flows * (1 / scale) if scale != 1 else flows),
-                        ),
-                        1,
-                    )
-                    fds, mask = block(temp, scale=scale)
+
+                temp = torch.cat(
+                    (
+                        wimg,
+                        wf,
+                        timestep,
+                        mask,
+                        (flows * (1 / scale) if scale != 1 else flows),
+                    ),
+                    1,
+                )
+                fds, mask = block(temp, scale=scale)
 
             flows = flows + fds
-            if self.ensemble:
-                flows_rev = torch.cat(
-                    torch.split(flows, [2, 2], dim=1)[::-1], dim=1
-                )
+
         precomp = (
-            (self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div)
+            (self.backWarp + flows.reshape((2, 2, h, w)) * self.tenFlow)
             .permute(0, 2, 3, 1)
             .to(dtype=self.dtype)
         )
@@ -275,4 +216,4 @@ def forward(self, img0, img1, timestep):
 
         frame = warped_img0 * mask + warped_img1 * (1 - mask)
         frame = frame[:, :, : self.height, : self.width][0]
-        return frame.permute(1, 2, 0).mul(255).float()
+        return frame.permute(1, 2, 0).mul(255).float(), f1
diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py
index 69510fac..291fc356 100644
--- a/backend/src/InterpolateTorch.py
+++ b/backend/src/InterpolateTorch.py
@@ -9,6 +9,7 @@
     errorAndLog,
     modelsDirectory,
    check_bfloat16_support,
+    log,
 )
 
 torch.set_float32_matmul_precision("high")
@@ -119,6 +120,7 @@ def __init__(
         self.stream = torch.cuda.Stream()
         self.prepareStream = torch.cuda.Stream()
         scale = 1
+        self.f1encode = None  # cached encoder features of the previous pair's second frame
         if UHDMode:
             scale = 0.5
         with torch.cuda.stream(self.prepareStream):
@@ -147,6 +149,18 @@ def __init__(
             )
             self.timestepDict[timestep] = timestep_tens
         # detect what rife arch to use
+        self.inputs = [
+            torch.zeros(
+                (1, 3, self.ph, self.pw), dtype=self.dtype, device=device
+            ),
+            torch.zeros(
+                (1, 3, self.ph, self.pw), dtype=self.dtype, device=device
+            ),
+            torch.zeros(
+                (1, 1, self.ph, self.pw), dtype=self.dtype, device=device
+            ),
+        ]
+        log("interp arch: " + interpolateArch.lower())
         match interpolateArch.lower():
             case "rife46":
                 from .InterpolateArchs.RIFE.rife46IFNET import IFNet
@@ -156,25 +170,57 @@ def __init__(
                 from .InterpolateArchs.RIFE.rife47IFNET import IFNet
 
                 v1 = False
+                self.inputs.append(
+                    torch.zeros(
+                        (1, 4, self.ph, self.pw), dtype=self.dtype, device=device
+                    ),
+                )
+                self.encode = torch.nn.Sequential(
+                    torch.nn.Conv2d(3, 16, 3, 2, 1),
+                    torch.nn.ConvTranspose2d(16, 4, 4, 2, 1)
+                ).to(device=self.device, dtype=self.dtype)
             case "rife413":
-                from .InterpolateArchs.RIFE.rife413IFNET import IFNet
+                from .InterpolateArchs.RIFE.rife413IFNET import IFNet, Head
 
                 v1 = False
+                self.inputs.append(
+                    torch.zeros(
+                        (1, 8, self.ph, self.pw), dtype=self.dtype, device=device
+                    ),
+                )
+                self.encode = Head().to(device=self.device, dtype=self.dtype)
             case "rife420":
-                from .InterpolateArchs.RIFE.rife420IFNET import IFNet
+                from .InterpolateArchs.RIFE.rife420IFNET import IFNet, Head
 
                 v1 = False
+                self.inputs.append(
+                    torch.zeros(
+                        (1, 8, self.ph, self.pw), dtype=self.dtype, device=device
+                    ),
+                )
+                self.encode = Head().to(device=self.device, dtype=self.dtype)
             case "rife421":
-                from .InterpolateArchs.RIFE.rife421IFNET import IFNet
-
+                from .InterpolateArchs.RIFE.rife421IFNET import IFNet, Head
 
                 v1 = False
+                self.inputs.append(
+                    torch.zeros(
+                        (1, 8, self.ph, self.pw), dtype=self.dtype, device=device
+                    ),
+                )
+                self.encode = Head().to(device=self.device, dtype=self.dtype)
             case "rife422lite":
-                from .InterpolateArchs.RIFE.rife422_liteIFNET import IFNet
+                from .InterpolateArchs.RIFE.rife422_liteIFNET import IFNet, Head
 
+                self.inputs.append(
+                    torch.zeros(
+                        (1, 4, self.ph, self.pw), dtype=self.dtype, device=device
+                    ),
+                )
+                self.encode = Head().to(device=self.device, dtype=self.dtype)
                 v1 = False
             case _:
                 errorAndLog("Invalid Interpolation Arch")
-
+        self.v1 = v1
         # if 4.6 v1
         if v1:
             self.tenFlow_div = torch.tensor(
@@ -273,21 +319,11 @@ def __init__(
                 ),
             )
             if not os.path.isfile(trt_engine_path):
-                inputs = [
-                    torch.zeros(
-                        (1, 3, self.ph, self.pw), dtype=self.dtype, device=device
-                    ),
-                    torch.zeros(
-                        (1, 3, self.ph, self.pw), dtype=self.dtype, device=device
-                    ),
-                    torch.zeros(
-                        (1, 1, self.ph, self.pw), dtype=self.dtype, device=device
-                    ),
-                ]
+
                 self.flownet = torch_tensorrt.compile(
                     self.flownet,
                     ir="dynamo",
-                    inputs=inputs,
+                    inputs=self.inputs,
                     enabled_precisions={self.dtype},
                     debug=trt_debug,
                     workspace_size=trt_workspace_size,
@@ -297,7 +333,7 @@ def __init__(
                     device=device,
                 )
 
-                torch_tensorrt.save(self.flownet, trt_engine_path, inputs=inputs)
+                torch_tensorrt.save(self.flownet, trt_engine_path, inputs=self.inputs)
 
             self.flownet = torch.export.load(trt_engine_path).module()
 
@@ -313,7 +349,12 @@ def handlePrecision(self, precision):
     def process(self, img0, img1, timestep):
        with torch.cuda.stream(self.stream):
             timestep = self.timestepDict[timestep]
-            output = self.flownet(img0, img1, timestep)
+            if not self.v1:
+                if self.f1encode is None:
+                    self.f1encode = self.encode(img0[:, :3])
+                output, self.f1encode = self.flownet(img0, img1, timestep, self.f1encode)
+            else:
+                output = self.flownet(img0, img1, timestep)
             output = self.tensor_to_frame(output)
         self.stream.synchronize()
         return output
diff --git a/backend/src/RenderVideo.py b/backend/src/RenderVideo.py
index de49699f..9a51722b 100644
--- a/backend/src/RenderVideo.py
+++ b/backend/src/RenderVideo.py
@@ -189,8 +189,9 @@ def renderInterpolate(self):
                     frame = self.interpolate(self.setup_frame0, setup_frame1, timestep)
                     self.writeQueue.put(frame)
             else:
-                # uncache the cached frame
-                self.undoSetup(self.frame0)
+                # run one pass to refresh the cached encoder features (output discarded)
+                self.interpolate(self.setup_frame0, setup_frame1, 0)
+
             for n in range(self.ceilInterpolateFactor):
                 self.writeQueue.put(self.frame0)
             try:  # get_nowait sends an error out of the queue is empty, I would like a better solution than this though
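A note on the TensorRT path above: the cached-features tensor becomes a fourth static engine input, so each arch appends a zero tensor whose channel count matches its encoder Head output (4 for rife47 and rife422lite, 8 for rife413/rife420/rife421) before torch_tensorrt.compile is called. A sketch of the resulting input layout, using a hypothetical helper name:

    import torch

    def exampleFlownetInputs(ph, pw, featureChannels, dtype, device):
        # img0, img1, timestep, f0 -- shapes must be static for the
        # torch_tensorrt dynamo path; featureChannels matches the arch's
        # Head output (4 or 8 in this patch)
        return [
            torch.zeros((1, 3, ph, pw), dtype=dtype, device=device),
            torch.zeros((1, 3, ph, pw), dtype=dtype, device=device),
            torch.zeros((1, 1, ph, pw), dtype=dtype, device=device),
            torch.zeros((1, featureChannels, ph, pw), dtype=dtype, device=device),
        ]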