diff --git a/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py b/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py
new file mode 100644
index 00000000..c6807c29
--- /dev/null
+++ b/backend/src/pytorch/InterpolateArchs/GIMM/GIMM.py
@@ -0,0 +1,177 @@
+from types import SimpleNamespace
+
+import os
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+from gimmvfi_r import GIMMVFI_R
+# flow_to_image is assumed to sit next to this file, like gimmvfi_r
+# (the standard RAFT-style flow visualiser that GIMM-VFI ships).
+from flow_viz import flow_to_image
+
+
+class InputPadder:
+    """Pads images such that dimensions are divisible by divisor"""
+
+    def __init__(self, dims, divisor=16):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
+        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
+        self._pad = [
+            pad_wd // 2,
+            pad_wd - pad_wd // 2,
+            pad_ht // 2,
+            pad_ht - pad_ht // 2,
+        ]
+
+    def pad(self, *inputs):
+        if len(inputs) == 1:
+            return F.pad(inputs[0], self._pad, mode="replicate")
+        else:
+            return [F.pad(x, self._pad, mode="replicate") for x in inputs]
+
+    def unpad(self, *inputs):
+        if len(inputs) == 1:
+            return self._unpad(inputs[0])
+        else:
+            return [self._unpad(x) for x in inputs]
+
+    def _unpad(self, x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0] : c[1], c[2] : c[3]]
+
+
+def convert(param):
+    # Strip the "module." prefix that DataParallel checkpoints carry.
+    return {k.replace("module.", ""): v for k, v in param.items() if "module" in k}
+
+
+# Build the combined GIMM + RAFT checkpoint before constructing the model,
+# so that GIMMVFI_R can load it at construction time.
+ckpt = torch.load("gimmvfi_r_arb_lpips.pt", map_location="cpu")
+raft = torch.load("raft-things.pth", map_location="cpu")
+combined_state_dict = {
+    "gimmvfi_r": ckpt["state_dict"],
+    "raft": convert(raft),
+}
+torch.save(combined_state_dict, "GIMMVFI_RAFT.pth")
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = GIMMVFI_R("GIMMVFI_RAFT.pth").to(device)
+model.load_state_dict(combined_state_dict["gimmvfi_r"])
+
+# Stand-ins for the CLI arguments this script was lifted from: N - 1
+# intermediate frames and the coordinate-sampling downscale factor.
+args = SimpleNamespace(N=8, ds_factor=1.0)
+
+images = []
+flows = []
+ori_image = []
+
+
+def load_image(img_path):
+    img = Image.open(img_path)
+    raw_img = np.array(img.convert("RGB"))
+    img = torch.from_numpy(raw_img.copy()).permute(2, 0, 1) / 255.0
+    return img.to(torch.float).unsqueeze(0)
+
+
+# Two consecutive frames from a source directory (placeholder path/index).
+source_path = "input_frames"
+img_list = sorted(os.listdir(source_path))
+j = 0
+img_path0 = os.path.join(source_path, img_list[j])
+img_path2 = os.path.join(source_path, img_list[j + 1])
+# prepare data b,c,h,w
+I0 = load_image(img_path0)
+I2 = load_image(img_path2)
+padder = InputPadder(I0.shape, 32)
+I0, I2 = padder.pad(I0, I2)
+xs = torch.cat((I0.unsqueeze(2), I2.unsqueeze(2)), dim=2).to(
+    device, non_blocking=True
+)
+# Seed ori_image with the first input frame (as BGR uint8) so that
+# ori_image[-1] below refers to a real frame.
+ori_image.append(
+    (
+        padder.unpad(I0).squeeze().detach().cpu().numpy().transpose(1, 2, 0)
+        * 255.0
+    )[:, :, ::-1].astype(np.uint8)
+)
+
+model.eval()
+batch_size = xs.shape[0]
+s_shape = xs.shape[-2:]
+
+model.zero_grad()
+ds_factor = args.ds_factor
+with torch.no_grad():
+    # One coordinate grid and one timestep per intermediate frame,
+    # t = i / N for i = 1 .. N - 1.
+    coord_inputs = [
+        (
+            model.sample_coord_input(
+                batch_size,
+                s_shape,
+                [1 / args.N * i],
+                device=xs.device,
+                upsample_ratio=ds_factor,
+            ),
+            None,
+        )
+        for i in range(1, args.N)
+    ]
+    timesteps = [
+        i * 1 / args.N * torch.ones(xs.shape[0]).to(xs.device).to(torch.float)
+        for i in range(1, args.N)
+    ]
+    all_outputs = model(xs, coord_inputs, t=timesteps, ds_factor=ds_factor)
+    out_frames = [padder.unpad(im) for im in all_outputs["imgt_pred"]]
+    out_flowts = [padder.unpad(f) for f in all_outputs["flowt"]]
+
+# Colorize the predicted flows and convert the predictions to BGR uint8.
+flowt_imgs = [
+    flow_to_image(
+        flowt.squeeze().detach().cpu().permute(1, 2, 0).numpy(),
+        convert_to_bgr=True,
+    )
+    for flowt in out_flowts
+]
+I1_pred_img = [
+    (I1_pred[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0)[
+        :, :, ::-1
+    ].astype(np.uint8)
+    for I1_pred in out_frames
+]
+
+for i in range(args.N - 1):
+    images.append(I1_pred_img[i])
+    flows.append(flowt_imgs[i])
+
+    # Show each prediction side by side with the most recent original frame.
+    images[-1] = cv2.hconcat([ori_image[-1], images[-1]])
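+
+# Finally append the second input frame itself, both as the reference
+# "original" and as the last output image, and pair it with itself.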
+images.append(
+    (
+        (padder.unpad(I2)).squeeze().detach().cpu().numpy().transpose(1, 2, 0)
+        * 255.0
+    )[:, :, ::-1].astype(np.uint8)
+)
+ori_image.append(
+    (
+        (padder.unpad(I2)).squeeze().detach().cpu().numpy().transpose(1, 2, 0)
+        * 255.0
+    )[:, :, ::-1].astype(np.uint8)
+)
+images[-1] = cv2.hconcat([ori_image[-1], images[-1]])
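+
+# images now holds BGR uint8 frames (original | prediction pairs) and
+# flows the colorized flow maps; writing them to disk (e.g. with
+# cv2.imwrite) is left to the surrounding pipeline.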