-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_metrics.py
62 lines (50 loc) · 2.1 KB
/
compute_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import datetime
import pdb
import re
import time
import torch
import torch.nn.functional as F
from torch import nn, optim
from torchvision.transforms.functional import InterpolationMode
import glob
import os
import os.path as op
import json
import numpy as np
from torchvision import transforms
from tqdm import tqdm
import collections
import argparse
from utils import compute_map, compute_mrr, compute_ndcg, compute_recall
from evaluate_retrieval import get_metric_values_with_rank, load_data
def natural_sort_key(path):
    """Return a sort key that orders embedded integer runs numerically.

    ``re.split`` with a single capturing group alternates non-digit text and
    captured digit runs, so every odd position holds a digit string. Turning
    those into ``int`` makes ``"val_image_2.pt"`` sort before
    ``"val_image_10.pt"``. Zero-padded names of equal width keep their usual
    order. A given position always holds the same type across keys, so
    comparisons never mix ``str`` and ``int``.
    """
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", path))]


# Script entry point: scores every text embedding of a split against all image
# feature shards, ranks the images per query, and writes the retrieval metrics
# plus per-query top-20 image indices to <dir>/<split>_results.json.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dir", type=str, default="")
    parser.add_argument("-s", "--split", type=str, default="val", choices=["val", "test"])
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="torch device used for the similarity computation (default: cuda)",
    )
    opt = parser.parse_args()
    split = opt.split
    feature_dir = op.join(opt.dir, "features")

    with open(op.join(feature_dir, f"{split}_gt_rel.json"), "r", encoding="utf-8") as file:
        gt_rel = json.load(file)

    text_embeds = torch.load(op.join(feature_dir, f"{split}_text_feat.pt"), map_location=opt.device)

    # Shards are concatenated along the image axis, so their order defines the
    # global image index. glob returns files in arbitrary order, and plain
    # string sorting would put "_10" before "_2", so sort numerically.
    image_feat_files = sorted(
        glob.glob(op.join(feature_dir, f"{split}_image_*")), key=natural_sort_key
    )
    if not image_feat_files:
        raise FileNotFoundError(
            f"no image feature files matching '{split}_image_*' in {feature_dir}"
        )

    # Each shard contributes one slab of dot-product scores; concatenating on
    # dim=1 gives one row per text query and one column per image overall.
    sim_chunks = []
    for fname in image_feat_files:
        image_embeds = torch.load(fname, map_location=opt.device)
        sim_chunks.append(text_embeds @ image_embeds.t())
    sim_matrix = torch.cat(sim_chunks, dim=1)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(sim_matrix) != len(gt_rel):
        raise ValueError(
            f"similarity matrix has {len(sim_matrix)} rows but "
            f"{split}_gt_rel.json has {len(gt_rel)} entries"
        )

    gt_rank = []
    top20 = []
    for sims, gt_idx in zip(sim_matrix, gt_rel):
        _, sorted_idx = torch.sort(sims, descending=True)
        # argsort of the sort permutation is its inverse: idx_rank[i] is the
        # 0-based position of image i in this query's ranking.
        idx_rank = torch.argsort(sorted_idx)
        # Only the first listed ground-truth index of each query is ranked.
        gt_rank.append(idx_rank.cpu().numpy()[gt_idx[0]])
        top20.append(sorted_idx[:20].cpu().numpy().tolist())

    eval_results = get_metric_values_with_rank(gt_rel, gt_rank)
    # Keep the first two metadata entries produced by get_metric_values_with_rank
    # and store the per-query top-20 image indices as the third.
    eval_results["metadata"] = (eval_results["metadata"][0], eval_results["metadata"][1], top20)
    with open(op.join(opt.dir, f"{split}_results.json"), "w", encoding="utf-8") as file:
        json.dump(eval_results, file)