feat(optimizers): add RAdam
marcpinet committed Dec 3, 2024
1 parent d059bad commit 48a71b2
Showing 1 changed file with 118 additions and 0 deletions.
118 changes: 118 additions & 0 deletions neuralnetlib/optimizers.py
@@ -340,3 +340,121 @@ def __str__(self):
        return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")


class RAdam(Optimizer):
    def __init__(self, learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: float = 0.999,
                 epsilon: float = 1e-8, clip_norm: float = None, clip_value: float = None) -> None:
        super().__init__(learning_rate)
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.clip_norm = clip_norm
        self.clip_value = clip_value
        self.t = 0

        # Per-layer first and second moment estimates for weights and biases
        self.m_w, self.v_w = {}, {}
        self.m_b, self.v_b = {}, {}

        self._min_denom = 1e-16
        self._max_exp = np.log(np.finfo(np.float64).max)

        # Maximum length of the approximated SMA: rho_inf = 2 / (1 - beta_2) - 1
        self.rho_inf = 2 / (1 - beta_2) - 1

    def _clip_gradients(self, grad: np.ndarray) -> np.ndarray:
        if grad is None:
            return None

        if self.clip_norm is not None:
            grad_norm = np.linalg.norm(grad)
            if grad_norm > self.clip_norm:
                grad = grad * (self.clip_norm / (grad_norm + self._min_denom))

        if self.clip_value is not None:
            grad = np.clip(grad, -self.clip_value, self.clip_value)

        return grad

    def _compute_moments(self, param: np.ndarray, grad: np.ndarray, m: np.ndarray, v: np.ndarray) -> tuple:
        grad = self._clip_gradients(grad)

        # Exponential moving averages of the gradient and squared gradient
        m = self.beta_1 * m + (1 - self.beta_1) * grad
        v = self.beta_2 * v + (1 - self.beta_2) * np.square(grad)

        beta1_t = self.beta_1 ** self.t
        beta2_t = self.beta_2 ** self.t

        # Bias-corrected first moment
        m_hat = m / (1 - beta1_t)

        # Length of the approximated SMA at step t
        rho_t = self.rho_inf - 2 * self.t * beta2_t / (1 - beta2_t)

        if rho_t > 4:
            # Variance is tractable: apply the rectified adaptive update
            v_hat = np.sqrt(v / (1 - beta2_t))
            r_t = np.sqrt(((rho_t - 4) * (rho_t - 2) * self.rho_inf) /
                          ((self.rho_inf - 4) * (self.rho_inf - 2) * rho_t))

            denom = v_hat + self.epsilon
            update = r_t * self.learning_rate * m_hat / np.maximum(denom, self._min_denom)
        else:
            # Early steps: fall back to an un-adapted (SGD-with-momentum-like) update
            update = self.learning_rate * m_hat

        update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
        param -= update

        return param, m, v
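    # Illustrative note (not part of the original commit): with the default
    # beta_2 = 0.999, rho_inf = 1999 and rho_t stays at or below 4 for the first
    # four updates (rho_1 ~= 1, rho_2 ~= 2, rho_3 ~= 3, rho_4 ~= 4), so those
    # steps take the un-adapted branch; rectification kicks in from t = 5 onward.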

    def update(self, layer_index: int, weights: np.ndarray, weights_grad: np.ndarray, bias: np.ndarray,
               bias_grad: np.ndarray) -> None:
        # Lazily initialize the per-layer moment buffers on first use
        if layer_index not in self.m_w:
            self.m_w[layer_index] = np.zeros_like(weights)
            self.v_w[layer_index] = np.zeros_like(weights)
            self.m_b[layer_index] = np.zeros_like(bias)
            self.v_b[layer_index] = np.zeros_like(bias)

        self.t += 1

        weights, self.m_w[layer_index], self.v_w[layer_index] = self._compute_moments(
            weights, weights_grad, self.m_w[layer_index], self.v_w[layer_index]
        )

        bias, self.m_b[layer_index], self.v_b[layer_index] = self._compute_moments(
            bias, bias_grad, self.m_b[layer_index], self.v_b[layer_index]
        )

    def get_config(self) -> dict:
        return {
            "name": self.__class__.__name__,
            "learning_rate": self.learning_rate,
            "beta_1": self.beta_1,
            "beta_2": self.beta_2,
            "epsilon": self.epsilon,
            "clip_norm": self.clip_norm,
            "clip_value": self.clip_value,
            "t": self.t,
            "m_w": dict_with_ndarray_to_dict_with_list(self.m_w),
            "v_w": dict_with_ndarray_to_dict_with_list(self.v_w),
            "m_b": dict_with_ndarray_to_dict_with_list(self.m_b),
            "v_b": dict_with_ndarray_to_dict_with_list(self.v_b)
        }

    @staticmethod
    def from_config(config: dict):
        radam = RAdam(
            learning_rate=config['learning_rate'],
            beta_1=config['beta_1'],
            beta_2=config['beta_2'],
            epsilon=config['epsilon'],
            clip_norm=config.get('clip_norm'),
            clip_value=config.get('clip_value')
        )
        radam.t = config['t']
        radam.m_w = dict_with_list_to_dict_with_ndarray(config['m_w'])
        radam.v_w = dict_with_list_to_dict_with_ndarray(config['v_w'])
        radam.m_b = dict_with_list_to_dict_with_ndarray(config['m_b'])
        radam.v_b = dict_with_list_to_dict_with_ndarray(config['v_b'])
        return radam

    def __str__(self):
        return (f"{self.__class__.__name__}(learning_rate={self.learning_rate}, "
                f"beta_1={self.beta_1}, beta_2={self.beta_2}, epsilon={self.epsilon}, "
                f"clip_norm={self.clip_norm}, clip_value={self.clip_value})")
