-
Notifications
You must be signed in to change notification settings - Fork 1
/
loop_mosh.py
136 lines (100 loc) · 4.4 KB
/
loop_mosh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import gc
import sys
from argparse import Namespace
import cv2
import numpy as np
import torch
from kornia.geometry.transform import remap
from kornia.utils import create_meshgrid
from tqdm import tqdm
from mosh_utils import np_to_torch, read_frames, write_frames, write_gif
# Extend the module search path so the RAFT sources can be imported below.
# These entries are relative, so they resolve against the current working
# directory: the script presumably has to be launched from the directory that
# contains the RAFT checkout -- TODO confirm.
sys.path.append('RAFT/core')
sys.path.append('RAFT')
# These two imports only resolve once the path entries above are in place.
from raft import RAFT
from utils.utils import InputPadder
# Torch device used for RAFT inference: CUDA when available, otherwise CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def _load_raft_model():
    """Build the small RAFT network, load its weights and return it in eval mode on DEVICE."""
    raft_args = Namespace(
        small=True,
        alternate_corr=False,
        model='RAFT/models/raft-small.pth',
        # path='RAFT/demo-frames/',
        mixed_precision=True
    )
    # The network is wrapped in DataParallel only for loading the state dict
    # (presumably the checkpoint keys carry the wrapper's prefix -- confirm);
    # the bare module is what is actually used for inference.
    model = torch.nn.DataParallel(RAFT(raft_args))
    model.load_state_dict(torch.load(raft_args.model, map_location=DEVICE))
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model


def _raft_flow(model, image1, image2, iters):
    """Return the RAFT flow from image1 to image2 as a CPU tensor, cropped back to the input size."""
    with torch.no_grad():
        image1 = np_to_torch(image1).to(DEVICE)
        image2 = np_to_torch(image2).to(DEVICE)
        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)
        _, flow_up = model(image1, image2, iters=iters, test_mode=True)
        # Undo the padding so the flow has the same spatial size as the frames;
        # otherwise it cannot be added to the (h, w) sampling grid whenever the
        # padder actually had to pad.
        return padder.unpad(flow_up).cpu()


def _farneback_flow(image1, image2, pyr_levels):
    """Return the Farneback flow from image1 to image2 as a (1, 2, H, W) tensor."""
    gray1 = cv2.cvtColor(image1, cv2.COLOR_RGB2GRAY)
    gray2 = cv2.cvtColor(image2, cv2.COLOR_RGB2GRAY)
    # TODO: pass the optical flow parameters through argparse
    flow = cv2.calcOpticalFlowFarneback(gray1, gray2, None, 0.5, pyr_levels, 25, 3, 7, 1.5, 0)
    # OpenCV returns (H, W, 2); reorder to channels-first and add a batch axis.
    return torch.from_numpy(np.array(flow)).permute(2, 0, 1).unsqueeze(0)


def loop_mosh(args):
    """Datamosh a video by repeatedly warping one frame along the optical flow.

    The frames are read in reverse order, the flow between every pair of
    neighbouring frames is estimated (RAFT or Farneback), and the first frame
    of the original video is warped cumulatively along those flows. Pixels the
    warp could not fill are taken from the corresponding real frame. The
    result is reversed again and written as a gif or as video frames.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``raft``, ``raft_iter``, ``flow_speed``, ``pyr_levels``,
            ``height``, ``width``, ``input_path``, ``output_path``, ``gif``.
    """
    model = None
    # if using CNN-based optical flow -- RAFT
    if args.raft:
        print("Datamoshing using RAFT...")
        model = _load_raft_model()

    # frame height and width
    h, w = args.height, args.width

    # load up the video to be moshed
    print("Reading video: {}".format(args.input_path))
    vid_frames = read_frames(args.input_path, h=h, w=w)
    vid_frames = vid_frames[::-1]

    # One flow per pair of neighbouring frames, hence len - 1 iterations.
    frame_pairs = zip(vid_frames[:-1], vid_frames[1:])
    n_pairs = max(len(vid_frames) - 1, 0)
    flows = []
    for image2, image1 in tqdm(frame_pairs, total=n_pairs):
        if args.raft:
            flow = _raft_flow(model, image1, image2, args.raft_iter)
        else:
            flow = _farneback_flow(image1, image2, args.pyr_levels)
        flows.append(flow * args.flow_speed)

    vid_frames = [np.array(f).astype(np.uint8) for f in vid_frames]

    # the frames are reversed so this is actually the first frame
    start_frame = vid_frames[-1]
    warped = torch.from_numpy(start_frame).permute(2, 0, 1).unsqueeze(0).float()
    # All-255 mask warped alongside the image; wherever it ends up zero the
    # warp produced no valid pixel (presumably out-of-bounds samples -- confirm
    # against kornia's remap padding behaviour).
    fg_mask = torch.ones_like(warped) * 255.

    print("Creating datamosh...")
    # The identity sampling grid is the same for every frame; build it once.
    base_grid = create_meshgrid(h, w, False)
    warps = []
    masks = []
    for flw in flows:
        grid = base_grid.clone()
        grid += flw.permute(0, 2, 3, 1).cpu()
        # Each step warps the previous result, so the displacement accumulates.
        warped = remap(warped, grid[..., 0], grid[..., 1], mode='nearest', align_corners=True)
        fg_mask = remap(fg_mask, grid[..., 0], grid[..., 1], mode='nearest', align_corners=True)
        masks.append(fg_mask.squeeze(0).permute(1, 2, 0).numpy())
        warps.append(warped.squeeze(0).permute(1, 2, 0).numpy())

    # Compositing happens only after all warps are computed, so the filled-in
    # pixels never feed back into the warp chain.
    outputs = []
    for orig, warped_frame, mask in zip(vid_frames, warps, masks):
        holes = ~mask.astype(bool)
        warped_frame[holes] = orig[holes]
        outputs.append(warped_frame)

    # Drop the large intermediates before writing the result.
    del masks
    del warps
    del flows
    gc.collect()

    # Undo the initial reversal.
    outputs = outputs[::-1]
    if args.gif:
        write_gif(outputs, args.output_path)
    else:
        write_frames(outputs, args.output_path, h, w)
# Command-line entry point: parse the options and run the datamosh.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # data arguments
    parser.add_argument('-ip', '--input_path', help="path to the input video, e.g. ./input.mp4", default=None)
    parser.add_argument('-op', '--output_path', help="path to output video, e.g. ./output.mp4")
    parser.add_argument('-g', '--gif', action='store_true', help='whether or not to use output a gif')

    # flow arguments
    parser.add_argument('-rt', '--raft', action='store_true', help='flag to use raft flow')
    parser.add_argument('-ri', '--raft_iter', default=5, type=int, help='raft iterations')
    parser.add_argument('-fs', '--flow_speed', type=float, help='optical flow speed', default=1.0)
    parser.add_argument('-pl', '--pyr_levels', type=int, help='number of farneback pyramid levels', default=3)

    # image arguments
    parser.add_argument('-fh', '--height', help='frame height', default=720, type=int)
    # The help text previously said 'frame height' for the width option.
    parser.add_argument('-fw', '--width', help='frame width', default=1080, type=int)

    args = parser.parse_args()
    loop_mosh(args)