implemented RMSprop; better optim tests
FilippoAiraldi committed Oct 17, 2023
1 parent cca9d3c commit 8a8bd98
Showing 4 changed files with 259 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/mpcrl/optim/__init__.py
@@ -1,8 +1,9 @@
__all__ = ["Adam", "GradientDescent", "GD", "NetwonMethod", "NM"]
__all__ = ["Adam", "GradientDescent", "GD", "NetwonMethod", "NM", "RMSprop"]

from .adam import Adam
from .gradient_descent import GradientDescent
from .newton_method import NetwonMethod
from .rmsprop import RMSprop

GD = GradientDescent
NM = NetwonMethod
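
With this change, the new optimizer is exported alongside the existing ones. A minimal usage sketch of the new import path (not part of the commit; it assumes the package is installed as `mpcrl`, and the hyper-parameters shown are the defaults defined in `rmsprop.py` below):

from mpcrl.optim import RMSprop

# a constant learning rate is passed as a plain float; the remaining
# hyper-parameters default to the same values as torch.optim.RMSprop
optimizer = RMSprop(
    learning_rate=1e-3, alpha=0.99, eps=1e-8, momentum=0.0, centered=False
)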
2 changes: 1 addition & 1 deletion src/mpcrl/optim/adam.py
@@ -29,7 +29,7 @@ class Adam(GradientBasedOptimizer):
"""

_hessian_sparsity = "diag"
"""In Adam, the hessian is at most diagonal, i.e., in case we have constraints."""
"""In Adam, hessian is at most diagonal, i.e., in case we have constraints."""

def __init__(
self,
146 changes: 146 additions & 0 deletions src/mpcrl/optim/rmsprop.py
@@ -0,0 +1,146 @@
from typing import Optional, Union

import casadi as cs
import numpy as np
import numpy.typing as npt

from mpcrl.core.learning_rate import LearningRate, LrType
from mpcrl.core.parameters import LearnableParametersDict, SymType
from mpcrl.core.schedulers import Scheduler
from mpcrl.optim.gradient_based_optimizer import GradientBasedOptimizer


class RMSprop(GradientBasedOptimizer):
"""RMSprop optimizer, based on [1,2].
References
----------
[1] Geoffrey Hinton, Nitish Srivastava, and Kevin Swersky. Neural networks for
machine learning lecture 6a overview of mini-batch gradient descent. page 14,
2012.
[2] RMSprop - PyTorch 2.1 documentation.
https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
"""

_hessian_sparsity = "diag"
"""In RMSprop, hessian is at most diagonal, i.e., in case we have constraints."""

def __init__(
self,
learning_rate: Union[LrType, Scheduler[LrType], LearningRate[LrType]],
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0.0,
momentum: float = 0.0,
centered: bool = False,
max_percentage_update: float = float("+inf"),
) -> None:
"""Instantiates the optimizer.
Parameters
----------
learning_rate : float/array, scheduler or LearningRate
The learning rate of the optimizer. A float/array can be passed in case the
learning rate must stay constant; otherwise, a scheduler can be passed, which
is stepped `on_update` by default. Alternatively, a `LearningRate` object can
be passed, allowing both the scheduling and stepping strategies of this
fundamental hyper-parameter to be specified.
alpha : float, optional
A positive float that specifies the decay rate of the running average of the
squared gradient. By default, it is set to `0.99`.
eps : float, optional
Term added to the denominator to improve numerical stability. By default, it
is set to `1e-8`.
weight_decay : float, optional
A positive float that specifies the decay of the learnable parameters in the
form of an L2 regularization term. By default, it is set to `0.0`, so no
decay/regularization takes place.
momentum : float, optional
A positive float that specifies the momentum factor. By default, it is set
to `0.0`, so no momentum is used.
centered : bool, optional
If `True`, the centered variant of RMSprop is computed, i.e., the gradient is
normalized by an estimate of its variance. By default, it is set to `False`.
max_percentage_update : float, optional
A positive float that specifies the maximum percentage change the learnable
parameters can experience in each update. For example,
`max_percentage_update=0.5` means that the parameters can be updated by up
to 50% of their current value. By default, it is set to `+inf`. If
specified, the update becomes constrained and has to be solved as a QP,
which is inevitably slower than its unconstrained counterpart.
"""
super().__init__(learning_rate, max_percentage_update)
self.weight_decay = weight_decay
self.alpha = alpha
self.eps = eps
self.momentum = momentum
self.centered = centered

def set_learnable_parameters(self, pars: LearnableParametersDict[SymType]) -> None:
super().set_learnable_parameters(pars)
# also initialize the running averages and the momentum buffer
n = pars.size
self._square_avg = np.zeros(n, dtype=float)
self._grad_avg = np.zeros(n, dtype=float) if self.centered else None
self._momentum_buf = np.zeros(n, dtype=float) if self.momentum > 0.0 else None

def _first_order_update(
self, gradient: npt.NDArray[np.floating]
) -> tuple[npt.NDArray[np.floating], Optional[str]]:
theta = self.learnable_parameters.value

# compute candidate update
weight_decay = self.weight_decay
lr = self.learning_rate.value
if weight_decay > 0.0:
gradient = gradient + weight_decay * theta
dtheta, self._square_avg, self._grad_avg, self._momentum_buf = _rmsprop(
gradient,
self._square_avg,
lr,
self.alpha,
self.eps,
self.centered,
self._grad_avg,
self.momentum,
self._momentum_buf,
)

# if unconstrained, apply the update directly; otherwise, solve the QP
solver = self._update_solver
if solver is None:
return theta + dtheta, None
lbx, ubx = self._get_update_bounds(theta)
sol = solver(h=cs.DM.eye(theta.shape[0]), g=-dtheta, lbx=lbx, ubx=ubx)
dtheta = sol["x"].full().reshape(-1)
stats = solver.stats()
return theta + dtheta, None if stats["success"] else stats["return_status"]


def _rmsprop(
grad: np.ndarray,
square_avg: np.ndarray,
lr: LrType,
alpha: float,
eps: float,
centered: bool,
grad_avg: Optional[np.ndarray],
momentum: float,
momentum_buffer: Optional[np.ndarray],
) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
"""Computes the update's change according to Adam algorithm."""
square_avg = alpha * square_avg + (1 - alpha) * np.square(grad)

if centered:
grad_avg = alpha * grad_avg + (1 - alpha) * grad
avg = np.sqrt(square_avg - np.square(grad_avg))
else:
avg = np.sqrt(square_avg)
avg += eps

if momentum > 0.0:
momentum_buffer = momentum * momentum_buffer + grad / avg
dtheta = -lr * momentum_buffer
else:
dtheta = -lr * grad / avg
return dtheta, square_avg, grad_avg, momentum_buffer
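
Since `_rmsprop` is a pure helper, its update rule can be exercised in isolation. A minimal sketch (not part of the commit, assuming the private helper is importable from the new module) that iterates it on a toy quadratic, carrying the running average of squared gradients between steps:

import numpy as np

from mpcrl.optim.rmsprop import _rmsprop

theta = np.array([1.0, -2.0])      # parameters of the toy objective ||theta||^2
square_avg = np.zeros_like(theta)  # running average of squared gradients
grad_avg = None                    # only used when centered=True
momentum_buffer = None             # only used when momentum > 0

for _ in range(200):
    grad = 2.0 * theta  # gradient of the toy objective
    dtheta, square_avg, grad_avg, momentum_buffer = _rmsprop(
        grad,
        square_avg,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        centered=False,
        grad_avg=grad_avg,
        momentum=0.0,
        momentum_buffer=momentum_buffer,
    )
    theta = theta + dtheta  # dtheta already carries the -lr factor

print(theta)  # after 200 steps, theta is close to the minimizer at the origin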
113 changes: 110 additions & 3 deletions tests/test_optim.py
@@ -229,12 +229,11 @@ def test_update__constrained__with_small_bounds(


class TestAdam(unittest.TestCase):
@parameterized.expand(product((False, True), (False, True)))
def test(self, decouple_weight_decay: bool, amsgrad: bool):
@parameterized.expand(product((0, 0.01), (False, True), (False, True)))
def test(self, weight_decay: float, decouple_weight_decay: bool, amsgrad: bool):
# prepare data
betas = tuple(np.random.uniform(0.9, 1.0, size=2))
eps = np.random.uniform(1e-8, 1e-6)
weight_decay = 0 if np.random.rand() < 0.5 else np.random.uniform(0.0, 1e-2)
lr = np.random.uniform(1e-4, 1e-3)

# prepare torch elements
@@ -337,5 +336,113 @@ def test(self, decouple_weight_decay: bool, amsgrad: bool):
)


class TestRMSprop(unittest.TestCase):
@parameterized.expand(product((0, 0.1), (0, 0.9), (False, True)))
def test(self, weight_decay: float, momentum: float, centered: bool):
# prepare data
alpha = np.random.uniform(0.9, 0.99)
eps = np.random.uniform(1e-8, 1e-6)
lr = np.random.uniform(1e-4, 1e-3)

# prepare torch elements
x = torch.linspace(-np.pi, np.pi, 2, dtype=torch.float64)
xx_torch = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))
y_actual_torch = torch.sin(x)
model_torch = torch.nn.Linear(3, 1, dtype=torch.float64)
loss_fn_torch = torch.nn.MSELoss(reduction="sum")
optimizer_torch = torch.optim.RMSprop(
params=model_torch.parameters(),
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
)

# prepare mpcrl elements
xx_dm = cs.DM(xx_torch.detach().numpy())
A_sym = cs.MX.sym("A", *model_torch.weight.shape)
b_sym = cs.MX.sym("b", *model_torch.bias.shape)
y_pred_sym = xx_dm @ A_sym.T + b_sym
y_actual_dm = cs.DM(y_actual_torch.detach().clone().numpy())
loss_sym = cs.sumsqr(y_pred_sym - y_actual_dm)
p_sym = cs.veccat(A_sym, b_sym)
dldp_sym = cs.gradient(loss_sym, p_sym)
model_mpcrl = cs.Function(
"F", [p_sym], [y_pred_sym, loss_sym, dldp_sym], ["p"], ["y", "l", "dldp"]
)
A_mpcrl = model_torch.weight.data.detach().clone().numpy().flatten()
b_mpcrl = model_torch.bias.data.detach().clone().numpy()
learnable_pars = LearnableParametersDict(
[
LearnableParameter("A", A_mpcrl.size, A_mpcrl),
LearnableParameter("b", b_mpcrl.size, b_mpcrl),
]
)
optimizer_mpcrl = O.RMSprop(
learning_rate=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
max_percentage_update=1e4, # test constrained
)
optimizer_mpcrl.set_learnable_parameters(learnable_pars)

# run test
cmp = lambda x, y, msg: np.testing.assert_allclose(
x, y, rtol=1e-6, atol=1e-6, err_msg=msg
)
for i in range(20):
# torch
y_pred_torch = model_torch(xx_torch).flatten()
loss_torch = loss_fn_torch(y_pred_torch, y_actual_torch)
optimizer_torch.zero_grad()
loss_torch.backward()
optimizer_torch.step()
grad_torch = np.concatenate(
[
model_torch.weight.grad.detach().clone().numpy(),
model_torch.bias.grad.detach().clone().numpy(),
],
None,
)

# mpcrl
y_pred_mpcrl, loss_mpcrl, grad_mpcrl = model_mpcrl(
np.concatenate([A_mpcrl, b_mpcrl], None)
)
grad_mpcrl = grad_mpcrl.full().flatten()
status = optimizer_mpcrl.update(grad_mpcrl)
p_new = learnable_pars.value
A_mpcrl, b_mpcrl = np.array_split(p_new, [A_mpcrl.size])

# check
self.assertIsNone(status)
cmp(
y_pred_mpcrl.full().flatten(),
y_pred_torch.detach().clone().numpy(),
f"prediction mismatch at iteration {i}",
)
cmp(
float(loss_mpcrl),
loss_torch.detach().clone().item(),
f"loss mismatch at iteration {i}",
)
cmp(grad_mpcrl, grad_torch, f"gradient mismatch at iteration {i}")
cmp(
A_mpcrl,
model_torch.weight.detach().clone().numpy().reshape(A_mpcrl.shape),
f"`A` mismatch at iteration {i}",
)
cmp(
b_mpcrl,
model_torch.bias.detach().clone().numpy().reshape(b_mpcrl.shape),
f"`b` mismatch at iteration {i}",
)


if __name__ == "__main__":
unittest.main()
