-
Notifications
You must be signed in to change notification settings - Fork 2
/
evaluate.py
83 lines (61 loc) · 2.52 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import numpy as np
import torch
import torch.nn.functional as F
from utils import centralize, resize_img, resize_flow, EPE
from datasets import Sintel_Clean, KITTI_2015
# Multiplier applied to FastFlowNet's upsampled output before it is scored
# against the ground-truth flow (RAFT's output is used unscaled).
div_flow = 20.0

# Evaluation datasets: constructed once, when this module is imported, and
# shared by the evaluation functions below (Sintel Clean first, then KITTI).
sintel_clean_dataset, kitti_2015_dataset = Sintel_Clean(), KITTI_2015()
def test_sintel_clean(model):
    """Evaluate ``model`` on Sintel Clean and return the average end-point error.

    Each sample of ``sintel_clean_dataset`` goes through these steps:

    1. Both frames are moved to the GPU, resized to 448x1024 with
       ``resize_img`` and centralized with ``centralize``.
    2. The frames are concatenated along dim 1 and fed to the model.
    3. The prediction is bilinearly upsampled to 448x1024 and multiplied by
       ``div_flow`` for FastFlowNet.
    4. The prediction is resized with ``resize_flow`` back to the frames'
       original spatial size.
    5. It is scored with ``EPE`` against the ground truth under an all-ones
       mask.

    Args:
        model: Flow network whose class is named ``'FastFlowNet'`` or
            ``'RAFT'``. It is switched to eval mode and left there. It
            presumably must already live on the CUDA device, since the
            inputs are moved there (TODO confirm with callers).

    Returns:
        The sum of the per-sample ``EPE`` values divided by the number of
        samples. An empty dataset raises ``ZeroDivisionError``.

    Raises:
        ValueError: If the model's class name is not a supported one.
    """
    print('\nTesting Sintel Clean')

    # Decide once, before touching any data, how the raw output is scaled.
    # An unsupported model fails fast here with a clear message instead of
    # leaving ``flow_pred`` unbound inside the loop.
    model_name = model.__class__.__name__
    if model_name == 'FastFlowNet':
        scale = div_flow
    elif model_name == 'RAFT':
        scale = 1.0
    else:
        raise ValueError(
            f"Unsupported model class {model_name!r}; "
            "expected 'FastFlowNet' or 'RAFT'")

    epe_all = 0.0
    model.eval()
    test_iters = len(sintel_clean_dataset)
    for i in range(test_iters):
        img1, img2, flow = sintel_clean_dataset[i]
        img1 = img1.unsqueeze(0).cuda()
        img2 = img2.unsqueeze(0).cuda()
        flow = flow.unsqueeze(0).cuda()
        # All-ones single-channel mask: every pixel contributes to the EPE.
        # ``ones_like`` already allocates on ``flow``'s (CUDA) device.
        mask = torch.ones_like(flow[:, :1, :, :])
        # Remember the original resolution so the prediction can be mapped
        # back to it (same approach as the KITTI evaluation).
        input_size = img1.shape[2:]
        img1 = resize_img(img1, size=(448, 1024))
        img2 = resize_img(img2, size=(448, 1024))
        img1, img2, _ = centralize(img1, img2)
        imgs = torch.cat([img1, img2], 1)
        with torch.no_grad():
            output = model(imgs)
        flow_pred = scale * F.interpolate(
            output, size=(448, 1024), mode='bilinear', align_corners=False)
        flow_pred = resize_flow(flow_pred, input_size)
        epe_all += EPE(flow_pred, flow, mask)
    epe_all /= test_iters
    return epe_all
def test_kitti_2015(model):
    """Evaluate ``model`` on KITTI 2015 and return the average end-point error.

    Each sample of ``kitti_2015_dataset`` goes through these steps:

    1. The ground-truth tensor is split into its first two channels (the
       flow) and the remaining channel(s), which are used as the ``EPE``
       mask.
    2. Both frames are moved to the GPU, resized to 512x1024 with
       ``resize_img`` and centralized with ``centralize``.
    3. The frames are concatenated along dim 1 and fed to the model.
    4. The prediction is bilinearly upsampled to 512x1024 and multiplied by
       ``div_flow`` for FastFlowNet.
    5. It is resized with ``resize_flow`` back to the frames' original
       spatial size and scored with ``EPE``.

    Args:
        model: Flow network whose class is named ``'FastFlowNet'`` or
            ``'RAFT'``. It is switched to eval mode and left there. It
            presumably must already live on the CUDA device, since the
            inputs are moved there (TODO confirm with callers).

    Returns:
        The sum of the per-sample ``EPE`` values divided by the number of
        samples. An empty dataset raises ``ZeroDivisionError``.

    Raises:
        ValueError: If the model's class name is not a supported one.
    """
    print('\nTesting KITTI 2015')

    # Decide once, before touching any data, how the raw output is scaled.
    # An unsupported model fails fast here with a clear message instead of
    # leaving ``flow_pred`` unbound inside the loop.
    model_name = model.__class__.__name__
    if model_name == 'FastFlowNet':
        scale = div_flow
    elif model_name == 'RAFT':
        scale = 1.0
    else:
        raise ValueError(
            f"Unsupported model class {model_name!r}; "
            "expected 'FastFlowNet' or 'RAFT'")

    epe_all = 0.0
    model.eval()
    test_iters = len(kitti_2015_dataset)
    for i in range(test_iters):
        img1, img2, flow = kitti_2015_dataset[i]
        img1 = img1.unsqueeze(0).cuda()
        img2 = img2.unsqueeze(0).cuda()
        flow = flow.unsqueeze(0).cuda()
        # Channels 2+ of the ground truth serve as the EPE mask, presumably
        # KITTI's per-pixel validity flag (TODO confirm against KITTI_2015).
        # Channels 0-1 are the flow itself.
        mask = flow[:, 2:, :, :]
        flow = flow[:, :2, :, :]
        # Original resolution, so the prediction can be mapped back to it.
        input_size = img1.shape[2:]
        img1 = resize_img(img1, size=(512, 1024))
        img2 = resize_img(img2, size=(512, 1024))
        img1, img2, _ = centralize(img1, img2)
        imgs = torch.cat([img1, img2], 1)
        with torch.no_grad():
            output = model(imgs)
        flow_pred = scale * F.interpolate(
            output, size=(512, 1024), mode='bilinear', align_corners=False)
        flow_pred = resize_flow(flow_pred, input_size)
        epe_all += EPE(flow_pred, flow, mask)
    epe_all /= test_iters
    return epe_all