-
Notifications
You must be signed in to change notification settings - Fork 6
/
eval_utils.py
105 lines (90 loc) · 3.1 KB
/
eval_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from collections import defaultdict
class DSTEvaluator:
def __init__(self, slot_meta):
self.slot_meta = slot_meta
self.init()
def init(self):
self.joint_goal_hit = 0
self.all_hit = 0
self.slot_turn_acc = 0
self.slot_F1_pred = 0
self.slot_F1_count = 0
def update(self, gold, pred):
self.all_hit += 1
if set(pred) == set(gold):
self.joint_goal_hit += 1
temp_acc = compute_acc(gold, pred, self.slot_meta)
self.slot_turn_acc += temp_acc
temp_f1, _, _, count = compute_prf(gold, pred)
self.slot_F1_pred += temp_f1
self.slot_F1_count += count
def compute(self):
turn_acc_score = self.slot_turn_acc / self.all_hit
slot_F1_score = self.slot_F1_pred / self.slot_F1_count
joint_goal_accuracy = self.joint_goal_hit / self.all_hit
eval_result = {
"joint_goal_accuracy": joint_goal_accuracy,
"turn_slot_accuracy": turn_acc_score,
"turn_slot_f1": slot_F1_score,
}
return eval_result
def compute_acc(gold, pred, slot_meta):
miss_gold = 0
miss_slot = []
for g in gold:
if g not in pred:
miss_gold += 1
miss_slot.append(g.rsplit("-", 1)[0])
wrong_pred = 0
for p in pred:
if p not in gold and p.rsplit("-", 1)[0] not in miss_slot:
wrong_pred += 1
ACC_TOTAL = len(slot_meta)
ACC = len(slot_meta) - miss_gold - wrong_pred
ACC = ACC / float(ACC_TOTAL)
return ACC
def compute_prf(gold, pred):
TP, FP, FN = 0, 0, 0
if len(gold) != 0:
count = 1
for g in gold:
if g in pred:
TP += 1
else:
FN += 1
for p in pred:
if p not in gold:
FP += 1
precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
F1 = (
2 * precision * recall / float(precision + recall)
if (precision + recall) != 0
else 0
)
else:
if len(pred) == 0:
precision, recall, F1, count = 1, 1, 1, 1
else:
precision, recall, F1, count = 0, 0, 0, 1
return F1, recall, precision, count
def eval_wrong_count(predictions, dev_labels):
wrong_value = defaultdict(float)
wrong_slot = defaultdict(float)
total = 0
for key in predictions.keys():
label_list = dev_labels[key]
pred_list = predictions[key]
for label in label_list:
domain, slot, value = label.split('-')
if label not in pred_list:
wrong_value[label] += 1
wrong_slot[domain + '-' + slot] += 1
total += 1
wrong_value = dict(sorted(wrong_value.items(), key=(lambda x: x[1]), reverse=True)[:10])
wrong_slot = dict(sorted(wrong_slot.items(), key=(lambda x: x[1]), reverse=True)[:10])
for key in wrong_value.keys():
wrong_value[key] /= total
for key in wrong_slot.keys():
wrong_slot[key] /= total
return wrong_value, wrong_slot