-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
131 lines (107 loc) · 3.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import csv
import torch
import torch.multiprocessing as mp
def get_device():
    ''' Pick the torch device to run on and announce the choice on stdout.

    CUDA is checked first, then Apple's MPS (Metal) backend; the CPU is the
    fallback when neither reports itself as available.

    Returns:
        (torch.device): the selected device
    '''
    if torch.cuda.is_available():
        backend, label = "cuda", "CUDA"
    elif torch.backends.mps.is_available():
        backend, label = "mps", "Metal"
    else:
        backend, label = "cpu", "CPU"

    chosen = torch.device(backend)
    print("--> Running on the " + label)
    return chosen
def ensure_shared_grads(model, shared_model):
    ''' Point the shared model's gradients at the local model's gradients.

    Parameters of both models are walked in lockstep. As soon as a shared
    parameter that already has a gradient is met, the function returns and
    leaves every remaining parameter untouched.

    Args:
        model: model whose ``param.grad`` values are handed over
        shared_model: model whose parameters receive those gradients
    '''
    paired = zip(model.parameters(), shared_model.parameters())
    for local_p, global_p in paired:
        already_set = global_p.grad is not None
        if already_set:
            # Stop entirely (not just skip this parameter).
            return
        global_p._grad = local_p.grad
def plot_curve(csv_path, save_path, algorithm):
    ''' Read data from csv file and plot the results

    Args:
        csv_path (str): CSV file with a header row containing at least the
            'timestep' and 'reward' columns
        save_path (str): file the figure is written to; its parent
            directory is created when missing
        algorithm (str): label shown in the legend for the curve

    Raises:
        KeyError: a row lacks the 'timestep' or 'reward' column
        ValueError: a 'timestep' is not an int or a 'reward' is not a float
    '''
    # Imported inside the function, so matplotlib is only needed when this
    # function is actually called.
    import matplotlib.pyplot as plt

    xs = []
    ys = []
    # newline='' is what the csv module asks for when reading files.
    with open(csv_path, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            xs.append(int(row['timestep']))
            ys.append(float(row['reward']))

    fig, ax = plt.subplots()
    try:
        ax.plot(xs, ys, label=algorithm)
        ax.set(xlabel='timestep', ylabel='reward')
        ax.legend()
        ax.grid()

        # dirname() is '' for a bare file name; makedirs('') would raise,
        # and the current directory needs no creation anyway.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        fig.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
# https://github.com/alexis-jacq/Pytorch-DPPO/blob/master/utils.py
class Counter:
    """Integer counter that lets the chief follow the workers' update count.

    The count lives in an ``mp.Value`` and every access goes through a
    separate ``mp.Lock``.
    """

    def __init__(self, val=True):
        # NOTE(review): `val` is accepted but never read; the count always
        # starts at 0.
        self.lock = mp.Lock()
        self.val = mp.Value("i", 0)

    def get(self):
        """Return the current count (called by the chief)."""
        with self.lock:
            current = self.val.value
        return current

    def increment(self):
        """Add one to the count (called by the workers)."""
        with self.lock:
            self.val.value = self.val.value + 1

    def reset(self):
        """Set the count back to zero (called by the chief)."""
        with self.lock:
            self.val.value = 0
class Logger(object):
    ''' Logger saves the running results and helps make plots from the results

    Use it as a context manager: entering creates ``log_dir`` if needed and
    opens ``log.txt`` and ``performance.csv`` inside it; leaving closes both
    files.
    '''

    def __init__(self, log_dir):
        ''' Remember the log directory; no file is touched until ``__enter__``.

        Args:
            log_dir (str): The directory the log files are written to
        '''
        self.log_dir = log_dir
        # Handles stay None until __enter__ so that __exit__ is safe to call
        # even if the files were never opened.
        self.txt_file = None
        self.csv_file = None
        self.writer = None

    def __enter__(self):
        ''' Create the log directory, open both files and write the CSV header.

        Returns:
            (Logger): this instance
        '''
        self.txt_path = os.path.join(self.log_dir, 'log.txt')
        self.csv_path = os.path.join(self.log_dir, 'performance.csv')
        self.fig_path = os.path.join(self.log_dir, 'fig.png')

        # An empty log_dir means the current directory; makedirs('') raises.
        if self.log_dir:
            os.makedirs(self.log_dir, exist_ok=True)

        self.txt_file = open(self.txt_path, 'w')
        try:
            # newline='' is what the csv module asks for when writing files.
            self.csv_file = open(self.csv_path, 'w', newline='')
        except OSError:
            # Do not leak the text log when the CSV cannot be opened.
            self.txt_file.close()
            self.txt_file = None
            raise

        fieldnames = ['timestep', 'reward', 'loss']
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        self.writer.writeheader()
        return self

    def log(self, text):
        ''' Write the text to log file.

        Args:
            text (string): text to log; a newline is appended
        '''
        self.txt_file.write(text+'\n')
        self.txt_file.flush()

    def log_performance(self, timestep, reward, loss):
        ''' Log a point in the curve to both the CSV file and the text log

        Args:
            timestep (int): the timestep of the current point
            reward (float): the reward of the current point
            loss (float): the loss of the current point
        '''
        self.writer.writerow({'timestep': timestep, 'reward': reward, 'loss': loss})
        # Flush like log() does, so the CSV is readable while the logger is
        # still open and rows are not lost if the process dies.
        self.csv_file.flush()
        self.log('----------------------------------------')
        self.log(' timestep | ' + str(timestep))
        self.log(' reward | ' + str(reward))
        self.log(' loss | ' + str(loss))
        self.log('----------------------------------------')

    def __exit__(self, type, value, traceback):
        ''' Close whichever files were opened and report the log directory.

        Exceptions raised inside the ``with`` body are not suppressed.
        '''
        if self.txt_file is not None:
            self.txt_file.close()
        if self.csv_file is not None:
            self.csv_file.close()
        print('\nLogs saved in', self.log_dir)