Skip to content

Commit

Permalink
Merge pull request #323 from zaccharieramzi:line-search-failed-breaks…
Browse files Browse the repository at this point in the history
…-lbfgs

PiperOrigin-RevId: 481901559
  • Loading branch information
JAXopt authors committed Oct 18, 2022
2 parents 6fe6d62 + 3ecda44 commit 38ebb66
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class LbfgsState(NamedTuple):
rho_history: jnp.ndarray
gamma: jnp.ndarray
aux: Optional[Any] = None
failed_linesearch: bool = False


@dataclass(eq=False)
Expand Down Expand Up @@ -173,6 +174,8 @@ class LBFGS(base.IterativeSolver):
or a callable specifying the **positive** stepsize to use at each iteration.
linesearch: the type of line search to use: "backtracking" for backtracking
line search or "zoom" for zoom line search.
stop_if_linesearch_fails: whether to stop iterations if the line search fails.
When True, this matches the behavior of core JAX.
maxls: maximum number of iterations to use in the line search.
decrease_factor: factor by which to decrease the stepsize during line search
(default: 0.8).
Expand Down Expand Up @@ -212,6 +215,7 @@ class LBFGS(base.IterativeSolver):

stepsize: Union[float, Callable] = 0.0
linesearch: str = "zoom"
stop_if_linesearch_fails: bool = False
condition: str = "strong-wolfe"
maxls: int = 15
decrease_factor: float = 0.8
Expand All @@ -231,6 +235,14 @@ class LBFGS(base.IterativeSolver):

verbose: bool = False

def _cond_fun(self, inputs):
_, state = inputs[0]
if self.verbose:
print("error:", state.error)
# We continue the optimization loop while the error tolerance is not met and,
# either failed linesearch is disallowed or linesearch hasn't failed.
return (state.error > self.tol) & (~self.stop_if_linesearch_fails | ~state.failed_linesearch)

def init_state(self,
init_params: Any,
*args,
Expand All @@ -257,7 +269,8 @@ def init_state(self,
y_history=init_history(init_params, self.history_size),
rho_history=jnp.zeros(self.history_size, dtype=dtype),
gamma=jnp.asarray(1.0, dtype=dtype),
aux=aux)
aux=aux,
failed_linesearch=jnp.asarray(False))

def update(self,
params: Any,
Expand All @@ -284,7 +297,8 @@ def update(self,
start=start)
descent_direction = tree_scalar_mul(-1.0, product)

if not isinstance(self.stepsize, Callable) and self.stepsize <= 0:
use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0
if use_linesearch:
# with line search

if self.linesearch == "backtracking":
Expand Down Expand Up @@ -352,6 +366,11 @@ def update(self,
else:
gamma = jnp.array(1.0)

if use_linesearch and self.linesearch == "zoom":
failed_linesearch = ls_state.failed
else: # backtracking linesearch doesn't support failed state yet
failed_linesearch = jnp.asarray(False)

new_state = LbfgsState(iter_num=state.iter_num + 1,
value=new_value,
stepsize=jnp.asarray(new_stepsize),
Expand All @@ -363,7 +382,8 @@ def update(self,
# FIXME: we should return new_aux here but
# BacktrackingLineSearch currently doesn't support
# an aux.
aux=aux)
aux=aux,
failed_linesearch=failed_linesearch)

return base.OptStep(params=new_params, state=new_state)

Expand Down

0 comments on commit 38ebb66

Please sign in to comment.