-
Notifications
You must be signed in to change notification settings - Fork 11
/
normalizer.py
70 lines (64 loc) · 2.71 KB
/
normalizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import threading
import numpy as np
from mpi4py import MPI
class normalizer:
    """Running mean/std normalizer whose statistics are aggregated across MPI ranks.

    Samples are accumulated per rank with :meth:`update`. :meth:`recompute_stats`
    averages the pending local sums over all ranks of ``MPI.COMM_WORLD``, folds
    them into the running totals and refreshes ``mean``/``std``. :meth:`normalize`
    then standardizes and clips inputs using the current statistics.
    """

    def __init__(self, size, eps=1e-2, default_clip_range=np.inf):
        """Create a normalizer for samples with ``size`` features.

        Args:
            size: Number of features per sample; every statistic array is
                allocated with this shape.
            eps: Lower bound on the std. It is applied as a floor of ``eps**2`` on
                the variance in :meth:`recompute_stats`, so :meth:`normalize`
                never divides by a (near-)zero std as long as ``eps > 0``.
            default_clip_range: Symmetric clip bound used by :meth:`normalize`
                when no ``clip_range`` is passed; ``np.inf`` disables clipping.
        """
        self.size = size
        self.eps = eps
        self.default_clip_range = default_clip_range
        # Per-rank statistics accumulated since the last recompute_stats() call.
        self.local_sum = np.zeros(self.size, np.float32)
        self.local_sumsq = np.zeros(self.size, np.float32)
        self.local_count = np.zeros(1, np.float32)
        # Running totals over all synced batches. total_count starts at 1 (not 0)
        # so the divisions in recompute_stats can never be by zero; the effect is
        # the same as one extra zero-valued sample in the statistics.
        self.total_sum = np.zeros(self.size, np.float32)
        self.total_sumsq = np.zeros(self.size, np.float32)
        self.total_count = np.ones(1, np.float32)
        # Statistics used by normalize(); an identity transform until the first
        # recompute_stats().
        self.mean = np.zeros(self.size, np.float32)
        self.std = np.ones(self.size, np.float32)
        # Guards the local_* accumulators shared by update() and recompute_stats().
        self.lock = threading.Lock()

    def update(self, v):
        """Accumulate a batch of samples into the local (per-rank) statistics.

        Args:
            v: ndarray or array-like (list/tuple) whose element count is a
                multiple of ``size``; it is viewed as ``(-1, size)``, one row per
                sample.
        """
        # np.asarray admits plain sequences (which have no .reshape) and returns
        # ndarrays unchanged, so existing callers see no difference.
        v = np.asarray(v).reshape(-1, self.size)
        with self.lock:
            self.local_sum += v.sum(axis=0)
            self.local_sumsq += np.square(v).sum(axis=0)
            self.local_count[0] += v.shape[0]

    def sync(self, local_sum, local_sumsq, local_count):
        """Replace each array, in place, by its mean over all MPI ranks.

        All three arrays are divided by the same world size, so ratios such as
        ``sum / count`` equal the global ones.

        Returns:
            The same three (now averaged) array objects that were passed in.
        """
        local_sum[...] = self._mpi_average(local_sum)
        local_sumsq[...] = self._mpi_average(local_sumsq)
        local_count[...] = self._mpi_average(local_count)
        return local_sum, local_sumsq, local_count

    def recompute_stats(self):
        """Fold the pending local statistics into the totals and refresh mean/std.

        The local accumulators are snapshotted and reset under the lock. The
        snapshot is averaged across ranks via MPI ``Allreduce`` (a collective
        operation, so every rank in ``COMM_WORLD`` must call this), added to the
        running totals, and ``mean``/``std`` are recomputed from those totals.
        """
        with self.lock:
            local_count = self.local_count.copy()
            local_sum = self.local_sum.copy()
            local_sumsq = self.local_sumsq.copy()
            # Reset so the next call only sees samples added after this snapshot.
            self.local_count[...] = 0
            self.local_sum[...] = 0
            self.local_sumsq[...] = 0
        # The MPI sync runs after the lock is released, so update() is not
        # blocked while ranks communicate.
        sync_sum, sync_sumsq, sync_count = self.sync(local_sum, local_sumsq, local_count)
        self.total_sum += sync_sum
        self.total_sumsq += sync_sumsq
        self.total_count += sync_count
        # Var = E[x^2] - E[x]^2, floored at eps^2 so that std >= eps.
        mean = self.total_sum / self.total_count
        variance = (self.total_sumsq / self.total_count) - np.square(mean)
        self.mean = mean
        self.std = np.sqrt(np.maximum(np.square(self.eps), variance))

    def _mpi_average(self, x):
        """Return the element-wise mean of ``x`` over all ranks of ``MPI.COMM_WORLD``.

        ``x`` itself is not modified. It must have a floating-point dtype, because
        the in-place true division below raises for integer arrays.
        """
        buf = np.zeros_like(x)
        MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM)
        buf /= MPI.COMM_WORLD.Get_size()
        return buf

    def normalize(self, v, clip_range=None):
        """Standardize ``v`` with the current statistics and clip the result.

        Args:
            v: Array (or array-like) broadcastable against ``mean``/``std``.
            clip_range: Symmetric clip bound; ``None`` means use
                ``default_clip_range``.

        Returns:
            ``clip((v - mean) / std, -clip_range, clip_range)``.
        """
        if clip_range is None:
            clip_range = self.default_clip_range
        return np.clip((v - self.mean) / self.std, -clip_range, clip_range)