Skip to content

Commit

Permalink
add gimm.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TNTwise committed Dec 7, 2024
1 parent 65be4a9 commit 7c532f8
Showing 1 changed file with 131 additions and 0 deletions.
131 changes: 131 additions & 0 deletions backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from gimmvfi_r import GIMMVFI_R
import torch
import torch.nn.functional as F
import os
from PIL import Image
import numpy as np

class InputPadder:
    """Pad images so that their height and width are divisible by ``divisor``.

    The padding amounts are computed once from ``dims`` and are then applied
    to (``pad``) or stripped from (``unpad``) the last two dimensions of any
    tensor handed to the instance.
    """

    def __init__(self, dims, divisor=16):
        """Compute the padding needed for inputs of shape ``dims``.

        Args:
            dims: Shape of the tensors that will be padded; only the last two
                entries (height, width) are read.
            divisor: Positive integer the padded height and width must be
                multiples of.

        Raises:
            ValueError: If ``divisor`` is not positive. A negative divisor
                would produce negative padding (i.e. silent cropping) and
                zero would fail with an unhelpful ``ZeroDivisionError``.
        """
        if divisor <= 0:
            raise ValueError(f"divisor must be positive, got {divisor}")
        self.ht, self.wd = dims[-2:]
        # Distance from each size up to the next multiple of ``divisor``
        # (0 when the size is already a multiple).
        pad_ht = -self.ht % divisor
        pad_wd = -self.wd % divisor
        # Order expected by F.pad for the last two dims:
        # (left, right, top, bottom). Any odd remainder goes to right/bottom.
        self._pad = [
            pad_wd // 2,
            pad_wd - pad_wd // 2,
            pad_ht // 2,
            pad_ht - pad_ht // 2,
        ]

    def pad(self, *inputs):
        """Replicate-pad every tensor in ``inputs``.

        Returns:
            The padded tensor when exactly one input is given, otherwise a
            list with one padded tensor per input.
        """
        padded = [F.pad(x, self._pad, mode="replicate") for x in inputs]
        return padded[0] if len(padded) == 1 else padded

    def unpad(self, *inputs):
        """Strip the padding added by ``pad`` from every tensor in ``inputs``.

        Returns:
            The cropped tensor when exactly one input is given, otherwise a
            list with one cropped tensor per input.
        """
        cropped = [self._unpad(x) for x in inputs]
        return cropped[0] if len(cropped) == 1 else cropped

    def _unpad(self, x):
        """Crop ``x`` along its last two dimensions by the stored padding."""
        left, right, top, bottom = self._pad
        ht, wd = x.shape[-2:]
        return x[..., top : ht - bottom, left : wd - right]



# All inference in this script runs on the CUDA device.
device = torch.device("cuda")


def convert(param):
    """Strip the "module." marker from the keys of a state dict.

    Only entries whose key contains "module" are kept; every occurrence of
    "module." is removed from the kept keys.

    Args:
        param: Mapping of parameter name to value (a loaded state dict).

    Returns:
        A new dict with the renamed keys and the original values.
    """
    return {k.replace("module.", ""): v for k, v in param.items() if "module" in k}


# SECURITY NOTE(review): torch.load unpickles the file contents; only load
# checkpoints from a trusted source (consider `weights_only=True` if the
# checkpoints contain nothing but tensors — TODO confirm).
ckpt = torch.load("gimmvfi_r_arb_lpips.pt", map_location="cpu")
raft = torch.load("raft-things.pth", map_location="cpu")

# Bundle both sets of weights into one file.
combined_state_dict = {
    "gimmvfi_r": ckpt["state_dict"],
    "raft": convert(raft),
}
torch.save(combined_state_dict, "GIMMVFI_RAFT.pth")

# Build the model only after "GIMMVFI_RAFT.pth" has been written: the
# constructor is handed that path and presumably reads it, which would fail on
# a first run if the model were created before the file exists.
model = GIMMVFI_R("GIMMVFI_RAFT.pth").to(device)
model.load_state_dict(combined_state_dict["gimmvfi_r"])

# Accumulates the output frames appended by the interpolation code below.
images = []


def load_image(img_path):
    """Read an image file into a ``(1, 3, H, W)`` float tensor scaled to [0, 1].

    Args:
        img_path: Path of the image file to open.

    Returns:
        A ``torch.float`` tensor with a leading batch dimension of 1 and the
        channels in RGB order.
    """
    # Image.open keeps the file handle open until the image is closed; the
    # context manager releases it as soon as the pixels have been copied out.
    with Image.open(img_path) as img:
        raw_img = np.array(img.convert("RGB"))
    # HWC uint8 -> CHW float; the division already allocates a new tensor, so
    # no defensive copy of the numpy array is needed.
    img = torch.from_numpy(raw_img).permute(2, 0, 1) / 255.0
    return img.to(torch.float).unsqueeze(0)


