-
Notifications
You must be signed in to change notification settings - Fork 4
/
plotresults.py
118 lines (84 loc) · 2.96 KB
/
plotresults.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/python
import numpy as np
import matplotlib.pyplot as plt
import sys
import argparse
def loaddata(filename):
    """Load one results file from the ``data/`` directory.

    ``filename`` may be given with or without a leading ``data/`` (or
    ``./data/``) and with or without the trailing ``.dat``; the file actually
    read is always ``data/<name>.dat``, parsed as comma-separated numbers.

    Two column layouts are recognised, by column count:

    * current (>= 5 columns): col 1 = time, col 2 = score, col 3 = reward,
      col 4 = goal reached;
    * old (3 or 4 columns): col 0 = score, col 1 = reward, col 2 = goal
      reached; the time axis is then the row index.

    Returns:
        ``(tm, rv, fname)`` -- time values, reward values and the bare name
        (used as the plot label), or ``(None, None, None)`` when the file
        cannot be read or has too few columns (an error is printed).
    """
    fname = str(filename)
    # Strip only a trailing extension / leading directory. A plain
    # str.replace would also mangle names such as "run.data.dat" or
    # "mydata/x.dat".
    if fname.endswith('.dat'):
        fname = fname[:-len('.dat')]
    if fname.startswith('./'):
        fname = fname[len('./'):]
    if fname.startswith('data/'):
        fname = fname[len('data/'):]
    try:
        # ndmin=2 keeps a single-row file two-dimensional so a[:, k] works.
        a = np.loadtxt('data/' + fname + '.dat', delimiter=",", ndmin=2)
    except (OSError, ValueError):
        print("Error in loading file ", fname)
        return None, None, None
    ncols = a.shape[1]
    # NOTE(review): layout detection is by column count only; an old-format
    # file that also carried an extra column at index 4 would be read with
    # the current layout -- confirm no such files remain.
    if ncols >= 5:
        tm = np.array(a[:, 1])  # time vector
        rv = a[:, 3]  # reward vector (score is col 2, goal reached col 4)
    elif ncols >= 3:  # old version: no time column
        rv = a[:, 1]  # reward vector (score is col 0, goal reached col 2)
        tm = range(0, len(rv))
    else:
        print("Unsupported column layout in file ", fname)
        return None, None, None
    return tm, rv, fname
def getplotdata(tm, data, nbins=100, spread=0.5):
    """Bin a series into consecutive chunks and summarise each chunk.

    The series is cut into chunks of ``d = max(1, len(data) // nbins)``
    consecutive samples (so roughly ``nbins`` chunks; the last one may be
    shorter). For each non-empty chunk this records the mean of ``tm``, the
    mean of ``data``, and a band of ``mean +/- spread * std`` of ``data``.

    Args:
        tm: time values, indexed in step with ``data``.
        data: values to summarise.
        nbins: approximate number of chunks (default 100); must be >= 1.
        spread: multiple of the chunk standard deviation used for the band
            half-width (default 0.5).

    Returns:
        ``(x, y, ytop, ybot)`` lists: chunk time means, chunk value means,
        upper band edges and lower band edges. All empty for empty input.

    Raises:
        ValueError: if ``nbins`` < 1.
    """
    if nbins < 1:
        raise ValueError("nbins must be >= 1")
    x = []  # x axis vector
    y = []  # y axis vector
    ytop = []  # confidence interval (top-edge)
    ybot = []  # confidence interval (bottom-edge)
    n = len(data)
    # At least one sample per chunk: int(n/100) was 0 for n < 100 and
    # made the old loop divide by zero.
    d = max(1, n // nbins)
    # Step by d up to n so a trailing partial chunk is kept, not dropped.
    for start in range(0, n, d):
        di = data[start:start + d]
        ti = tm[start:start + d]
        if len(ti) == 0:  # tm shorter than data
            continue
        mean = np.mean(di)
        half = spread * np.std(di)
        x.append(np.mean(ti))
        y.append(mean)
        ytop.append(mean + half)
        ybot.append(mean - half)
    return x, y, ytop, ybot
def showplots(xx, yy, yytop, yybot, yylabel, save):
    """Draw every binned series on one figure, optionally save it, then show it.

    Args:
        xx, yy: per-series x positions and mean values.
        yytop, yybot: per-series upper / lower band edges, drawn as a
            translucent fill around the mean line.
        yylabel: per-series legend labels.
        save: path passed to ``plt.savefig``, or ``None`` to skip saving.
    """
    colors = ['r', 'b', 'g', 'yellow', 'cyan', 'magenta']
    # Highest upper band edge over all non-empty series.
    ytop = max((max(series) for series in yytop if len(series) > 0), default=None)
    if ytop is not None and ytop > 0:
        # Anchor the axis at 0 with 20% headroom. For non-positive data a
        # fixed bottom=0 would give an empty or inverted axis, so leave
        # matplotlib to autoscale in that case.
        plt.ylim(bottom=0, top=ytop * 1.2)
    plt.title("Average reward")
    plt.xlabel('Time')
    plt.ylabel('Avg Reward')
    for i, (x, y, top, bot, label) in enumerate(zip(xx, yy, yytop, yybot, yylabel)):
        # Cycle the palette so more than len(colors) files do not IndexError.
        color = colors[i % len(colors)]
        plt.fill_between(x, top, bot, facecolor=color, alpha=0.25)
        plt.plot(x, y, color=color, label=label)
    plt.legend()
    if save is not None:
        plt.savefig(save)
        print('File saved: ', save)
    plt.show()
def plotdata(datafiles, save):
    """Load, bin and plot the reward curve of each given data file.

    Files that ``loaddata`` cannot read are skipped (it reports them itself).
    If at least one file loaded, all curves are drawn together by
    ``showplots``, which also saves the figure when ``save`` is a path.
    """
    series = []
    for path in datafiles:
        tm, rv, label = loaddata(path)
        if tm is None:
            continue
        x, y, ytop, ybot = getplotdata(tm, rv)
        series.append((x, y, ytop, ybot, label))
    if not series:
        return
    # Transpose per-file records into one list per quantity.
    xx, yy, yytop, yybot, yylabel = map(list, zip(*series))
    showplots(xx, yy, yytop, yybot, yylabel, save)
# Command-line entry point: plot one or more result files, optionally saving the figure.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Plot results')
    cli.add_argument('-save', type=str, default=None,
                     help='save figure on specified file')
    cli.add_argument('-datafiles', nargs='+', required=True,
                     help='[Required] Data files to plot')
    opts = cli.parse_args()
    plotdata(opts.datafiles, opts.save)