-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_sent_level.py
120 lines (101 loc) · 3.46 KB
/
eval_sent_level.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding:UTF-8 -*-
# modified from https://github.com/DaDaMrX/ReaLiSe
import argparse
from collections import OrderedDict
import json
def read_file(path):
    """Parse a sentence-level result file into a list of items.

    Each non-blank line is either ``id, 0`` (no edits) or
    ``id, pos1, text1, pos2, text2, ...`` with fields separated by ``", "``.

    Args:
        path: Path to a UTF-8 text file in the format above.

    Returns:
        One list per non-blank line: ``[id, (pos, text), ...]`` where ``pos``
        is converted to ``int`` (``text`` is presumably the corrected
        character). A line ``id, 0`` yields just ``[id]``.

    Raises:
        IndexError / ValueError: on a malformed line (odd number of edit
        fields, or a non-integer position).
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()
    data = []
    for line in lines:
        line = line.strip()
        # Blank lines used to become phantom [''] items that counted as
        # correctly predicted sentences (inflating Accuracy) or broke the
        # pred/gold length check when only one file contained them.
        if not line:
            continue
        row = line.split(', ')
        item = [row[0]]
        data.append(item)
        # "id, 0" marks a sentence with no edits.
        if len(row) == 2 and row[1] == '0':
            continue
        # Remaining fields come in (position, text) pairs.
        for i in range(1, len(row), 2):
            item.append((int(row[i]), row[i + 1]))
    return data
def metric_file(pred_path, targ_path):
    """Score a prediction file against a gold file at sentence level.

    Args:
        pred_path: Path to the prediction file (parsed by ``read_file``).
        targ_path: Path to the gold file (parsed by ``read_file``).

    Returns:
        OrderedDict with keys "Detection" and "Correction", each mapping to the
        metrics returned by ``sent_metric_detect`` / ``sent_metric_correct``.
    """
    predictions, targets = read_file(pred_path), read_file(targ_path)
    return OrderedDict([
        ("Detection", sent_metric_detect(preds=predictions, targs=targets)),
        ("Correction", sent_metric_correct(preds=predictions, targs=targets)),
    ])
def sent_metric_detect(preds, targs):
    """Compute sentence-level error *detection* metrics.

    A sentence counts as correctly detected when its predicted edit positions
    equal its gold edit positions exactly; the replacement text is ignored.

    Args:
        preds: Parsed predictions, ``[[id, (pos, text), ...], ...]``.
        targs: Parsed gold annotations, aligned one-to-one with ``preds``.

    Returns:
        OrderedDict with 'Accuracy', 'Precision', 'Recall' and 'F1', each in
        percent. A metric whose denominator is zero is reported as 0.0
        instead of raising ZeroDivisionError.
    """
    assert len(preds) == len(targs)
    tp, targ_p, pred_p, hit = 0, 0, 0, 0
    for pred_item, targ_item in zip(preds, targs):
        assert pred_item[0] == targ_item[0]
        pred, targ = sorted(pred_item[1:]), sorted(targ_item[1:])
        if targ:
            targ_p += 1
        if pred:
            pred_p += 1
        # Detection compares positions only; evaluate the match once.
        if [e[0] for e in pred] == [e[0] for e in targ]:
            hit += 1
            # A true positive additionally requires a non-empty prediction.
            if pred:
                tp += 1
    acc = hit / len(targs) if targs else 0.0
    p = tp / pred_p if pred_p else 0.0
    r = tp / targ_p if targ_p else 0.0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return OrderedDict([
        ('Accuracy', acc * 100),
        ('Precision', p * 100),
        ('Recall', r * 100),
        ('F1', f1 * 100),
    ])
def sent_metric_correct(preds, targs):
    """Compute sentence-level error *correction* metrics.

    A sentence counts as correctly corrected when its sorted list of
    predicted ``(pos, text)`` edits equals the sorted gold list exactly.

    Args:
        preds: Parsed predictions, ``[[id, (pos, text), ...], ...]``.
        targs: Parsed gold annotations, aligned one-to-one with ``preds``.

    Returns:
        OrderedDict with 'Accuracy', 'Precision', 'Recall' and 'F1', each in
        percent. A metric whose denominator is zero is reported as 0.0
        instead of raising ZeroDivisionError.
    """
    assert len(preds) == len(targs)
    tp, targ_p, pred_p, hit = 0, 0, 0, 0
    for pred_item, targ_item in zip(preds, targs):
        assert pred_item[0] == targ_item[0]
        pred, targ = sorted(pred_item[1:]), sorted(targ_item[1:])
        if targ:
            targ_p += 1
        if pred:
            pred_p += 1
        if pred == targ:
            hit += 1
            # A true positive additionally requires a non-empty prediction.
            if pred:
                tp += 1
    acc = hit / len(targs) if targs else 0.0
    p = tp / pred_p if pred_p else 0.0
    r = tp / targ_p if targ_p else 0.0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return OrderedDict([
        ('Accuracy', acc * 100),
        ('Precision', p * 100),
        ('Recall', r * 100),
        ('F1', f1 * 100),
    ])
def get_sent_metrics(pred_path, targ_path):
    """Compute sentence-level metrics, print a summary, and return them.

    Args:
        pred_path: Path to the prediction file.
        targ_path: Path to the gold file.

    Returns:
        The OrderedDict produced by ``metric_file`` (sections "Detection" and
        "Correction", each holding percentage scores).
    """
    results = metric_file(pred_path=pred_path, targ_path=targ_path)
    banner = "=" * 10
    print(f"{banner} Sentence Level {banner}")
    for section, scores in results.items():
        print(f"{section}: ")
        # Scores are rounded to two decimals for display only.
        parts = [f"{name}: {round(value, 2)}" for name, value in scores.items()]
        print(", ".join(parts))
    return results
def main(args):
    """Command-line entry point: score ``args.hyp`` against ``args.gold``.

    Prints the sentence-level summary and, when ``args.json`` is set, also
    writes the metrics to that path as pretty-printed UTF-8 JSON.

    Args:
        args: Parsed CLI namespace with ``hyp``, ``gold`` and ``json``.
    """
    # Delegate to get_sent_metrics rather than duplicating its
    # scoring and printing code.
    metrics = get_sent_metrics(pred_path=args.hyp, targ_path=args.gold)
    if args.json:
        with open(args.json, "w", encoding="utf8") as fw:
            json.dump(metrics, fw, ensure_ascii=False, indent=2)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Compute sentence-level detection and correction metrics.")
    parser.add_argument('--hyp', required=True,
                        help="prediction file: one 'id, 0' or "
                             "'id, pos, text, ...' line per sentence")
    parser.add_argument('--gold', required=True,
                        help="gold file in the same format as --hyp")
    parser.add_argument("--json", required=False,
                        help="optional path to also write the metrics as JSON")
    args = parser.parse_args()
    main(args)