Skip to content

Commit

Permalink
Add min_stepsize and max_stepsize to NonlinearCG.
Browse files Browse the repository at this point in the history
mblondel committed Jun 10, 2022
1 parent dafe765 commit a666044
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
@@ -63,6 +63,8 @@ class NonlinearCG(base.IterativeSolver):
(default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.2).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
@@ -87,6 +89,10 @@ class NonlinearCG(base.IterativeSolver):
maxls: int = 15
decrease_factor: float = 0.8
increase_factor: float = 1.2
max_stepsize: float = 1.0
# FIXME: should depend on whether float32 or float64 is used.
min_stepsize: float = 1e-6

implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None

@@ -110,7 +116,7 @@ def init_state(self,
value, grad = self._value_and_grad_fun(init_params, *args, **kwargs)

return NonlinearCGState(iter_num=jnp.asarray(0),
stepsize=jnp.asarray(1.0),
stepsize=jnp.asarray(self.max_stepsize),
error=jnp.asarray(jnp.inf),
value=value,
grad=grad,
@@ -133,12 +139,19 @@ def update(self,

eps = 1e-6
value, grad, descent_direction = state.value, state.grad, state.descent_direction
init_stepsize = state.stepsize * self.increase_factor
ls = BacktrackingLineSearch(fun=self._value_and_grad_fun,
value_and_grad=True,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
condition=self.condition)
condition=self.condition,
max_stepsize=self.max_stepsize)

init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
self.max_stepsize,
# Otherwise, we increase a bit the previous one.
state.stepsize * self.increase_factor)

new_stepsize, ls_state = ls.run(init_stepsize,
params,
value,

0 comments on commit a666044

Please sign in to comment.