-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a more general sequence LSTM implementation
- Loading branch information
1 parent
c20e98c
commit 0813ccc
Showing
1 changed file
with
84 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import numpy as np | ||
|
||
""" | ||
There are many formulations of LSTMs. This code follows the formulation from | ||
https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf with some simplifications | ||
""" | ||
|
||
def sigmoid(x):
    """Elementwise logistic sigmoid: 1 / (1 + e^(-x)).

    Accepts a scalar or a numpy array; returns the same shape.
    """
    neg_exp = np.exp(-x)
    return 1.0 / (1.0 + neg_exp)
|
||
# LSTM functions | ||
# LSTM functions
def lstm_step(x, hidden_state, cell_state, Wx, Wh, b):
    """
    Forward pass for a single timestep of an LSTM.

    The input data has dimension D, the hidden state has dimension H, and we
    use a batch (sequence) size of S.

    Inputs:
        x: Input data, of shape (S, D)
        hidden_state: Previous hidden state, of shape (S, H)
        cell_state: Previous cell state, of shape (S, H)
        Wx: Input-to-hidden weights, of shape (D, 4*H)
        Wh: Hidden-to-hidden weights, of shape (H, 4*H)
        b: Biases, of shape (4*H,)

    Returns a tuple of:
        next_hidden_state: Next hidden state, of shape (S, H)
        next_cell_state: Next cell state, of shape (S, H)
    """
    _, hidden_size = hidden_state.shape
    # All four gate pre-activations in one affine transform; the (S, 4*H)
    # result is sliced into the input/forget/output/change gates below.
    activations = hidden_state.dot(Wh) + x.dot(Wx) + b

    ingate = sigmoid(activations[:, 0:hidden_size])
    forget = sigmoid(activations[:, hidden_size:2*hidden_size])
    outgate = sigmoid(activations[:, 2*hidden_size:3*hidden_size])
    change = np.tanh(activations[:, 3*hidden_size:4*hidden_size])

    # Standard LSTM cell update: forget part of the old cell state, admit
    # some of the new candidate values, then gate the exposed hidden state.
    next_cell_state = forget * cell_state + ingate * change
    next_hidden_state = outgate * np.tanh(next_cell_state)

    return next_hidden_state, next_cell_state
|
||
def lstm(x, hidden_size, rng=None):
    """
    Run an LSTM over T timesteps with randomly initialized (untrained) weights.

    For sequence length S, input dimension D, and hidden size H.

    Inputs:
        x: Input data of shape T x (S, D) (a sequence of T arrays)
        hidden_size: hidden dimension H of the LSTM
        rng: optional numpy random Generator used to draw the weights, for
             reproducible runs; defaults to the legacy global np.random state
             (the original behavior)

    Returns a tuple of:
        all_hidden: Hidden states for all timesteps, of shape T x (S, H)
        cell_state: Last cell state, of shape (S, H)
        (None, None) when x is empty.
    """
    if len(x) == 0:
        return None, None

    seq_len, input_dim = x[0].shape

    # Zero initial hidden/cell state for every sequence position.
    hidden_state = np.zeros((seq_len, hidden_size), dtype=np.float64)
    cell_state = np.zeros((seq_len, hidden_size), dtype=np.float64)

    # Random weights only: this code demonstrates the forward pass, no training.
    if rng is None:
        Wx = np.random.rand(input_dim, 4*hidden_size)
        Wh = np.random.rand(hidden_size, 4*hidden_size)
    else:
        Wx = rng.random((input_dim, 4*hidden_size))
        Wh = rng.random((hidden_size, 4*hidden_size))
    b = np.zeros(4*hidden_size)

    all_hidden = []
    for xt in x:
        hidden_state, cell_state = lstm_step(xt, hidden_state, cell_state, Wx, Wh, b)
        all_hidden.append(hidden_state)

    return all_hidden, cell_state
|
||
|
||
### Test the LSTM with some sample data
if __name__ == "__main__":
    # Smoke test: run random inputs through the LSTM and report the shapes.
    # Guarded so that importing this module does not trigger the demo.
    T, S, H, D = 17, 13, 7, 5
    x = [np.random.rand(S, D) for _ in range(T)]
    ## Pass through an LSTM for T timesteps.
    # This will produce the following output
    # input sequence length: 13, input dimension: 5, hidden size: 7, time steps 17
    # input shape: (13, 5)
    # hidden state size: 17 x (13, 7) cell state size: (13, 7)
    print(f"input sequence length: {S}, input dimension: {D}, hidden size: {H}, time steps {T} \n")
    print(f"input shape: {x[0].shape} ")
    hidden_state, cell_state = lstm(x, H)
    print(f"hidden state size: {len(hidden_state)} x {hidden_state[0].shape} cell state size: {cell_state.shape}")