feat(LSTM): vanishing gradient + gradient clipping
marcpinet committed Nov 9, 2024
1 parent d9fac00 commit dfee511
Showing 1 changed file with 114 additions and 53 deletions.
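
In short, this change targets vanishing/exploding gradients in the LSTM layer: LSTMCell now defers weight creation to a lazy initialize_weights call, switches from scaled normal draws to orthogonal initialization, clips every intermediate gradient in backward to [-clip_value, clip_value], and clamps the sigmoid output away from exact 0 and 1; the LSTM layer exposes the new clip_value argument and round-trips it through get_config/from_config. A minimal usage sketch of the new argument, based only on the signatures visible in this diff (the import path is assumed from the changed file, neuralnetlib/layers.py):

import numpy as np
from neuralnetlib.layers import LSTM

x = np.random.randn(32, 10, 8)                   # (batch, timesteps, features)

layer = LSTM(units=16, clip_value=5.0)           # clip_value defaults to 5.0
h_last = layer.forward_pass(x)                   # last hidden state, shape (32, 16)

# backward_pass expands a 2D error to the full sequence and clips every
# intermediate gradient element-wise to [-clip_value, clip_value]
dx = layer.backward_pass(np.ones_like(h_last))   # shape (32, 10, 8)
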
167 changes: 114 additions & 53 deletions neuralnetlib/layers.py
@@ -1383,56 +1383,88 @@ def from_config(config: dict):


class LSTMCell:
def __init__(self, input_dim: int, units: int, random_state: int | None = None):
self.input_dim = input_dim
def __init__(self, units: int, random_state: int | None = None, clip_value: float = 5.0):
self.units = units
self.random_state = random_state
self.clip_value = clip_value
self.rng = np.random.default_rng(
random_state if random_state is not None else int(time.time_ns()))

self.Wf = None
self.Uf = None
self.bf = None

self.Wi = None
self.Ui = None
self.bi = None

self.Wc = None
self.Uc = None
self.bc = None

self.Wo = None
self.Uo = None
self.bo = None

self._init_gradients()
self.cache = None

scale_gates = np.sqrt(2.0 / (input_dim + units)) # Xavier
scale_candidates = np.sqrt(2.0 / input_dim) # He

def initialize_weights(self, input_dim: int):
# Forget gate
self.Wf = self.rng.normal(0, scale_gates, (input_dim, units))
self.Uf = self.rng.normal(0, scale_gates, (units, units))
self.bf = np.ones((1, units))
self.Wf = self.orthogonal_init((input_dim, self.units))
self.Uf = self.orthogonal_init((self.units, self.units))
self.bf = np.ones((1, self.units))

# Input gate
self.Wi = self.rng.normal(0, scale_gates, (input_dim, units))
self.Ui = self.rng.normal(0, scale_gates, (units, units))
self.bi = np.zeros((1, units))
self.Wi = self.orthogonal_init((input_dim, self.units))
self.Ui = self.orthogonal_init((self.units, self.units))
self.bi = np.zeros((1, self.units))

# Cell gate
self.Wc = self.rng.normal(0, scale_candidates, (input_dim, units))
self.Uc = self.rng.normal(0, scale_candidates, (units, units))
self.bc = np.zeros((1, units))
self.Wc = self.orthogonal_init((input_dim, self.units))
self.Uc = self.orthogonal_init((self.units, self.units))
self.bc = np.zeros((1, self.units))

# Output gate
self.Wo = self.rng.normal(0, scale_gates, (input_dim, units))
self.Uo = self.rng.normal(0, scale_gates, (units, units))
self.bo = np.zeros((1, units))

# Gradients
self.dWf = np.zeros_like(self.Wf)
self.dUf = np.zeros_like(self.Uf)
self.dbf = np.zeros_like(self.bf)
self.Wo = self.orthogonal_init((input_dim, self.units))
self.Uo = self.orthogonal_init((self.units, self.units))
self.bo = np.zeros((1, self.units))

self.dWi = np.zeros_like(self.Wi)
self.dUi = np.zeros_like(self.Ui)
self.dbi = np.zeros_like(self.bi)

self.dWc = np.zeros_like(self.Wc)
self.dUc = np.zeros_like(self.Uc)
self.dbc = np.zeros_like(self.bc)

self.dWo = np.zeros_like(self.Wo)
self.dUo = np.zeros_like(self.Uo)
self.dbo = np.zeros_like(self.bo)

self.cache = None
# Initialize gradients
self._init_gradients()

def _init_gradients(self):
if self.Wf is not None:
self.dWf = np.zeros_like(self.Wf)
self.dUf = np.zeros_like(self.Uf)
self.dbf = np.zeros_like(self.bf)

self.dWi = np.zeros_like(self.Wi)
self.dUi = np.zeros_like(self.Ui)
self.dbi = np.zeros_like(self.bi)

self.dWc = np.zeros_like(self.Wc)
self.dUc = np.zeros_like(self.Uc)
self.dbc = np.zeros_like(self.bc)

self.dWo = np.zeros_like(self.Wo)
self.dUo = np.zeros_like(self.Uo)
self.dbo = np.zeros_like(self.bo)

def orthogonal_init(self, shape):
if len(shape) < 2:
return self.rng.normal(0, 1, shape)
flat_shape = (shape[0], np.prod(shape[1:]))
a = self.rng.normal(0, 1, flat_shape)
u, _, vt = np.linalg.svd(a, full_matrices=False)
q = u if u.shape == flat_shape else vt
q = q.reshape(shape)
return q

def forward(self, x_t: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
if self.Wf is None:
self.initialize_weights(x_t.shape[1])

# Store inputs for backprop
self.x_t = x_t
self.h_prev = h_prev
@@ -1478,7 +1510,7 @@ def forward(self, x_t: np.ndarray, h_prev: np.ndarray, c_prev: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
'h_t': self.h_t
}

return self.h_t, self.c_t

def backward(self, dh_next: np.ndarray, dc_next: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x_t = self.cache['x_t']
@@ -1491,12 +1523,20 @@ def backward(self, dh_next: np.ndarray, dc_next: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
c_t = self.cache['c_t']
c_t_tanh = self.cache['c_t_tanh']

# Clip incoming gradients
dh_next = self.clip_gradients(dh_next)
dc_next = self.clip_gradients(dc_next)

# Gradients of the hidden and cell states
do = dh_next * c_t_tanh
dc = dc_next + dh_next * o_t * (1 - c_t_tanh ** 2)

# Clip cell state gradients
dc = self.clip_gradients(dc)

# Output gate gradients
do_input = do * o_t * (1 - o_t)
do_input = self.clip_gradients(do_input)
self.dWo += np.dot(x_t.T, do_input)
self.dUo += np.dot(h_prev.T, do_input)
self.dbo += np.sum(do_input, axis=0, keepdims=True)
@@ -1507,50 +1547,67 @@ def backward(self, dh_next: np.ndarray, dc_next: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
di = dc * c_tilde
dc_tilde = dc * i_t

# Clip all gate gradients
df = self.clip_gradients(df)
di = self.clip_gradients(di)
dc_tilde = self.clip_gradients(dc_tilde)

# Forget gate gradients
df_input = df * f_t * (1 - f_t)
df_input = self.clip_gradients(df_input)
self.dWf += np.dot(x_t.T, df_input)
self.dUf += np.dot(h_prev.T, df_input)
self.dbf += np.sum(df_input, axis=0, keepdims=True)

# Input gate gradients
di_input = di * i_t * (1 - i_t)
di_input = self.clip_gradients(di_input)
self.dWi += np.dot(x_t.T, di_input)
self.dUi += np.dot(h_prev.T, di_input)
self.dbi += np.sum(di_input, axis=0, keepdims=True)

# Cell candidate gradients
dc_tilde_input = dc_tilde * (1 - c_tilde ** 2)
dc_tilde_input = self.clip_gradients(dc_tilde_input)
self.dWc += np.dot(x_t.T, dc_tilde_input)
self.dUc += np.dot(h_prev.T, dc_tilde_input)
self.dbc += np.sum(dc_tilde_input, axis=0, keepdims=True)

# Input gradients
dx = (np.dot(df_input, self.Wf.T) +
np.dot(di_input, self.Wi.T) +
np.dot(dc_tilde_input, self.Wc.T) +
np.dot(do_input, self.Wo.T))

dh_prev = (np.dot(df_input, self.Uf.T) +
np.dot(di_input, self.Ui.T) +
np.dot(dc_tilde_input, self.Uc.T) +
np.dot(do_input, self.Uo.T))

return dx, dh_prev, dc_prev
dx = self.clip_gradients(dx)
dh_prev = self.clip_gradients(dh_prev)
dc_prev = self.clip_gradients(dc_prev)

def sigmoid(self, x: np.ndarray) -> np.ndarray:
return 0.5 * (1 + np.tanh(x * 0.5))
return dx, dh_prev, dc_prev

def get_config(self) -> dict:
return {
'input_dim': self.input_dim,
'units': self.units,
'random_state': self.random_state,
'clip_value': self.clip_value
}

def clip_gradients(self, gradient: np.ndarray) -> np.ndarray:
return np.clip(gradient, -self.clip_value, self.clip_value)

def sigmoid(self, x: np.ndarray) -> np.ndarray:
EPSILON = 1e-12
result = 0.5 * (1 + np.tanh(x * 0.5))
return np.clip(result, EPSILON, 1 - EPSILON)


class LSTM(Layer):
def __init__(self, units: int, return_sequences: bool = False, return_state: bool = False, random_state: int | None = None, **kwargs):
def __init__(self, units: int, return_sequences: bool = False, return_state: bool = False, random_state: int | None = None, clip_value: float = 5.0, **kwargs):
super().__init__()
self.units = units
self.return_sequences = return_sequences
@@ -1562,19 +1619,21 @@ def __init__(self, units: int, return_sequences: bool = False, return_state: bool = False, random_state: int | None = None, **kwargs):
self.last_c = None
self.cache = None
self.input_shape = None
self.clip_value = clip_value

for key, value in kwargs.items():
setattr(self, key, value)

def __str__(self) -> str:
return f'LSTM(units={self.units}, return_sequences={self.return_sequences}, return_state={self.return_state})'
return f'LSTM(units={self.units}, return_sequences={self.return_sequences}, return_state={self.return_state}, random_state={self.random_state}, clip_value={self.clip_value})'

def forward_pass(self, x: np.ndarray, training: bool = True) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
self.input_shape = x.shape

batch_size, timesteps, input_dim = x.shape

if not self.initialized:
self.cell = LSTMCell(input_dim, self.units, self.random_state)
self.cell = LSTMCell(self.units, self.random_state, self.clip_value)
self.cell.initialize_weights(input_dim)
self.initialized = True

h = np.zeros((batch_size, self.units))
@@ -1606,26 +1665,26 @@ def forward_pass(self, x: np.ndarray, training: bool = True) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
return self.last_h, self.last_h, self.last_c
return self.last_h

def backward_pass(self, dout: np.ndarray) -> np.ndarray:
def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
batch_size, timesteps, input_dim = self.input_shape

if len(dout.shape) == 2:
if len(output_error.shape) == 2:
full_dout = np.zeros((batch_size, timesteps, self.units))
full_dout[:, -1, :] = dout
dout = full_dout
full_dout[:, -1, :] = output_error
output_error = full_dout

dx = np.zeros((batch_size, timesteps, self.cell.input_dim))
dx = np.zeros((batch_size, timesteps, input_dim))
dh_next = np.zeros((batch_size, self.units))
dc_next = np.zeros((batch_size, self.units))

for t in reversed(range(timesteps)):
dh = dout[:, t, :] + dh_next
dh = output_error[:, t, :] + dh_next
dh = np.clip(dh, -self.clip_value, self.clip_value)

self.cell.cache = self.cache[t]
dx_t, dh_next, dc_next = self.cell.backward(dh, dc_next)
dx[:, t, :] = dx_t

self.cache = None

return dx

def get_config(self) -> dict:
@@ -1634,6 +1693,7 @@ def get_config(self) -> dict:
'units': self.units,
'return_sequences': self.return_sequences,
'return_state': self.return_state,
'clip_value': self.clip_value,
'random_state': self.random_state,
'cell': self.cell.get_config() if self.cell is not None else None
}
@@ -1644,6 +1704,7 @@ def from_config(config: dict):
config['units'],
config['return_sequences'],
config['return_state'],
config['random_state'],
config.get('clip_value', 5.0)
)
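
For completeness, a standalone sanity check of LSTMCell with the new clipping, again assuming the neuralnetlib.layers import path; shapes follow the forward/backward signatures shown in the diff, and weights are built lazily from x_t on the first forward call:

import numpy as np
from neuralnetlib.layers import LSTMCell

cell = LSTMCell(units=4, random_state=42, clip_value=1.0)

x_t = np.random.randn(2, 3)                  # (batch, input_dim)
h_prev = np.zeros((2, 4))
c_prev = np.zeros((2, 4))

h_t, c_t = cell.forward(x_t, h_prev, c_prev)

# every returned gradient has passed through clip_gradients, so its
# magnitude is bounded by clip_value
dx, dh_prev, dc_prev = cell.backward(np.ones_like(h_t), np.zeros_like(c_t))
assert np.all(np.abs(dx) <= 1.0) and np.all(np.abs(dh_prev) <= 1.0)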

