-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
130 lines (115 loc) · 3.97 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# @File : evaluation.py
# @Author: Cuiqingyao
# @Date : 2018/12/11
# @Desc : Evaluate word-segmentation models: precision, recall and F1 against an answer file.
# @Contact: [email protected]
from cut_model import HMM, MechanicalSegmentation, Bigram
from util import read_from_txt
from settings import *
class statistic(object):
    """
    Compute precision, recall and F1 for a word-segmentation model.

    The model passed to the constructor is instantiated once. The counters
    are filled by ``evaluation`` and are never reset, so repeated calls on
    the same instance accumulate.
    """

    def __init__(self, model):
        '''
        :param model: zero-argument callable (e.g. a model class); the object
            it returns must provide ``segment(file=...)``.
        '''
        # Model under evaluation (instantiated here, not by the caller).
        self.model = model()
        # Number of words in the reference answers.
        self.origin_words_num = 0
        # Number of words produced by the model.
        self.total_seg_words_num = 0
        # Produced words that were found in the matching answer line.
        self.correct_words_num = 0
        # Produced words that were not found in the matching answer line.
        self.incor_words_num = 0

    @staticmethod
    def _ratio(numerator, denominator):
        '''
        Divide two counters, returning 0.0 instead of raising when the
        denominator is zero (e.g. empty test data).
        '''
        if not denominator:
            return 0.0
        return numerator / denominator

    def __statistic(self, test_data_file, answer_file):
        '''
        Update the four counters from one test file and its answer file.

        :param test_data_file: passed to the model as ``segment(file=...)``.
        :param answer_file: passed to ``read_from_txt``.
        :return: None
        :raises AssertionError: if correct + incorrect does not equal the
            total number of produced words.
        '''
        # Presumably each entry is the list of words for one line -- TODO
        # confirm against read_from_txt and the model's segment().
        answers = read_from_txt(answer_file)
        segment_words = self.model.segment(file=test_data_file)

        # NOTE(review): the original check that both sides have the same
        # number of lines was disabled; it is left disabled here.
        # assert len(answers) == len(segment_words)

        # Reference word count and produced word count.
        self.origin_words_num += sum(len(answer) for answer in answers)
        self.total_seg_words_num += sum(len(words) for words in segment_words)

        # A produced word counts as correct when it occurs anywhere in the
        # answer entry with the same index; position is not checked.
        for idx, answer in enumerate(answers):
            for word in segment_words[idx]:
                if word in answer:
                    self.correct_words_num += 1
                else:
                    self.incor_words_num += 1

        # Every produced word must have been classified exactly once.
        assert self.total_seg_words_num == self.correct_words_num + self.incor_words_num

    def precision(self):
        '''
        Precision: correct words divided by all words the model produced.

        :return: float; 0.0 when the model produced no words.
        '''
        return self._ratio(self.correct_words_num, self.total_seg_words_num)

    def recall(self):
        '''
        Recall: correct words divided by all words in the reference answers.

        :return: float; 0.0 when the reference contains no words.
        '''
        return self._ratio(self.correct_words_num, self.origin_words_num)

    def F1(self):
        '''
        F1 score: harmonic mean of precision and recall.

        :return: float; 0.0 when precision and recall are both zero.
        '''
        p = self.precision()
        r = self.recall()
        if p + r == 0:
            return 0.0
        return 2 * p * r / (p + r)

    def print_report(self):
        '''
        Print the model, the raw counters and the derived metrics.

        :return: None
        '''
        print('当前分词模型:', self.model)
        print('原始分词数量(origin_words_num):', self.origin_words_num)
        print('总分词数量(total_seg_words_num):', self.total_seg_words_num)
        print('正确分词数量(correct_words_num):', self.correct_words_num)
        print('错误分词数量(incor_words_num):', self.incor_words_num)
        print('准确率(precision):', self.precision())
        print('召回率(recall):', self.recall())
        print('F1值:', self.F1())

    def evaluation(self, test_data_file, answer_file):
        '''
        Public entry point: collect the counters, then print the report.

        :param test_data_file: test data to segment.
        :param answer_file: reference segmentation for the test data.
        :return: None
        '''
        self.__statistic(test_data_file=test_data_file, answer_file=answer_file)
        self.print_report()
if __name__ == '__main__':
    # Command-line entry point.
    #   no argument      -> MechanicalSegmentation
    #   bigram | hmm | ms -> the corresponding model
    # Anything else prints an error and exits with a non-zero status.
    import sys

    # Accepted command-line argument -> model class.
    models_by_name = {
        'bigram': Bigram,
        'hmm': HMM,
        'ms': MechanicalSegmentation,
    }

    if len(sys.argv) == 1:
        evaluation_model = MechanicalSegmentation
    elif len(sys.argv) == 2:
        para = sys.argv[1]
        if para not in models_by_name:
            print("不合法的参数!")
            # sys.exit instead of the site-provided exit(); non-zero status
            # so that callers can detect the failure.
            sys.exit(1)
        if para == 'ms':
            # NOTE(review): this rebinds IS_COMBINE in this module's global
            # namespace only. Whether MechanicalSegmentation sees the change
            # depends on how cut_model reads the setting -- TODO confirm.
            IS_COMBINE = False
        evaluation_model = models_by_name[para]
    else:
        print("输入参数错误!")
        sys.exit(1)

    # Every path that reaches this point has selected a model class.
    stt = statistic(model=evaluation_model)
    stt.evaluation(test_data_file=TEST_FILE, answer_file=ANSWERS)