diff --git a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py index d6a93ca4..4fa1448f 100644 --- a/backend/src/InterpolateArchs/RIFE/rife46IFNET.py +++ b/backend/src/InterpolateArchs/RIFE/rife46IFNET.py @@ -145,8 +145,7 @@ def forward(self, img0, img1, timestep): temp = torch.cat( ( - wimg, - wf, + warped_imgs, timestep, mask, (flows * (1 / scale) if scale != 1 else flows), @@ -179,9 +178,7 @@ def forward(self, img0, img1, timestep): padding_mode="border", align_corners=True, ) - wimg, wf = torch.split(warps, [1, 2], dim=1) - wimg = torch.reshape(wimg, (1, 2, h, w)) - wf = torch.reshape(wf, (1, 4, h, w)) + warped_imgs = warps.reshape(1,6,h,w) mask = torch.sigmoid(mask) warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])