Skip to content

Commit

Permalink
remove any v1 instance
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Aug 23, 2024
1 parent d6638b4 commit cfca1f0
Showing 1 changed file with 19 additions and 45 deletions.
64 changes: 19 additions & 45 deletions backend/src/InterpolateTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,70 +134,44 @@ def __init__(
case "rife46":
from .InterpolateArchs.RIFE.rife46IFNET import IFNet

v1 = False
case "rife47":
from .InterpolateArchs.RIFE.rife47IFNET import IFNet

v1 = False
case "rife413":
from .InterpolateArchs.RIFE.rife413IFNET import IFNet

v1 = False
case "rife420":
from .InterpolateArchs.RIFE.rife420IFNET import IFNet

v1 = False
case "rife421":
from .InterpolateArchs.RIFE.rife421IFNET import IFNet

v1 = False
case "rife422-lite":
from .InterpolateArchs.RIFE.rife422_liteIFNET import IFNet

v1 = False
case _:
errorAndLog("Invalid Interpolation Arch")

# if 4.6 v1
if v1:
self.tenFlow_div = torch.tensor(
[(self.pw - 1.0) / 2.0, (self.ph - 1.0) / 2.0],
dtype=self.dtype,
device=self.device,
)
tenHorizontal = (
torch.linspace(-1.0, 1.0, self.pw, dtype=self.dtype, device=self.device)
.view(1, 1, 1, self.pw)
.expand(-1, -1, self.ph, -1)
).to(dtype=self.dtype, device=self.device)
tenVertical = (
torch.linspace(-1.0, 1.0, self.ph, dtype=self.dtype, device=self.device)
.view(1, 1, self.ph, 1)
.expand(-1, -1, -1, self.pw)
).to(dtype=self.dtype, device=self.device)
self.backwarp_tenGrid = torch.cat([tenHorizontal, tenVertical], 1)

else:
# if v2
h_mul = 2 / (self.pw - 1)
v_mul = 2 / (self.ph - 1)
self.tenFlow_div = (
torch.Tensor([h_mul, v_mul])
.to(device=self.device, dtype=self.dtype)
.reshape(1, 2, 1, 1)
)

h_mul = 2 / (self.pw - 1)
v_mul = 2 / (self.ph - 1)
self.tenFlow_div = (
torch.Tensor([h_mul, v_mul])
.to(device=self.device, dtype=self.dtype)
.reshape(1, 2, 1, 1)
)

self.backwarp_tenGrid = torch.cat(
(
(torch.arange(self.pw) * h_mul - 1)
.reshape(1, 1, 1, -1)
.expand(-1, -1, self.ph, -1),
(torch.arange(self.ph) * v_mul - 1)
.reshape(1, 1, -1, 1)
.expand(-1, -1, -1, self.pw),
),
dim=1,
).to(device=self.device, dtype=self.dtype)
self.backwarp_tenGrid = torch.cat(
(
(torch.arange(self.pw) * h_mul - 1)
.reshape(1, 1, 1, -1)
.expand(-1, -1, self.ph, -1),
(torch.arange(self.ph) * v_mul - 1)
.reshape(1, 1, -1, 1)
.expand(-1, -1, -1, self.pw),
),
dim=1,
).to(device=self.device, dtype=self.dtype)

self.flownet = IFNet(
scale=scale,
Expand Down

0 comments on commit cfca1f0

Please sign in to comment.