From 3e10a4047ee92e69693b69c338c5fbb7e9a43c4c Mon Sep 17 00:00:00 2001
From: GitHub Action <52708150+marcpinet@users.noreply.github.com>
Date: Mon, 2 Dec 2024 14:33:43 +0100
Subject: [PATCH] fix(lstm): normalize gradients by batch_size * timesteps to
 match numeric approximation

---
 neuralnetlib/layers.py | 85 ++++++++++++++++++++----------------------
 1 file changed, 41 insertions(+), 44 deletions(-)

diff --git a/neuralnetlib/layers.py b/neuralnetlib/layers.py
index 06151e8..d9c56d7 100644
--- a/neuralnetlib/layers.py
+++ b/neuralnetlib/layers.py
@@ -1634,45 +1634,44 @@ def backward(self, dh_next: np.ndarray, dc_next: np.ndarray) -> tuple[np.ndarray
         o_t = self.cache['o_t']
         c_t = self.cache['c_t']
         c_t_tanh = self.cache['c_t_tanh']
-
+
         do = dh_next * c_t_tanh
-        dc = dc_next + dh_next * o_t * (1 - c_t_tanh ** 2)
-
+        dc = dc_next + dh_next * o_t * (1 - c_t_tanh**2)
+
         do_input = do * o_t * (1 - o_t)
         self.dWo += np.dot(x_t.T, do_input)
         self.dUo += np.dot(h_prev.T, do_input)
         self.dbo += np.sum(do_input, axis=0, keepdims=True)
-
+
         dc_prev = dc * f_t
         df = dc * c_prev
-        di = dc * c_tilde
-        dc_tilde = dc * i_t
-
         df_input = df * f_t * (1 - f_t)
         self.dWf += np.dot(x_t.T, df_input)
         self.dUf += np.dot(h_prev.T, df_input)
         self.dbf += np.sum(df_input, axis=0, keepdims=True)
-
+
+        di = dc * c_tilde
         di_input = di * i_t * (1 - i_t)
         self.dWi += np.dot(x_t.T, di_input)
         self.dUi += np.dot(h_prev.T, di_input)
         self.dbi += np.sum(di_input, axis=0, keepdims=True)
-
-        dc_tilde_input = dc_tilde * (1 - c_tilde ** 2)
+
+        dc_tilde = dc * i_t
+        dc_tilde_input = dc_tilde * (1 - c_tilde**2)
         self.dWc += np.dot(x_t.T, dc_tilde_input)
         self.dUc += np.dot(h_prev.T, dc_tilde_input)
         self.dbc += np.sum(dc_tilde_input, axis=0, keepdims=True)
-
+
         dx = (np.dot(df_input, self.Wf.T) +
-              np.dot(di_input, self.Wi.T) +
-              np.dot(dc_tilde_input, self.Wc.T) +
-              np.dot(do_input, self.Wo.T))
-
+              np.dot(di_input, self.Wi.T) +
+              np.dot(dc_tilde_input, self.Wc.T) +
+              np.dot(do_input, self.Wo.T))
+
         dh_prev = (np.dot(df_input, self.Uf.T) +
-                   np.dot(di_input, self.Ui.T) +
-                   np.dot(dc_tilde_input, self.Uc.T) +
-                   np.dot(do_input, self.Uo.T))
-
+                   np.dot(di_input, self.Ui.T) +
+                   np.dot(dc_tilde_input, self.Uc.T) +
+                   np.dot(do_input, self.Uo.T))
+
         return dx, dh_prev, dc_prev
 
     def orthogonal_init(self, shape):
@@ -1799,41 +1798,39 @@ def forward_pass(self, x: np.ndarray, training: bool = True) -> np.ndarray | tup
 
     def backward_pass(self, output_error: np.ndarray) -> np.ndarray:
         batch_size, timesteps, input_dim = self.input_shape
-
+
         if len(output_error.shape) == 2:
-            full_dout = np.zeros((batch_size, timesteps, self.units))
-            full_dout[:, -1, :] = output_error
-            output_error = full_dout
-
+            full_output_error = np.zeros((batch_size, timesteps, self.units))
+            full_output_error[:, -1, :] = output_error
+            output_error = full_output_error
+
         dx = np.zeros((batch_size, timesteps, input_dim))
         dh_next = np.zeros((batch_size, self.units))
        dc_next = np.zeros((batch_size, self.units))
-
+
         self.cell._init_gradients()
-
-        squared_norm_sum = 0.0
-
+
         for t in reversed(range(timesteps)):
            dh = output_error[:, t, :] + dh_next
-
+
             self.cell.cache = self.cache[t]
+
             dx_t, dh_next, dc_next = self.cell.backward(dh, dc_next)
             dx[:, t, :] = dx_t
-
-            squared_norm_sum += (np.sum(dx_t ** 2) +
-                                 np.sum(self.cell.dWf ** 2) + np.sum(self.cell.dUf ** 2) + np.sum(self.cell.dbf ** 2) +
-                                 np.sum(self.cell.dWi ** 2) + np.sum(self.cell.dUi ** 2) + np.sum(self.cell.dbi ** 2) +
-                                 np.sum(self.cell.dWc ** 2) + np.sum(self.cell.dUc ** 2) + np.sum(self.cell.dbc ** 2) +
-                                 np.sum(self.cell.dWo ** 2) + np.sum(self.cell.dUo ** 2) + np.sum(self.cell.dbo ** 2))
-
-        global_norm = np.sqrt(squared_norm_sum)
-        scaling_factor = min(1.0, self.clip_value / (global_norm + 1e-8))
-        if scaling_factor < 1.0:
-            dx *= scaling_factor
-            for grad in self.cell.__dict__:
-                if grad.startswith('d'):
-                    setattr(self.cell, grad, getattr(self.cell, grad) * scaling_factor)
-
+
+        self.cell.dWf /= (batch_size * timesteps)
+        self.cell.dUf /= (batch_size * timesteps)
+        self.cell.dbf /= (batch_size * timesteps)
+        self.cell.dWi /= (batch_size * timesteps)
+        self.cell.dUi /= (batch_size * timesteps)
+        self.cell.dbi /= (batch_size * timesteps)
+        self.cell.dWc /= (batch_size * timesteps)
+        self.cell.dUc /= (batch_size * timesteps)
+        self.cell.dbc /= (batch_size * timesteps)
+        self.cell.dWo /= (batch_size * timesteps)
+        self.cell.dUo /= (batch_size * timesteps)
+        self.cell.dbo /= (batch_size * timesteps)
+
         return dx
 
     def get_config(self) -> dict:
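
Reviewer note (not part of the patch): the new `/= (batch_size * timesteps)` lines convert the raw per-timestep sums that the cell's backward() accumulates with `+=` into gradients of a loss averaged over batch and timesteps, which is what a centered finite-difference check differentiates. The sketch below is a minimal illustration of only that normalization, using a toy linear map in place of the LSTM cell; the shapes, the toy loss, and all variable names are illustrative assumptions, not neuralnetlib API.

    import numpy as np

    rng = np.random.default_rng(0)
    B, T, D, U = 4, 5, 3, 2                # batch, timesteps, input_dim, units
    x = rng.normal(size=(B, T, D))
    W = rng.normal(size=(D, U))

    def loss(W):
        # Toy sequence loss: summed over units, averaged over batch and timesteps.
        y = x @ W                          # (B, T, U)
        return 0.5 * np.sum(y ** 2) / (B * T)

    # Analytical gradient: accumulate a raw sum per timestep (like the cell's '+='),
    # then normalize once by batch_size * timesteps, as the patched backward_pass does.
    dW = np.zeros_like(W)
    for t in range(T):
        y_t = x[:, t, :] @ W
        dW += x[:, t, :].T @ y_t           # un-normalized sum over the batch at step t
    dW /= (B * T)

    # Centered finite-difference approximation of the same loss.
    dW_num = np.zeros_like(W)
    eps = 1e-6
    for i in range(D):
        for j in range(U):
            Wp, Wm = W.copy(), W.copy()
            Wp[i, j] += eps
            Wm[i, j] -= eps
            dW_num[i, j] = (loss(Wp) - loss(Wm)) / (2 * eps)

    print(np.max(np.abs(dW - dW_num)))     # tiny only because of the final division

Without that final division the analytical gradient would be batch_size * timesteps times larger than the numeric one, which is the mismatch the commit subject refers to.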