-
Notifications
You must be signed in to change notification settings - Fork 1
/
metrics.py
156 lines (136 loc) · 5.48 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import tensorflow as tf
import numpy as np
import time
def masked_softmax_cross_entropy(preds, labels, mask):
    """Return the mean softmax cross-entropy restricted to masked entries.

    The per-example loss is weighted by ``mask / mean(mask)`` before the
    final mean, so for a 0/1 mask the result equals
    ``sum(loss * mask) / sum(mask)``.

    Args:
        preds: Logits passed to ``tf.nn.softmax_cross_entropy_with_logits``.
        labels: Target distribution for each example (same layout as preds).
        mask: Per-example weights; cast to float32 here.
            NOTE(review): presumably a 0/1 vector with one entry per
            example -- confirm against the caller.

    Returns:
        Scalar tensor holding the masked mean loss.
    """
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=preds, labels=labels)
    weights = tf.cast(mask, dtype=tf.float32)
    # Rescale so the plain mean below averages over masked entries only.
    weights = weights / tf.reduce_mean(weights)
    weighted_loss = per_example_loss * weights
    return tf.reduce_mean(weighted_loss)
def masked_accuracy(preds, labels, mask):
    """Return the classification accuracy restricted to masked entries.

    A prediction counts as correct when the arg-max over axis 1 of ``preds``
    equals the arg-max over axis 1 of ``labels``. The 0/1 correctness vector
    is weighted by ``mask / mean(mask)`` before averaging, so for a 0/1 mask
    the result equals ``sum(correct * mask) / sum(mask)``.

    Args:
        preds: Class scores, one row per example.
        labels: Target scores / one-hot labels, one row per example.
        mask: Per-example weights; cast to float32 here.
            NOTE(review): presumably a 0/1 vector with one entry per
            example -- confirm against the caller.

    Returns:
        Scalar tensor holding the masked accuracy.
    """
    predicted_class = tf.argmax(preds, 1)
    target_class = tf.argmax(labels, 1)
    is_correct = tf.cast(tf.equal(predicted_class, target_class), tf.float32)
    weights = tf.cast(mask, dtype=tf.float32)
    # Rescale so the plain mean below averages over masked entries only.
    weights = weights / tf.reduce_mean(weights)
    weighted_hits = is_correct * weights
    return tf.reduce_mean(weighted_hits)
def calc_precision_recall(qB, rB, query_L, retrieval_L, eps=2.2204e-16):
    """Compute precision and recall within every Hamming radius.

    A (query, retrieval) pair is relevant when the two label vectors share
    at least one positive label. For each integer radius ``r`` from 0 up to
    the largest observed Hamming distance, all pairs whose distance is at
    most ``r`` are treated as retrieved.

    Args:
        qB: Binary codes of the query set, one row per query.
        rB: Binary codes of the retrieval set, one row per item.
        query_L: 0-1 label matrix (numQuery x numLabel) of the query set.
        retrieval_L: 0-1 label matrix (numRetrieval x numLabel) of the
            retrieval set.
        eps: Small constant added to the precision denominator so a radius
            that retrieves nothing does not divide by zero.

    Returns:
        Tuple ``(precision, recall)``; each is an array of shape
        ``(maxHamm + 1, 1)`` indexed by radius. The recall denominator (the
        total number of relevant pairs) is not padded with ``eps``.

    Side effects:
        Prints the relevant-retrieved and retrieved pair counts per radius.
    """
    relevant_pairs = (np.dot(query_L, np.transpose(retrieval_L)) > 0).astype(int)
    hamming = calc_hammingDist(qB, rB)
    num_radii = int(np.max(hamming)) + 1
    total_relevant = np.sum(relevant_pairs)

    precision = np.zeros((num_radii, 1))
    recall = np.zeros((num_radii, 1))
    for radius in range(num_radii):
        # The 0.001 slack keeps distances equal to the radius inside it
        # despite floating-point noise.
        within_radius = (hamming <= (radius + 0.001)).astype(int)
        num_retrieved = np.sum(within_radius)
        num_relevant_retrieved = np.sum(np.multiply(relevant_pairs, within_radius))
        print(num_relevant_retrieved, num_retrieved)
        precision[radius] = num_relevant_retrieved * 1.0 / (num_retrieved + eps)
        recall[radius] = num_relevant_retrieved * 1.0 / total_relevant

    return precision, recall
