Merge pull request #442 from vroulet:zoom_linesearch_revamp
PiperOrigin-RevId: 543992823
JAXopt authors committed Jun 28, 2023
2 parents 7760823 + ef25b9f commit 1572796
Showing 9 changed files with 1,338 additions and 870 deletions.
1 change: 1 addition & 0 deletions jaxopt/__init__.py
@@ -52,3 +52,4 @@
from jaxopt._src.scipy_wrappers import ScipyLeastSquares
from jaxopt._src.scipy_wrappers import ScipyMinimize
from jaxopt._src.scipy_wrappers import ScipyRootFinding
from jaxopt._src.zoom_linesearch import ZoomLineSearch
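Editor's note: this commit promotes the revamped zoom line search to the public API. A minimal standalone sketch, assuming ZoomLineSearch exposes the same run(init_stepsize, params, value, grad, descent_direction, ...) interface as the existing BacktrackingLineSearch (the quadratic objective is illustrative only, not part of this commit):

import jax
import jax.numpy as jnp
from jaxopt import ZoomLineSearch

def fun(w):
  return jnp.sum(w ** 2)  # toy objective

ls = ZoomLineSearch(fun=fun, maxiter=30)
w = jnp.ones(5)
value, grad = fun(w), jax.grad(fun)(w)
# Search along the steepest-descent direction, starting from stepsize 1.0.
stepsize, ls_state = ls.run(1.0, w, value, grad, -grad)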
114 changes: 49 additions & 65 deletions jaxopt/_src/bfgs.py
@@ -28,14 +28,14 @@
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch
from jaxopt._src.zoom_linesearch import zoom_linesearch
from jaxopt.tree_util import tree_add_scalar_mul
from jaxopt.tree_util import tree_l2_norm
from jaxopt.tree_util import tree_sub
from jaxopt._src.tree_util import tree_single_dtype
from jaxopt._src.scipy_wrappers import make_onp_to_jnp
from jaxopt._src.scipy_wrappers import pytree_topology_from_example
from jaxopt._src.linesearch_util import _reset_stepsize
from jaxopt._src.linesearch_util import _setup_linesearch


_dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST)
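Editor's note: for context on the unchanged helper above, _dot is jnp.dot with the matmul precision pinned to HIGHEST, which matters on accelerators where the default precision is lower. A toy check:

import jax
import jax.numpy as jnp
from functools import partial

_dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST)
H = jnp.eye(3)
g = jnp.arange(3.0)
print(_dot(H, g))  # same result as H @ g, computed at highest matmul precision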
@@ -53,6 +53,7 @@ def pytree_to_flat_array(pytree, dtype):

class BfgsState(NamedTuple):
"""Named tuple containing state information."""

iter_num: int
value: float
grad: Any
@@ -74,10 +75,8 @@ class BFGS(base.IterativeSolver):
value_and_grad: whether ``fun`` just returns the value (False) or both
the value and gradient (True).
has_aux: whether ``fun`` outputs auxiliary data or not.
If ``has_aux`` is False, ``fun`` is expected to be
scalar-valued.
If ``has_aux`` is True, then we have one of the following
two cases.
If ``has_aux`` is False, ``fun`` is expected to be scalar-valued.
If ``has_aux`` is True, then we have one of the following two cases.
If ``value_and_grad`` is False, the output should be
``value, aux = fun(...)``.
If ``value_and_grad == True``, the output should be
@@ -92,6 +91,8 @@ class BFGS(base.IterativeSolver):
or a callable specifying the **positive** stepsize to use at each iteration.
linesearch: the type of line search to use: "backtracking" for backtracking
line search or "zoom" for zoom line search.
    condition: condition used to select the stepsize when using backtracking
      linesearch.
maxls: maximum number of iterations to use in the line search.
decrease_factor: factor by which to decrease the stepsize during line search
(default: 0.8).
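Editor's note: a usage illustration of the options documented above, with a made-up objective. Argument names follow this docstring; the "strong-wolfe" condition value is an assumption about what the backtracking line search accepts, not taken from this diff.

import jax.numpy as jnp
from jaxopt import BFGS

# has_aux=True convention from the docstring: fun returns (value, aux).
def loss_with_aux(w):
  value = jnp.sum((w - 1.0) ** 2)
  aux = {"residual": w - 1.0}  # any extra outputs, as a pytree
  return value, aux

# Zoom line search selects the stepsize automatically.
solver = BFGS(fun=loss_with_aux, has_aux=True, linesearch="zoom", maxls=30)

# Backtracking line search with a sufficient-decrease condition and a
# stepsize decrease factor of 0.8 per backtracking step.
solver_bt = BFGS(fun=loss_with_aux, has_aux=True, linesearch="backtracking",
                 condition="strong-wolfe", decrease_factor=0.8, maxls=15)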
@@ -151,6 +152,7 @@ def init_state(self,
init_params: pytree containing the initial parameters.
*args: additional positional arguments to be passed to ``fun``.
**kwargs: additional keyword arguments to be passed to ``fun``.
Returns:
state
"""
@@ -179,6 +181,7 @@ def update(self,
state: named tuple containing the solver state.
*args: additional positional arguments to be passed to ``fun``.
**kwargs: additional keyword arguments to be passed to ``fun``.
Returns:
(params, state)
"""
@@ -190,65 +193,18 @@ def update(self,

descent_direction = flat_array_to_pytree(-_dot(state.H, flat_grad))

if not isinstance(self.stepsize, Callable) and self.stepsize <= 0:
# with line search

if self.linesearch == "backtracking":
ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux,
value_and_grad=True,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
max_stepsize=self.max_stepsize,
condition=self.condition,
jit=self.jit,
unroll=self.unroll,
has_aux=True)
init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
self.max_stepsize,
# Else, we increase a bit the previous one.
state.stepsize * self.increase_factor)
new_stepsize, ls_state = ls.run(init_stepsize,
params, value, grad,
descent_direction,
*args, **kwargs)
new_value = ls_state.value
new_aux = ls_state.aux
new_params = ls_state.params
new_grad = ls_state.grad

elif self.linesearch == "zoom":
ls_state = zoom_linesearch(f=self._value_and_grad_with_aux,
xk=params, pk=descent_direction,
old_fval=value, gfk=grad, maxiter=self.maxls,
value_and_grad=True, has_aux=True, aux=state.aux,
args=args, kwargs=kwargs)
new_value = ls_state.f_k
new_aux = ls_state.aux
new_stepsize = ls_state.a_k
new_grad = ls_state.g_k
# FIXME: zoom_linesearch currently doesn't return new_params
# so we have to recompute it.
t = new_stepsize.astype(tree_single_dtype(params))
new_params = tree_add_scalar_mul(params, t, descent_direction)
# FIXME: (zaccharieramzi) sometimes the linesearch fails
# and therefore its value g_k does not correspond
# to the gradient at the new parameters.
# with the following conditional loop we have a hot fix that just
# recomputes the value, gradient and auxiliary value
# at the new parameters. It would be better to understand
# what the g_k passed by zoom_linesearch is in this case
# and why it is wrong.
(new_value, new_aux), new_grad = jax.lax.cond(
ls_state.failed,
lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs),
lambda: ((new_value, new_aux), new_grad),
)
else:
raise ValueError("Invalid name in 'linesearch' option.")

use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0

if use_linesearch:
init_stepsize = self._reset_stepsize(state.stepsize)
new_stepsize, ls_state = self.run_ls(
init_stepsize, params, value, grad, descent_direction, *args, **kwargs
)
new_params = ls_state.params
new_value = ls_state.value
new_grad = ls_state.grad
new_aux = ls_state.aux
else:
# without line search
if isinstance(self.stepsize, Callable):
new_stepsize = self.stepsize(state.iter_num)
else:
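Editor's note: the two inlined line-search branches removed above are replaced by a single self.run_ls call on a reset initial stepsize. A rough sketch of the reset rule, mirroring the jnp.where logic from the removed backtracking branch (_reset_stepsize in linesearch_util may differ in detail):

import jax.numpy as jnp

def reset_stepsize_sketch(stepsize, min_stepsize, max_stepsize, increase_factor):
  # If the previous stepsize collapsed below min_stepsize, restart from
  # max_stepsize; otherwise warm-start slightly above the previous value.
  return jnp.where(stepsize <= min_stepsize,
                   max_stepsize,
                   stepsize * increase_factor)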
@@ -284,7 +240,7 @@ def optimality_fun(self, params, *args, **kwargs):
return self._value_and_grad_fun(params, *args, **kwargs)[1]

def _value_and_grad_fun(self, params, *args, **kwargs):
(value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
(value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs)
return value, grad

def __post_init__(self):
@@ -294,3 +250,31 @@ def __post_init__(self):
has_aux=self.has_aux)

self.reference_signature = self.fun
jit, unroll = self._get_loop_options()
linesearch_solver = _setup_linesearch(
linesearch=self.linesearch,
fun=self._value_and_grad_with_aux,
value_and_grad=True,
has_aux=True,
maxlsiter=self.maxls,
max_stepsize=self.max_stepsize,
jit=jit,
unroll=unroll,
verbose=self.verbose,
condition=self.condition,
decrease_factor=self.decrease_factor,
increase_factor=self.increase_factor,
)

self._reset_stepsize = partial(
_reset_stepsize,
self.linesearch,
self.max_stepsize,
self.min_stepsize,
self.increase_factor,
)

if jit:
self.run_ls = jax.jit(linesearch_solver.run)
else:
self.run_ls = linesearch_solver.run
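Editor's note: with the line-search solver now built once in __post_init__ (and optionally jitted), the user-facing call pattern is unchanged. A minimal end-to-end sketch with an illustrative objective:

import jax.numpy as jnp
from jaxopt import BFGS

def fun(w, target):
  return jnp.sum((w - target) ** 2)

solver = BFGS(fun=fun, linesearch="zoom", maxiter=100)
params, state = solver.run(jnp.zeros(3), target=jnp.ones(3))
print(state.iter_num, state.value)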