forked from lottery-ticket/code
-
Notifications
You must be signed in to change notification settings - Fork 0
/
is_done_training.py
68 lines (52 loc) · 1.85 KB
/
is_done_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
import argparse
import os
import multiprocessing
def meets_thresh(curr_min, new_val, threshold):
    """Return True if ``new_val`` improves on ``curr_min`` by at least ``threshold``.

    Lower values are better (the caller tracks a running minimum).

    Args:
        curr_min: The best (lowest) value accepted so far.
        new_val: The candidate value.
        threshold: Either a string ending in ``'%'`` for a relative
            improvement (``'5%'`` means ``new_val`` must be at least 5% of
            ``|curr_min|`` below ``curr_min``), or an absolute decrease given
            as a string or a number (e.g. ``'0.01'`` or ``0.01``).

    Returns:
        True if the decrease from ``curr_min`` to ``new_val`` meets the
        threshold, otherwise False.
    """
    if isinstance(threshold, str) and threshold.endswith('%'):
        required = float(threshold[:-1]) / 100
        if curr_min == 0:
            # The relative change from zero is undefined (the old formula
            # divided by zero here). Treat any strict decrease as an unbounded
            # improvement and anything else as no improvement.
            return new_val < curr_min
        # Normalise by |curr_min| so a decrease counts as positive improvement
        # even when losses are negative; `1 - new/curr` flips sign there.
        improvement = (curr_min - new_val) / abs(curr_min)
        return improvement >= required
    return curr_min - new_val >= float(threshold)
def get_losses(events_file):
    """Read ``(step, value)`` pairs for every loss summary in an event file.

    Args:
        events_file: Path to a TensorFlow ``events.out*`` file.

    Returns:
        A list of ``(event.step, value.simple_value)`` tuples, in file order,
        for every summary value whose tag contains the substring ``'loss'``.
    """
    # Imported locally; check_dir invokes this in a separate worker process.
    import tensorflow as tf
    # `tf.train.summary_iterator` exists only in TF1. Prefer the compat.v1
    # alias, which is available in TF >= 1.13 and TF2, and fall back otherwise.
    try:
        summary_iterator = tf.compat.v1.train.summary_iterator
    except AttributeError:
        summary_iterator = tf.train.summary_iterator
    # NOTE(review): values logged via TF2 `tf.summary.scalar` may live in
    # `value.tensor` rather than `simple_value` — confirm for TF2-written logs.
    return [
        (event.step, value.simple_value)
        for event in summary_iterator(events_file)
        for value in event.summary.value
        if 'loss' in value.tag
    ]
def listdir(directory):
    """Return the entry names in ``directory`` as a list, via TensorFlow's gfile.

    Args:
        directory: Directory path to list.

    Returns:
        A list of entry names (not full paths).
    """
    # Imported locally; check_dir invokes this in a separate worker process.
    import tensorflow as tf
    # `tf.gfile.ListDirectory` exists only in TF1. Prefer `tf.io.gfile.listdir`
    # (TF >= 1.13 and TF2) and fall back for older builds.
    try:
        list_entries = tf.io.gfile.listdir
    except AttributeError:
        list_entries = tf.gfile.ListDirectory
    return list(list_entries(directory))
def _run_in_subprocess(func, *args):
    """Evaluate ``func(*args)`` in a fresh single-worker pool and return the result."""
    # NOTE(review): the functions run through here import TensorFlow locally,
    # so isolating them presumably keeps TF out of the parent process — confirm.
    # The `with` block tears the pool down even if `func` raises; the previous
    # Pool()/close() pairing leaked the pool on error and never joined it.
    with multiprocessing.Pool(1) as pool:
        return pool.apply(func, args)


def check_dir(directory, iterations_without_improvement, improvement_threshold):
    """Report whether the loss curve logged in ``directory`` has plateaued.

    Every file in ``directory`` whose name starts with ``events.out`` is read
    with ``get_losses``. All collected ``(step, value)`` pairs are sorted by
    step (then by value) and walked in order. The tracked minimum is replaced
    whenever a value beats it by at least ``improvement_threshold`` (see
    ``meets_thresh``).

    Progress is printed to stdout: the first pair, each accepted improvement,
    and the last step.

    Args:
        directory: Directory to scan for event files.
        iterations_without_improvement: Number of steps tolerated after the
            last qualifying improvement.
        improvement_threshold: Threshold forwarded to ``meets_thresh``, e.g.
            ``'1%'`` or ``'0.01'``.

    Returns:
        True if the last logged step is more than
        ``iterations_without_improvement`` steps past the last qualifying
        improvement. False otherwise, including when no loss values were found.
    """
    losses = []
    for name in _run_in_subprocess(listdir, directory):
        if not name.startswith('events.out'):
            continue
        events_path = os.path.join(directory, name)
        losses.extend(_run_in_subprocess(get_losses, events_path))
    if not losses:
        return False
    losses.sort()

    best_step, best_val = losses[0]
    print(losses[0])
    for step, val in losses:
        if meets_thresh(best_val, val, improvement_threshold):
            best_step, best_val = step, val
            print((step, val))
    last_step = losses[-1][0]
    print(last_step)
    return last_step - best_step > iterations_without_improvement
def main():
    """Command-line entry point.

    Positional arguments, in order:
      directory                       directory containing the event files
      iterations_without_improvement  integer step tolerance
      improvement_threshold           threshold string, e.g. '1%' or '0.01'

    Prints the boolean returned by ``check_dir``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('directory')
    cli.add_argument('iterations_without_improvement', type=int)
    cli.add_argument('improvement_threshold')
    opts = cli.parse_args()

    plateaued = check_dir(
        opts.directory,
        opts.iterations_without_improvement,
        opts.improvement_threshold,
    )
    print(plateaued)


if __name__ == '__main__':
    # Only parse argv and run when executed as a script, not on import.
    main()