-
Notifications
You must be signed in to change notification settings - Fork 30
/
reid_metric.py
48 lines (40 loc) · 1.4 KB
/
reid_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# encoding: utf-8
"""
@author: liaoxingyu
@contact: [email protected]
"""
import numpy as np
import torch
from ignite.metrics import Metric
from data.datasets.eval_reid import eval_func
class R1_mAP(Metric):
    """Ignite metric that accumulates re-ID features and evaluates them.

    Features, person ids and camera ids are collected batch by batch via
    ``update``. ``compute`` treats the first ``num_query`` accumulated
    samples as the query set and the remainder as the gallery, builds the
    squared Euclidean query-to-gallery distance matrix and passes it to
    ``eval_func``.

    Args:
        num_query: Number of leading samples that form the query set.
        max_rank: Stored on the instance as ``self.max_rank``.
            NOTE(review): this value is not forwarded to ``eval_func`` by
            this class - confirm whether ``eval_func`` should receive it.
    """

    def __init__(self, num_query, max_rank=50):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank

    def reset(self):
        """Drop everything accumulated so far."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Accumulate one batch.

        Args:
            output: A 3-tuple ``(feat, pid, camid)``. ``feat`` is kept as-is
                and later concatenated with ``torch.cat`` along dim 0;
                ``pid`` and ``camid`` are converted with ``np.asarray`` and
                their elements appended to the running id lists.
        """
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        """Evaluate all accumulated samples.

        Returns:
            The ``(cmc, mAP)`` pair returned by ``eval_func``.
        """
        feats = torch.cat(self.feats, dim=0)

        # Query split: the first ``num_query`` accumulated samples.
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])

        # Gallery split: everything after the query samples.
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])

        # Squared Euclidean distance via ||q||^2 + ||g||^2 - 2 * q.g
        m, n = qf.shape[0], gf.shape[0]
        distmat = (
            torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n)
            + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        )
        # Keyword form of addmm_: distmat = 1 * distmat + (-2) * (qf @ gf.T).
        # The positional (beta, alpha, mat1, mat2) overload is deprecated in
        # PyTorch, so beta/alpha are passed by name.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.cpu().numpy()

        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
        return cmc, mAP