diff --git a/jaxopt/_src/backtracking_linesearch.py b/jaxopt/_src/backtracking_linesearch.py index 189e3df5..60f032f4 100644 --- a/jaxopt/_src/backtracking_linesearch.py +++ b/jaxopt/_src/backtracking_linesearch.py @@ -39,7 +39,7 @@ class BacktrackingLineSearchState(NamedTuple): params: Any value: float grad: Any # either initial or final for armijo or glodstein - value_init: float + value_init: float grad_init: Any error: float done: bool @@ -260,11 +260,11 @@ def update( if self.condition in ["armijo", "goldstein"]: # If we are done for the armijo or the goldstein conditions, - # we compute the final gradient (we had not computed it before since + # we compute the final gradient (we had not computed it before since # these conditions did not require it) new_grad = cond(done | failed, self._compute_final_grad, - lambda *_: grad, + lambda *_: grad, new_params, fun_args, fun_kwargs, jit=self.jit) maybe_additional_eval = jnp.asarray(done | failed, dtype=base.NUM_EVAL_DTYPE) @@ -284,7 +284,7 @@ def update( num_grad_eval=num_grad_eval) return base.LineSearchStep(stepsize=new_stepsize, state=new_state) - + def _compute_final_grad(self, params, fun_args, fun_kwargs): return self._grad_with_aux(params, *fun_args, **fun_kwargs)[0] diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index bd54222d..83e34988 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -14,6 +14,8 @@ """Limited-memory Broyden method""" +import warnings + from functools import partial from typing import Any @@ -28,7 +30,7 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch +from jaxopt._src.linesearch_util import _setup_linesearch, _init_stepsize from jaxopt.tree_util import tree_map from jaxopt.tree_util import tree_vdot from jaxopt.tree_util import tree_add_scalar_mul @@ -80,21 +82,21 @@ def inv_jacobian_product(pytree: Any, Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`. c_history: pytree with the same structure as `pytree`. Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`. - gamma: scalar to use for the initial inverse jacobian approximation, + gamma: pytree with scalars to use for the initial inverse jacobian approximation, i.e., `gamma * I`. start: starting index in the circular buffer. """ fun = partial(inv_jacobian_product_leaf, - gamma=gamma, start=start) - return tree_map(fun, pytree, d_history, c_history) + return tree_map(fun, pytree, d_history, c_history, gamma) def inv_jacobian_rproduct(pytree: Any, d_history: Any, c_history: Any, gamma: float = 1.0, start: int = 0): - return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start) + gamma_conj = tree_map(jnp.conjugate, gamma) + return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start) def init_history(pytree, history_size): @@ -115,7 +117,7 @@ class BroydenState(NamedTuple): error: float d_history: Any c_history: Any - gamma: jnp.ndarray + gamma: Any aux: Optional[Any] = None failed_linesearch: bool = False @@ -194,10 +196,11 @@ class Broyden(base.IterativeSolver): stepsize: Union[float, Callable] = 0.0 linesearch: str = "backtracking" + linesearch_init: str = "increase" stop_if_linesearch_fails: bool = False - condition: str = "wolfe" + condition: Any = None # deprecated in v0.8 maxls: int = 15 - decrease_factor: float = 0.8 + decrease_factor: Any = None # deprecated in v0.8 increase_factor: float = 1.5 max_stepsize: float = 1.0 # FIXME: should depend on whether float32 or float64 is used. @@ -205,6 +208,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 + compute_gamma: bool = True implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -244,18 +248,37 @@ def init_state(self, iter_num=init_params.state.iter_num, stepsize=init_params.state.stepsize, ) + # XXX: not computing the jacobian init approx + # when starting from an OptStep object init_params = init_params.params dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) else: dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) + if self.compute_gamma: + # we use scipy's formula: + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569 + # self.alpha = 0.5*max(norm(x0), 1) / normf0 + normf0 = tree_map(jnp.linalg.norm, value) + normx0 = tree_map(jnp.linalg.norm, init_params) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0) + def safe_divide_by_zero(x, y): + # a classical division of x by y + # when y == 0 then return 1 + return jnp.where(y == 0, 1, x / y) + gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) + else: + gamma = self.gamma + # repeat gamma as a pytree of the shape of init_params + gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), c_history=init_history(init_params, self.history_size), - gamma=jnp.asarray(self.gamma, dtype=dtype), + gamma=gamma, iter_num=jnp.asarray(0), stepsize=jnp.asarray(self.max_stepsize, dtype=dtype), ) - value, aux = self._value_with_aux(init_params, *args, **kwargs) return BroydenState(value=value, error=jnp.asarray(jnp.inf), **state_kwargs, @@ -308,45 +331,29 @@ def update(self, use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0 if use_linesearch: - if self.linesearch == "backtracking": - # we need to build the function used for the line search - # which is going to be the squared norm of the original function - # as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278 - # we then need to check if the gradient can be obtained with jax - # and if not we can build it in the same fashion as scipy - # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285 - def ls_fun_with_aux(params, *args, **kwargs): - f, aux = self._value_with_aux(params, *args, **kwargs) - norm_squared = tree_l2_norm(f, squared=True) - return norm_squared, (f, aux) - # here we need a check if the function is not smooth - ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True) - ls = BacktrackingLineSearch(fun=ls_fun_with_aux_and_grad, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - max_stepsize=self.max_stepsize, - condition=self.condition, - jit=self.jit, - unroll=self.unroll, - has_aux=True, - tol=1e-2) - init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Else, we increase a bit the previous one. - state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize, - params, value, None, - descent_direction, - fun_args=args, fun_kwargs=kwargs) - new_value, new_aux = ls_state.aux - new_params = ls_state.params - new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num - new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval - failed_linesearch = ls_state.failed - else: - raise ValueError("Invalid name in 'linesearch' option.") + init_stepsize = _init_stepsize( + self.linesearch_init, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + state.stepsize, + ) + new_stepsize, ls_state = self.run_ls( + init_stepsize, + params, + value=tree_l2_norm(value), + # in the case of Broyden, it's the value that's actually the equivalent + # of the gradient in the optimization case. + grad=value, + descent_direction=descent_direction, + fun_args=args, + fun_kwargs=kwargs, + ) + new_value, new_aux = ls_state.aux + new_params = ls_state.params + new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num + new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval + failed_linesearch = ls_state.failed else: # without line search if isinstance(self.stepsize, Callable): @@ -409,3 +416,36 @@ def __post_init__(self): if self.history_size is None: self.history_size = self.maxiter + + # we need to build the function used for the line search + # which is going to be the squared norm of the original function + # as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278 + # we then need to check if thtree_l2_norme gradient can be obtained with jax + # and if not we can build it in the same fashion as scipy + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285 + def ls_fun_with_aux(params, *args, **kwargs): + f, aux = self._value_with_aux(params, *args, **kwargs) + norm_squared = tree_l2_norm(f, squared=True) + return norm_squared, (f, aux) + # here we need a check if the function is not smooth + ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True) + self.linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=ls_fun_with_aux_and_grad, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=self.jit, + unroll=self.unroll, + verbose=self.verbose, + ) + self.run_ls = self.linesearch_solver.run + + # FIXME: to remove in future releases + if self.condition is not None: + warnings.warn("Argument condition is deprecated", DeprecationWarning) + if self.decrease_factor is not None: + warnings.warn( + "Argument decrease_factor is deprecated", DeprecationWarning + ) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 60503a55..1c9427eb 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0) + sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, stop_if_linesearch_fails=True).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol) @@ -134,7 +134,7 @@ def test_affine_contractive_mapping(self): def g(x, M, b): return M @ x + b - x tol = 1e-6 - fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) + fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) self.assertLess(state.error, tol)