From 8a8bd98def09d09dc3820a1b932b024c1424ac1b Mon Sep 17 00:00:00 2001
From: Filippo Airaldi
Date: Tue, 17 Oct 2023 13:34:03 +0200
Subject: [PATCH] implemented RMSprop; better optim tests

---
 src/mpcrl/optim/__init__.py |   3 +-
 src/mpcrl/optim/adam.py     |   2 +-
 src/mpcrl/optim/rmsprop.py  | 146 ++++++++++++++++++++++++++++++++++++
 tests/test_optim.py         | 113 +++++++++++++++++++++++++++-
 4 files changed, 259 insertions(+), 5 deletions(-)
 create mode 100644 src/mpcrl/optim/rmsprop.py

diff --git a/src/mpcrl/optim/__init__.py b/src/mpcrl/optim/__init__.py
index 4a75bdd..890cdac 100644
--- a/src/mpcrl/optim/__init__.py
+++ b/src/mpcrl/optim/__init__.py
@@ -1,8 +1,9 @@
-__all__ = ["Adam", "GradientDescent", "GD", "NetwonMethod", "NM"]
+__all__ = ["Adam", "GradientDescent", "GD", "NetwonMethod", "NM", "RMSprop"]
 
 from .adam import Adam
 from .gradient_descent import GradientDescent
 from .newton_method import NetwonMethod
+from .rmsprop import RMSprop
 
 GD = GradientDescent
 NM = NetwonMethod
diff --git a/src/mpcrl/optim/adam.py b/src/mpcrl/optim/adam.py
index ba01761..bad1f4e 100644
--- a/src/mpcrl/optim/adam.py
+++ b/src/mpcrl/optim/adam.py
@@ -29,7 +29,7 @@ class Adam(GradientBasedOptimizer):
     """
 
     _hessian_sparsity = "diag"
-    """In Adam, the hessian is at most diagonal, i.e., in case we have constraints."""
+    """In Adam, hessian is at most diagonal, i.e., in case we have constraints."""
 
     def __init__(
         self,
diff --git a/src/mpcrl/optim/rmsprop.py b/src/mpcrl/optim/rmsprop.py
new file mode 100644
index 0000000..defffbc
--- /dev/null
+++ b/src/mpcrl/optim/rmsprop.py
@@ -0,0 +1,146 @@
+from typing import Optional, Union
+
+import casadi as cs
+import numpy as np
+import numpy.typing as npt
+
+from mpcrl.core.learning_rate import LearningRate, LrType
+from mpcrl.core.parameters import LearnableParametersDict, SymType
+from mpcrl.core.schedulers import Scheduler
+from mpcrl.optim.gradient_based_optimizer import GradientBasedOptimizer
+
+
+class RMSprop(GradientBasedOptimizer):
+    """RMSprop optimizer, based on [1,2].
+
+    References
+    ----------
+    [1] Geoffrey Hinton, Nitish Srivastava, and Kevin Swersky. Neural networks for
+        machine learning, lecture 6a: overview of mini-batch gradient descent,
+        page 14, 2012.
+    [2] RMSprop - PyTorch 2.1 documentation.
+        https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
+    """
+
+    _hessian_sparsity = "diag"
+    """In RMSprop, hessian is at most diagonal, i.e., in case we have constraints."""
+
+    def __init__(
+        self,
+        learning_rate: Union[LrType, Scheduler[LrType], LearningRate[LrType]],
+        alpha: float = 0.99,
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        momentum: float = 0.0,
+        centered: bool = False,
+        max_percentage_update: float = float("+inf"),
+    ) -> None:
+        """Instantiates the optimizer.
+
+        Parameters
+        ----------
+        learning_rate : float/array, scheduler or LearningRate
+            The learning rate of the optimizer. A float/array can be passed in case the
+            learning rate must stay constant; otherwise, a scheduler can be passed which
+            will be stepped `on_update` by default. Alternatively, a `LearningRate`
+            object can be passed, allowing one to specify both the scheduling and
+            stepping strategies of this fundamental hyper-parameter.
+        alpha : float, optional
+            A positive float that specifies the decay rate of the running average of the
+            squared gradient. By default, it is set to `0.99`.
+        eps : float, optional
+            Term added to the denominator to improve numerical stability. By default, it
+            is set to `1e-8`.
+        weight_decay : float, optional
+            A positive float that specifies the decay of the learnable parameters in the
+            form of an L2 regularization term. By default, it is set to `0.0`, so no
+            decay/regularization takes place.
+        momentum : float, optional
+            A positive float that specifies the momentum factor. By default, it is set
+            to `0.0`, so no momentum is used.
+        centered : bool, optional
+            If `True`, compute the centered RMSprop, i.e., the gradient is normalized by
+            an estimation of its variance.
+        max_percentage_update : float, optional
+            A positive float that specifies the maximum percentage change the learnable
+            parameters can experience in each update. For example,
+            `max_percentage_update=0.5` means that the parameters can be updated by up
+            to 50% of their current value. By default, it is set to `+inf`. If
+            specified, the update becomes constrained and has to be solved as a QP,
+            which is inevitably slower than its unconstrained counterpart.
+        """
+        super().__init__(learning_rate, max_percentage_update)
+        self.weight_decay = weight_decay
+        self.alpha = alpha
+        self.eps = eps
+        self.momentum = momentum
+        self.centered = centered
+
+    def set_learnable_parameters(self, pars: LearnableParametersDict[SymType]) -> None:
+        super().set_learnable_parameters(pars)
+        # also initialize the running averages
+        n = pars.size
+        self._square_avg = np.zeros(n, dtype=float)
+        self._grad_avg = np.zeros(n, dtype=float) if self.centered else None
+        self._momentum_buf = np.zeros(n, dtype=float) if self.momentum > 0.0 else None
+
+    def _first_order_update(
+        self, gradient: npt.NDArray[np.floating]
+    ) -> tuple[npt.NDArray[np.floating], Optional[str]]:
+        theta = self.learnable_parameters.value
+
+        # compute candidate update
+        weight_decay = self.weight_decay
+        lr = self.learning_rate.value
+        if weight_decay > 0.0:
+            gradient = gradient + weight_decay * theta
+        dtheta, self._square_avg, self._grad_avg, self._momentum_buf = _rmsprop(
+            gradient,
+            self._square_avg,
+            lr,
+            self.alpha,
+            self.eps,
+            self.centered,
+            self._grad_avg,
+            self.momentum,
+            self._momentum_buf,
+        )
+
+        # if unconstrained, apply the update directly; otherwise, solve the QP
+        solver = self._update_solver
+        if solver is None:
+            return theta + dtheta, None
+        lbx, ubx = self._get_update_bounds(theta)
+        sol = solver(h=cs.DM.eye(theta.shape[0]), g=-dtheta, lbx=lbx, ubx=ubx)
+        dtheta = sol["x"].full().reshape(-1)
+        stats = solver.stats()
+        return theta + dtheta, None if stats["success"] else stats["return_status"]
+
+
+def _rmsprop(
+    grad: np.ndarray,
+    square_avg: np.ndarray,
+    lr: LrType,
+    alpha: float,
+    eps: float,
+    centered: bool,
+    grad_avg: Optional[np.ndarray],
+    momentum: float,
+    momentum_buffer: Optional[np.ndarray],
+) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
+    """Computes the update's change according to the RMSprop algorithm."""
+    square_avg = alpha * square_avg + (1 - alpha) * np.square(grad)
+
+    if centered:
+        grad_avg = alpha * grad_avg + (1 - alpha) * grad
+        avg = np.sqrt(square_avg - np.square(grad_avg))
+    else:
+        avg = np.sqrt(square_avg)
+    avg += eps
+
+    if momentum > 0.0:
+        momentum_buffer = momentum * momentum_buffer + grad / avg
+        dtheta = -lr * momentum_buffer
+    else:
+        dtheta = -lr * grad / avg
+    return dtheta, square_avg, grad_avg, momentum_buffer
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 5e946e4..18630c7 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -229,12 +229,11 @@ def test_update__constrained__with_small_bounds(
 
 
 class TestAdam(unittest.TestCase):
-    @parameterized.expand(product((False, True), (False, True)))
-    def test(self, decouple_weight_decay: bool, amsgrad: bool):
+    @parameterized.expand(product((0, 0.01), (False, True), (False, True)))
+    def test(self, weight_decay: float, decouple_weight_decay: bool, amsgrad: bool):
         # prepare data
         betas = tuple(np.random.uniform(0.9, 1.0, size=2))
         eps = np.random.uniform(1e-8, 1e-6)
-        weight_decay = 0 if np.random.rand() < 0.5 else np.random.uniform(0.0, 1e-2)
         lr = np.random.uniform(1e-4, 1e-3)
 
         # prepare torch elements
@@ -337,5 +336,113 @@ def test(self, decouple_weight_decay: bool, amsgrad: bool):
         )
 
 
+class TestRMSprop(unittest.TestCase):
+    @parameterized.expand(product((0, 0.1), (0, 0.9), (False, True)))
+    def test(self, weight_decay: float, momentum: float, centered: bool):
+        # prepare data
+        alpha = np.random.uniform(0.9, 0.99)
+        eps = np.random.uniform(1e-8, 1e-6)
+        lr = np.random.uniform(1e-4, 1e-3)
+
+        # prepare torch elements
+        x = torch.linspace(-np.pi, np.pi, 2, dtype=torch.float64)
+        xx_torch = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))
+        y_actual_torch = torch.sin(x)
+        model_torch = torch.nn.Linear(3, 1, dtype=torch.float64)
+        loss_fn_torch = torch.nn.MSELoss(reduction="sum")
+        optimizer_torch = torch.optim.RMSprop(
+            params=model_torch.parameters(),
+            lr=lr,
+            alpha=alpha,
+            eps=eps,
+            weight_decay=weight_decay,
+            momentum=momentum,
+            centered=centered,
+        )
+
+        # prepare mpcrl elements
+        xx_dm = cs.DM(xx_torch.detach().numpy())
+        A_sym = cs.MX.sym("A", *model_torch.weight.shape)
+        b_sym = cs.MX.sym("b", *model_torch.bias.shape)
+        y_pred_sym = xx_dm @ A_sym.T + b_sym
+        y_actual_dm = cs.DM(y_actual_torch.detach().clone().numpy())
+        loss_sym = cs.sumsqr(y_pred_sym - y_actual_dm)
+        p_sym = cs.veccat(A_sym, b_sym)
+        dldp_sym = cs.gradient(loss_sym, p_sym)
+        model_mpcrl = cs.Function(
+            "F", [p_sym], [y_pred_sym, loss_sym, dldp_sym], ["p"], ["y", "l", "dldp"]
+        )
+        A_mpcrl = model_torch.weight.data.detach().clone().numpy().flatten()
+        b_mpcrl = model_torch.bias.data.detach().clone().numpy()
+        learnable_pars = LearnableParametersDict(
+            [
+                LearnableParameter("A", A_mpcrl.size, A_mpcrl),
+                LearnableParameter("b", b_mpcrl.size, b_mpcrl),
+            ]
+        )
+        optimizer_mpcrl = O.RMSprop(
+            learning_rate=lr,
+            alpha=alpha,
+            eps=eps,
+            weight_decay=weight_decay,
+            momentum=momentum,
+            centered=centered,
+            max_percentage_update=1e4,  # test constrained
+        )
+        optimizer_mpcrl.set_learnable_parameters(learnable_pars)
+
+        # run test
+        cmp = lambda x, y, msg: np.testing.assert_allclose(
+            x, y, rtol=1e-6, atol=1e-6, err_msg=msg
+        )
+        for i in range(20):
+            # torch
+            y_pred_torch = model_torch(xx_torch).flatten()
+            loss_torch = loss_fn_torch(y_pred_torch, y_actual_torch)
+            optimizer_torch.zero_grad()
+            loss_torch.backward()
+            optimizer_torch.step()
+            grad_torch = np.concatenate(
+                [
+                    model_torch.weight.grad.detach().clone().numpy(),
+                    model_torch.bias.grad.detach().clone().numpy(),
+                ],
+                None,
+            )
+
+            # mpcrl
+            y_pred_mpcrl, loss_mpcrl, grad_mpcrl = model_mpcrl(
+                np.concatenate([A_mpcrl, b_mpcrl], None)
+            )
+            grad_mpcrl = grad_mpcrl.full().flatten()
+            status = optimizer_mpcrl.update(grad_mpcrl)
+            p_new = learnable_pars.value
+            A_mpcrl, b_mpcrl = np.array_split(p_new, [A_mpcrl.size])
+
+            # check
+            self.assertIsNone(status)
+            cmp(
+                y_pred_mpcrl.full().flatten(),
+                y_pred_torch.detach().clone().numpy(),
+                f"prediction mismatch at iteration {i}",
+            )
+            cmp(
+                float(loss_mpcrl),
+                loss_torch.detach().clone().item(),
+                f"loss mismatch at iteration {i}",
+            )
+            cmp(grad_mpcrl, grad_torch, f"gradient mismatch at iteration {i}")
+            cmp(
+                A_mpcrl,
+                model_torch.weight.detach().clone().numpy().reshape(A_mpcrl.shape),
+                f"`A` mismatch at iteration {i}",
+            )
+            cmp(
+                b_mpcrl,
+                model_torch.bias.detach().clone().numpy().reshape(b_mpcrl.shape),
+                f"`b` mismatch at iteration {i}",
+            )
+
+
 if __name__ == "__main__":
     unittest.main()
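
For reference, below is a minimal usage sketch of the optimizer added by this patch, driven by hand outside the test suite. It only reuses calls already exercised in `TestRMSprop` (`RMSprop(...)`, `set_learnable_parameters`, `update`, `learnable_pars.value`); it assumes `LearnableParameter` is importable from `mpcrl.core.parameters` alongside `LearnableParametersDict`, and the quadratic objective and its gradient are made up purely for illustration.

    import numpy as np

    from mpcrl.core.parameters import LearnableParameter, LearnableParametersDict
    from mpcrl.optim import RMSprop

    # one learnable parameter vector with an arbitrary initial value
    theta0 = np.array([1.0, -2.0, 0.5])
    pars = LearnableParametersDict([LearnableParameter("theta", theta0.size, theta0)])

    # centered RMSprop with a constant learning rate (a scheduler could be passed too)
    opt = RMSprop(learning_rate=1e-3, alpha=0.99, centered=True)
    opt.set_learnable_parameters(pars)

    for _ in range(100):
        # hypothetical objective f(theta) = ||theta||^2, whose gradient is 2 * theta
        gradient = 2.0 * pars.value
        status = opt.update(gradient)  # None on success, as asserted in the tests
        assert status is None

    print(pars.value)  # the parameters should have drifted towards the origin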