Skip to content

Commit

Permalink
reworked sensitivities in Q learning and DPG
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Oct 26, 2023
1 parent 7cc7205 commit cdf8ba5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/mpcrl/agents/lstd_dpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def update(self) -> Optional[str]:
)
if self.policy_gradients is not None:
self.policy_gradients.append(dJdtheta)
return self._do_gradient_update(dJdtheta)
return self.optimizer.update(dJdtheta)

def train_one_episode(
self,
Expand Down
59 changes: 31 additions & 28 deletions src/mpcrl/agents/lstd_q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def update(self) -> Optional[str]:
hessians.append(H)
gradient = np.mean(gradients, 0)
hessian = np.mean(hessians, 0) if self.hessian_type != "none" else None
return self._do_gradient_update(gradient, hessian)
return self.optimizer.update(gradient, hessian)

def train_one_episode(
self,
Expand Down Expand Up @@ -205,36 +205,39 @@ def train_one_episode(

def _init_sensitivity(
self, hessian_type: Literal["none", "approx", "full"]
) -> Callable[[cs.DM], tuple[np.ndarray, np.ndarray]]:
) -> Union[
Callable[[cs.DM], np.ndarray], Callable[[cs.DM], tuple[np.ndarray, np.ndarray]]
]:
"""Internal utility to compute the derivative of Q(s,a) w.r.t. the learnable
parameters, a.k.a., theta."""
assert hessian_type in ("none", "approx", "full"), "Invalid hessian type."
order = self.optimizer._order
theta = cs.vvcat(self._learnable_pars.sym.values())
nlp = self._Q.nlp
nlp_ = NlpSensitivity(nlp, theta)
Lt = nlp_.jacobians["L-p"] # a.k.a., dQdtheta
Ltt = nlp_.hessians["L-pp"] # a.k.a., approximated d2Qdtheta2
x_lam_p = cs.vertcat(nlp.primal_dual, nlp.p)
dQ = nlp_.jacobians["L-p"] # a.k.a., dQdtheta

if hessian_type == "none":
d2Qdtheta2 = cs.DM.nan()
elif hessian_type == "approx":
d2Qdtheta2 = Ltt
elif hessian_type == "full":
dydtheta, _ = nlp_.parametric_sensitivity(second_order=False)
d2Qdtheta2 = dydtheta.T @ nlp_.jacobians["K-p"] + Ltt
assert order == 1, "Expected 1st-order optimizer with `hessian_type=none`."
sensitivity = cs.Function(
"S", (x_lam_p,), (dQ,), ("x_lam_p",), ("dQ",), {"cse": True}
)
return lambda v: np.asarray(sensitivity(v).elements())

assert (
order == 2
), "Expected 2nd-order optimizer with `hessian_type=approx` or `full`."
if hessian_type == "approx":
ddQ = nlp_.hessians["L-pp"]
else:
raise ValueError(f"Invalid type of hessian; got {hessian_type}.")
dydtheta, _ = nlp_.parametric_sensitivity(second_order=False)
ddQ = dydtheta.T @ nlp_.jacobians["K-p"] + nlp_.hessians["L-pp"]

# convert to function (much faster runtime)
x_lam_p = cs.vertcat(nlp.primal_dual, nlp.p)
sensitivity = cs.Function(
"Q_sensitivity",
(x_lam_p,),
(Lt, d2Qdtheta2),
("x_lam_p",),
("dQ", "d2Q"),
{"cse": True},
"S", (x_lam_p,), (dQ, ddQ), ("x_lam_p",), ("dQ", "ddQ"), {"cse": True}
)

# wrap to conveniently return numpy arrays
def func(sol_values: cs.DM) -> tuple[np.ndarray, np.ndarray]:
dQ, ddQ = sensitivity(sol_values)
return np.asarray(dQ.elements()), ddQ.toarray()
Expand All @@ -249,15 +252,15 @@ def _try_store_experience(
it. Returns whether it was successful or not."""
if solQ.success and solV.success:
sol_values = solQ.all_vals
dQ, ddQ = self._sensitivity(sol_values)
td_error = cost + self.discount_factor * solV.f - solQ.f
g = -td_error * dQ
H = (
(np.multiply.outer(dQ, dQ) - td_error * ddQ)
if self.hessian_type != "none"
else np.nan
)
self.store_experience((g, H))
if self.hessian_type == "none":
dQ = self._sensitivity(sol_values)
hessian = np.nan
else:
dQ, ddQ = self._sensitivity(sol_values)
hessian = np.multiply.outer(dQ, dQ) - td_error * ddQ
gradient = -td_error * dQ
self.store_experience((gradient, hessian))
success = True
else:
td_error = np.nan
Expand Down
14 changes: 1 addition & 13 deletions src/mpcrl/agents/rl_learning_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from abc import ABC
from typing import Any, Generic, Optional, TypeVar

import numpy as np
from typing import Any, Generic, TypeVar

from mpcrl.agents.agent import SymType
from mpcrl.agents.learning_agent import LearningAgent
Expand Down Expand Up @@ -45,13 +43,3 @@ def establish_callback_hooks(self) -> None:
lr_hook = lr.hook
if lr_hook is not None:
self.hook_callback(repr(lr), lr_hook, lr.step)

def _do_gradient_update(
self, gradient: np.ndarray, hessian: Optional[np.ndarray] = None
) -> Optional[str]:
"""Internal utility to call the optimizer and perform the gradient update."""
return (
self.optimizer.update(gradient)
if hessian is None
else self.optimizer.update(gradient, hessian)
)

0 comments on commit cdf8ba5

Please sign in to comment.