# revin.py (forked from QData/spacetimeformer)
"""
Reversible Instance Normalization from
https://github.com/ts-kim/RevIN
"""
import torch
import torch.nn as nn


class MovingAvg(nn.Module):
    """Moving average block used to extract the trend of a time series."""

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # pad both ends of the time series with the edge values so the pooled
        # output keeps (approximately) the original sequence length
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dimension, so move time there and back
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class SeriesDecomposition(nn.Module):
    """Decompose a series into a moving-average trend and a residual."""

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = MovingAvg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean
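
# Usage sketch (illustrative; not part of the original file): split a batch of
# series shaped (batch, length, channels) into residual and trend parts. The
# kernel size of 25 and the tensor shape below are arbitrary example values.
#
#   decomp = SeriesDecomposition(kernel_size=25)
#   x = torch.randn(32, 96, 7)        # (batch, length, channels)
#   residual, trend = decomp(x)       # each (32, 96, 7); residual + trend
#                                     # recovers x up to float rounding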


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str, update_stats=True):
        assert x.ndim == 3  # expects (batch, length, num_features)
        if mode == "norm":
            if update_stats:
                self._get_statistics(x)
            x = self._normalize(x)
        elif mode == "denorm":
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # per-instance mean/std over the time dimension, detached so the
        # statistics act as constants during backprop
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by a near-zero affine weight
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x
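

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file): the typical RevIN
# round trip. Inputs are normalized per instance, a forecasting model would run
# on the normalized series, and its outputs are denormalized with the stored
# statistics. The shapes below are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    revin = RevIN(num_features=3)
    x = torch.randn(8, 96, 3)              # (batch, length, num_features)

    x_norm = revin(x, mode="norm")         # normalize and store mean/stdev
    # ... a forecasting model would consume x_norm here ...
    x_back = revin(x_norm, mode="denorm")  # invert with the stored statistics

    # with the default affine weights the round trip is exact up to eps terms
    print(torch.allclose(x, x_back, atol=1e-4))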