Skip to content

Commit

Permalink
force fp32 on warp
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Oct 11, 2024
1 parent 057af36 commit 5dc5a67
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions backend/src/InterpolateArchs/RIFE/warplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
import torch.nn.functional as F


'''def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
dtype = tenInput.dtype
tenInput = tenInput.to(torch.float)
tenFlow = tenFlow.to(torch.float)

tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1)
g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1)
grid_sample = F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
return grid_sample.to(dtype)'''
return grid_sample.to(dtype)

def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
"""def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
tenFlow = torch.cat(
[tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1
)
Expand All @@ -24,4 +24,5 @@ def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid):
mode="bilinear",
padding_mode="border",
align_corners=True,
)
)
"""

0 comments on commit 5dc5a67

Please sign in to comment.