diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 9f79ada9..dfac2aaf 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -144,6 +144,7 @@ class LbfgsState(NamedTuple): rho_history: jnp.ndarray gamma: jnp.ndarray aux: Optional[Any] = None + failed_linesearch: bool = False @dataclass(eq=False) @@ -173,6 +174,8 @@ class LBFGS(base.IterativeSolver): or a callable specifying the **positive** stepsize to use at each iteration. linesearch: the type of line search to use: "backtracking" for backtracking line search or "zoom" for zoom line search. + stop_if_linesearch_fails: whether to stop iterations if the line search fails. + When True, this matches the behavior of core JAX. maxls: maximum number of iterations to use in the line search. decrease_factor: factor by which to decrease the stepsize during line search (default: 0.8). @@ -212,6 +215,7 @@ class LBFGS(base.IterativeSolver): stepsize: Union[float, Callable] = 0.0 linesearch: str = "zoom" + stop_if_linesearch_fails: bool = False condition: str = "strong-wolfe" maxls: int = 15 decrease_factor: float = 0.8 @@ -231,6 +235,14 @@ class LBFGS(base.IterativeSolver): verbose: bool = False + def _cond_fun(self, inputs): + _, state = inputs[0] + if self.verbose: + print("error:", state.error) + # We continue the optimization loop while the error tolerance is not met and, + # either failed linesearch is disallowed or linesearch hasn't failed. + return (state.error > self.tol) & (~self.stop_if_linesearch_fails | ~state.failed_linesearch) + def init_state(self, init_params: Any, *args, @@ -257,7 +269,8 @@ def init_state(self, y_history=init_history(init_params, self.history_size), rho_history=jnp.zeros(self.history_size, dtype=dtype), gamma=jnp.asarray(1.0, dtype=dtype), - aux=aux) + aux=aux, + failed_linesearch=jnp.asarray(False)) def update(self, params: Any, @@ -284,7 +297,8 @@ def update(self, start=start) descent_direction = tree_scalar_mul(-1.0, product) - if not isinstance(self.stepsize, Callable) and self.stepsize <= 0: + use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0 + if use_linesearch: # with line search if self.linesearch == "backtracking": @@ -352,6 +366,11 @@ def update(self, else: gamma = jnp.array(1.0) + if use_linesearch and self.linesearch == "zoom": + failed_linesearch = ls_state.failed + else: # backtracking linesearch doesn't support failed state yet + failed_linesearch = jnp.asarray(False) + new_state = LbfgsState(iter_num=state.iter_num + 1, value=new_value, stepsize=jnp.asarray(new_stepsize), @@ -363,7 +382,8 @@ def update(self, # FIXME: we should return new_aux here but # BacktrackingLineSearch currently doesn't support # an aux. - aux=aux) + aux=aux, + failed_linesearch=failed_linesearch) return base.OptStep(params=new_params, state=new_state)