From 5dc5a678f8d312d31628c80319a7021df49cd5f8 Mon Sep 17 00:00:00 2001 From: TNTwise Date: Fri, 11 Oct 2024 10:40:22 -0500 Subject: [PATCH] force fp32 on warp --- backend/src/InterpolateArchs/RIFE/warplayer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/backend/src/InterpolateArchs/RIFE/warplayer.py b/backend/src/InterpolateArchs/RIFE/warplayer.py index 59c61639..52d28036 100644 --- a/backend/src/InterpolateArchs/RIFE/warplayer.py +++ b/backend/src/InterpolateArchs/RIFE/warplayer.py @@ -2,7 +2,7 @@ import torch.nn.functional as F -'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): +def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): dtype = tenInput.dtype tenInput = tenInput.to(torch.float) tenFlow = tenFlow.to(torch.float) @@ -10,9 +10,9 @@ tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1) g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1) grid_sample = F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True) - return grid_sample.to(dtype)''' + return grid_sample.to(dtype) -def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): +"""def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): tenFlow = torch.cat( [tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1 ) @@ -24,4 +24,5 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): mode="bilinear", padding_mode="border", align_corners=True, - ) \ No newline at end of file + ) +""" \ No newline at end of file