forked from bgshih/aster
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
142 lines (113 loc) · 4.9 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import functools
import logging
import tensorflow as tf
from google.protobuf import text_format
from aster import evaluator
from aster.protos import eval_pb2
from aster.protos import pipeline_pb2
from aster.protos import input_reader_pb2
from aster.builders import model_builder
from aster.builders import input_reader_builder
# Keep records emitted on the 'tensorflow' logger from also propagating up to
# the root logger (which gets a handler from basicConfig below); otherwise
# each TensorFlow message would be printed twice.
logging.getLogger('tensorflow').propagate = False # avoid logging duplicates
# Let TensorFlow's own logging emit INFO-level messages and above.
tf.logging.set_verbosity(tf.logging.INFO)
# Attach a default handler to the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Command-line flags for the evaluation job.
#
# Configs can be supplied in three ways (checked in this order by main()):
#   1. --exp_dir: everything is derived from the experiment directory.
#   2. --pipeline_config_path: one TrainEvalPipelineConfig text proto.
#   3. --model_config_path / --eval_config_path / --input_config_path.
# Options 2 and 3 additionally require --checkpoint_dir and --eval_dir.
flags = tf.app.flags
# Passed to the evaluator as `repeat_evaluation`.
flags.DEFINE_boolean('repeat', True, 'If true, evaluate repeatedly.')
# Read by get_configs_from_exp_dir / get_configs_from_pipeline_file when
# choosing which parts of the pipeline config to return.
flags.DEFINE_boolean('eval_training_data', False,
'If training data should be evaluated for this job.')
# Ignored when --exp_dir is set (main() then uses <exp_dir>/log).
flags.DEFINE_string('checkpoint_dir', '',
'Directory containing checkpoints to evaluate, typically '
'set to `train_dir` used in the training job.')
# When set, main() derives the checkpoint dir (<exp_dir>/log), the eval dir
# (<exp_dir>/log/eval) and the pipeline config
# (<exp_dir>/config/trainval.prototxt) from it; the other path flags are not
# consulted.
flags.DEFINE_string('exp_dir', '',
'Directory containing config, training log and evaluations')
# Ignored when --exp_dir is set (main() then uses <exp_dir>/log/eval).
flags.DEFINE_string('eval_dir', '',
'Directory to write eval summaries to.')
# Takes precedence over the three per-file config flags below (only used
# when --exp_dir is empty).
flags.DEFINE_string('pipeline_config_path', '',
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file. If provided, other configs are ignored')
# The next three are read by get_configs_from_multiple_files().
flags.DEFINE_string('eval_config_path', '',
'Path to an eval_pb2.EvalConfig config file.')
flags.DEFINE_string('input_config_path', '',
'Path to an input_reader_pb2.InputReader config file.')
# NOTE(review): help text names model_pb2.DetectionModel; confirm that is the
# actual aster model message name (this module never imports model_pb2).
flags.DEFINE_string('model_config_path', '',
'Path to a model_pb2.DetectionModel config file.')
FLAGS = flags.FLAGS
def _load_pipeline_configs(pipeline_config_path):
    """Parse a TrainEvalPipelineConfig text proto and select eval configs.

    Args:
      pipeline_config_path: path (anything tf.gfile.GFile can open) to a
        text-format pipeline_pb2.TrainEvalPipelineConfig.

    Returns:
      A (model_config, eval_config, input_config) tuple:
        model_config: the pipeline's `model` message.
        eval_config: the pipeline's eval_config (an eval_pb2.EvalConfig),
          regardless of --eval_training_data.
        input_config: train_input_reader when --eval_training_data is set,
          otherwise eval_input_reader (input_reader_pb2.InputReader).
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)

    model_config = pipeline_config.model
    # The evaluator is always handed an EvalConfig. Evaluating training data
    # means reading examples from the training input reader, not swapping in
    # the TrainConfig (a different message type) as the eval config.
    eval_config = pipeline_config.eval_config
    if FLAGS.eval_training_data:
        input_config = pipeline_config.train_input_reader
    else:
        input_config = pipeline_config.eval_input_reader
    return model_config, eval_config, input_config


def get_configs_from_exp_dir():
    """Reads evaluation configuration from the experiment directory.

    The pipeline config is read from <--exp_dir>/config/trainval.prototxt.

    Returns:
      model_config: the pipeline's `model` message
      eval_config: a eval_pb2.EvalConfig
      input_config: a input_reader_pb2.InputReader (the training reader when
        --eval_training_data is set, otherwise the eval reader)
    """
    pipeline_config_path = os.path.join(
        FLAGS.exp_dir, 'config', 'trainval.prototxt')
    return _load_pipeline_configs(pipeline_config_path)


def get_configs_from_pipeline_file():
    """Reads evaluation configuration from a pipeline_pb2.TrainEvalPipelineConfig.

    Reads evaluation config from file specified by pipeline_config_path flag.

    Returns:
      model_config: the pipeline's `model` message
      eval_config: a eval_pb2.EvalConfig
      input_config: a input_reader_pb2.InputReader (the training reader when
        --eval_training_data is set, otherwise the eval reader)
    """
    return _load_pipeline_configs(FLAGS.pipeline_config_path)
def get_configs_from_multiple_files():
    """Reads evaluation configuration from multiple config files.

    Reads the evaluation config from the following files:
      model_config: Read from --model_config_path
      eval_config: Read from --eval_config_path
      input_config: Read from --input_config_path

    Returns:
      model_config: a message of the type used by the `model` field of
        pipeline_pb2.TrainEvalPipelineConfig
      eval_config: a eval_pb2.EvalConfig
      input_config: a input_reader_pb2.InputReader

    Raises:
      ValueError: if any of the three config path flags is empty.
    """
    # Fail with a clear message instead of an opaque file-open error on ''.
    required = ('model_config_path', 'eval_config_path', 'input_config_path')
    missing = ['--' + name for name in required if not getattr(FLAGS, name)]
    if missing:
        raise ValueError(
            'Set --pipeline_config_path or all of --model_config_path, '
            '--eval_config_path and --input_config_path; missing: %s'
            % ', '.join(missing))

    eval_config = eval_pb2.EvalConfig()
    with tf.gfile.GFile(FLAGS.eval_config_path, 'r') as f:
        text_format.Merge(f.read(), eval_config)

    # model_pb2 is not imported by this module, so the model message is
    # built through the pipeline proto's `model` field instead; this also
    # guarantees the same message type the pipeline-file path returns.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(FLAGS.model_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config.model)
    model_config = pipeline_config.model

    input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(FLAGS.input_config_path, 'r') as f:
        text_format.Merge(f.read(), input_config)
    return model_config, eval_config, input_config
def main(unused_argv):
    """Evaluates checkpoints as configured by the command-line flags.

    Inputs are located in one of two ways:
      * --exp_dir set: checkpoints are read from <exp_dir>/log, summaries go
        to <exp_dir>/log/eval, and configs come from get_configs_from_exp_dir().
      * otherwise --checkpoint_dir and --eval_dir are required; configs come
        from --pipeline_config_path if set, else from the three separate
        --*_config_path flags.

    Args:
      unused_argv: positional arguments left over after flag parsing
        (ignored).

    Raises:
      ValueError: if --exp_dir is empty and --checkpoint_dir or --eval_dir is
        missing.
    """
    if FLAGS.exp_dir:
        checkpoint_dir = os.path.join(FLAGS.exp_dir, 'log')
        eval_dir = os.path.join(checkpoint_dir, 'eval')
        model_config, eval_config, input_config = get_configs_from_exp_dir()
    else:
        # Explicit checks rather than `assert`, which is stripped under -O.
        if not FLAGS.checkpoint_dir:
            raise ValueError('`checkpoint_dir` is missing.')
        if not FLAGS.eval_dir:
            raise ValueError('`eval_dir` is missing.')
        checkpoint_dir = FLAGS.checkpoint_dir
        eval_dir = FLAGS.eval_dir
        if FLAGS.pipeline_config_path:
            model_config, eval_config, input_config = (
                get_configs_from_pipeline_file())
        else:
            model_config, eval_config, input_config = (
                get_configs_from_multiple_files())

    # Deferred constructors handed to the evaluator; the model is built in
    # inference mode.
    model_fn = functools.partial(
        model_builder.build, config=model_config, is_training=False)
    create_input_dict_fn = functools.partial(
        input_reader_builder.build, input_config)

    evaluator.evaluate(create_input_dict_fn, model_fn, eval_config,
                       checkpoint_dir, eval_dir,
                       repeat_evaluation=FLAGS.repeat)
if __name__ == '__main__':
    # Hand the entry point to tf.app.run explicitly instead of letting it look
    # up `main` on the __main__ module; flags are parsed before it is called.
    tf.app.run(main=main)