evaluator.py (forked from bgshih/aster)
import os
import logging

import tensorflow as tf
import numpy as np
import editdistance

from aster.core import preprocessor
from aster.core import prefetcher
from aster.core import standard_fields as fields
from aster.builders import preprocessor_builder
from aster import eval_util

EVAL_METRICS_FN_DICT = {
  'recognition_metrics': eval_util.evaluate_recognition_results,
}

def _extract_prediction_tensors(model,
                                create_input_dict_fn,
                                data_preprocessing_steps,
                                ignore_groundtruth=False,
                                evaluate_with_lexicon=False):
  """Builds the input pipeline and inference graph and returns the dict of
  tensors to evaluate for a single example."""
  # input queue
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.to_float(input_dict[fields.InputDataFields.image])
  original_image_shape = tf.shape(original_image)
  input_dict[fields.InputDataFields.image] = original_image

  # data preprocessing
  preprocessed_input_dict = preprocessor.preprocess(
      input_dict, data_preprocessing_steps)

  # model inference
  preprocessed_image = preprocessed_input_dict[fields.InputDataFields.image]
  preprocessed_image_shape = tf.shape(preprocessed_image)
  predictions_dict = model.predict(tf.expand_dims(preprocessed_image, 0))
  recognitions = model.postprocess(predictions_dict)

  def _lexicon_search(lexicon, word):
    """Returns the lexicon entry with the smallest edit distance to word."""
    edit_distances = []
    for lex_word in lexicon:
      edit_distances.append(editdistance.eval(lex_word.lower(), word.lower()))
    edit_distances = np.asarray(edit_distances, dtype=np.int)
    argmin = np.argmin(edit_distances)
    return lexicon[argmin]

  if evaluate_with_lexicon:
    # Constrain the predicted text to the closest word in the lexicon.
    lexicon = input_dict[fields.InputDataFields.lexicon]
    recognition_text = tf.py_func(
      _lexicon_search,
      [lexicon, recognitions['text'][0]],
      tf.string,
      stateful=False,
    )
  else:
    recognition_text = recognitions['text'][0]

  tensor_dict = {
    'original_image': original_image,
    'original_image_shape': original_image_shape,
    'preprocessed_image_shape': preprocessed_image_shape,
    'filename': preprocessed_input_dict[fields.InputDataFields.filename],
    'groundtruth_text': input_dict[fields.InputDataFields.groundtruth_text],
    'recognition_text': recognition_text,
  }
  if 'control_points' in predictions_dict:
    # Rectification outputs are only present for models with a rectification module.
    tensor_dict.update({
      'control_points': predictions_dict['control_points'],
      'rectified_images': predictions_dict['rectified_images']
    })
  return tensor_dict

def evaluate(create_input_dict_fn, create_model_fn, eval_config,
             checkpoint_dir, eval_dir,
             repeat_evaluation=True):
  """Evaluates the latest checkpoint in checkpoint_dir, optionally repeating
  whenever a new checkpoint appears, and writes summaries to eval_dir."""
  model = create_model_fn()
  data_preprocessing_steps = [
      preprocessor_builder.build(step)
      for step in eval_config.data_preprocessing_steps]

  tensor_dict = _extract_prediction_tensors(
      model=model,
      create_input_dict_fn=create_input_dict_fn,
      data_preprocessing_steps=data_preprocessing_steps,
      ignore_groundtruth=eval_config.ignore_groundtruth,
      evaluate_with_lexicon=eval_config.eval_with_lexicon)

  summary_writer = tf.summary.FileWriter(eval_dir)

  def _process_batch(tensor_dict, sess, batch_index, counters, update_op):
    """Runs one example through the graph; returns its result dict, or {} if skipped."""
    if batch_index >= eval_config.num_visualizations:
      # Past the visualization budget, drop the image tensor to avoid fetching it.
      if 'original_image' in tensor_dict:
        tensor_dict = {k: v for (k, v) in tensor_dict.items()
                       if k != 'original_image'}
    try:
      (result_dict, _) = sess.run([tensor_dict, update_op])
      counters['success'] += 1
    except tf.errors.InvalidArgumentError:
      logging.info('Skipping image')
      counters['skipped'] += 1
      return {}
    global_step = tf.train.global_step(sess, tf.train.get_global_step())
    if batch_index < eval_config.num_visualizations:
      eval_util.visualize_recognition_results(
          result_dict,
          'Recognition_{}'.format(batch_index),
          global_step,
          summary_dir=eval_dir,
          export_dir=os.path.join(eval_dir, 'vis'),
          summary_writer=summary_writer,
          only_visualize_incorrect=eval_config.only_visualize_incorrect)
    return result_dict

  def _process_aggregated_results(result_lists):
    eval_metric_fn_key = eval_config.metrics_set
    if eval_metric_fn_key not in EVAL_METRICS_FN_DICT:
      raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
    return EVAL_METRICS_FN_DICT[eval_metric_fn_key](result_lists)

  variables_to_restore = tf.global_variables()
  global_step = tf.train.get_or_create_global_step()
  variables_to_restore.append(global_step)
  if eval_config.use_moving_averages:
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  def _restore_latest_checkpoint(sess):
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest_checkpoint)

  eval_util.repeated_checkpoint_run(
      tensor_dict=tensor_dict,
      update_op=tf.no_op(),
      summary_dir=eval_dir,
      aggregated_result_processor=_process_aggregated_results,
      batch_processor=_process_batch,
      checkpoint_dirs=[checkpoint_dir],
      variables_to_restore=None,
      restore_fn=_restore_latest_checkpoint,
      num_batches=eval_config.num_examples,
      eval_interval_secs=eval_config.eval_interval_secs,
      max_number_of_evaluations=(
          1 if eval_config.ignore_groundtruth else
          eval_config.max_evals if eval_config.max_evals else
          None if repeat_evaluation else 1),
      master=eval_config.eval_master,
      save_graph=eval_config.save_graph,
      save_graph_dir=(eval_dir if eval_config.save_graph else ''))

  summary_writer.close()
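
For reference, a minimal driver sketch for evaluate() follows. It mirrors the general structure of the upstream ASTER eval script (a text-format pipeline protobuf plus model and input-reader builders), but the imported builder modules, the config path, and the directory names are assumptions for illustration; only evaluate() and its keyword arguments come from the file above.

# Hypothetical usage sketch -- not part of evaluator.py.
import functools

from google.protobuf import text_format

from aster import evaluator
from aster.builders import input_reader_builder  # assumed builder module
from aster.builders import model_builder         # assumed builder module
from aster.protos import pipeline_pb2            # assumed pipeline config proto

# Parse the text-format pipeline config (path is illustrative).
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with open('experiments/demo/config/trainval.prototxt', 'r') as f:
  text_format.Merge(f.read(), pipeline_config)

# Factory callables expected by evaluate(): one builds the model,
# the other builds the raw input tensor dict.
create_model_fn = functools.partial(
    model_builder.build, pipeline_config.model, is_training=False)
create_input_dict_fn = functools.partial(
    input_reader_builder.build, pipeline_config.eval_input_reader)

evaluator.evaluate(
    create_input_dict_fn=create_input_dict_fn,
    create_model_fn=create_model_fn,
    eval_config=pipeline_config.eval_config,
    checkpoint_dir='experiments/demo/log',
    eval_dir='experiments/demo/log/eval',
    repeat_evaluation=True)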