From fb7f4460e35d09c95a9a0c50bca172278817a0b9 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Sat, 30 Mar 2024 17:13:08 +0100 Subject: [PATCH 1/2] Replace jnp with tree_utils --- jaxopt/_src/levenberg_marquardt.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/jaxopt/_src/levenberg_marquardt.py b/jaxopt/_src/levenberg_marquardt.py index 1a3646ac..6b69db88 100644 --- a/jaxopt/_src/levenberg_marquardt.py +++ b/jaxopt/_src/levenberg_marquardt.py @@ -35,8 +35,9 @@ from jaxopt._src.linear_solve import solve_inv from jaxopt._src.linear_solve import solve_lu from jaxopt._src.linear_solve import solve_qr -from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul +from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul, tree_vdot, tree_zeros_like, tree_scalar_mul +import jax.flatten_util class LevenbergMarquardtState(NamedTuple): """Named tuple containing state information.""" @@ -214,9 +215,9 @@ def init_state(self, init_params: Any, *args, hess_res = None gradient = self._jt_op(init_params, residual, *args, **kwargs) jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs) - damping_factor = self.damping_parameter * jnp.max(jtj_diag) + damping_factor = self.damping_parameter * tree_inf_norm(jtj_diag) - delta_params = jnp.zeros_like(init_params) + delta_params = tree_zeros_like(init_params) return LevenbergMarquardtState( iter_num=jnp.asarray(0), @@ -320,14 +321,14 @@ def update_state_using_delta_params(self, loss_curr, params, delta_params, for the value of dparams. """ - updated_params = params + delta_params + updated_params = tree_add(params, delta_params) residual_next = self._fun(updated_params, *args, **kwargs) # Calculate denominator of the gain ratio based on Eq. 6.16, "Introduction # to optimization and data fitting", L(0)-L(hlm)=0.5*hlm^T*(mu*hlm-g). - gain_ratio_denom = 0.5 * delta_params.T @ ( - damping_factor * delta_params - gradient) + gain_ratio_denom = 0.5 * tree_vdot(delta_params, + tree_sub(tree_scalar_mul(damping_factor, delta_params), gradient)) # Current value of loss function F=0.5*||f||^2. loss_next = 0.5 * jnp.sum(jnp.square(residual_next)) @@ -421,7 +422,7 @@ def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep: # Negative coefficient is due to the sign of the RHS vector in the update equation # (J^T @ J + µ I) @ ∆params = -J^T @ f(x). - delta_params = -delta_params + delta_params = tree_scalar_mul(-1, delta_params) # Checking if the dparams satisfy the "sufficiently small" criteria. params, damping_factor, increase_factor, residual, gradient, jac, jt, jtj, hess_res, aux = ( @@ -531,8 +532,13 @@ def _jtj_op(self, params, vec, *args, **kwargs): def _jtj_diag_op(self, params, *args, **kwargs): """Diagonal elements of J^T.J, where J is jacobian of fun at params.""" - diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs) - return jax.vmap(diag_op)(jnp.eye(len(params))).T + diag_op = lambda v: tree_vdot(v, self._jtj_op(params, v, *args, **kwargs)) + _, unflatten_fn = jax.flatten_util.ravel_pytree(params) + param_count = sum(x.size for x in jax.tree_leaves(params)) + eye_pytree = jax.vmap(unflatten_fn)(jnp.eye(param_count)) + diag_vec = jax.vmap(diag_op)(eye_pytree).T + diag_pytree = unflatten_fn(diag_vec) + return diag_pytree def _d2fvv_op(self, primals, tangents1, tangents2, *args, **kwargs): """Product with d2f.v1v2.""" @@ -554,7 +560,7 @@ def _solve_linear_eqs(self, matvec, state, params, *args, **kwargs): acceleration = self.solver_fn(matvec, jtrpp, ridge=state.damping_factor) delta_params += 0.5*acceleration else: - acceleration = jnp.zeros_like(velocity) + acceleration = tree_zeros_like(velocity) return (velocity, acceleration, delta_params) From 9dc3e7b6b7a657223a7f30b28397af12b1c1dee6 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Sun, 31 Mar 2024 10:50:39 +0200 Subject: [PATCH 2/2] Replace jnp with tree_utils in geodesic --- jaxopt/_src/levenberg_marquardt.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/levenberg_marquardt.py b/jaxopt/_src/levenberg_marquardt.py index 6b69db88..c3d92225 100644 --- a/jaxopt/_src/levenberg_marquardt.py +++ b/jaxopt/_src/levenberg_marquardt.py @@ -35,7 +35,10 @@ from jaxopt._src.linear_solve import solve_inv from jaxopt._src.linear_solve import solve_lu from jaxopt._src.linear_solve import solve_qr -from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul, tree_vdot, tree_zeros_like, tree_scalar_mul + +from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add +from jaxopt._src.tree_util import tree_mul, tree_vdot, tree_zeros_like +from jaxopt._src.tree_util import tree_scalar_mul, tree_add_scalar_mul import jax.flatten_util @@ -415,7 +418,7 @@ def update(self, params, state: NamedTuple, *args, **kwargs) -> base.OptStep: ) if self.geodesic: - contribution_ratio_diff = jnp.linalg.norm(acceleration) / jnp.linalg.norm( + contribution_ratio_diff = tree_l2_norm(acceleration) / tree_l2_norm( velocity) - self.contribution_ratio_threshold else: contribution_ratio_diff = 0.0 @@ -558,7 +561,7 @@ def _solve_linear_eqs(self, matvec, state, params, *args, **kwargs): rpp = self._d2fvv_op(params, velocity, velocity, *args, **kwargs) jtrpp = self._jt_op(params, rpp, *args, **kwargs) acceleration = self.solver_fn(matvec, jtrpp, ridge=state.damping_factor) - delta_params += 0.5*acceleration + delta_params = tree_add_scalar_mul(delta_params, 0.5, acceleration) else: acceleration = tree_zeros_like(velocity)