Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LevenbergMarquardt and pytrees #587

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions jaxopt/_src/levenberg_marquardt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@
from jaxopt._src.linear_solve import solve_inv
from jaxopt._src.linear_solve import solve_lu
from jaxopt._src.linear_solve import solve_qr
from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul

from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add
from jaxopt._src.tree_util import tree_mul, tree_vdot, tree_zeros_like
from jaxopt._src.tree_util import tree_scalar_mul, tree_add_scalar_mul

import jax.flatten_util

class LevenbergMarquardtState(NamedTuple):
"""Named tuple containing state information."""
Expand Down Expand Up @@ -214,9 +218,9 @@ def init_state(self, init_params: Any, *args,
hess_res = None
gradient = self._jt_op(init_params, residual, *args, **kwargs)
jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
damping_factor = self.damping_parameter * jnp.max(jtj_diag)
damping_factor = self.damping_parameter * tree_inf_norm(jtj_diag)

delta_params = jnp.zeros_like(init_params)
delta_params = tree_zeros_like(init_params)

return LevenbergMarquardtState(
iter_num=jnp.asarray(0),
Expand Down Expand Up @@ -320,14 +324,14 @@ def update_state_using_delta_params(self, loss_curr, params, delta_params,
for the value of dparams.
"""

updated_params = params + delta_params
updated_params = tree_add(params, delta_params)

residual_next = self._fun(updated_params, *args, **kwargs)

# Calculate denominator of the gain ratio based on Eq. 6.16, "Introduction
# to optimization and data fitting", L(0)-L(hlm)=0.5*hlm^T*(mu*hlm-g).
gain_ratio_denom = 0.5 * delta_params.T @ (
damping_factor * delta_params - gradient)
gain_ratio_denom = 0.5 * tree_vdot(delta_params,
tree_sub(tree_scalar_mul(damping_factor, delta_params), gradient))

# Current value of loss function F=0.5*||f||^2.
loss_next = 0.5 * jnp.sum(jnp.square(residual_next))
Expand Down Expand Up @@ -414,14 +418,14 @@ def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep:
)

if self.geodesic:
contribution_ratio_diff = jnp.linalg.norm(acceleration) / jnp.linalg.norm(
contribution_ratio_diff = tree_l2_norm(acceleration) / tree_l2_norm(
velocity) - self.contribution_ratio_threshold
else:
contribution_ratio_diff = 0.0

# Negative coefficient is due to the sign of the RHS vector in the update equation
# (J^T @ J + µ I) @ ∆params = -J^T @ f(x).
delta_params = -delta_params
delta_params = tree_scalar_mul(-1, delta_params)

# Checking if the dparams satisfy the "sufficiently small" criteria.
params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = (
Expand Down Expand Up @@ -531,8 +535,13 @@ def _jtj_op(self, params, vec, *args, **kwargs):

def _jtj_diag_op(self, params, *args, **kwargs):
"""Diagonal elements of J^T.J, where J is jacobian of fun at params."""
diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
return jax.vmap(diag_op)(jnp.eye(len(params))).T
diag_op = lambda v: tree_vdot(v, self._jtj_op(params, v, *args, **kwargs))
_, unflatten_fn = jax.flatten_util.ravel_pytree(params)
param_count = sum(x.size for x in jax.tree_leaves(params))
eye_pytree = jax.vmap(unflatten_fn)(jnp.eye(param_count))
diag_vec = jax.vmap(diag_op)(eye_pytree).T
diag_pytree = unflatten_fn(diag_vec)
return diag_pytree

def _d2fvv_op(self, primals, tangents1, tangents2, *args, **kwargs):
"""Product with d2f.v1v2."""
Expand All @@ -552,9 +561,9 @@ def _solve_linear_eqs(self, matvec, state, params, *args, **kwargs):
rpp = self._d2fvv_op(params, velocity, velocity, *args, **kwargs)
jtrpp = self._jt_op(params, rpp, *args, **kwargs)
acceleration = self.solver_fn(matvec, jtrpp, ridge=state.damping_factor)
delta_params += 0.5*acceleration
delta_params = tree_add_scalar_mul(delta_params, 0.5, acceleration)
else:
acceleration = jnp.zeros_like(velocity)
acceleration = tree_zeros_like(velocity)

return (velocity, acceleration, delta_params)

Expand Down