This repository has been archived by the owner on May 23, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_merged_metadata.py
executable file
·92 lines (75 loc) · 2.5 KB
/
plot_merged_metadata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import os
import sys
import cPickle as pickle
from matplotlib import pyplot as plt
import numpy as np
def listdir(root_dir, predicate=lambda f: True):
    """Recursively list the files below ``root_dir``.

    The tree is traversed with ``os.walk``. ``predicate`` is called with
    each bare file name (no directory part); files for which it returns a
    truthy value are included. The default predicate accepts every file.

    Returns a list of paths, each one the walked directory joined with the
    file name, in the order ``os.walk`` produced them.
    """
    matches = []
    for dirpath, _subdirs, names in os.walk(root_dir):
        for name in names:
            if predicate(name):
                matches.append(os.path.join(dirpath, name))
    return matches
def merge_metadata(data_dir):
    """Merge every ``.pkl`` file found under ``data_dir`` into one dict.

    Each pickle is loaded and iterated as a mapping of key -> sequence of
    records. Records that share a key are concatenated across files, and
    each merged list is then sorted by the first element of its records
    (presumably the training step, given how ``main`` unpacks them --
    confirm against the code that writes these files).

    Prints one ``merged <filename>`` line per file processed.

    Returns:
        dict mapping each key to its sorted list of records.
    """
    merged_data = {}
    for filename in listdir(data_dir, lambda f: f.endswith('.pkl')):
        # Pickle data is binary: text mode can corrupt it on Windows and
        # does not work at all on Python 3.
        with open(filename, 'rb') as pkl:
            # NOTE(security): unpickling can execute arbitrary code; only
            # point this script at log directories you trust.
            part = pickle.load(pkl)
        for k, v in part.items():
            # Always accumulate into a list owned by merged_data so the
            # freshly loaded object is never aliased and mutated in place.
            merged_data.setdefault(k, []).extend(v)
        print('merged {}'.format(filename))
    for k in merged_data:
        merged_data[k] = sorted(merged_data[k], key=lambda v: v[0])
    return merged_data
def average_sample(x, step):
    """Down-sample ``x`` by averaging consecutive chunks of ``step`` items.

    The sequence is cut into slices ``x[0:step]``, ``x[step:2*step]``, ...
    and ``np.mean`` of each slice is returned. The last slice may be
    shorter than ``step`` when ``len(x)`` is not a multiple of it.

    Args:
        x: a sliceable sequence of numbers with a length.
        step: chunk size; passed as the stride of ``range``.

    Returns:
        list with one mean per chunk (empty when ``x`` is empty).
    """
    # ``range`` rather than ``xrange`` so the helper works on Python 2 and 3.
    return [np.mean(x[start:start + step])
            for start in range(0, len(x), step)]
def main(data_dir):
    """Plot the merged training metadata found under ``data_dir``.

    The figure has two stacked panels:

    * top -- reconstruction, train and dev loss against step, with the
      y-axis limited to (0, 10) and a legend in the upper left;
    * bottom -- the KL term weight (y-axis 0..1) and, on a twin axis
      created with ``twinx``, the KL term value (y-axis 0..8).

    The figure is displayed with ``plt.show()``.
    """
    plt.figure()

    # plot overall loss
    plt.subplot(2, 1, 1)
    data = merge_metadata(data_dir)
    # (metadata key, keyword arguments for plt.plot), drawn in this order.
    loss_curves = (
        ('reconstruction_loss', dict(color='g', label='reconstruction loss')),
        ('train_loss', dict(label='train loss')),
        ('dev_loss', dict(color='r', label='dev loss')),
    )
    for key, style in loss_curves:
        xs, ys = zip(*data[key])
        plt.plot(xs, ys, **style)
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.legend(loc='upper left')
    plt.ylim((0, 10))

    # plot vae loss, and annealing weight
    weight_axis = plt.subplot(2, 1, 2)
    xs, weights = zip(*data['annealing_weight'])
    weight_axis.plot(xs, weights, 'b-', label='KL term weight', lw=2.0)
    plt.ylim(0, 1)
    plt.ylabel('KL term weight')
    plt.xlabel('step')

    # The pyplot calls below act on the current axes, presumably the twin
    # just created; the call order matches the original on purpose.
    kl_axis = weight_axis.twinx()
    xs, kl_values = zip(*data['kl_loss'])
    kl_axis.plot(xs, kl_values, 'r-', label='KL term value', lw=2.0)
    plt.ylim(0, 8)
    plt.yticks(np.linspace(0, 8, 9))
    plt.ylabel('KL term value')

    plt.show()
if __name__ == '__main__':
    # Script entry point: expects the log directory as the first argument.
    args = sys.argv[1:]
    if not args:
        # A missing argument is a usage error: report it on stderr and exit
        # non-zero so shells and wrappers can tell the run did nothing.
        print('usage:\n{} log_dir\n'.format(__file__), file=sys.stderr)
        sys.exit(1)
    main(args[0])