
Commit

fix(lstm): normalize gradients by batch_size * timesteps to match numeric approximation
marcpinet committed Dec 2, 2024
1 parent 652d0a4 commit 3e10a40
Showing 1 changed file with 41 additions and 44 deletions.
85 changes: 41 additions & 44 deletions neuralnetlib/layers.py
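
A note on the commit message before the diff: for a loss that is a mean over both the batch and the timesteps, the exact weight gradient is

\frac{\partial L}{\partial W} = \frac{1}{B\,T} \sum_{b=1}^{B} \sum_{t=1}^{T} \frac{\partial \ell_{b,t}}{\partial W}

so per-step gradients accumulated with += over all (b, t) pairs overshoot a finite-difference (numeric) estimate by exactly B * T; that is the factor this commit divides out. The formula is the editor's gloss of the commit message, not text from the repository.
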
@@ -1634,45 +1634,44 @@ def backward(self, dh_next: np.ndarray, dc_next: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    o_t = self.cache['o_t']
    c_t = self.cache['c_t']
    c_t_tanh = self.cache['c_t_tanh']

    # Output gate: h_t = o_t * tanh(c_t), so dh splits into do and dc.
    do = dh_next * c_t_tanh
    dc = dc_next + dh_next * o_t * (1 - c_t_tanh ** 2)

    do_input = do * o_t * (1 - o_t)
    self.dWo += np.dot(x_t.T, do_input)
    self.dUo += np.dot(h_prev.T, do_input)
    self.dbo += np.sum(do_input, axis=0, keepdims=True)

    # Split the cell-state gradient across the forget, input and candidate paths.
    dc_prev = dc * f_t
    df = dc * c_prev
    di = dc * c_tilde
    dc_tilde = dc * i_t

    df_input = df * f_t * (1 - f_t)
    self.dWf += np.dot(x_t.T, df_input)
    self.dUf += np.dot(h_prev.T, df_input)
    self.dbf += np.sum(df_input, axis=0, keepdims=True)

    di_input = di * i_t * (1 - i_t)
    self.dWi += np.dot(x_t.T, di_input)
    self.dUi += np.dot(h_prev.T, di_input)
    self.dbi += np.sum(di_input, axis=0, keepdims=True)

    dc_tilde_input = dc_tilde * (1 - c_tilde ** 2)
    self.dWc += np.dot(x_t.T, dc_tilde_input)
    self.dUc += np.dot(h_prev.T, dc_tilde_input)
    self.dbc += np.sum(dc_tilde_input, axis=0, keepdims=True)

    # The input and previous-hidden gradients combine all four gate paths.
    dx = (np.dot(df_input, self.Wf.T) +
          np.dot(di_input, self.Wi.T) +
          np.dot(dc_tilde_input, self.Wc.T) +
          np.dot(do_input, self.Wo.T))

    dh_prev = (np.dot(df_input, self.Uf.T) +
               np.dot(di_input, self.Ui.T) +
               np.dot(dc_tilde_input, self.Uc.T) +
               np.dot(do_input, self.Uo.T))

    return dx, dh_prev, dc_prev
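
Editor's note: the gate backprop above relies on two scalar identities, sigmoid'(z) = s * (1 - s) with s = sigmoid(z) (used for do_input, df_input and di_input) and tanh'(z) = 1 - tanh(z) ** 2 (used for dc and dc_tilde_input). A minimal standalone check against central finite differences (editor's sketch, not code from neuralnetlib):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.linspace(-3.0, 3.0, 7)
eps = 1e-5

# sigmoid'(z) = s * (1 - s)
s = sigmoid(z)
numeric = (sigmoid(z + eps) - sigmoid(z - eps)) / (2 * eps)
assert np.allclose(s * (1 - s), numeric, atol=1e-8)

# tanh'(z) = 1 - tanh(z) ** 2
t = np.tanh(z)
numeric = (np.tanh(z + eps) - np.tanh(z - eps)) / (2 * eps)
assert np.allclose(1 - t ** 2, numeric, atol=1e-8)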

def orthogonal_init(self, shape):
@@ -1799,41 +1798,39 @@ def forward_pass(self, x: np.ndarray, training: bool = True) -> np.ndarray | tuple

def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
    batch_size, timesteps, input_dim = self.input_shape

    # A 2D upstream error means only the last timestep's output was consumed;
    # scatter it into a full (batch, timesteps, units) tensor of zeros.
    if len(output_error.shape) == 2:
        full_output_error = np.zeros((batch_size, timesteps, self.units))
        full_output_error[:, -1, :] = output_error
        output_error = full_output_error

    dx = np.zeros((batch_size, timesteps, input_dim))
    dh_next = np.zeros((batch_size, self.units))
    dc_next = np.zeros((batch_size, self.units))

    self.cell._init_gradients()

    squared_norm_sum = 0.0

    # Backpropagate through time; at each step, add the squared norms of dx_t
    # and of the (running) weight and bias gradients for the clipping below.
    for t in reversed(range(timesteps)):
        dh = output_error[:, t, :] + dh_next
        self.cell.cache = self.cache[t]
        dx_t, dh_next, dc_next = self.cell.backward(dh, dc_next)
        dx[:, t, :] = dx_t

        squared_norm_sum += (np.sum(dx_t ** 2) +
                             np.sum(self.cell.dWf ** 2) + np.sum(self.cell.dUf ** 2) + np.sum(self.cell.dbf ** 2) +
                             np.sum(self.cell.dWi ** 2) + np.sum(self.cell.dUi ** 2) + np.sum(self.cell.dbi ** 2) +
                             np.sum(self.cell.dWc ** 2) + np.sum(self.cell.dUc ** 2) + np.sum(self.cell.dbc ** 2) +
                             np.sum(self.cell.dWo ** 2) + np.sum(self.cell.dUo ** 2) + np.sum(self.cell.dbo ** 2))

    # Global-norm clipping: when the combined norm exceeds clip_value, rescale
    # dx and every accumulated cell gradient by the same factor (see the
    # standalone sketch just below).
    global_norm = np.sqrt(squared_norm_sum)
    scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8))
    if scaling_factor < 1.0:
        dx *= scaling_factor
        for grad in self.cell.__dict__:
            if grad.startswith('d'):
                setattr(self.cell, grad, getattr(self.cell, grad) * scaling_factor)
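
The clipping step above, rewritten as a standalone helper to make the scaling visible (editor's sketch; clip_by_global_norm is a hypothetical name, not a neuralnetlib API):

import numpy as np

def clip_by_global_norm(grads, clip_value, eps=1e-8):
    # Rescale every tensor by one shared factor when the combined
    # L2 norm across all of them exceeds clip_value.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, clip_value / (global_norm + eps))
    return [g * scale for g in grads], global_norm

grads = [np.full((2, 2), 3.0), np.full((3,), 4.0)]
clipped, norm = clip_by_global_norm(grads, clip_value=1.0)
# norm = sqrt(4 * 9 + 3 * 16) ~ 9.17; the clipped tensors have combined norm ~ 1.0

Scaling all tensors by one factor preserves the direction of the overall gradient, which is what distinguishes global-norm clipping from clipping each tensor independently.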


    # Normalize the accumulated weight and bias gradients by the number of
    # summed terms so they match the numeric approximation of a mean-reduced
    # loss (see the toy check after this method).
    self.cell.dWf /= (batch_size * timesteps)
    self.cell.dUf /= (batch_size * timesteps)
    self.cell.dbf /= (batch_size * timesteps)
    self.cell.dWi /= (batch_size * timesteps)
    self.cell.dUi /= (batch_size * timesteps)
    self.cell.dbi /= (batch_size * timesteps)
    self.cell.dWc /= (batch_size * timesteps)
    self.cell.dUc /= (batch_size * timesteps)
    self.cell.dbc /= (batch_size * timesteps)
    self.cell.dWo /= (batch_size * timesteps)
    self.cell.dUo /= (batch_size * timesteps)
    self.cell.dbo /= (batch_size * timesteps)

    return dx
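
And the reason for dividing by batch_size * timesteps, checked numerically on a toy mean-reduced loss (editor's sketch with made-up shapes):

import numpy as np

rng = np.random.default_rng(0)
B, T = 4, 5
x = rng.normal(size=(B, T))
w = 0.7

def loss(w):
    # Mean over both batch and time, mirroring a mean-reduced training loss.
    return np.mean(w * x)

# A backward pass that only accumulates with += over every (b, t) pair
# produces sum(x), which overshoots the true gradient by a factor of B * T.
accumulated = np.sum(x)
eps = 1e-6
numeric = (loss(w + eps) - loss(w - eps)) / (2 * eps)
assert np.isclose(accumulated / (B * T), numeric)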

def get_config(self) -> dict:
