diff --git a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py index 4d33f797..c66295b8 100644 --- a/backend/src/InterpolateArchs/RIFE/rife413IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife413IFNET.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn - +import math try: from .interpolate import interpolate except ImportError: @@ -138,6 +138,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7 + 16, c=192) @@ -151,11 +153,14 @@ def __init__( self.ensemble = ensemble self.width = width self.height = height + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div + # self.contextnet = Contextnet() # self.unet = Unet() - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): # cant be cached h, w = img0.shape[2], img0.shape[3] imgs = torch.cat([img0, img1], dim=1) @@ -237,7 +242,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): torch.split(flows, [2, 2], dim=1)[::-1], dim=1 ) precomp = ( - (backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div) + (self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div) .permute(0, 2, 3, 1) .to(dtype=self.dtype) ) diff --git a/backend/src/InterpolateArchs/RIFE/rife420IFNET.py b/backend/src/InterpolateArchs/RIFE/rife420IFNET.py index a8d8b57c..df39a438 100644 --- a/backend/src/InterpolateArchs/RIFE/rife420IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife420IFNET.py @@ -5,7 +5,7 @@ from .interpolate import interpolate except ImportError: from torch.nn.functional import interpolate - +import math def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): return nn.Sequential( @@ -138,6 +138,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7 + 16, c=384) @@ -151,14 +153,12 @@ def __init__( self.ensemble = ensemble self.width = width self.height = height + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div - # self.contextnet = Contextnet() - # self.unet = Unet() - - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): # cant be cached h, w = img0.shape[2], img0.shape[3] - tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1) imgs = torch.cat([img0, img1], dim=1) imgs_2 = torch.reshape(imgs, (2, 3, h, w)) fs_2 = self.encode(imgs_2) @@ -240,7 +240,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): torch.split(flows, [2, 2], dim=1)[::-1], dim=1 ) precomp = ( - (backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div) + (self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div) .permute(0, 2, 3, 1) .to(dtype=self.dtype) ) diff --git a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py index 1da3d340..d95c7d89 100644 --- a/backend/src/InterpolateArchs/RIFE/rife421IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife421IFNET.py @@ -128,6 +128,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7 + 16, c=256) @@ -141,13 +143,12 @@ def __init__( self.ensemble = ensemble self.width = width self.height = height - # self.contextnet = Contextnet() - # self.unet = Unet() + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): # cant be cached h, w = img0.shape[2], img0.shape[3] - tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1) imgs = torch.cat([img0, img1], dim=1) imgs_2 = torch.reshape(imgs, (2, 3, h, w)) fs_2 = self.encode(imgs_2) @@ -189,7 +190,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): torch.split(flows, [2, 2], dim=1)[::-1], dim=1 ) precomp = ( - backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div + self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div ).permute(0, 2, 3, 1) if scale == 1: warped_imgs = torch.nn.functional.grid_sample( diff --git a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py index 7d2b5aa6..4d36e02c 100644 --- a/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife422_liteIFNET.py @@ -117,6 +117,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7+8, c=192) @@ -131,16 +133,13 @@ def __init__( self.width = width self.height = height - # self.contextnet = Contextnet() - # self.unet = Unet() - - + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): # cant be cached h, w = img0.shape[2], img0.shape[3] - tenFlow_div = tenFlow_div.reshape(1, 2, 1, 1) imgs = torch.cat([img0, img1], dim=1) imgs_2 = torch.reshape(imgs, (2, 3, h, w)) fs_2 = self.encode(imgs_2) @@ -177,7 +176,7 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): precomp = ( - backwarp_tenGrid + flows.reshape((2, 2, h, w)) * tenFlow_div + self.backwarp_tenGrid + flows.reshape((2, 2, h, w)) * self.tenFlow_div ).permute(0, 2, 3, 1) if scale == 1: warped_imgs = torch.nn.functional.grid_sample( diff --git a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py index cc1dabe7..90020192 100644 --- a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py @@ -102,6 +102,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7, c=192) @@ -114,8 +116,10 @@ def __init__( self.device = device self.width = width self.height = height + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): flow_list = [] merged = [] mask_list = [] @@ -167,8 +171,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): mask = mask + m0 mask_list.append(mask) flow_list.append(flow) - warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) - warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) + warped_img0 = warp(img0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid) + warped_img1 = warp(img1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid) merged.append((warped_img0, warped_img1)) mask_list[3] = torch.sigmoid(mask_list[3]) frame = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3]) diff --git a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py index f9fef994..48bf0722 100644 --- a/backend/src/InterpolateArchs/RIFE/rife47IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife47IFNET.py @@ -117,6 +117,8 @@ def __init__( device="cuda", width=1920, height=1080, + backwarp_tenGrid=None, + tenFlow_div=None, ): super(IFNet, self).__init__() self.block0 = IFBlock(7 + 8, c=192) @@ -132,11 +134,13 @@ def __init__( self.ensemble = ensemble self.width = width self.height = height + self.backwarp_tenGrid = backwarp_tenGrid + self.tenFlow_div = tenFlow_div # self.contextnet = Contextnet() # self.unet = Unet() - def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): + def forward(self, img0, img1, timestep): f0 = self.encode(img0[:, :3]) f1 = self.encode(img1[:, :3]) flow_list = [] @@ -163,8 +167,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 mask = (mask + (-m_)) / 2 else: - wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid) - wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) + wf0 = warp(f0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid) + wf1 = warp(f1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid) fd, m0 = block[i]( torch.cat( ( @@ -203,8 +207,8 @@ def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): flow = flow + fd mask_list.append(mask) flow_list.append(flow) - warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) - warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) + warped_img0 = warp(img0, flow[:, :2], self.tenFlow_div, self.backwarp_tenGrid) + warped_img1 = warp(img1, flow[:, 2:4], self.tenFlow_div, self.backwarp_tenGrid) merged.append((warped_img0, warped_img1)) mask = torch.sigmoid(mask) frame= warped_img0 * mask + warped_img1 * (1 - mask) diff --git a/backend/src/InterpolateTorch.py b/backend/src/InterpolateTorch.py index ac1709aa..bd04cffb 100644 --- a/backend/src/InterpolateTorch.py +++ b/backend/src/InterpolateTorch.py @@ -201,6 +201,8 @@ def __init__( device=self.device, width=self.width, height=self.height, + backwarp_tenGrid=self.backwarp_tenGrid, + tenFlow_div=self.tenFlow_div, ) state_dict = { @@ -252,14 +254,7 @@ def __init__( torch.zeros( (1, 1, self.ph, self.pw), dtype=self.dtype, device=device ), - torch.zeros((2,), dtype=self.dtype, device=device), - torch.zeros( - (1, 2, self.ph, self.pw), dtype=self.dtype, device=device - ) - if v1 else - torch.zeros( - (1, 2, 1, 1), dtype=self.dtype, device=device - ), + ] self.flownet = torch_tensorrt.compile( self.flownet, @@ -293,7 +288,7 @@ def process(self, img0, img1, timestep): ) output = self.flownet( - img0, img1, timestep, self.tenFlow_div, self.backwarp_tenGrid + img0, img1, timestep ) return self.tensor_to_frame(output) @@ -305,6 +300,7 @@ def tensor_to_frame(self, frame: torch.Tensor): return ( frame + .byte() .contiguous() .cpu() .numpy() diff --git a/backend/src/UpscaleTorch.py b/backend/src/UpscaleTorch.py index 31fb1daf..fc8dd5ab 100644 --- a/backend/src/UpscaleTorch.py +++ b/backend/src/UpscaleTorch.py @@ -188,22 +188,16 @@ def renderToNPArray(self, image: torch.Tensor) -> torch.Tensor: .float() .clamp(0.0, 1.0) .mul(255) + .byte() .contiguous() .detach() .cpu() .numpy() ) - @torch.inference_mode() - def renderImagesInDirectory(self, dir): - pass - def getScale(self): return self.scale - def saveImage(self, image: np.array, fullOutputPathLocation): - cv2.imwrite(fullOutputPathLocation, image) - @torch.inference_mode() def renderTiledImage( self,