def calc_map(qB, rB, query_L, retrieval_L):
    """Compute mean average precision (mAP) for Hamming-ranked retrieval.

    Adapted from the Deep Cross-Modal Hashing evaluation code.

    Args:
        qB: Query codes, {-1,+1}^{m x q}.
        rB: Retrieval codes, {-1,+1}^{n x q}.
        query_L: Query labels, {0,1}^{m x l}.
        retrieval_L: Retrieval labels, {0,1}^{n x l}.

    Returns:
        The sum of per-query average precisions divided by the total number
        of queries ``m``. Queries with no relevant retrieval item contribute
        0 to the sum but are still counted in the divisor.
    """
    num_query = query_L.shape[0]
    mean_ap = 0.0
    # `range` works on both Python 2 and 3 (`xrange` is Python 2 only).
    for query_idx in range(num_query):
        # An item is relevant when it shares at least one label with the query.
        gnd = (np.dot(query_L[query_idx, :], retrieval_L.transpose()) > 0).astype(np.float32)
        # np.linspace requires an integer sample count, so convert the
        # float32 sum before using it.
        num_relevant = int(np.sum(gnd))
        if num_relevant == 0:
            continue
        hamm = calc_hammingDist(qB[query_idx, :], rB)
        ranking = np.argsort(hamm)
        gnd = gnd[ranking]
        # count[j] = number of relevant items seen up to the j-th relevant hit;
        # tindex[j] = 1-based rank of that hit, so count / tindex is the
        # precision at each relevant position.
        count = np.linspace(1, num_relevant, num_relevant)
        tindex = np.asarray(np.where(gnd == 1)) + 1.0
        mean_ap = mean_ap + np.mean(count / tindex)
    mean_ap = mean_ap / num_query
    return mean_ap
def _mean_topk_precision(queries, gallery, labels, k):
    """Return the mean over all queries of their top-k average precision.

    For each query the gallery is ranked by Hamming distance. Precision is
    recorded at every rank 1..k whose retrieved item has exactly the same
    label as the query; a query with no such rank scores 0.

    NOTE(review): ``labels[i]`` is used both for query ``i`` and for gallery
    item ``i``, so the two sets are presumably index-aligned pairs -- confirm
    against the caller.
    """
    avg_precs = []
    for i, query in enumerate(queries):
        query_label = labels[i]
        ranking = np.argsort(calc_hammingDist(query, gallery))

        precs = []
        for topk in range(1, k + 1):
            top_k = ranking[0:topk]
            # Only ranks whose last retrieved item matches contribute.
            if np.any(query_label != labels[top_k[-1]]):
                continue
            hits = sum(1 for j in top_k if np.all(labels[j] == query_label))
            precs.append(float(hits) / float(topk))
        if not precs:
            precs.append(0)
        avg_precs.append(np.average(precs))
    return np.mean(avg_precs)


def calculate_map(test_img_feats_trans, test_txt_vecs_trans, test_labels, k=50):
    """Calculate and print top-k mAP for both retrieval directions.

    Args:
        test_img_feats_trans: Image codes/features, one row per test item.
        test_txt_vecs_trans: Text codes/features, one row per test item.
        test_labels: Labels indexed like the two inputs above.
        k: Number of top-ranked items considered per query (default 50).

    Returns:
        Tuple ``(txt2img_map, img2txt_map)``. Both values are also printed.
        Each printed time is measured from the start of this call, so the
        img2txt figure includes the txt2img pass.
    """
    start = time.time()

    # Text queries ranked against the image gallery.
    txt2img_map = _mean_topk_precision(
        test_txt_vecs_trans, test_img_feats_trans, test_labels, k)
    print('[Eval - txt2img] mAP: %f in %4.4fs' % (txt2img_map, (time.time() - start)))

    # Image queries ranked against the text gallery.
    img2txt_map = _mean_topk_precision(
        test_img_feats_trans, test_txt_vecs_trans, test_labels, k)
    print('[Eval - img2txt] mAP: %f in %4.4fs' % (img2txt_map, (time.time() - start)))

    return txt2img_map, img2txt_map
def calc_hammingDist(request, retrieval_all):
    """Return Hamming distances between request code(s) and retrieval codes.

    Computed as ``0.5 * (K - <request, retrieval>)`` where ``K`` is the code
    length (second dimension of ``retrieval_all``). For codes whose entries
    are all -1 or +1 this equals the number of differing positions.

    Args:
        request: One code (1-D) or several codes (2-D, one per row).
        retrieval_all: 2-D array of codes, one per row.

    Returns:
        Distances with shape ``(n,)`` for a single request or ``(m, n)``
        for ``m`` requests, where ``n`` is the number of retrieval codes.
    """
    code_length = retrieval_all.shape[1]
    inner_products = np.dot(request, retrieval_all.T)
    return 0.5 * (code_length - inner_products)
def calc_l2_norm(request, retrieval_all):
    """Return the Euclidean distance from ``request`` to each retrieval row.

    Args:
        request: Vector broadcast against every row of ``retrieval_all``.
        retrieval_all: 2-D array with one candidate vector per row.

    Returns:
        1-D array holding the L2 norm of ``retrieval_all[i] - request``
        for every row ``i``.
    """
    return np.linalg.norm(retrieval_all - request, axis=1)