-
Notifications
You must be signed in to change notification settings - Fork 538
/
evaluate.py
99 lines (73 loc) · 3.12 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import time
from collections import Counter
from models.hmm import HMM
from models.crf import CRFModel
from models.bilstm_crf import BILSTM_Model
from utils import save_model, flatten_lists
from evaluating import Metrics
def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_O=False):
    """Train an HMM, save a checkpoint, then score it on the test split.

    Args:
        train_data: ``(word_lists, tag_lists)`` pair used for training.
        test_data: ``(word_lists, tag_lists)`` pair used for evaluation.
        word2id: word-to-index mapping; ``len(word2id)`` is passed as the
            second ``HMM`` constructor argument and the mapping is handed to
            ``train``/``test``.
        tag2id: tag-to-index mapping; ``len(tag2id)`` is passed as the first
            ``HMM`` constructor argument and the mapping is handed to
            ``train``/``test``.
        remove_O: forwarded unchanged to ``Metrics``.

    Returns:
        The tag lists the HMM predicts for the test word lists.
    """
    train_words, train_tags = train_data
    test_words, test_tags = test_data

    # Fit the model and persist it before running evaluation.
    model = HMM(len(tag2id), len(word2id))
    model.train(train_words, train_tags, word2id, tag2id)
    save_model(model, "./ckpts/hmm.pkl")

    # Decode the test sentences and report metrics against the gold tags.
    predictions = model.test(test_words, word2id, tag2id)
    scorer = Metrics(test_tags, predictions, remove_O=remove_O)
    scorer.report_scores()
    scorer.report_confusion_matrix()
    return predictions
def crf_train_eval(train_data, test_data, remove_O=False):
    """Train a CRF model, save a checkpoint, then score it on the test split.

    Args:
        train_data: ``(word_lists, tag_lists)`` pair used for training.
        test_data: ``(word_lists, tag_lists)`` pair used for evaluation.
        remove_O: forwarded unchanged to ``Metrics``.

    Returns:
        The tag lists the CRF predicts for the test word lists.
    """
    (train_words, train_tags), (test_words, test_tags) = train_data, test_data

    # Fit and persist the model.
    model = CRFModel()
    model.train(train_words, train_tags)
    save_model(model, "./ckpts/crf.pkl")

    # Predict on the test sentences and print both reports.
    predictions = model.test(test_words)
    evaluator = Metrics(test_tags, predictions, remove_O=remove_O)
    for report in (evaluator.report_scores, evaluator.report_confusion_matrix):
        report()
    return predictions
def bilstm_train_and_eval(train_data, dev_data, test_data,
                          word2id, tag2id, crf=True, remove_O=False):
    """Train a BiLSTM (optionally with a CRF layer), save it, and evaluate it.

    Args:
        train_data: ``(word_lists, tag_lists)`` pair used for training.
        dev_data: ``(word_lists, tag_lists)`` pair passed to ``train``
            alongside the training data.
        test_data: ``(word_lists, tag_lists)`` pair used for evaluation.
        word2id: word-to-index mapping; its length is the vocabulary size.
        tag2id: tag-to-index mapping; its length is the output size.
        crf: passed to ``BILSTM_Model``; also selects the checkpoint name
            (``bilstm_crf`` vs ``bilstm``).
        remove_O: forwarded unchanged to ``Metrics``.

    Returns:
        The predicted tag lists returned by ``BILSTM_Model.test``.
    """
    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    vocab_size = len(word2id)
    out_size = len(tag2id)
    bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf)
    bilstm_model.train(train_word_lists, train_tag_lists,
                       dev_word_lists, dev_tag_lists, word2id, tag2id)

    model_name = "bilstm_crf" if crf else "bilstm"
    save_model(bilstm_model, "./ckpts/" + model_name + ".pkl")
    # The reported time covers training plus saving the checkpoint.
    print("训练完毕,共用时{}秒.".format(int(time.perf_counter() - start)))

    print("评估{}模型中...".format(model_name))
    # test() hands back gold tags together with the predictions, and those
    # returned gold tags are what gets scored. Presumably they are re-ordered
    # to line up with the predictions — TODO confirm against BILSTM_Model.test.
    pred_tag_lists, test_tag_lists = bilstm_model.test(
        test_word_lists, test_tag_lists, word2id, tag2id)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
    return pred_tag_lists
def ensemble_evaluate(results, targets, remove_O=False):
    """Ensemble several models by per-token majority vote and report metrics.

    Args:
        results: one entry per model, each with the same nesting as
            ``targets`` (flattened with ``flatten_lists``). Not modified.
        targets: gold tag lists, flattened the same way.
        remove_O: forwarded unchanged to ``Metrics``.

    Returns:
        The flat list of ensembled tags, one per token.

    Raises:
        ValueError: if ``results`` is empty, or if any model's flattened
            predictions differ in length from the flattened targets.
    """
    # Build new flat lists instead of overwriting entries of the caller's
    # ``results`` list.
    flat_results = [flatten_lists(result) for result in results]
    flat_targets = flatten_lists(targets)

    # Explicit checks (not ``assert``) so validation survives ``python -O``;
    # checking every model also prevents zip() from silently truncating.
    if not flat_results:
        raise ValueError("ensemble_evaluate requires at least one model's results")
    for idx, flat in enumerate(flat_results):
        if len(flat) != len(flat_targets):
            raise ValueError(
                "model {} produced {} tags but there are {} target tags".format(
                    idx, len(flat), len(flat_targets)))

    # Majority vote per token. Counter.most_common orders equal counts by
    # first occurrence, so ties go to the earliest model in ``results``.
    pred_tags = [Counter(votes).most_common(1)[0][0]
                 for votes in zip(*flat_results)]

    print("Ensemble {}个模型的结果如下:".format(len(flat_results)))
    metrics = Metrics(flat_targets, pred_tags, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
    return pred_tags