
implemented natural PG with fisher matrix
FilippoAiraldi committed Apr 13, 2024
1 parent 38d5488 commit 8976e45
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions src/mpcrl/agents/lstd_dpg.py
@@ -66,6 +66,7 @@ def __init__(
         warmstart: Union[
             Literal["last", "last-successful"], WarmStartStrategy
         ] = "last-successful",
+        hessian_type: Literal["none", "natural"] = "none",
         rollout_length: int = -1,
         record_policy_performance: bool = False,
         record_policy_gradient: bool = False,
@@ -137,6 +138,13 @@ def __init__(
             useful to generate multiple initial conditions for very non-convex problems.
             Can only be used with an MPC that has an underlying multistart NLP problem
             (see `csnlp.MultistartNlp`).
+        hessian_type : {'none', 'natural'}, optional
+            The type of Hessian to use in this (potentially) second-order algorithm.
+            If 'none', no second-order information is used. If 'natural', the Fisher
+            information matrix is used to perform a natural policy gradient update.
+            This option must be consistent with the choice of `optimizer`: if the
+            optimizer does not use second-order information, then this option must
+            be set to 'none'.
         rollout_length : int, optional
             Number of steps of each closed-loop simulation, which defines a complete
             trajectory of the states (i.e., a rollout), and is saved in the experience
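
For intuition on the `natural` option documented above: the optimizer is expected to precondition the estimated policy gradient with the inverse of the Fisher information matrix. The following minimal numpy sketch shows such a damped natural gradient step; it is an illustration only (not mpcrl's optimizer), and the learning rate and damping values are arbitrary.

```python
import numpy as np


def natural_gradient_step(theta, dJdtheta, fisher, lr=0.01, damping=1e-6):
    """One damped natural gradient step: theta <- theta - lr * F^-1 dJ/dtheta."""
    F = fisher + damping * np.eye(fisher.shape[0])  # damping keeps the solve well-posed
    step = np.linalg.solve(F, np.asarray(dJdtheta).reshape(-1))
    return theta - lr * step


# toy usage with a random symmetric positive semi-definite Fisher matrix
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
theta_new = natural_gradient_step(np.zeros(3), rng.standard_normal(3), A @ A.T)
```

Damping the Fisher matrix before solving is a common safeguard when the outer-product estimate is rank-deficient (e.g., with short rollouts).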
@@ -206,6 +214,7 @@ def __init__(
             use_last_action_on_fail=use_last_action_on_fail,
             name=name,
         )
+        self.hessian_type = hessian_type
         self._sensitivity = self._init_sensitivity(linsolver)
         self._Phi = (
             monomials_basis_function(mpc.ns, 0, 2)
@@ -232,12 +241,18 @@ def __init__(
 
     def update(self) -> Optional[str]:
         sample = self.experience.sample()
-        dJdtheta = _estimate_gradient_update(
-            sample, self.discount_factor, self.regularization
-        )
+        if self.hessian_type == "natural":
+            dJdtheta, fisher_hessian = _estimate_gradient_update(
+                sample, self.discount_factor, self.regularization, True
+            )
+        else:
+            dJdtheta = _estimate_gradient_update(
+                sample, self.discount_factor, self.regularization, False
+            )
+            fisher_hessian = None
         if self.policy_gradients is not None:
             self.policy_gradients.append(dJdtheta)
-        return self.optimizer.update(dJdtheta)
+        return self.optimizer.update(dJdtheta, fisher_hessian)
 
     def train_one_episode(
         self,
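
The rewritten `update` above forwards `fisher_hessian` (or `None`) to `self.optimizer.update`. A toy optimizer honoring that two-argument call might look as follows; this is a hypothetical sketch, not the optimizer shipped with mpcrl, and the learning rate, damping, and internals are assumptions.

```python
import numpy as np


class ToyOptimizer:
    """Hypothetical second-order optimizer accepting update(gradient, hessian)."""

    _order = 2  # second-order, as required by the `natural` hessian type

    def __init__(self, theta, lr=0.01, damping=1e-6):
        self.theta = np.asarray(theta, dtype=float)
        self.lr = lr
        self.damping = damping

    def update(self, gradient, hessian=None):
        g = np.asarray(gradient, dtype=float).reshape(-1)
        if hessian is None:
            step = g  # first-order fallback: plain gradient step
        else:
            # natural-gradient step: solve (F + damping * I) step = gradient
            step = np.linalg.solve(hessian + self.damping * np.eye(g.size), g)
        self.theta = self.theta - self.lr * step
        return None  # mirrors the Optional[str] return of the agent's update()
```

When the Fisher matrix is passed as the `hessian` argument, the plain gradient step turns into the natural gradient step sketched earlier.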
@@ -305,7 +320,9 @@ def train_one_episode(
     def _init_sensitivity(self, linsolver: str) -> Callable[[cs.DM, int], np.ndarray]:
         """Internal utility to compute the derivatives w.r.t. the learnable parameters
         and other functions in order to estimate the policy gradient."""
-        assert self.optimizer._order == 1, "Expected 1st-order optimizer."
+        assert (self.hessian_type == "none" and self.optimizer._order == 1) or (
+            self.hessian_type == "natural" and self.optimizer._order == 2
+        ), "expected 1st-order (2nd-order) optimizer with `none` (`natural`) hessian"
         nlp = self._V.nlp
         y = nlp.primal_dual
         theta = cs.vvcat(self._learnable_pars.sym.values())
@@ -426,18 +443,32 @@ def _compute_cafa_weight_w(
 
 
 def _estimate_gradient_update(
-    sample: Iterator[ExpType], discount_factor: float, regularization: float
-) -> np.ndarray:
-    """Internal utility to estimate the gradient of the policy."""
+    sample: Iterator[ExpType],
+    discount_factor: float,
+    regularization: float,
+    return_fisher_hessian: bool,
+) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
+    """Internal utility to estimate the gradient of the policy and possibly the
+    Fisher information matrix as well."""
     # compute average v and w
     sample_ = list(sample)  # load whole iterator into a list
     v = np.mean([o[4] for o in sample_], 0)
-    w_list = [
+    w_ = [
         _compute_cafa_weight_w(Phi, Psi, L, v, discount_factor, regularization)
         for L, Phi, Psi, _, _ in sample_
     ]
-    w = np.mean(w_list, 0)
-
-    # compute policy gradient estimate
-    dJdtheta_list = [(o[3] @ o[3].transpose((0, 2, 1))).sum(0) @ w for o in sample_]
-    return np.mean(dJdtheta_list, 0)
+    w = np.mean(w_, 0)
+
+    if return_fisher_hessian:
+        # compute both policy gradient and Fisher information matrix
+        fisher_hess_ = []
+        dJdtheta_ = []
+        for _, _, _, dpidtheta, _ in sample_:
+            F = (dpidtheta @ dpidtheta.transpose((0, 2, 1))).sum(0)
+            fisher_hess_.append(F)
+            dJdtheta_.append(F @ w)
+        return np.mean(dJdtheta_, 0), np.mean(fisher_hess_, 0)
+
+    # compute only policy gradient estimate
+    dJdtheta_ = [(o[3] @ o[3].transpose((0, 2, 1))).sum(0) @ w for o in sample_]
+    return np.mean(dJdtheta_, 0)
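
As a quick shape check on the estimator above: the per-step policy sensitivities (`o[3]` / `dpidtheta`) are stacked along the first axis, so the outer-product sum yields a square Fisher matrix that multiplies the averaged CAFA weight `w`. A small numpy sketch with made-up dimensions and random data, purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
T, n_theta, n_a = 20, 5, 2  # hypothetical: time steps, learnable parameters, actions
dpidtheta = rng.standard_normal((T, n_theta, n_a))  # stand-in for o[3] in a sample
w = rng.standard_normal((n_theta, 1))  # stand-in for the averaged CAFA weight

F = (dpidtheta @ dpidtheta.transpose((0, 2, 1))).sum(0)  # Fisher estimate, (n_theta, n_theta)
dJdtheta = F @ w  # policy gradient estimate, (n_theta, 1)
assert F.shape == (n_theta, n_theta) and dJdtheta.shape == (n_theta, 1)
```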
