diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 88603f41..11eed03e 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -63,6 +63,8 @@ class NonlinearCG(base.IterativeSolver): (default: 0.8). increase_factor: factor by which to increase the stepsize during line search (default: 1.2). + max_stepsize: upper bound on stepsize. + min_stepsize: lower bound on stepsize. implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. @@ -87,6 +89,10 @@ class NonlinearCG(base.IterativeSolver): maxls: int = 15 decrease_factor: float = 0.8 increase_factor: float = 1.2 + max_stepsize: float = 1.0 + # FIXME: should depend on whether float32 or float64 is used. + min_stepsize: float = 1e-6 + implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -110,7 +116,7 @@ def init_state(self, value, grad = self._value_and_grad_fun(init_params, *args, **kwargs) return NonlinearCGState(iter_num=jnp.asarray(0), - stepsize=jnp.asarray(1.0), + stepsize=jnp.asarray(self.max_stepsize), error=jnp.asarray(jnp.inf), value=value, grad=grad, @@ -133,12 +139,19 @@ def update(self, eps = 1e-6 value, grad, descent_direction = state.value, state.grad, state.descent_direction - init_stepsize = state.stepsize * self.increase_factor ls = BacktrackingLineSearch(fun=self._value_and_grad_fun, value_and_grad=True, maxiter=self.maxls, decrease_factor=self.decrease_factor, - condition=self.condition) + condition=self.condition, + max_stepsize=self.max_stepsize) + + init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, + # If stepsize became too small, we restart it. + self.max_stepsize, + # Otherwise, we slightly increase the previous one. + state.stepsize * self.increase_factor) + new_stepsize, ls_state = ls.run(init_stepsize, params, value,