-
Notifications
You must be signed in to change notification settings - Fork 1
/
ConvLSTMCell.py
48 lines (34 loc) · 1.64 KB
/
ConvLSTMCell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
# Original ConvLSTM cell as proposed by Shi et al.
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM cell with peephole connections (Shi et al.).

    The input-to-state and state-to-state transitions are computed by one
    2-D convolution over the channel-wise concatenation of the input and the
    previous hidden state. The cell state additionally feeds the input,
    forget and output gates through learned element-wise (Hadamard) weights.

    Args:
        in_channels: Number of channels in the input ``X``.
        out_channels: Number of channels in the hidden and cell states.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``. It must keep the spatial
            size unchanged (e.g. ``kernel_size // 2`` for an odd kernel) so
            the gate pre-activations broadcast against the peephole weights.
        activation: ``"tanh"`` or ``"relu"``. It is applied to the candidate
            cell value and to the cell state before the output gate.
        frame_size: Spatial size ``(height, width)`` of the states. It fixes
            the shape of the peephole weights.

    Raises:
        ValueError: If ``activation`` is not ``"tanh"`` or ``"relu"``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):
        super().__init__()

        activations = {"tanh": torch.tanh, "relu": torch.relu}
        try:
            self.activation = activations[activation]
        except KeyError:
            # Fail here rather than with an AttributeError later in forward().
            raise ValueError(
                f"activation must be 'tanh' or 'relu', got {activation!r}"
            ) from None

        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch:
        # a single conv produces the pre-activations of all four gates at once.
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Peephole (Hadamard-product) weights, one value per channel and pixel.
        # torch.Tensor(...) would leave these uninitialized (arbitrary memory,
        # possibly NaN/inf). Zeros give a well-defined start equivalent to a
        # plain ConvLSTM; training learns the peephole contribution.
        self.W_ci = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.zeros(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.zeros(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance the cell by one time step.

        Args:
            X: Input tensor, concatenated with ``H_prev`` along dim 1
                (channels). Assumed layout is (batch, in_channels, *frame_size).
            H_prev: Previous hidden state, (batch, out_channels, *frame_size).
            C_prev: Previous cell state, (batch, out_channels, *frame_size).

        Returns:
            Tuple ``(H, C)`` with the new hidden state and cell state.
        """
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))

        # Split channels into input, forget, candidate-cell and output parts.
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current cell state.
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)

        # The output gate peeks at the *new* cell state, as in Shi et al.
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current hidden state.
        H = output_gate * self.activation(C)

        return H, C