forked from QData/spacetimeformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtime2vec.py
36 lines (33 loc) · 1.44 KB
/
time2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn
class Time2Vec(nn.Module):
    """Time2Vec temporal embedding layer.

    Each of the ``input_dim`` time features is expanded into
    ``embed_dim // input_dim`` values through a learned per-feature affine
    map. The first value of each group is kept linear (a trend term) and
    the remaining values are passed through a periodic activation
    (``torch.sin`` by default).
    """

    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        """Build the embedding parameters.

        Args:
            input_dim: number of raw time features per position.
            embed_dim: total output width; must be divisible by
                ``input_dim``. An ``embed_dim`` of 0 disables the layer
                (``forward`` becomes the identity).
            act_function: periodic activation applied to all but the first
                channel of each feature's embedding.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``input_dim``.
        """
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let an invalid config through silently.
        if embed_dim % input_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"input_dim ({input_dim})"
            )
        self.enabled = embed_dim > 0
        if self.enabled:
            # Width of each feature's embedding group.
            self.embed_dim = embed_dim // input_dim
            self.input_dim = input_dim
            self.embed_weight = nn.Parameter(
                torch.randn(self.input_dim, self.embed_dim)
            )
            self.embed_bias = nn.Parameter(
                torch.randn(self.input_dim, self.embed_dim)
            )
            self.act_function = act_function

    def forward(self, x):
        """Embed time features.

        Args:
            x: tensor of shape ``(..., input_dim)``. Originally written for
                ``(batch, seq_len, input_dim)``; any leading dims now work.

        Returns:
            Tensor of shape ``(..., input_dim * (embed_dim // input_dim))``,
            or ``x`` unchanged when the layer is disabled (embed_dim == 0).
        """
        if not self.enabled:
            return x
        # Place each feature on its own diagonal so the matmul applies an
        # independent weight row to each input feature.
        x = torch.diag_embed(x)
        # x.shape = (..., input_dim, input_dim)
        x_affine = torch.matmul(x, self.embed_weight) + self.embed_bias
        # x_affine.shape = (..., input_dim, time_embed_dim)
        # First channel stays linear; the rest get the periodic activation.
        x_affine_0, x_affine_remain = torch.split(
            x_affine, [1, self.embed_dim - 1], dim=-1
        )
        x_affine_remain = self.act_function(x_affine_remain)
        x_output = torch.cat([x_affine_0, x_affine_remain], dim=-1)
        # Flatten the (input_dim, time_embed_dim) pair. reshape over the
        # last two dims generalizes the original 3-D-only
        # `.view(size(0), size(1), -1)` to arbitrary leading dims.
        return x_output.reshape(*x_output.shape[:-2], -1)