-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
72 lines (63 loc) · 2.28 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from models import SetPredictor
from datasets import CLEVR
from utils import average_precision_clevr
def evaluate(model, dataloader, device, limit=2, *, thresholds=None,
             num_slots=10, num_features=19):
    """Run ``model`` over ``dataloader`` and print CLEVR average precision.

    Args:
        model: Module whose forward output is indexable by ``'prediction'``.
        dataloader: Iterable of batches; each batch is indexable by
            ``'image'`` and ``'target'`` and both values are tensors.
        device: Device the model and the input images are moved to.
        limit: Index of the last batch to evaluate. Iteration stops *after*
            batch ``limit``, so ``limit + 1`` batches are processed (the
            default of 2 evaluates 3 batches). ``None`` (or a negative value)
            evaluates the whole dataloader.
        thresholds: Distance thresholds passed to ``average_precision_clevr``.
            Defaults to ``[-1, 1, 0.5, 0.25, 0.125, 0.0625]``.
        num_slots: Number of object slots predictions/targets are reshaped to.
        num_features: Feature length per slot used in the reshape.

    Returns:
        dict mapping each threshold to the AP value returned by
        ``average_precision_clevr`` (the values are also printed).

    Raises:
        RuntimeError: If the dataloader yields no batches.
    """
    if thresholds is None:
        thresholds = [-1, 1, 0.5, 0.25, 0.125, 0.0625]

    model.to(device)
    model.eval()
    predictions = []
    targets = []
    print('Making predictions...')
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            img = batch['image'].to(device)
            pred = model(img)
            # Accumulate on the CPU so device memory does not grow with the
            # number of evaluated batches; targets never enter the model, so
            # they do not need to visit `device` at all.
            predictions.append(pred['prediction'].cpu())
            targets.append(batch['target'].cpu())
            if limit is not None and i == limit:
                break

    if not predictions:
        raise RuntimeError('dataloader yielded no batches; nothing to evaluate')

    print('Predictions are ready. Calculating metrics...')
    shape = (-1, num_slots, num_features)
    predictions = torch.cat(predictions, dim=0).reshape(shape).detach().numpy()
    targets = torch.cat(targets, dim=0).reshape(shape).detach().numpy()

    results = {}
    for thr in thresholds:
        ap = average_precision_clevr(predictions, targets, distance_threshold=thr)
        print(f"AP({thr}): {ap:.4}")
        results[thr] = ap
    return results
def main():
    """Parse CLI arguments, load a SetPredictor checkpoint and evaluate it on CLEVR val."""
    parser = argparse.ArgumentParser(
        description='Evaluate a SetPredictor checkpoint on the CLEVR validation split.')
    parser.add_argument("--batch_size",
                        "-b",
                        type=int,
                        default=64)
    # The default matches the checkpoint the script previously hard-coded, so
    # running without `-p` loads the same file as before, but the flag now works.
    parser.add_argument("--model_path",
                        "-p",
                        type=str,
                        default='e725.pth',
                        help="path to the state_dict checkpoint to evaluate")
    parser.add_argument("--device",
                        "-d",
                        type=str,
                        default='cpu')
    # type=int is required: without it a CLI value stays a string, never
    # equals the integer batch index, and the whole dataset is evaluated.
    parser.add_argument("--limit_data",
                        "-l",
                        type=int,
                        default=2,
                        help="index of the last batch to evaluate; negative evaluates all batches")
    args = parser.parse_args()

    pl.seed_everything(39)
    device = torch.device(args.device)
    data = CLEVR(
        images_path='./CLEVR_v1.0/images/val',
        scenes_path='./CLEVR_v1.0/scenes/CLEVR_val_scenes.json',
        max_objs=10
    )
    dataloader = DataLoader(data, batch_size=args.batch_size)

    model = SetPredictor(num_slots=10)
    state_dict = torch.load(args.model_path, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: missing keys {incompatible.missing_keys}, "
              f"unexpected keys {incompatible.unexpected_keys}")

    limit = args.limit_data if args.limit_data >= 0 else None
    evaluate(model, dataloader, device, limit)
    print('*** END ***')
# Script entry point: run the evaluation only when executed directly, not on import.
if '__main__' == __name__:
    main()