# Interpolate frames between one pair of consecutive images with `model` and
# collect the results in `images` / `flows` / `ori_image`.
# NOTE(review): `source_path`, `img_list`, `j`, `args`, `flow_to_image`,
# `flows`, `ori_image` and `cv2` are used below but are not defined or imported
# in the visible part of this file — confirm where they come from.
# NOTE(review): every statement below is flush-left as shown here; the bodies
# of `with torch.no_grad():` and `for i in range(args.N - 1):` presumably need
# their indentation restored — TODO confirm the intended nesting (in
# particular whether the first `cv2.hconcat` line belongs inside the loop).

# Paths of the frames at positions j and j + 1 of img_list.
img_path0 = os.path.join(source_path, img_list[j])
img_path2 = os.path.join(source_path, img_list[j + 1])
# prepare data b,c,h,w
I0 = load_image(img_path0)
I2 = load_image(img_path2)
# Pad both frames so height and width are multiples of 32.
padder = InputPadder(I0.shape, 32)
I0, I2 = padder.pad(I0, I2)
# Stack the two frames along a new dimension 2: (b, c, h, w) -> (b, c, 2, h, w).
xs = torch.cat((I0.unsqueeze(2), I2.unsqueeze(2)), dim=2).to(
device, non_blocking=True
)
model.eval()
batch_size = xs.shape[0]
# Padded (height, width) of the inputs.
s_shape = xs.shape[-2:]

model.zero_grad()
ds_factor = args.ds_factor
with torch.no_grad():
# One (coordinate input, None) tuple per intermediate timestep i / N,
# for i = 1 .. N - 1.
# NOTE(review): the meaning of `sample_coord_input`'s result and of the `None`
# slot is defined by the model class, which is not visible here.
coord_inputs = [
(
model.sample_coord_input(
batch_size,
s_shape,
[1 / args.N * i],
device=xs.device,
upsample_ratio=ds_factor,
),
None,
)
for i in range(1, args.N)
]
# One tensor of length `batch` per timestep, filled with i / N.
timesteps = [
i * 1 / args.N * torch.ones(xs.shape[0]).to(xs.device).to(torch.float)
for i in range(1, args.N)
]
# The result is indexed below with the keys "imgt_pred" and "flowt".
all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
# Remove the padding from every predicted frame and flow.
out_frames = [padder.unpad(im) for im in all_outputs["imgt_pred"]]
out_flowts = [padder.unpad(f) for f in all_outputs["flowt"]]
# Turn each flow into an image via `flow_to_image`, passing it as an
# (h, w, c) numpy array.
# NOTE(review): `squeeze()` only drops the batch dimension when it is 1 —
# presumably a batch of one is assumed.
flowt_imgs = [
flow_to_image(
flowt.squeeze().detach().cpu().permute(1, 2, 0).numpy(),
convert_to_bgr=True,
)
for flowt in out_flowts
]
# First batch item of each prediction as an (h, w, c) uint8 array with the
# channel axis reversed (presumably RGB -> BGR for OpenCV).
# NOTE(review): there is no clamp before the uint8 cast — TODO confirm the
# model output stays within [0, 1].
I1_pred_img = [
(I1_pred[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0)[
:, :, ::-1
].astype(np.uint8)
for I1_pred in out_frames
]

# Collect every interpolated frame and its flow image.
for i in range(args.N - 1):
images.append(I1_pred_img[i])
flows.append(flowt_imgs[i])

# Replace the newest entry with a side-by-side of the latest original frame
# and that entry.
images[-1] = cv2.hconcat([ori_image[-1], images[-1]])

# Append the second input frame (unpadded, uint8, channel axis reversed) to
# both `images` and `ori_image`; the same expression is evaluated twice.
images.append(
(
(padder.unpad(I2)).squeeze().detach().cpu().numpy().transpose(1, 2, 0)
* 255.0
)[:, :, ::-1].astype(np.uint8)
)
ori_image.append(
(
(padder.unpad(I2)).squeeze().detach().cpu().numpy().transpose(1, 2, 0)
* 255.0
)[:, :, ::-1].astype(np.uint8)
)
# Side-by-side of the frame just stored in `ori_image` and its copy in `images`.
images[-1] = cv2.hconcat([ori_image[-1], images[-1]])

0 comments on commit 7c532f8

Please sign in to comment.