# base.py
import torch
import torchdiffeq
from einops import rearrange, repeat
from torch import nn
from torchdiffeq import odeint
from basehelper import *  # provides NLayerNN, Parameter, tol


class Tinvariant_NLayerNN(NLayerNN):
    """Time-invariant wrapper: accepts (t, x) but ignores t."""

    def forward(self, t, x):
        return super(Tinvariant_NLayerNN, self).forward(x)


class dfwrapper(nn.Module):
    """Reshape wrapper around a vector field `df`: flattens the state to
    [batch, -1] for the solver and, if a running-cost functional `recf` is
    given, appends its integrand to the flattened derivative."""

    def __init__(self, df, shape, recf=None):
        super(dfwrapper, self).__init__()
        self.df = df
        self.shape = shape
        self.recf = recf

    def forward(self, t, x):
        bsize = x.shape[0]
        if self.recf:
            # Strip the appended running-cost channels before evaluating df.
            x = x[:, :-self.recf.osize].reshape(bsize, *self.shape)
            dx = self.df(t, x)
            dr = self.recf(t, x, dx).reshape(bsize, -1)
            dx = dx.reshape(bsize, -1)
            dx = torch.cat([dx, dr], dim=1)
        else:
            x = x.reshape(bsize, *self.shape)
            dx = self.df(t, x)
            dx = dx.reshape(bsize, -1)
        return dx


class NODEintegrate(nn.Module):
    def __init__(self, df, shape=None, tol=tol, adjoint=True, evaluation_times=None, recf=None):
        """
        Neural ODE integration module for
            x' = df(t, x), x(t0) = x0.
        :param df: function computing the derivative; input & output shape [batch, channel, feature]
        :param shape: state shape excluding batch; if given, the state is flattened for the solver
        :param tol: relative and absolute solver tolerance (default `tol` comes from basehelper)
        :param adjoint: if True, backpropagate with torchdiffeq's adjoint method
        :param evaluation_times: time stamps at which the solution is returned, default [0, 1]
        :param recf: optional running-cost functional integrated alongside the state
        """
        super().__init__()
        self.df = dfwrapper(df, shape, recf) if shape else df
        self.tol = tol
        self.odeint = torchdiffeq.odeint_adjoint if adjoint else torchdiffeq.odeint
        self.evaluation_times = evaluation_times if evaluation_times is not None else torch.Tensor([0.0, 1.0])
        self.shape = shape
        self.recf = recf
        if recf:
            assert shape is not None

    def forward(self, x0):
        """
        Integrate the ODE from the given initial condition.
        :param x0: initial condition, shape [batch, channel, feature]
        :return: solution at self.evaluation_times, shape [time, batch, channel, feature];
                 if recf is set, also the accumulated running cost at the final time
        """
        bsize = x0.shape[0]
        if self.shape:
            assert x0.shape[1:] == torch.Size(self.shape), \
                'Input shape {} does not match with model shape {}'.format(x0.shape[1:], self.shape)
            x0 = x0.reshape(bsize, -1)
            if self.recf:
                # Append zero-initialized channels that accumulate the running cost.
                reczeros = torch.zeros_like(x0[:, :1])
                reczeros = repeat(reczeros, 'b 1 -> b c', c=self.recf.osize)
                x0 = torch.cat([x0, reczeros], dim=1)
            # Use self.odeint so the adjoint flag passed to __init__ takes effect.
            out = self.odeint(self.df, x0, self.evaluation_times, rtol=self.tol, atol=self.tol)
            if self.recf:
                rec = out[-1, :, -self.recf.osize:]
                out = out[:, :, :-self.recf.osize]
                out = out.reshape(-1, bsize, *self.shape)
                return out, rec
            else:
                return out.reshape(-1, bsize, *self.shape)
        else:
            out = self.odeint(self.df, x0, self.evaluation_times, rtol=self.tol, atol=self.tol)
            return out

    @property
    def nfe(self):
        return self.df.nfe

    def to(self, device, *args, **kwargs):
        super().to(device, *args, **kwargs)
        # nn.Module.to does not move plain tensor attributes; reassign explicitly.
        self.evaluation_times = self.evaluation_times.to(device)
        return self
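

# Usage sketch (added for illustration; `VectorField` and the `_example_*`
# helpers throughout this file are hypothetical names, not part of the original
# API): integrate a small MLP field and read the solution at five time stamps.
def _example_nodeintegrate():
    class VectorField(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(dim, 16), nn.Tanh(), nn.Linear(16, dim))

        def forward(self, t, x):
            return self.net(x)  # autonomous field: t is ignored

    node = NODEintegrate(VectorField(3), shape=(3,), tol=1e-7, adjoint=False,
                         evaluation_times=torch.linspace(0.0, 1.0, 5))
    x0 = torch.randn(8, 3)  # [batch, *shape]
    return node(x0)         # [time, batch, *shape] == [5, 8, 3]

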
class NODElayer(NODEintegrate):
    """NODEintegrate that returns only the state at the final evaluation time,
    making it usable as a drop-in layer."""

    def forward(self, x0):
        out = super(NODElayer, self).forward(x0)
        if isinstance(out, tuple):
            out, rec = out
            return out[-1], rec
        else:
            return out[-1]
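

# Usage sketch (illustrative): because NODElayer returns only the final state,
# it composes with ordinary layers inside nn.Sequential.
def _example_nodelayer():
    class VF(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.lin = nn.Linear(dim, dim)

        def forward(self, t, x):
            return torch.tanh(self.lin(x))

    model = nn.Sequential(
        nn.Linear(4, 8),
        NODElayer(VF(8), shape=(8,), tol=1e-7, adjoint=False),
        nn.Linear(8, 2),
    )
    return model(torch.randn(10, 4))  # [10, 2]

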
'''
class ODERNN(nn.Module):
def __init__(self, node, rnn, evaluation_times, nhidden):
super(ODERNN, self).__init__()
self.t = torch.as_tensor(evaluation_times).float()
self.n_t = len(self.t)
self.node = node
self.rnn = rnn
self.nhidden = (nhidden,) if isinstance(nhidden, int) else nhidden
def forward(self, x):
assert len(x) == self.n_t
batchsize = x.shape[1]
out = torch.zeros([self.n_t, batchsize, *self.nhidden]).to(x.device)
for i in range(1, self.n_t):
odesol = odeint(self.node, out[i - 1], self.t[i - 1:i + 1])
h_ode = odesol[1]
out[i] = self.rnn(h_ode, x[i])
return out
'''


class NODE(nn.Module):
    def __init__(self, df=None, **kwargs):
        super(NODE, self).__init__()
        self.__dict__.update(kwargs)
        self.df = df
        self.nfe = 0  # number of function evaluations made by the solver
        self.elem_t = None

    def forward(self, t, x):
        self.nfe += 1
        if self.elem_t is None:
            return self.df(t, x)
        else:
            # Rescale the field by the per-sample time scales set via update().
            return self.elem_t * self.df(self.elem_t, x)

    def update(self, elem_t):
        # Append a singleton dim so elem_t broadcasts over the feature dim.
        self.elem_t = elem_t.view(*elem_t.shape, 1)
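

# Usage sketch (illustrative): NODE wraps any (t, x) -> dx field, counts solver
# function evaluations in `nfe`, and `update` installs per-sample time scales
# (used by the ODE-RNN classes below).
def _example_node_nfe():
    f = NODE(df=lambda t, x: -x)  # linear decay field x' = -x
    odeint(f, torch.randn(8, 2), torch.tensor([0.0, 1.0]))
    return f.nfe  # number of field evaluations the solver made

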
class SONODE(NODE):
def forward(self, t, x):
"""
Compute [y y']' = [y' y''] = [y' df(t, y, y')]
:param t: time, shape [1]
:param x: [y y'], shape [batch, 2, vec]
:return: [y y']', shape [batch, 2, vec]
"""
self.nfe += 1
v = x[:, 1:, :]
out = self.df(t, x)
return torch.cat((v, out), dim=1)
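

# Usage sketch (illustrative): SONODE integrates a second-order ODE by stacking
# position and velocity along dim 1; here a damped oscillator y'' = -y - 0.1 y'.
def _example_sonode():
    def df(t, x):  # x = [y, y'], shape [batch, 2, vec]
        y, v = x[:, 0:1, :], x[:, 1:2, :]
        return -y - 0.1 * v  # y'', shape [batch, 1, vec]

    f = SONODE(df)
    x0 = torch.randn(8, 2, 3)  # [batch, 2, vec]
    return odeint(f, x0, torch.linspace(0.0, 1.0, 5))  # [5, 8, 2, 3]

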
class HeavyBallNODE(NODE):
    def __init__(self, df, actv_h=None, gamma_guess=-3.0, gamma_act='sigmoid', corr=-100, corrf=True, sign=1):
        super().__init__(df)
        # Momentum parameter gamma
        self.gamma = Parameter([gamma_guess], frozen=False)
        self.gammaact = nn.Sigmoid() if gamma_act == 'sigmoid' else gamma_act
        self.corr = Parameter([corr], frozen=corrf)
        self.sp = nn.Softplus()
        self.sign = sign  # Sign of df
        self.actv_h = nn.Identity() if actv_h is None else actv_h  # Activation for dh, GHBNODE only

    def forward(self, t, x):
        """
        Compute [h' m'] with the heavy-ball parametrization
        $$ h' = actv_h(-m) $$
        $$ m' = sign * df(t, h) - gammaact(gamma) * m + softplus(corr) * h $$
        based on https://www.jmlr.org/papers/volume21/18-808/18-808.pdf
        :param t: time, shape [1]
        :param x: [h m], shape [batch, 2, dim]
        :return: [h' m'], shape [batch, 2, dim]
        """
        self.nfe += 1
        h, m = torch.split(x, 1, dim=1)
        dh = self.actv_h(-m)
        dm = self.df(t, h) * self.sign - self.gammaact(self.gamma()) * m
        dm = dm + self.sp(self.corr()) * h
        out = torch.cat((dh, dm), dim=1)
        if self.elem_t is None:
            return out
        else:
            return self.elem_t * out

    def update(self, elem_t):
        # Two singleton dims so elem_t broadcasts over the [batch, 2, dim] state.
        self.elem_t = elem_t.view(*elem_t.shape, 1, 1)


HBNODE = HeavyBallNODE  # Alias
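

# Usage sketch (illustrative): HeavyBallNODE expects the state stacked as
# [h, m] along dim 1; basehelper's Parameter supplies the trainable gamma.
def _example_hbnode():
    class VF(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.lin = nn.Linear(dim, dim)

        def forward(self, t, h):
            return torch.tanh(self.lin(h))

    f = HBNODE(VF(3))
    x0 = torch.zeros(8, 2, 3)  # [batch, 2, dim]: position h and momentum m
    return odeint(f, x0, torch.tensor([0.0, 1.0]))  # [2, 8, 2, 3]

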
class AdamNODEs(NODE):
    def __init__(self, df, actv_h=None, gamma_guess=-3.0, gamma_act='sigmoid', corr=-100, corrf=True, sign=1):
        super().__init__(df)
        # Momentum parameter gamma
        self.gamma = Parameter([gamma_guess], frozen=False)
        self.gammaact = nn.Sigmoid() if gamma_act == 'sigmoid' else gamma_act
        self.corr = Parameter([corr], frozen=corrf)
        self.corr2 = Parameter([corr], frozen=corrf)
        self.sp = nn.Softplus()
        self.sign = sign  # Sign of df
        self.actv_h = nn.Identity() if actv_h is None else actv_h  # Activation for dh
        self.alpha_1 = nn.Parameter(torch.Tensor([-5.0]))
        self.alpha_2 = nn.Parameter(torch.Tensor([5.0]))
        self.epsilon = 1e-8
        self.act = nn.Softplus()  # nonnegative map applied to the second moment v

    def forward(self, t, x):
        """
        Compute [h' m' v'], an Adam-style extension of the heavy-ball
        parametrization above:
        $$ h' = actv_h(-m) / (sqrt(act(v)) + epsilon) $$
        $$ m' = sigmoid(alpha_1) * (-df(t, h) - m) + softplus(corr) * h $$
        $$ v' = sigmoid(alpha_2) * (df(t, h)^2 - v) + softplus(corr2) * h $$
        :param t: time, shape [1]
        :param x: [h m v], shape [batch, 3, dim]
        :return: [h' m' v'], shape [batch, 3, dim]
        """
        self.nfe += 1
        h, m, v = torch.tensor_split(x, 3, dim=1)
        dh = self.actv_h(-m) / (torch.sqrt(self.act(v)) + self.epsilon)
        df = self.df(t, h)
        dm = torch.sigmoid(self.alpha_1) * (-df - m)
        dv = torch.sigmoid(self.alpha_2) * (torch.pow(df, 2) - v)
        dm = dm + self.sp(self.corr()) * h
        dv = dv + self.sp(self.corr2()) * h
        out = torch.cat((dh, dm, dv), dim=1)
        if self.elem_t is None:
            return out
        else:
            return self.elem_t * out

    def update(self, elem_t):
        # Two singleton dims so elem_t broadcasts over the [batch, 3, dim] state.
        self.elem_t = elem_t.view(*elem_t.shape, 1, 1)


AdamNODE = AdamNODEs  # Alias
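

# Usage sketch (illustrative): AdamNODE carries a third channel for the
# second-moment estimate, so the state is [h, m, v] along dim 1.
def _example_adamnode():
    class VF(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.lin = nn.Linear(dim, dim)

        def forward(self, t, h):
            return torch.tanh(self.lin(h))

    f = AdamNODE(VF(3))
    x0 = torch.zeros(8, 3, 3)  # [batch, 3, dim]: h, m, v
    return odeint(f, x0, torch.tensor([0.0, 1.0]))  # [2, 8, 3, 3]

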
class ODE_RNN(nn.Module):
    def __init__(self, ode, rnn, nhid, ic, rnn_out=False, both=False, tol=1e-7):
        super().__init__()
        self.ode = ode
        self.t = torch.Tensor([0, 1])  # unit interval; actual step lengths enter via ode.update
        self.nhid = [nhid] if isinstance(nhid, int) else nhid
        self.rnn = rnn
        self.tol = tol
        self.rnn_out = rnn_out
        self.ic = ic  # optional network mapping all observations to the initial hidden state
        self.both = both

    def forward(self, t, x, multiforecast=None):
        """
        Alternate ODE evolution and RNN updates of the hidden state.
        :param t: integration time of each step, shape [time, batch]
        :param x: observations, shape [time, batch, ...]
        :param multiforecast: optional time stamps to forecast from the final state
        :return: tuple of hidden-state trajectories, each of shape [time + 1, batch, *nhid]
        """
n_t, n_b = t.shape
h_ode = torch.zeros(n_t + 1, n_b, *self.nhid, device=x.device)
h_rnn = torch.zeros(n_t + 1, n_b, *self.nhid, device=x.device)
if self.ic:
h_ode[0] = h_rnn[0] = self.ic(rearrange(x, 't b c -> b (t c)')).view(h_ode[0].shape)
        if self.rnn_out:
            # ODE evolves first, then the RNN updates; return the RNN trajectory.
            for i in range(n_t):
                self.ode.update(t[i])
                h_ode[i] = odeint(self.ode, h_rnn[i], self.t, atol=self.tol, rtol=self.tol)[-1]
                h_rnn[i + 1] = self.rnn(h_ode[i], x[i])
            out = (h_rnn,)
        else:
            # RNN updates first, then the ODE evolves; return the ODE trajectory.
            for i in range(n_t):
                self.ode.update(t[i])
                h_rnn[i] = self.rnn(h_ode[i], x[i])
                h_ode[i + 1] = odeint(self.ode, h_rnn[i], self.t, atol=self.tol, rtol=self.tol)[-1]
            out = (h_ode,)
if self.both:
out = (h_rnn, h_ode)
        if multiforecast is not None:
            # Reset to unit time scale and roll the final state forward.
            self.ode.update(torch.ones_like(t[0]))
            forecast = odeint(self.ode, out[-1][-1], multiforecast * 1.0, atol=self.tol, rtol=self.tol)
            out = (*out, forecast)
return out
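

# Usage sketch (illustrative; `_GRUWrapper` is hypothetical): the rnn argument
# must map (hidden, observation) -> hidden, so nn.GRUCell is wrapped with its
# arguments flipped.
def _example_ode_rnn():
    class _GRUWrapper(nn.Module):
        def __init__(self, nhid, nin):
            super().__init__()
            self.cell = nn.GRUCell(nin, nhid)

        def forward(self, h, x):
            return self.cell(x, h)

    ode = NODE(df=lambda t, h: torch.tanh(h))
    model = ODE_RNN(ode, _GRUWrapper(6, 2), nhid=6, ic=None)
    t = torch.rand(5, 4)      # per-step integration times [time, batch]
    x = torch.randn(5, 4, 2)  # observations [time, batch, feature]
    (h_ode,) = model(t, x)    # [time + 1, batch, nhid] == [6, 4, 6]
    return h_ode

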
class ODE_RNN_with_Grad_Listener(nn.Module):
    """ODE_RNN variant that keeps per-step hidden states in Python lists so their
    gradients can be retained (retain_grad) and inspected after backward()."""

    def __init__(self, ode, rnn, nhid, ic, rnn_out=False, both=False, tol=1e-7):
        super().__init__()
        self.ode = ode
        self.t = torch.Tensor([0, 1])
        self.nhid = [nhid] if isinstance(nhid, int) else nhid
        self.rnn = rnn
        self.tol = tol
        self.rnn_out = rnn_out
        self.ic = ic
        self.both = both

    def forward(self, t, x, multiforecast=None, retain_grad=False):
        """
        Same evolution as ODE_RNN.forward, with optional gradient retention.
        :param t: integration time of each step, shape [time, batch]
        :param x: observations, shape [time, batch, ...]
        :param multiforecast: optional time stamps to forecast from the final state
        :param retain_grad: if True, keep .grad on every intermediate hidden state
        :return: tuple of hidden-state trajectories, each of shape [time + 1, batch, *nhid]
        """
        n_t, n_b = t.shape
        h_ode = [None] * (n_t + 1)
        h_rnn = [None] * (n_t + 1)
        h_ode[-1] = h_rnn[-1] = torch.zeros(n_b, *self.nhid, device=x.device)
        if self.ic:
            h_ode[0] = h_rnn[0] = self.ic(rearrange(x, 't b c -> b (t c)')).view((n_b, *self.nhid))
        else:
            h_ode[0] = h_rnn[0] = torch.zeros(n_b, *self.nhid, device=x.device)
if self.rnn_out:
for i in range(n_t):
self.ode.update(t[i])
h_ode[i] = odeint(self.ode, h_rnn[i], self.t, atol=self.tol, rtol=self.tol)[-1]
h_rnn[i + 1] = self.rnn(h_ode[i], x[i])
out = (h_rnn,)
else:
for i in range(n_t):
self.ode.update(t[i])
h_rnn[i] = self.rnn(h_ode[i], x[i])
h_ode[i + 1] = odeint(self.ode, h_rnn[i], self.t, atol=self.tol, rtol=self.tol)[-1]
out = (h_ode,)
if self.both:
out = (h_rnn, h_ode)
out = [torch.stack(h, dim=0) for h in out]
        if multiforecast is not None:
            self.ode.update(torch.ones_like(t[0]))
            forecast = odeint(self.ode, out[-1][-1], multiforecast * 1.0, atol=self.tol, rtol=self.tol)
            out = (*out, forecast)
if retain_grad:
self.h_ode = h_ode
self.h_rnn = h_rnn
for i in range(n_t + 1):
if self.h_ode[i].requires_grad:
self.h_ode[i].retain_grad()
if self.h_rnn[i].requires_grad:
self.h_rnn[i].retain_grad()
return out
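

# Usage sketch (illustrative): with retain_grad=True the listener variant keeps
# per-step hidden states (and their gradients) accessible after backward().
def _example_grad_listener():
    class _GRUWrapper(nn.Module):
        def __init__(self, nhid, nin):
            super().__init__()
            self.cell = nn.GRUCell(nin, nhid)

        def forward(self, h, x):
            return self.cell(x, h)

    ode = NODE(df=lambda t, h: torch.tanh(h))
    model = ODE_RNN_with_Grad_Listener(ode, _GRUWrapper(6, 2), nhid=6, ic=None)
    t, x = torch.rand(5, 4), torch.randn(5, 4, 2)
    (h_ode,) = model(t, x, retain_grad=True)
    h_ode.sum().backward()
    return model.h_ode[1].grad  # gradient at an intermediate hidden state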