# From https://github.com/PKU-MARL/Multi-Agent-Transformer/blob/main/mat/utils/valuenorm.py
import numpy as np
import torch
import torch.nn as nn


class ValueNorm(nn.Module):
    """Normalize a vector of observations across the first norm_axes dimensions."""

    def __init__(
        self,
        input_shape,
        norm_axes=1,
        beta=0.99999,
        per_element_update=False,
        epsilon=1e-5,
        device=torch.device("cpu"),
    ):
        super(ValueNorm, self).__init__()

        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update
        self.tpdv = dict(dtype=torch.float32, device=device)

        # Running statistics are non-trainable buffers, not learned parameters.
        self.running_mean = nn.Parameter(
            torch.zeros(input_shape), requires_grad=False
        ).to(**self.tpdv)
        self.running_mean_sq = nn.Parameter(
            torch.zeros(input_shape), requires_grad=False
        ).to(**self.tpdv)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(
            **self.tpdv
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_mean_sq.zero_()
        self.debiasing_term.zero_()
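
    # The three buffers implement exponentially weighted moving averages:
    #   running_mean    <- w * running_mean    + (1 - w) * batch_mean
    #   running_mean_sq <- w * running_mean_sq + (1 - w) * batch_sq_mean
    #   debiasing_term  <- w * debiasing_term  + (1 - w) * 1
    # With a constant weight w, debiasing_term equals 1 - w^t after t updates,
    # so dividing by it removes the bias from zero initialization (the same
    # correction Adam applies to its moment estimates).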
    def running_mean_var(self):
        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(
            min=self.epsilon
        )
        # The clamp keeps the variance strictly positive for the sqrt in
        # normalize()/denormalize().
        debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    @torch.no_grad()
    def update(self, input_vector):
        if isinstance(input_vector, np.ndarray):
            input_vector = torch.from_numpy(input_vector)
        input_vector = input_vector.to(**self.tpdv)

        batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes)))
        batch_sq_mean = (input_vector**2).mean(dim=tuple(range(self.norm_axes)))

        if self.per_element_update:
            # Decay once per element in the batch rather than once per call.
            batch_size = np.prod(input_vector.size()[: self.norm_axes])
            weight = self.beta**batch_size
        else:
            weight = self.beta

        self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
        self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
        self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

    def normalize(self, input_vector):
        # Make sure input is float32.
        if isinstance(input_vector, np.ndarray):
            input_vector = torch.from_numpy(input_vector)
        input_vector = input_vector.to(**self.tpdv)

        mean, var = self.running_mean_var()
        # Indexing with (None,) * norm_axes prepends singleton dimensions so
        # the statistics broadcast over the leading batch dimensions.
        out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[
            (None,) * self.norm_axes
        ]
        return out

    def denormalize(self, input_vector):
        """Transform normalized data back into the original distribution."""
        if isinstance(input_vector, np.ndarray):
            input_vector = torch.from_numpy(input_vector)
        input_vector = input_vector.to(**self.tpdv)

        mean, var = self.running_mean_var()
        out = (
            input_vector * torch.sqrt(var)[(None,) * self.norm_axes]
            + mean[(None,) * self.norm_axes]
        )
        return out
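

# A minimal usage sketch (the variable names and shapes below are illustrative
# assumptions, not part of the upstream repository): a typical PPO-style loop
# updates the running statistics with fresh returns, trains the critic on
# normalized targets, and denormalizes predictions at evaluation time.
if __name__ == "__main__":
    value_normalizer = ValueNorm(input_shape=1)

    # Stand-in for flattened rollout returns of shape (batch, 1); with
    # norm_axes=1, statistics are averaged over the batch dimension only.
    returns = torch.randn(320, 1) * 50.0 + 100.0

    value_normalizer.update(returns)
    normalized_targets = value_normalizer.normalize(returns)

    # Round trip: denormalize(normalize(x)) should recover x.
    recovered = value_normalizer.denormalize(normalized_targets)
    assert torch.allclose(recovered, returns, atol=1e-3)
    print("round-trip max abs error:", (recovered - returns).abs().max().item())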