From ef25b9f606da469ca30a87dc55e427794f654c66 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Wed, 14 Jun 2023 09:09:32 -0700 Subject: [PATCH] revamped zoom linesearch --- jaxopt/__init__.py | 1 + jaxopt/_src/bfgs.py | 114 ++-- jaxopt/_src/lbfgs.py | 130 ++-- jaxopt/_src/lbfgsb.py | 171 ++--- jaxopt/_src/linesearch_util.py | 109 +++ jaxopt/_src/nonlinear_cg.py | 120 ++-- jaxopt/_src/zoom_linesearch.py | 1143 +++++++++++++++++++------------- tests/lbfgsb_test.py | 4 +- tests/zoom_linesearch_test.py | 416 ++++++++---- 9 files changed, 1338 insertions(+), 870 deletions(-) create mode 100644 jaxopt/_src/linesearch_util.py diff --git a/jaxopt/__init__.py b/jaxopt/__init__.py index bdc28b09..f5cca75c 100644 --- a/jaxopt/__init__.py +++ b/jaxopt/__init__.py @@ -52,3 +52,4 @@ from jaxopt._src.scipy_wrappers import ScipyLeastSquares from jaxopt._src.scipy_wrappers import ScipyMinimize from jaxopt._src.scipy_wrappers import ScipyRootFinding +from jaxopt._src.zoom_linesearch import ZoomLineSearch diff --git a/jaxopt/_src/bfgs.py b/jaxopt/_src/bfgs.py index 3f60b98e..e1d5229b 100644 --- a/jaxopt/_src/bfgs.py +++ b/jaxopt/_src/bfgs.py @@ -28,14 +28,14 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch -from jaxopt._src.zoom_linesearch import zoom_linesearch from jaxopt.tree_util import tree_add_scalar_mul from jaxopt.tree_util import tree_l2_norm from jaxopt.tree_util import tree_sub from jaxopt._src.tree_util import tree_single_dtype from jaxopt._src.scipy_wrappers import make_onp_to_jnp from jaxopt._src.scipy_wrappers import pytree_topology_from_example +from jaxopt._src.linesearch_util import _reset_stepsize +from jaxopt._src.linesearch_util import _setup_linesearch _dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST) @@ -53,6 +53,7 @@ def pytree_to_flat_array(pytree, dtype): class BfgsState(NamedTuple): """Named tuple containing state information.""" + iter_num: int value: float grad: Any @@ -74,10 +75,8 @@ class BFGS(base.IterativeSolver): value_and_grad: whether ``fun`` just returns the value (False) or both the value and gradient (True). has_aux: whether ``fun`` outputs auxiliary data or not. - If ``has_aux`` is False, ``fun`` is expected to be - scalar-valued. - If ``has_aux`` is True, then we have one of the following - two cases. + If ``has_aux`` is False, ``fun`` is expected to be scalar-valued. + If ``has_aux`` is True, then we have one of the following two cases. If ``value_and_grad`` is False, the output should be ``value, aux = fun(...)``. If ``value_and_grad == True``, the output should be @@ -92,6 +91,8 @@ class BFGS(base.IterativeSolver): or a callable specifying the **positive** stepsize to use at each iteration. linesearch: the type of line search to use: "backtracking" for backtracking line search or "zoom" for zoom line search. + condition: condition used to select the stepsize when using backtracking + linesearch maxls: maximum number of iterations to use in the line search. decrease_factor: factor by which to decrease the stepsize during line search (default: 0.8). @@ -151,6 +152,7 @@ def init_state(self, init_params: pytree containing the initial parameters. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. + Returns: state """ @@ -179,6 +181,7 @@ def update(self, state: named tuple containing the solver state. *args: additional positional arguments to be passed to ``fun``. 
**kwargs: additional keyword arguments to be passed to ``fun``. + Returns: (params, state) """ @@ -190,65 +193,18 @@ def update(self, descent_direction = flat_array_to_pytree(-_dot(state.H, flat_grad)) - if not isinstance(self.stepsize, Callable) and self.stepsize <= 0: - # with line search - - if self.linesearch == "backtracking": - ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - max_stepsize=self.max_stepsize, - condition=self.condition, - jit=self.jit, - unroll=self.unroll, - has_aux=True) - init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Else, we increase a bit the previous one. - state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize, - params, value, grad, - descent_direction, - *args, **kwargs) - new_value = ls_state.value - new_aux = ls_state.aux - new_params = ls_state.params - new_grad = ls_state.grad - - elif self.linesearch == "zoom": - ls_state = zoom_linesearch(f=self._value_and_grad_with_aux, - xk=params, pk=descent_direction, - old_fval=value, gfk=grad, maxiter=self.maxls, - value_and_grad=True, has_aux=True, aux=state.aux, - args=args, kwargs=kwargs) - new_value = ls_state.f_k - new_aux = ls_state.aux - new_stepsize = ls_state.a_k - new_grad = ls_state.g_k - # FIXME: zoom_linesearch currently doesn't return new_params - # so we have to recompute it. - t = new_stepsize.astype(tree_single_dtype(params)) - new_params = tree_add_scalar_mul(params, t, descent_direction) - # FIXME: (zaccharieramzi) sometimes the linesearch fails - # and therefore its value g_k does not correspond - # to the gradient at the new parameters. - # with the following conditional loop we have a hot fix that just - # recomputes the value, gradient and auxiliary value - # at the new parameters. It would be better to understand - # what the g_k passed by zoom_linesearch is in this case - # and why it is wrong. 
- (new_value, new_aux), new_grad = jax.lax.cond( - ls_state.failed, - lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs), - lambda: ((new_value, new_aux), new_grad), - ) - else: - raise ValueError("Invalid name in 'linesearch' option.") - + use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0 + + if use_linesearch: + init_stepsize = self._reset_stepsize(state.stepsize) + new_stepsize, ls_state = self.run_ls( + init_stepsize, params, value, grad, descent_direction, *args, **kwargs + ) + new_params = ls_state.params + new_value = ls_state.value + new_grad = ls_state.grad + new_aux = ls_state.aux else: - # without line search if isinstance(self.stepsize, Callable): new_stepsize = self.stepsize(state.iter_num) else: @@ -284,7 +240,7 @@ def optimality_fun(self, params, *args, **kwargs): return self._value_and_grad_fun(params, *args, **kwargs)[1] def _value_and_grad_fun(self, params, *args, **kwargs): - (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs) + (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs) return value, grad def __post_init__(self): @@ -294,3 +250,31 @@ def __post_init__(self): has_aux=self.has_aux) self.reference_signature = self.fun + jit, unroll = self._get_loop_options() + linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=self._value_and_grad_with_aux, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=jit, + unroll=unroll, + verbose=self.verbose, + condition=self.condition, + decrease_factor=self.decrease_factor, + increase_factor=self.increase_factor, + ) + + self._reset_stepsize = partial( + _reset_stepsize, + self.linesearch, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + ) + + if jit: + self.run_ls = jax.jit(linesearch_solver.run) + else: + self.run_ls = linesearch_solver.run \ No newline at end of file diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index ebd08f21..59aac2b8 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -28,9 +28,6 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch -from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch -from jaxopt._src.zoom_linesearch import zoom_linesearch from jaxopt.tree_util import tree_map from jaxopt.tree_util import tree_vdot from jaxopt.tree_util import tree_add_scalar_mul @@ -39,6 +36,8 @@ from jaxopt.tree_util import tree_sum from jaxopt.tree_util import tree_l2_norm from jaxopt._src.tree_util import tree_single_dtype +from jaxopt._src.linesearch_util import _reset_stepsize +from jaxopt._src.linesearch_util import _setup_linesearch def inv_hessian_product_leaf(v: jnp.ndarray, @@ -48,6 +47,7 @@ def inv_hessian_product_leaf(v: jnp.ndarray, gamma: float = 1.0, start: int = 0): + """Product between an approximate Hessian inverse and the leaf of a pytree.""" history_size = len(s_history) indices = (start + jnp.arange(history_size)) % history_size @@ -67,7 +67,7 @@ def body_left(r, args): r = r + s_history[i] * (alpha - beta) return r, beta - r, beta = jax.lax.scan(body_left, r, (indices, alpha)) + r, _ = jax.lax.scan(body_left, r, (indices, alpha)) return r @@ -97,6 +97,9 @@ def inv_hessian_product(pytree: Any, i.e., `gamma * I`. start: starting index in the circular buffer. + Returns: + Product between approximate Hessian inverse and the pytree + Reference: Jorge Nocedal and Stephen Wright. Numerical Optimization, second edition. 
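With the per-solver branching above replaced by the shared helpers in linesearch_util, choosing a line search becomes purely declarative at the solver level. A minimal usage sketch, relying only on the public jaxopt.BFGS options shown in this patch (the quadratic objective is illustrative, and it assumes the default stepsize is non-positive so that the `use_linesearch` branch is taken):

import jax.numpy as jnp
from jaxopt import BFGS

def fun(w, A, b):
  # Simple least-squares objective, used only for illustration.
  r = A @ w - b
  return 0.5 * jnp.dot(r, r)

A = jnp.eye(3)
b = jnp.array([1.0, -2.0, 0.5])
w0 = jnp.zeros(3)

# "zoom" now routes through the new ZoomLineSearch class via _setup_linesearch.
solver = BFGS(fun=fun, linesearch="zoom", maxls=30, maxiter=50)
params, state = solver.run(w0, A, b)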
@@ -110,6 +113,7 @@ def inv_hessian_product(pytree: Any, def compute_gamma(s_history: Any, y_history: Any, last: int): + """Compute scalar gamma defining the initialization of the approximate Hessian.""" # Let gamma = vdot(y_history[last], s_history[last]) / sqnorm(y_history[last]). # The initial inverse Hessian approximation can be set to gamma * I. # See Numerical Optimization, second edition, equation (7.20). @@ -179,6 +183,8 @@ class LBFGS(base.IterativeSolver): line search. stop_if_linesearch_fails: whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX. + condition: condition used to select the stepsize when using backtracking + linesearch maxls: maximum number of iterations to use in the line search. decrease_factor: factor by which to decrease the stepsize during line search (default: 0.8). @@ -256,6 +262,7 @@ def init_state(self, init_params: pytree containing the initial parameters. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. + Returns: state """ @@ -301,6 +308,7 @@ def update(self, state: named tuple containing the solver state. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. + Returns: (params, state) """ @@ -318,83 +326,17 @@ def update(self, use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0 if use_linesearch: - # with line search - - if self.linesearch == "backtracking": - ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - max_stepsize=self.max_stepsize, - condition=self.condition, - jit=self.jit, - unroll=self.unroll, - has_aux=True) - init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Else, we increase a bit the previous one. - state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize, - params, value, grad, - descent_direction, - *args, **kwargs) - new_value = ls_state.value - new_aux = ls_state.aux - new_params = ls_state.params - new_grad = ls_state.grad - - elif self.linesearch == "zoom": - ls_state = zoom_linesearch(f=self._value_and_grad_with_aux, - xk=params, pk=descent_direction, - old_fval=value, gfk=grad, maxiter=self.maxls, - value_and_grad=True, has_aux=True, aux=state.aux, - args=args, kwargs=kwargs) - new_value = ls_state.f_k - new_aux = ls_state.aux - new_stepsize = ls_state.a_k - new_grad = ls_state.g_k - # FIXME: zoom_linesearch currently doesn't return new_params - # so we have to recompute it. - t = new_stepsize.astype(tree_single_dtype(params)) - new_params = tree_add_scalar_mul(params, t, descent_direction) - # FIXME: (zaccharieramzi) sometimes the linesearch fails - # and therefore its value g_k does not correspond - # to the gradient at the new parameters. - # with the following conditional loop we have a hot fix that just - # recomputes the value, gradient and auxiliary value - # at the new parameters. It would be better to understand - # what the g_k passed by zoom_linesearch is in this case - # and why it is wrong. 
- (new_value, new_aux), new_grad = jax.lax.cond( - ls_state.failed, - lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs), - lambda: ((new_value, new_aux), new_grad), - ) - elif self.linesearch == "hager-zhang": - # By default Hager-Zhang uses the Wolfe Conditions & Approximate Wolfe - # Conditions. - ls = HagerZhangLineSearch(fun=self._value_and_grad_fun, - value_and_grad=True, - maxiter=self.maxls, - max_stepsize=self.max_stepsize, - jit=self.jit, - unroll=self.unroll) - # Note that HZL doesn't use the previous step size. - new_stepsize, ls_state = ls.run(self.max_stepsize, - params, value, grad, - descent_direction, - *args, **kwargs) - new_params = ls_state.params - new_value = ls_state.value - new_grad = ls_state.grad - new_aux = ls_state.aux - else: - raise ValueError("Invalid name in 'linesearch' option.") + init_stepsize = self._reset_stepsize(state.stepsize) + new_stepsize, ls_state = self.run_ls( + init_stepsize, params, value, grad, descent_direction, *args, **kwargs + ) + new_params = ls_state.params + new_value = ls_state.value + new_grad = ls_state.grad + new_aux = ls_state.aux failed_linesearch = ls_state.failed else: - # without line search if isinstance(self.stepsize, Callable): new_stepsize = self.stepsize(state.iter_num) else: @@ -442,7 +384,7 @@ def _value_and_grad_fun(self, **kwargs): if isinstance(params, base.OptStep): params = params.params - (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs) + (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs) return value, grad def __post_init__(self): @@ -452,3 +394,33 @@ def __post_init__(self): has_aux=self.has_aux) self.reference_signature = self.fun + + jit, unroll = self._get_loop_options() + + linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=self._value_and_grad_with_aux, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=jit, + unroll=unroll, + verbose=self.verbose, + condition=self.condition, + decrease_factor=self.decrease_factor, + increase_factor=self.increase_factor, + ) + + self._reset_stepsize = partial( + _reset_stepsize, + self.linesearch, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + ) + + if jit: + self.run_ls = jax.jit(linesearch_solver.run) + else: + self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/lbfgsb.py b/jaxopt/_src/lbfgsb.py index 819447a3..a8437d60 100644 --- a/jaxopt/_src/lbfgsb.py +++ b/jaxopt/_src/lbfgsb.py @@ -21,6 +21,7 @@ # [2] J. Nocedal and S. Wright. Numerical Optimization, second edition. 
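The `use_linesearch` test above also documents the two stepping modes of LBFGS: the line search only runs when `stepsize` is non-positive, while a callable stepsize bypasses it entirely and is evaluated at the iteration number. A short sketch of both modes, using only options listed in the LBFGS docstring (the decay schedule itself is a hypothetical example):

import jax.numpy as jnp
from jaxopt import LBFGS

def fun(w):
  return jnp.sum((w - 1.0) ** 2)

w0 = jnp.zeros(4)

# Mode 1: stepsize <= 0 triggers the configured line search.
ls_solver = LBFGS(fun=fun, stepsize=0.0, linesearch="zoom",
                  stop_if_linesearch_fails=True)

# Mode 2: a callable stepsize schedule; no line search is run at all.
sched_solver = LBFGS(fun=fun, stepsize=lambda it: 0.5 / (it + 1))

params_ls, _ = ls_solver.run(w0)
params_sched, _ = sched_solver.run(w0)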
import dataclasses +import functools from typing import Any, Callable, NamedTuple, Optional, Union import jax @@ -28,12 +29,12 @@ from jaxopt._src import base from jaxopt._src import projection -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch -from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch from jaxopt._src.lbfgs import init_history from jaxopt._src.lbfgs import update_history +from jaxopt._src.linesearch_util import _reset_stepsize +from jaxopt._src.linesearch_util import _setup_linesearch + from jaxopt._src.tree_util import tree_single_dtype -from jaxopt._src.zoom_linesearch import zoom_linesearch from jaxopt.tree_util import tree_add_scalar_mul from jaxopt.tree_util import tree_inf_norm from jaxopt.tree_util import tree_map @@ -202,6 +203,7 @@ def _minimize_subspace( class LbfgsbState(NamedTuple): """Named tuple containing state information.""" + iter_num: int value: float grad: Any @@ -226,14 +228,19 @@ class LBFGSB(base.IterativeSolver): fun: a smooth function of the form ``fun(x, *args, **kwargs)``. value_and_grad: whether ``fun`` just returns the value (False) or both the value and gradient (True). See base.make_funs_with_aux for details. - has_aux: whether ``fun`` outputs auxiliary data or not. If ``has_aux`` is - False, ``fun`` is expected to be scalar-valued. If ``has_aux`` is True, - then we have one of the following two cases. If ``value_and_grad`` is - False, the output should be ``value, aux = fun(...)``. If ``value_and_grad - == True``, the output should be ``(value, aux), grad = fun(...)``. See - base.make_funs_with_aux for details. + has_aux: whether ``fun`` outputs auxiliary data or not. + If ``has_aux`` is False, ``fun`` is expected to be scalar-valued. + If ``has_aux`` is True, then we have one of the following two cases. + If ``value_and_grad`` is False, the output should be + ``value, aux = fun(...)``. + If ``value_and_grad == True``, the output should be + ``(value, aux), grad = fun(...)``. + At each iteration of the algorithm, the auxiliary outputs are stored + in ``state.aux``. + maxiter: maximum number of proximal gradient descent iterations. tol: tolerance of the stopping criterion. + stepsize: a stepsize to use (if <= 0, use backtracking line search), or a callable specifying the **positive** stepsize to use at each iteration. linesearch: the type of line search to use: "backtracking" for backtracking @@ -241,6 +248,8 @@ class LBFGSB(base.IterativeSolver): Hager-Zhang line search. stop_if_linesearch_fails: whether to stop iterations if the line search fails. When True, this matches the behavior of core JAX. + condition: condition used to select the stepsize when using backtracking + linesearch maxls: maximum number of iterations to use in the line search. decrease_factor: factor by which to decrease the stepsize during line search (default: 0.8). @@ -248,16 +257,20 @@ class LBFGSB(base.IterativeSolver): (default: 1.5). max_stepsize: upper bound on stepsize. min_stepsize: lower bound on stepsize. + history_size: size of the memory to use. use_gamma: whether to initialize the Hessian approximation with gamma * theta, where gamma is chosen following equation (7.20) of 'Numerical Optimization' [2]. If use_gamma is set to False, theta is used as initialization. + implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. + jit: whether to JIT-compile the optimization loop (default: "auto"). 
unroll: whether to unroll the optimization loop (default: "auto"). + verbose: whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit. """ @@ -450,105 +463,22 @@ def update( use_linesearch = (not isinstance(self.stepsize, Callable) and self.stepsize <= 0.) if use_linesearch: - if self.linesearch == "backtracking": - ls = BacktrackingLineSearch( - fun=self._value_and_grad_with_aux, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - max_stepsize=self.max_stepsize, - condition=self.condition, - jit=self.jit, - unroll=self.unroll, - has_aux=True, - ) - init_stepsize = jnp.where( - state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Else, we increase a bit the previous one. - state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run( - init_stepsize, - params, - state.value, - state.grad, - descent_direction, - *args, - **kwargs - ) - new_value = ls_state.value - new_aux = ls_state.aux - new_params = ls_state.params - new_grad = ls_state.grad - elif self.linesearch == "zoom": - ls_state = zoom_linesearch( - f=self._value_and_grad_with_aux, - xk=params, - pk=descent_direction, - old_fval=state.value, - gfk=state.grad, - maxiter=self.maxls, - value_and_grad=True, - has_aux=True, - aux=state.aux, - args=args, - kwargs=kwargs, - ) - new_value = ls_state.f_k - new_aux = ls_state.aux - new_stepsize = ls_state.a_k - new_grad = ls_state.g_k - # FIXME: zoom_linesearch currently doesn't return new_params - # so we have to recompute it. - new_params = tree_add_scalar_mul( - params, new_stepsize, descent_direction - ) - - # FIXME: (zaccharieramzi) sometimes the linesearch fails - # and therefore its value g_k does not correspond - # to the gradient at the new parameters. - # with the following conditional loop we have a hot fix that just - # recomputes the value, gradient and auxiliary value - # at the new parameters. It would be better to understand - # what the g_k passed by zoom_linesearch is in this case - # and why it is wrong. - (new_value, new_aux), new_grad = jax.lax.cond( - ls_state.failed, - lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs), - lambda: ((new_value, new_aux), new_grad), - ) - elif self.linesearch == "hager-zhang": - # By default Hager-Zhang uses the Wolfe Conditions & Approximate Wolfe - # Conditions. - ls = HagerZhangLineSearch( - fun=self._value_and_grad_fun, - value_and_grad=True, - maxiter=self.maxls, - max_stepsize=self.max_stepsize, - jit=self.jit, - unroll=self.unroll, - ) - # Note that HZL doesn't use the previous step size. 
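LBFGSB now goes through the same consolidated line-search path; the only extra ingredient is the bound pair. A sketch under the assumption that the bounds are passed to ``run`` as a ``(lower, upper)`` pair, as in the existing jaxopt.LBFGSB API (the box constraints here are arbitrary):

import jax.numpy as jnp
from jaxopt import LBFGSB

def fun(w):
  return jnp.sum((w - 2.0) ** 2)

lower = jnp.zeros(3)
upper = jnp.ones(3)
w0 = 0.5 * jnp.ones(3)

solver = LBFGSB(fun=fun, linesearch="zoom", maxls=20)
# The unconstrained optimum (w = 2) lies outside the box, so the solution is
# expected to sit on the upper bound.
params, state = solver.run(w0, bounds=(lower, upper))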
- new_stepsize, ls_state = ls.run( - self.max_stepsize, - params, - state.value, - state.grad, - descent_direction, - *args, - **kwargs - ) - - new_params = ls_state.params - new_value = ls_state.value - new_grad = ls_state.grad - new_aux = ls_state.aux - else: - raise ValueError("Invalid name in 'linesearch' option.") + init_stepsize = self._reset_stepsize(state.stepsize) + new_stepsize, ls_state = self.run_ls( + init_stepsize, + params, + state.value, + state.grad, + descent_direction, + *args, + **kwargs, + ) + new_params = ls_state.params + new_value = ls_state.value + new_grad = ls_state.grad + new_aux = ls_state.aux failed_linesearch = ls_state.failed else: - # without line search if isinstance(self.stepsize, Callable): new_stepsize = self.stepsize(state.iter_num) else: @@ -634,3 +564,32 @@ def __post_init__(self): ) self.reference_signature = self.fun + + jit, unroll = self._get_loop_options() + linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=self._value_and_grad_with_aux, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=jit, + unroll=unroll, + verbose=self.verbose, + condition=self.condition, + decrease_factor=self.decrease_factor, + increase_factor=self.increase_factor, + ) + + self._reset_stepsize = functools.partial( + _reset_stepsize, + self.linesearch, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + ) + + if jit: + self.run_ls = jax.jit(linesearch_solver.run) + else: + self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/linesearch_util.py b/jaxopt/_src/linesearch_util.py new file mode 100644 index 00000000..97bbaf39 --- /dev/null +++ b/jaxopt/_src/linesearch_util.py @@ -0,0 +1,109 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Line searches utilities.""" + +from jax import numpy as jnp +from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch +from jaxopt._src.hager_zhang_linesearch import HagerZhangLineSearch +from jaxopt._src.zoom_linesearch import ZoomLineSearch + + +def _setup_linesearch( + linesearch, + fun, + value_and_grad, + has_aux, + maxlsiter, + max_stepsize, + jit, + unroll, + verbose, + condition, # for backtracking only + decrease_factor, # for backtracking only + increase_factor, # for zoom only +): + """Instantiate linesearch.""" + + available_linesearches = ["backtracking", "zoom", "hager-zhang"] + if linesearch == "backtracking": + linesearch_solver = BacktrackingLineSearch( + fun=fun, + value_and_grad=value_and_grad, + has_aux=has_aux, + maxiter=maxlsiter, + decrease_factor=decrease_factor, + max_stepsize=max_stepsize, + condition=condition, + jit=jit, + unroll=unroll, + verbose=verbose, + ) + elif linesearch == "zoom": + linesearch_solver = ZoomLineSearch( + fun=fun, + value_and_grad=value_and_grad, + has_aux=has_aux, + maxiter=maxlsiter, + max_stepsize=max_stepsize, + increase_factor=increase_factor, + jit=jit, + unroll=unroll, + verbose=verbose, + ) + elif linesearch == "hager-zhang": + linesearch_solver = HagerZhangLineSearch( + fun=fun, + value_and_grad=value_and_grad, + has_aux=has_aux, + maxiter=maxlsiter, + max_stepsize=max_stepsize, + jit=jit, + unroll=unroll, + verbose=verbose, + ) + else: + raise ValueError( + f"Linesearch {linesearch} not available/tested. " + f"Available linesearches: {available_linesearches}" + ) + return linesearch_solver + + +def _reset_stepsize( + linesearch, max_stepsize, min_stepsize, increase_factor, stepsize +): + """Set stepsize at the start of the linesearch from previous guess.""" + available_linesearches = ["backtracking", "zoom", "hager-zhang"] + if linesearch == "hager-zhang": + # FIXME: HZL should be able to use the previous stepsize (see the paper) + # For now, the current implementation is simply initialized at the maximum + # stepsize. + init_stepsize = max_stepsize + elif linesearch == "zoom": + init_stepsize = stepsize + elif linesearch == "backtracking": + init_stepsize = jnp.where( + stepsize <= min_stepsize, + # If stepsize became too small, we restart it. + max_stepsize, + # Else, we increase a bit the previous one. + stepsize * increase_factor, + ) + else: + raise ValueError( + f"Linesearch {linesearch} not available/tested. " + f"Available linesearches: {available_linesearches}" + ) + return init_stepsize diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 30007aa0..6e0c3a5b 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -21,12 +21,15 @@ from dataclasses import dataclass +import functools + import jax import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch -from jaxopt._src.zoom_linesearch import zoom_linesearch +from jaxopt._src.linesearch_util import _reset_stepsize +from jaxopt._src.linesearch_util import _setup_linesearch + from jaxopt.tree_util import tree_vdot from jaxopt.tree_util import tree_scalar_mul from jaxopt.tree_util import tree_add_scalar_mul @@ -38,6 +41,7 @@ class NonlinearCGState(NamedTuple): """Named tuple containing state information.""" + iter_num: int stepsize: float error: float @@ -82,10 +86,13 @@ class NonlinearCG(base.IterativeSolver): implicit_diff: whether to enable implicit diff or autodiff of unrolled iterations. implicit_diff_solve: the linear system solver to use. 
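These two helpers are meant to be wired together once in a solver's ``__post_init__``, exactly as the solver diffs above do. A condensed sketch of that pattern (``wire_linesearch`` and the ``solver`` object holding the options are hypothetical names; the keyword arguments mirror the signatures added in this file):

import functools
import jax
from jaxopt._src.linesearch_util import _reset_stepsize, _setup_linesearch

def wire_linesearch(solver, value_and_grad_with_aux, jit, unroll):
  # Build the line-search object once from the solver's options.
  ls = _setup_linesearch(
      linesearch=solver.linesearch,
      fun=value_and_grad_with_aux,
      value_and_grad=True,
      has_aux=True,
      maxlsiter=solver.maxls,
      max_stepsize=solver.max_stepsize,
      jit=jit,
      unroll=unroll,
      verbose=solver.verbose,
      condition=solver.condition,
      decrease_factor=solver.decrease_factor,
      increase_factor=solver.increase_factor,
  )
  # Freeze the solver-level options; only the previous stepsize varies per call.
  reset = functools.partial(
      _reset_stepsize,
      solver.linesearch,
      solver.max_stepsize,
      solver.min_stepsize,
      solver.increase_factor,
  )
  run_ls = jax.jit(ls.run) if jit else ls.run
  return reset, run_ls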
+ jit: whether to JIT-compile the optimization loop (default: "auto"). unroll: whether to unroll the optimization loop (default: "auto"). + verbose: whether to print error on every iteration or not. Warning: verbose=True will automatically disable jit. + Reference: Jorge Nocedal and Stephen Wright. Numerical Optimization, second edition. @@ -122,10 +129,12 @@ def init_state(self, *args, **kwargs) -> NonlinearCGState: """Initialize the solver state. + Args: init_params: pytree containing the initial parameters. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. + Returns: state """ @@ -147,11 +156,13 @@ def update(self, *args, **kwargs) -> base.OptStep: """Performs one iteration of Fletcher-Reeves Algorithm. + Args: params: pytree containing the parameters. state: named tuple containing the solver state. *args: additional positional arguments to be passed to ``fun``. **kwargs: additional keyword arguments to be passed to ``fun``. + Returns: (params, state) """ @@ -161,61 +172,26 @@ def update(self, grad = state.grad descent_direction = state.descent_direction + # Kept choice of no descent direction for backtracking line-search + # FIXME: should discuss why it was the case if self.linesearch == "backtracking": - ls = BacktrackingLineSearch(fun=self._value_and_grad_with_aux, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - condition=self.condition, - max_stepsize=self.max_stepsize, - has_aux=True) - - init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Otherwise, we increase a bit the previous one. - state.stepsize * self.increase_factor) - - new_stepsize, ls_state = ls.run(init_stepsize, - params, - value, - grad, - None, # descent_direction - *args, **kwargs) - new_value = ls_state.value - new_aux = ls_state.aux - new_params = ls_state.params - new_grad = ls_state.grad - - elif self.linesearch == "zoom": - ls_state = zoom_linesearch(f=self._value_and_grad_with_aux, - xk=params, pk=descent_direction, - old_fval=value, gfk=grad, maxiter=self.maxls, - value_and_grad=True, has_aux=True, aux=state.aux, - args=args, kwargs=kwargs) - new_value = ls_state.f_k - new_aux = ls_state.aux - new_stepsize = ls_state.a_k - new_grad = ls_state.g_k - # FIXME: zoom_linesearch currently doesn't return new_params - # so we have to recompute it. - t = new_stepsize.astype(tree_single_dtype(params)) - new_params = tree_add_scalar_mul(params, t, descent_direction) - # FIXME: (zaccharieramzi) sometimes the linesearch fails - # and therefore its value g_k does not correspond - # to the gradient at the new parameters. - # with the following conditional loop we have a hot fix that just - # recomputes the value, gradient and auxiliary value - # at the new parameters. It would be better to understand - # what the g_k passed by zoom_linesearch is in this case - # and why it is wrong. 
- (new_value, new_aux), new_grad = jax.lax.cond( - ls_state.failed, - lambda: self._value_and_grad_with_aux(new_params, *args, **kwargs), - lambda: ((new_value, new_aux), new_grad), - ) + ls_descent_direction = None else: - raise ValueError("Invalid name in 'linesearch' option.") + ls_descent_direction = descent_direction + init_stepsize = self._reset_stepsize(state.stepsize) + new_stepsize, ls_state = self.run_ls( + init_stepsize, + params, + value, + grad, + ls_descent_direction, + *args, + **kwargs, + ) + new_params = ls_state.params + new_value = ls_state.value + new_grad = ls_state.grad + new_aux = ls_state.aux if self.method == "polak-ribiere": # See Numerical Optimization, second edition, equation (5.44). @@ -228,7 +204,7 @@ def update(self, gTg = tree_vdot(grad, grad) gTg = jnp.where(gTg >= eps, gTg, eps) new_beta = tree_div(tree_vdot(new_grad, new_grad), gTg) - elif self.method == 'hestenes-stiefel': + elif self.method == "hestenes-stiefel": # See Numerical Optimization, second edition, equation (5.45). grad_diff = tree_sub(new_grad, grad) dTg = tree_vdot(descent_direction, grad_diff) @@ -256,7 +232,7 @@ def optimality_fun(self, params, *args, **kwargs): return self._grad_fun(params, *args, **kwargs) def _value_and_grad_fun(self, params, *args, **kwargs): - (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs) + (value, _), grad = self._value_and_grad_with_aux(params, *args, **kwargs) return value, grad def _grad_fun(self, params, *args, **kwargs): @@ -269,3 +245,33 @@ def __post_init__(self): has_aux=self.has_aux) self.reference_signature = self.fun + + jit, unroll = self._get_loop_options() + + linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=self._value_and_grad_with_aux, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=jit, + unroll=unroll, + verbose=self.verbose, + condition=self.condition, + decrease_factor=self.decrease_factor, + increase_factor=self.increase_factor, + ) + + self._reset_stepsize = functools.partial( + _reset_stepsize, + self.linesearch, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + ) + + if jit: + self.run_ls = jax.jit(linesearch_solver.run) + else: + self.run_ls = linesearch_solver.run diff --git a/jaxopt/_src/zoom_linesearch.py b/jaxopt/_src/zoom_linesearch.py index 4ebe01ba..6f2de7c9 100644 --- a/jaxopt/_src/zoom_linesearch.py +++ b/jaxopt/_src/zoom_linesearch.py @@ -1,4 +1,4 @@ -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,487 +14,734 @@ """Zoom line search algorithm.""" -# Original code by Joshua George Albert: -# https://github.com/google/jax/pull/3101 +import dataclasses +import functools +from typing import Any +from typing import Callable +from typing import NamedTuple +from typing import Optional -from typing import Any, NamedTuple, Optional, Union -from functools import partial - -#from jax._src.numpy.util import _promote_dtypes_inexact -import jax.numpy as jnp import jax from jax import lax -from jaxopt.tree_util import tree_vdot, tree_add_scalar_mul, tree_map +import jax.numpy as jnp +from jaxopt._src import base +from jaxopt._src.base import _make_funs_with_aux from jaxopt._src.tree_util import tree_single_dtype +from jaxopt.tree_util import tree_add_scalar_mul +from jaxopt.tree_util import tree_scalar_mul +from jaxopt.tree_util import tree_vdot +# pylint: disable=g-bare-generic +# pylint: disable=invalid-name -_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) +_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST) def _cubicmin(a, fa, fpa, b, fb, c, fc): + """Cubic interpolation. + + Finds a critical point of a cubic polynomial + p(x) = A *(x-a)^3 + B*(x-a)^2 + C*(x-a) + D, that goes through + the points (a,fa), (b,fb), and (c,fc) with derivative at a of fpa. + May return NaN (if radical<0), in that case, the point will be ignored. + Taken from scipy.optimize._linesearch.py. + + Args: + a: scalar + fa: value of a function f at a + fpa: slope of a function f at a + b: scalar + fb: value of a function f at b + c: scalar + fc: value of a function f at c + + Returns: + xmin: point at which p'(xmin) = 0 + """ C = fpa db = b - a dc = c - a denom = (db * dc) ** 2 * (db - dc) - d1 = jnp.array([[dc ** 2, -db ** 2], - [-dc ** 3, db ** 3]]) + d1 = jnp.array([[dc**2, -(db**2)], [-(dc**3), db**3]]) A, B = _dot(d1, jnp.array([fb - fa - C * db, fc - fa - C * dc])) / denom - radical = B * B - 3. * A * C - xmin = a + (-B + jnp.sqrt(radical)) / (3. * A) + radical = B * B - 3.0 * A * C + xmin = a + (-B + jnp.sqrt(radical)) / (3.0 * A) return xmin def _quadmin(a, fa, fpa, b, fb): + """Quadratic interpolation. + + Finds a critical point of a quadratic polynomial + p(x) = B*(x-a)^2 + C*(x-a) + D, that goes through + the points (a,fa), (b,fb) with derivative at a of fpa. + Taken from scipy.optimize._linesearch.py. + + Args: + a: scalar + fa: value of a function f at a + fpa: slope of a function f at a + b: scalar + fb: value of a function f at b + + Returns: + xmin: point at which p'(xmin) = 0 + """ D = fa C = fpa db = b - a - B = (fb - D - C * db) / (db ** 2) - xmin = a - C / (2. 
* B) + B = (fb - D - C * db) / (db**2) + xmin = a - C / (2.0 * B) return xmin -def _binary_replace(replace_bit, original_dict, new_dict, keys=None): - if keys is None: - keys = new_dict.keys() - out = dict() - for key in keys: - #out[key] = jnp.where(replace_bit, new_dict[key], original_dict[key]) - out[key] = tree_map(lambda x, y: jnp.where(replace_bit, x, y), - new_dict[key], - original_dict[key]) - return out - - -class _ZoomState(NamedTuple): - done: Union[bool, jnp.ndarray] - failed: Union[bool, jnp.ndarray] - j: Union[int, jnp.ndarray] - a_lo: Union[float, jnp.ndarray] - phi_lo: Union[float, jnp.ndarray] - dphi_lo: Union[float, jnp.ndarray] - a_hi: Union[float, jnp.ndarray] - phi_hi: Union[float, jnp.ndarray] - dphi_hi: Union[float, jnp.ndarray] - a_rec: Union[float, jnp.ndarray] - phi_rec: Union[float, jnp.ndarray] - a_star: Union[float, jnp.ndarray] - phi_star: Union[float, jnp.ndarray] - dphi_star: Union[float, jnp.ndarray] - g_star: Union[float, jnp.ndarray] - nfev: Union[int, jnp.ndarray] - ngev: Union[int, jnp.ndarray] - aux_lo: Union[float, jnp.ndarray] - aux_hi: Union[float, jnp.ndarray] - aux_star: Union[float, jnp.ndarray] - - -def _zoom(restricted_func_and_grad, wolfe_one, wolfe_two, a_lo, phi_lo, - dphi_lo, a_hi, phi_hi, dphi_hi, g_0, pass_through, has_aux=False, aux=jnp.nan): - """ - Implementation of zoom. Algorithm 3.6 from Wright and Nocedal, 'Numerical - Optimization', 1999, pg. 59-61. Tries cubic, quadratic, and bisection methods - of zooming. +def _set_values(cond, candidate, default): + def _set_val(x, y): + return jnp.where(cond, x, y) + + return jax.tree_util.tree_map(_set_val, candidate, default) + + +def _check_failure_status(fail_code): + """Print failure reason according to fail value.""" + if fail_code == 1: + print("Provided descent direction is not a descent direction.") + elif fail_code == 2: + print("Maximal stepsize reached.") + elif fail_code == 3: + print("Maximal number of line search iterations reached.") + elif fail_code == 4: + print( + "Length of searched interval has been reduced below machine precision." + ) + elif fail_code == 5: + print("NaN or Inf values encountered in function values.") + + +class ZoomLineSearchState(NamedTuple): + """Named tuple containing state information for core loop.""" + + iter_num: int + params: Any # either initial or final + value: float # either initial or final + grad: Any # either initial or final + + # unchanged after initialization + value_init: float # (redundant with value, left for readability) + slope_init: float + descent_direction: Any + + num_fun_eval: int + num_grad_eval: int + + error: float + done: bool + fail_code: int # encode failure status, see _check_status + failed: bool # comply to semantic used by other line searches + + # Used only during the interval search + interval_found: bool + prev_stepsize: float + prev_value_step: float + prev_slope_step: float + + # Set up after interval search done, modified during zoom + low: float + value_low: float + slope_low: float + high: float + value_high: float + slope_high: float + cubic_ref: float + value_cubic_ref: float + + # Safeguard point: we may not be able to satisfy the curvature condition + # but we can still return a point that satisfies the decrease condition + safe_stepsize: float + + aux: Optional[Any] = None # either initial or final + + +@dataclasses.dataclass(eq=False) +class ZoomLineSearch(base.IterativeLineSearch): + """Inexact line search that satisfies sufficient decrease (Armijo) and small curvature (strong Wolfe) conditions. 
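In the notation of [1], writing phi(alpha) = f(x + alpha d) for the restriction of ``fun`` to the search direction d, the acceptance tests implemented by ``_decrease_error`` and ``_curvature_error`` below amount to (a sketch of the math, with c1, c2, c3 as defined in the attributes):

\begin{aligned}
\text{(sufficient decrease, eq. (3.7a) of [1]):}\quad & \phi(\alpha) \le \phi(0) + c_1\,\alpha\,\phi'(0),\\
\text{(small curvature, eq. (3.7b) of [1]):}\quad & |\phi'(\alpha)| \le c_2\,|\phi'(0)|,\\
\text{(approximate decrease, eq. (23) of [2]):}\quad & \phi'(\alpha) \le (2c_1 - 1)\,\phi'(0)
\ \text{ and }\ \phi(\alpha) \le \phi(0) + c_3\,|\phi(0)|.
\end{aligned}

A step is accepted once the curvature test holds together with either the exact or the approximate decrease test; the approximate variant is only allowed close to a minimizer, which is what the c3 proximity term encodes.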
+ + Algorithms 3.5, 3.6 from [1], pages 60-62. + + The sufficient decrease condition may be impossible to satisfy close to a + minimum, in that case, we switch to an approximate sufficient decrease + condition (approximate Wolfe) taken from [2]. + + [1] J. Nocedal and S. Wright, 'Numerical Optimization', 2nd edition, 2006. + [2] W. Hager, H. Zhang, Algorithm 851: CG_DESCENT, a Conjugate Gradient Method + with Guaranteed Descent. + + Attributes: + fun: a function of the form ``fun(params, *args, **kwargs)``, where + ``params`` are parameters of the model, ``*args`` and ``**kwargs`` are + additional arguments. + value_and_grad: if ``False``, ``fun`` should return the function value only. + If ``True``, ``fun`` should return both the function value and the + gradient. + has_aux: if ``False``, ``fun`` should return the function value only. If + ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux`` is a + pytree of auxiliary values. (default: False) + c1: constant used to check if a sufficient decrease has been found (Armijo) + (default: 1e-4) + c2: constant used to check if a small curvature has been found (strong + Wolfe) (default: 0.9) + c3: constant used to check if an approximate sufficient decrease + (approximate Wolfe) has been found (default: 1e-6) + rel_tol_cubic: point computed by cubic interpolation accepted if inside + rel_tol_cubic*interval_size (default: 0.2) + rel_tol_quad: point computed by quadratic interpolation accepted if inside + rel_tol_quad*interval_size (default: 0.1) + increase_factor: factor to mutliply stepsize at initialization until finding + interval satisfying curvature condition (default: 2.) + max_stepsize: maximal possible stepsize. (default: 2**30) + tol: tolerance of the stopping criterion. (default: 0.0) + maxiter: maximum number of line search iterations. (default: 30) + verbose: whether to print error on every iteration or not. verbose=True will + automatically disable jit. (default: False) + jit: whether to JIT-compile the optimization loop (default: "auto"). + unroll: whether to unroll the optimization loop (default: "auto"). """ - init_state = _ZoomState( - done=False, - failed=False, - j=0, - a_lo=a_lo, - phi_lo=phi_lo, - dphi_lo=dphi_lo, - a_hi=a_hi, - phi_hi=phi_hi, - dphi_hi=dphi_hi, - a_rec=(a_lo + a_hi) / 2., - phi_rec=(phi_lo + phi_hi) / 2., - a_star=1.0, - phi_star=phi_lo, - dphi_star=dphi_lo, - g_star=g_0, - nfev=0, - ngev=0, - # the auxiliary values are not used in the body of the loop - # but are just set at the end, so we need them to have matching shapes - # and dtypes - aux_lo=aux, - aux_hi=aux, - aux_star=aux, - ) - delta1 = 0.2 - delta2 = 0.1 - - def body(state): - # Body of zoom algorithm. We use boolean arithmetic to avoid using jax.cond - # so that it works on GPU/TPU. - dalpha = (state.a_hi - state.a_lo) - a = jnp.minimum(state.a_hi, state.a_lo) - b = jnp.maximum(state.a_hi, state.a_lo) - cchk = delta1 * dalpha - qchk = delta2 * dalpha - - # This will cause the line search to stop, and since the Wolfe conditions - # are not satisfied the minimization should stop too. - threshold = jnp.where((jnp.finfo(dalpha).bits < 64), 1e-5, 1e-10) - state = state._replace(failed=state.failed | (dalpha <= threshold)) - - # Cubmin is sometimes nan, though in this case the bounds check will fail. 
- a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi, - state.phi_hi, state.a_rec, state.phi_rec) - use_cubic = (state.j > 0) & (a_j_cubic > a + cchk) & (a_j_cubic < b - cchk) - a_j_quad = _quadmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi, state.phi_hi) - use_quad = (~use_cubic) & (a_j_quad > a + qchk) & (a_j_quad < b - qchk) - a_j_bisection = (state.a_lo + state.a_hi) / 2. + + fun: Callable + value_and_grad: bool = False + has_aux: bool = False + + c1: float = 1e-4 + c2: float = 0.9 + c3: float = 1e-6 + rel_tol_cubic: float = 0.2 + rel_tol_quad: float = 0.1 + increase_factor: float = 2.0 + + tol: float = 0.0 + maxiter: int = 30 + # max_stepsize needs to be large enough for the linesearch to be able + # to find a good stepsize + max_stepsize: float = 2**30 + + verbose: bool = False + jit: base.AutoOrBoolean = "auto" + unroll: base.AutoOrBoolean = "auto" + + def _value_and_slope_on_line( + self, params, stepsize, descent_direction, args, kwargs + ): + step = tree_add_scalar_mul(params, stepsize, descent_direction) + (value_step, aux_step), grad_step = self._value_and_grad_fun_with_aux( + step, *args, **kwargs + ) + slope_step = tree_vdot(grad_step, descent_direction) + return value_step, slope_step, step, grad_step, aux_step + + def _decrease_error( + self, stepsize, value_step, slope_step, value_init, slope_init + ): + # We consider either the usual sufficient decrease (Armijo condition), see + # equation (3.7a) of [1] + exact_decrease_error = ( + value_step - value_init - self.c1 * stepsize * slope_init + ) + # or an approximate decrease condition, see equation (23) of [2] + approx_decrease_error_ = slope_step - (2 * self.c1 - 1.0) * slope_init + + # The classical Armijo condition may fail to be satisfied if we are too + # close to a minimum, causing the optimizer to fail as explained in [2] + + # We switch to approximate Wolfe conditions only if we are close enough to + # the minimizer which is captured by the following criterion. + delta_values = value_step - value_init - self.c3 * jnp.abs(value_init) + approx_decrease_error = jnp.maximum(approx_decrease_error_, delta_values) + # We take then the *minimum* of both errors. + return jnp.minimum(approx_decrease_error, exact_decrease_error) + + def _curvature_error(self, slope_step, slope_init): + # See equation (3.7b) of [1]. 
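+ # Equivalently, with phi(a) = f(x + a * d), this returns
+ # |phi'(a)| - c2 * |phi'(0)|, which is <= 0 exactly when the strong Wolfe
+ # curvature condition |phi'(a)| <= c2 * |phi'(0)| holds.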
+ return jnp.abs(slope_step) - self.c2 * jnp.abs(slope_init) + + def _make_safe_step(self, _, state, args, kwargs): + safe_stepsize = state.safe_stepsize + step = tree_add_scalar_mul( + state.params, safe_stepsize, state.descent_direction + ) + (value_step, aux_step), grad_step = self._value_and_grad_fun_with_aux( + step, *args, **kwargs + ) + new_state = state._replace( + params=step, value=value_step, grad=grad_step, aux=aux_step + ) + return safe_stepsize, new_state + + def _keep_step(self, stepsize, state, _, __): + return stepsize, state + + def _search_interval(self, init_stepsize, state, args, kwargs): + """Line search procedure described in Algorithm 3.5 of [1].""" + # init_stepsize only used for iter_num = 0 + + iter_num = state.iter_num + + params_init = state.params + grad_init = state.grad + aux_init = state.aux + + fail_code = state.fail_code + + value_init = state.value_init + slope_init = state.slope_init + descent_direction = state.descent_direction + + prev_stepsize = state.prev_stepsize + prev_value_step = state.prev_value_step + prev_slope_step = state.prev_slope_step + + safe_stepsize = state.safe_stepsize + + # Choose new point, larger than previous one or set to initial guess + # for first iteration. + larger_stepsize = self.increase_factor * prev_stepsize + new_stepsize_ = jnp.where(iter_num == 0, init_stepsize, larger_stepsize) + new_stepsize = jnp.minimum(new_stepsize_, self.max_stepsize) + + max_stepsize_reached = new_stepsize == self.max_stepsize + fail_check1 = jnp.where( + (fail_code == 0) & max_stepsize_reached, 2, fail_code + ) + + new_value_step, new_slope_step, new_step, new_grad_step, new_aux_step = ( + self._value_and_slope_on_line( + params_init, new_stepsize, descent_direction, args, kwargs + ) + ) + is_value_nan = jnp.isnan(new_value_step) | jnp.isinf(new_value_step) + fail_check2 = jnp.where((fail_check1 == 0) & is_value_nan, 5, fail_check1) + + decrease_error_ = self._decrease_error( + new_stepsize, new_value_step, new_slope_step, value_init, slope_init + ) + decrease_error = jnp.maximum(decrease_error_, 0.0) + curvature_error_ = self._curvature_error(new_slope_step, slope_init) + curvature_error = jnp.maximum(curvature_error_, 0.0) + new_error = jnp.maximum(decrease_error, curvature_error) + + # If the new point satisfies at least the decrease error we keep it + # in case the curvature error cannot be satisfied. + safe_decrease = decrease_error <= self.tol + new_safe_stepsize = jnp.where(safe_decrease, new_stepsize, safe_stepsize) + + # If the new point not good, set high and low values according to + # conditions described in Algorithm 3.5 of [1] + set_high_to_new = (decrease_error > 0.0) | ( + (new_value_step >= prev_value_step) & (iter_num > 0) + ) + set_low_to_new = (new_slope_step >= 0.0) & (~set_high_to_new) + + # By default we set high to new and correct if we should have set + # low to new. If none should have set, the search for the interval + # continues anyway. 
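+ # Default ordering below: (low, high) = (previous step, new step); the
+ # candidate list swaps them and is selected when set_low_to_new holds.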
+ low_, value_low_, slope_low_, high_, value_high_, slope_high_ = ( + prev_stepsize, + prev_value_step, + prev_slope_step, + new_stepsize, + new_value_step, + new_slope_step, + ) + + default = [low_, value_low_, slope_low_, high_, value_high_, slope_high_] + candidate = [ + new_stepsize, + new_value_step, + new_slope_step, + prev_stepsize, + prev_value_step, + prev_slope_step, + ] + [low, value_low, slope_low, high, value_high, slope_high] = _set_values( + set_low_to_new, candidate, default + ) + + # If high or low have been set or the point is good, the interval has been + # found. Otherwise we'll keep on augmenting the stepsize. + interval_found = set_high_to_new | set_low_to_new | (new_error <= self.tol) + + # If new_error <= self.tol, the line search is done. In that case, we set + # directly the new parameters, gradient, value and aux to the ones found. + default = [0.0, params_init, value_init, grad_init, aux_init] + candidate = [ + new_stepsize, + new_step, + new_value_step, + new_grad_step, + new_aux_step, + ] + best_stepsize, next_params, next_value, next_grad, next_aux = _set_values( + new_error <= self.tol, candidate, default + ) + + done = new_error <= self.tol + + max_iter_reached = (iter_num + 1 >= self.maxiter) & (~done) + new_fail_code = jnp.where( + (fail_check2 == 0) & max_iter_reached, 3, fail_check2 + ) + + new_state = state._replace( + iter_num=iter_num + 1, + params=next_params, + value=next_value, + grad=next_grad, + aux=next_aux, + # + error=new_error, + done=done, + fail_code=new_fail_code, + failed=jnp.asarray(new_fail_code > 0), + interval_found=interval_found, + # + prev_stepsize=new_stepsize, + prev_value_step=new_value_step, + prev_slope_step=new_slope_step, + # + low=low, + value_low=value_low, + slope_low=slope_low, + high=high, + value_high=value_high, + slope_high=slope_high, + cubic_ref=low, + value_cubic_ref=value_low, + # + safe_stepsize=new_safe_stepsize, + # + num_fun_eval=state.num_fun_eval + 1, + num_grad_eval=state.num_grad_eval + 1, + ) + return base.LineSearchStep(stepsize=best_stepsize, state=new_state) + + def _zoom_into_interval(self, stepsize, state, args, kwargs): + """Zoom procedure described in Algorithm 3.6 of [1].""" + + # The stepsize is not used, only low, high, etc... 
are used to + # find a good point + dtype = stepsize.dtype + del stepsize + + iter_num = state.iter_num + + params_init = state.params + grad_init = state.grad + aux_init = state.aux + + value_init = state.value_init + slope_init = state.slope_init + descent_direction = state.descent_direction + + fail_code = state.fail_code + + low = state.low + value_low = state.value_low + slope_low = state.slope_low + high = state.high + value_high = state.value_high + slope_high = state.slope_high + cubic_ref = state.cubic_ref + value_cubic_ref = state.value_cubic_ref + + safe_stepsize = state.safe_stepsize + + # Check if interval not too small otherwise fail + delta = jnp.abs(high - low) + left = jnp.minimum(high, low) + right = jnp.maximum(high, low) + cubic_chk = self.rel_tol_cubic * delta + quad_chk = self.rel_tol_quad * delta + threshold = jnp.where((jnp.finfo(delta).bits < 64), 1e-5, 1e-10) + too_small_int = delta <= threshold + fail_check1 = jnp.where((fail_code == 0) & too_small_int, 4, fail_code) + + # Find new point by interpolation + middle_cubic = _cubicmin( + low, value_low, slope_low, high, value_high, cubic_ref, value_cubic_ref + ) + middle_cubic_valid = (middle_cubic > left + cubic_chk) & ( + middle_cubic < right - cubic_chk + ) + use_cubic = middle_cubic_valid + middle_quad = _quadmin(low, value_low, slope_low, high, value_high) + middle_quad_valid = (middle_quad > left + quad_chk) & ( + middle_quad < right - quad_chk + ) + use_quad = (~use_cubic) & middle_quad_valid + middle_bisection = (low + high) / 2.0 use_bisection = (~use_cubic) & (~use_quad) - a_j = jnp.where(use_cubic, a_j_cubic, state.a_rec) - a_j = jnp.where(use_quad, a_j_quad, a_j) - a_j = jnp.where(use_bisection, a_j_bisection, a_j) - - # TODO(jakevdp): should we use some sort of fixed-point approach here instead? - if has_aux: - (phi_j, dphi_j, g_j), aux_j = restricted_func_and_grad(a_j) - else: - phi_j, dphi_j, g_j = restricted_func_and_grad(a_j) - aux_j = jnp.nan - phi_j = phi_j.astype(state.phi_lo.dtype) - dphi_j = dphi_j.astype(state.dphi_lo.dtype) - #g_j = g_j.astype(state.g_star.dtype) - state = state._replace(nfev=state.nfev + 1, - ngev=state.ngev + 1) - - hi_to_j = wolfe_one(a_j, phi_j) | (phi_j >= state.phi_lo) - star_to_j = wolfe_two(dphi_j) & (~hi_to_j) - hi_to_lo = (dphi_j * (state.a_hi - state.a_lo) >= 0.) 
& (~hi_to_j) & (~star_to_j) - lo_to_j = (~hi_to_j) & (~star_to_j) - - state = state._replace( - **_binary_replace( - hi_to_j, - state._asdict(), - dict( - a_hi=a_j, - phi_hi=phi_j, - dphi_hi=dphi_j, - aux_hi=aux_j, - a_rec=state.a_hi, - phi_rec=state.phi_hi, - ), - ), + middle = jnp.where(use_cubic, middle_cubic, cubic_ref) + middle = jnp.where(use_quad, middle_quad, middle) + middle = jnp.where(use_bisection, middle_bisection, middle).astype(dtype) + + # Check if new point is good + value_middle, slope_middle, step, grad_step, aux_step = ( + self._value_and_slope_on_line( + params_init, middle, descent_direction, args, kwargs + ) ) + is_value_nan = jnp.isnan(value_middle) | jnp.isinf(value_middle) + fail_check2 = jnp.where((fail_check1 == 0) & is_value_nan, 5, fail_check1) - # for termination - state = state._replace( - done=star_to_j | state.done, - **_binary_replace( - star_to_j, - state._asdict(), - dict( - a_star=a_j, - phi_star=phi_j, - dphi_star=dphi_j, - g_star=g_j, - aux_star=aux_j, - ) - ), + decrease_error_ = self._decrease_error( + middle, value_middle, slope_middle, value_init, slope_init ) - state = state._replace( - **_binary_replace( - hi_to_lo, - state._asdict(), - dict( - a_hi=state.a_lo, - phi_hi=state.phi_lo, - dphi_hi=state.dphi_lo, - aux_hi=state.aux_lo, - a_rec=state.a_hi, - phi_rec=state.phi_hi, - ), - ), + decrease_error = jnp.maximum(decrease_error_, 0.0) + curvature_error_ = self._curvature_error(slope_middle, slope_init) + curvature_error = jnp.maximum(curvature_error_, 0.0) + + new_error = jnp.maximum(decrease_error, curvature_error) + + # If the new point satisfies at least the decrease error we keep it in case + # the curvature error cannot be satisfied. We take the largest possible one + safe_decrease = decrease_error <= self.tol + new_safe_stepsize_ = jnp.where(safe_decrease, middle, safe_stepsize) + new_safe_stepsize = jnp.maximum(new_safe_stepsize_, safe_stepsize) + + # If both armijo and curvature conditions are satisfied, we are done. + done = new_error <= self.tol + default = [0.0, params_init, value_init, grad_init, aux_init] + candidate = [middle, step, value_middle, grad_step, aux_step] + best_stepsize, next_params, next_value, next_grad, next_aux = _set_values( + new_error <= self.tol, candidate, default ) - state = state._replace( - **_binary_replace( - lo_to_j, - state._asdict(), - dict( - a_lo=a_j, - phi_lo=phi_j, - dphi_lo=dphi_j, - aux_lo=aux_j, - a_rec=state.a_lo, - phi_rec=state.phi_lo, - ), - ), + + # Otherwise, we update high and low values + set_high_to_middle = (decrease_error > 0.0) | (value_middle >= value_low) + secant_interval = slope_middle * (high - low) + set_high_to_low = (secant_interval >= 0.0) & (~set_high_to_middle) + set_low_to_middle = ~set_high_to_middle + + # Set high to middle, or low, or keep as it is + default = [high, value_high, slope_high] + candidate = [middle, value_middle, slope_middle] + [new_high_, new_value_high_, new_slope_high_] = _set_values( + set_high_to_middle, candidate, default + ) + default = [new_high_, new_value_high_, new_slope_high_] + candidate = [low, value_low, slope_low] + [new_high, new_value_high, new_slope_high] = _set_values( + set_high_to_low, candidate, default ) - state = state._replace(j=state.j + 1) - # Choose higher cutoff for maxiter than Scipy as Jax takes longer to find - # the same value - possibly floating point issues? 
- state = state._replace(failed= state.failed | (state.j >= 30)) - - # For dtype consistency - state = state._replace(a_lo=state.a_lo.astype(init_state.a_lo.dtype), - a_hi=state.a_hi.astype(init_state.a_hi.dtype), - a_rec=state.a_rec.astype(init_state.a_rec.dtype)) - - return state - - state = lax.while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed), - body, - init_state) - - return state - - -class _LineSearchState(NamedTuple): - done: Union[bool, jnp.ndarray] - failed: Union[bool, jnp.ndarray] - i: Union[int, jnp.ndarray] - a_i1: Union[float, jnp.ndarray] - phi_i1: Union[float, jnp.ndarray] - dphi_i1: Union[float, jnp.ndarray] - nfev: Union[int, jnp.ndarray] - ngev: Union[int, jnp.ndarray] - a_star: Union[float, jnp.ndarray] - phi_star: Union[float, jnp.ndarray] - dphi_star: Union[float, jnp.ndarray] - g_star: jnp.ndarray - aux_star: Union[float, jnp.ndarray] - - -class _LineSearchResults(NamedTuple): - """Results of line search. - Parameters: - failed: True if the strong Wolfe criteria were satisfied - nit: integer number of iterations - nfev: integer number of functions evaluations - ngev: integer number of gradients evaluations - k: integer number of iterations - a_k: integer step size - f_k: final function value - g_k: final gradient value - status: integer end status - """ - failed: Union[bool, jnp.ndarray] - nit: Union[int, jnp.ndarray] - nfev: Union[int, jnp.ndarray] - ngev: Union[int, jnp.ndarray] - k: Union[int, jnp.ndarray] - a_k: Union[int, jnp.ndarray] - f_k: jnp.ndarray - g_k: jnp.ndarray - status: Union[bool, jnp.ndarray] - aux: Union[float, jnp.ndarray] - - -def zoom_linesearch(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, - c1=1e-4, c2=0.9, maxiter=20, value_and_grad=False, - has_aux=False, aux=None, args=[], kwargs={}): - """Inexact line search that satisfies strong Wolfe conditions. - Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, - pages 59-61. - Args: - f: function of the form f(x) where x is a flat ndarray and returns a real - scalar. The function should be composed of operations with vjp defined. - x0: initial guess. - pk: direction to search in. Assumes the direction is a descent direction. - old_fval, gfk: initial value of value_and_gradient as position. - old_old_fval: unused argument, only for scipy API compliance. - maxiter: maximum number of iterations to search - c1, c2: Wolfe criteria constant, see ref. - value_and_grad: whether f returns just the value (False) or the value and - grad (True). - has_aux: if ``False``, ``f`` should return the function value only. - If ``True``, ``f`` should return a pair ``(value, aux)`` where ``aux`` - is a pytree of auxiliary values. - aux: auxiliary pytree data example for ``f``. - args, kwargs: optional positional and keywords arguments to be passed to f. 
- Returns: LineSearchResults - """ - #xk, pk = _promote_dtypes_inexact(xk, pk) - #xk = jnp.asarray(xk) - #pk = jnp.asarray(pk) - - if value_and_grad: - f_value_and_grad = f - else: - f_value_and_grad = jax.value_and_grad(f, has_aux=has_aux) - - def restricted_func_and_grad(t): - dtype = tree_single_dtype(xk) - if dtype is not None: - t = jnp.asarray(t, dtype=dtype) - xkp1 = tree_add_scalar_mul(xk, t, pk) - if has_aux: - (phi, aux), g = f_value_and_grad(xkp1, *args, **kwargs) - else: - phi, g = f_value_and_grad(xkp1, *args, **kwargs) - dphi = jnp.real(tree_vdot(g, pk)) - if has_aux: - return (phi, dphi, g), aux - else: - return phi, dphi, g - - if old_fval is None or gfk is None or (aux is None and has_aux): - if has_aux: - (phi_0, dphi_0, gfk), aux = restricted_func_and_grad(0) - else: - phi_0, dphi_0, gfk = restricted_func_and_grad(0) - else: - phi_0 = old_fval - dphi_0 = jnp.real(tree_vdot(gfk, pk)) - if not has_aux: - aux = jnp.nan - if old_old_fval is not None: - candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0 - start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value) - else: - start_value = 1 - - def wolfe_one(a_i, phi_i): - # actually negation of W1 - return phi_i > phi_0 + c1 * a_i * dphi_0 - - def wolfe_two(dphi_i): - return jnp.abs(dphi_i) <= -c2 * dphi_0 - - state = _LineSearchState( - done=False, - failed=False, - # algorithm begins at 1 as per Wright and Nocedal, however Scipy has a - # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157 - i=1, - a_i1=jnp.zeros([], dtype=phi_0.dtype), - phi_i1=phi_0, - dphi_i1=dphi_0, - nfev=1 if (old_fval is None or gfk is None) else 0, - ngev=1 if (old_fval is None or gfk is None) else 0, - a_star=0.0, - phi_star=phi_0, - dphi_star=dphi_0, - g_star=gfk, - aux_star=aux, - ) - - def body(state): - # no amax in this version, we just double as in scipy. - # unlike original algorithm we do our next choice at the start of this loop - a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.) - - if has_aux: - (phi_i, dphi_i, g_i), aux_i = restricted_func_and_grad(a_i) - else: - phi_i, dphi_i, g_i = restricted_func_and_grad(a_i) - aux_i = jnp.nan - state = state._replace(nfev=state.nfev + 1, - ngev=state.ngev + 1) - - star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) & (state.i > 1)) - star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1) - star_to_zoom2 = (dphi_i >= 0.) 
& (~star_to_zoom1) & (~star_to_i) - - zoom1 = _zoom(restricted_func_and_grad, - wolfe_one, - wolfe_two, - state.a_i1, - state.phi_i1, - state.dphi_i1, - a_i, - phi_i, - dphi_i, - gfk, - ~star_to_zoom1, - has_aux, - aux_i) - - state = state._replace(nfev=state.nfev + zoom1.nfev, - ngev=state.ngev + zoom1.ngev) - - zoom2 = _zoom(restricted_func_and_grad, - wolfe_one, - wolfe_two, - a_i, - phi_i, - dphi_i, - state.a_i1, - state.phi_i1, - state.dphi_i1, - gfk, - ~star_to_zoom2, - has_aux, - aux_i) - - state = state._replace(nfev=state.nfev + zoom2.nfev, - ngev=state.ngev + zoom2.ngev) - - state = state._replace( - done=star_to_zoom1 | state.done, - failed=(star_to_zoom1 & zoom1.failed) | state.failed, - **_binary_replace( - star_to_zoom1, - state._asdict(), - zoom1._asdict(), - keys=['a_star', 'phi_star', 'dphi_star', 'g_star', 'aux_star'], - ), + # Set low to middle or keep as it is + default = [low, value_low, slope_low] + candidate = [middle, value_middle, slope_middle] + [new_low, new_value_low, new_slope_low] = _set_values( + set_low_to_middle, candidate, default ) - state = state._replace( - done=star_to_i | state.done, - **_binary_replace( - star_to_i, - state._asdict(), - dict( - a_star=a_i, - phi_star=phi_i, - dphi_star=dphi_i, - g_star=g_i, - aux_star=aux_i, - ), - ), + + # Update cubic reference point. + # If high changed then it can be used as the new ref point. + # Otherwise, low has been updated and not kept as high + # so it can be used as the new ref point. + [new_cubic_ref, new_value_cubic_ref] = _set_values( + set_high_to_middle | set_high_to_low, + [high, value_high], + [low, value_low], + ) + + max_iter_reached = (iter_num + 1 >= self.maxiter) & (~done) + new_fail_code = jnp.where( + (fail_check2 == 0) & max_iter_reached, 3, fail_check2 + ) + + new_state = state._replace( + iter_num=iter_num + 1, + params=next_params, + value=next_value, + grad=next_grad, + aux=next_aux, + # + error=new_error, + done=done, + fail_code=new_fail_code, + failed=jnp.asarray(new_fail_code > 0), + # + low=new_low, + value_low=new_value_low, + slope_low=new_slope_low, + high=new_high, + value_high=new_value_high, + slope_high=new_slope_high, + cubic_ref=new_cubic_ref, + value_cubic_ref=new_value_cubic_ref, + # + safe_stepsize=new_safe_stepsize, + # + num_fun_eval=state.num_fun_eval + 1, + num_grad_eval=state.num_grad_eval + 1, ) - state = state._replace( - done=star_to_zoom2 | state.done, - failed=(star_to_zoom2 & zoom2.failed) | state.failed, - **_binary_replace( - star_to_zoom2, - state._asdict(), - zoom2._asdict(), - keys=['a_star', 'phi_star', 'dphi_star', 'g_star', 'aux_star'], - ), + return base.LineSearchStep(stepsize=best_stepsize, state=new_state) + + def init_state( + self, + init_stepsize: float, + params: Any, + value: Optional[float] = None, + grad: Optional[Any] = None, + descent_direction: Optional[Any] = None, + *args, + **kwargs, + ): + """Initialize the line search state by computing all relevant quantities and store it in the initial state. + + Args: + init_stepsize: initial step size value (used in update, not in + init_state). + params: current parameters. + value: current function value (recomputed if None). + grad: current gradient (recomputed if None). + descent_direction: descent direction (negative gradient if None). + *args: additional positional arguments to be passed to ``fun``. + **kwargs: additional keyword arguments to be passed to ``fun``. 
+
+    Returns:
+      state
+    """
+    # FIXME: Signature issue in base.IterativeLineSearch: Keyword argument
+    # before variable positional arguments.
+    dtype = tree_single_dtype(params)
+    num_fun_eval = jnp.asarray(0)
+    num_grad_eval = jnp.asarray(0)
+    del init_stepsize
+    aux = None
+    if value is None or grad is None:
+      (value, aux), grad = self._value_and_grad_fun_with_aux(
+          params, *args, **kwargs
+      )
+      num_fun_eval = num_fun_eval + 1
+      num_grad_eval = num_grad_eval + 1
+
+    # TODO(vroulet): ideally, we should also provide aux as an argument to
+    # avoid recomputing the function. It's especially problematic if the
+    # function provided has an artificial aux = None coming from its
+    # instantiation via base._make_funs_with_aux. This requires changing the
+    # signature of base.IterativeLineSearch.
+    if aux is None and self.has_aux:
+      _, aux = self._fun_with_aux(params, *args, **kwargs)
+
+    if descent_direction is None:
+      descent_direction = tree_scalar_mul(-1.0, grad)
+
+    slope = tree_vdot(grad, descent_direction)
+
+    fail_code = jnp.where(slope > 0, 1, 0)
+
+    return ZoomLineSearchState(
+        iter_num=jnp.asarray(0),
+        params=params,
+        value=value,
+        grad=grad,
+        aux=aux,
+        #
+        value_init=value,
+        slope_init=slope,
+        descent_direction=descent_direction,
+        #
+        error=jnp.asarray(jnp.inf),
+        done=jnp.asarray(False),
+        fail_code=fail_code,
+        failed=jnp.asarray(fail_code > 0),
+        interval_found=jnp.asarray(False),
+        #
+        prev_stepsize=jnp.asarray(0.0).astype(dtype),
+        prev_value_step=value,
+        prev_slope_step=slope,
+        #
+        low=jnp.asarray(0.0).astype(dtype),
+        value_low=value,
+        slope_low=slope,
+        high=jnp.asarray(0.0).astype(dtype),
+        value_high=value,
+        slope_high=slope,
+        cubic_ref=jnp.asarray(0.0).astype(dtype),
+        value_cubic_ref=value,
+        #
+        safe_stepsize=jnp.asarray(0.0).astype(dtype),
+        num_fun_eval=num_fun_eval,
+        num_grad_eval=num_grad_eval,
+    )
+
+  def update(
+      self,
+      stepsize: float,
+      state: NamedTuple,
+      params: Any,
+      value: Optional[float] = None,
+      grad: Optional[Any] = None,
+      descent_direction: Optional[Any] = None,
+      *args,
+      **kwargs,
+  ) -> base.LineSearchStep:
+    """Combines Algorithms 3.5 and 3.6 of [1].
+
+    Final state contains next_params, next_value, next_grad, next_aux if the
+    linesearch succeeded.
+
+    Args:
+      stepsize: current estimate of the step size.
+      state: named tuple containing the line search state.
+      params: current parameters (not used, recorded during init in state).
+      value: current function value (not used, recorded during init in state).
+      grad: current gradient (not used, recorded during init in state).
+      descent_direction: descent direction (not used, recorded during init in
+        state).
+      *args: additional positional arguments to be passed to ``fun``.
+      **kwargs: additional keyword arguments to be passed to ``fun``.
+
+    Returns:
+      (stepsize, state)
+    """
+    # FIXME: Signature issue in base.IterativeLineSearch: Keyword argument
+    # before variable positional arguments.
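# Side note (illustrative, not part of the patch): the fail_code=1 branch set
# in init_state above encodes the usual descent test. The slope of
# t -> fun(params + t * d) at t = 0 is <grad, d>, so d is a descent direction
# only if that inner product is negative. A tiny standalone check with the
# public tree utilities (the helper name below is hypothetical):
from jaxopt.tree_util import tree_scalar_mul, tree_vdot

def is_descent_direction(grad, descent_direction=None):
  if descent_direction is None:
    # Same default as init_state: steepest descent.
    descent_direction = tree_scalar_mul(-1.0, grad)
  return tree_vdot(grad, descent_direction) < 0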
+ # Params, value, grad, descent direction recorded in state at initialization + dtype = tree_single_dtype(params) + init_stepsize = jnp.asarray(stepsize).astype(dtype) + del params + del value + del grad + del descent_direction + + best_stepsize, new_state_ = lax.cond( + state.interval_found, + self._zoom_into_interval, + self._search_interval, + init_stepsize, + state, + args, + kwargs, + ) + + best_stepsize, new_state = lax.cond( + (new_state_.failed) & (new_state_.iter_num == self.maxiter), + self._make_safe_step, + self._keep_step, + best_stepsize, + new_state_, + args, + kwargs, + ) + + if self.verbose: + _check_failure_status(new_state.fail_code) + + return base.LineSearchStep(stepsize=best_stepsize, state=new_state) + + def __post_init__(self): + self._fun_with_aux, _, self._value_and_grad_fun_with_aux = ( + _make_funs_with_aux( + self.fun, value_and_grad=self.value_and_grad, has_aux=self.has_aux + ) ) - state = state._replace(i=state.i + 1, a_i1=a_i, phi_i1=phi_i, dphi_i1=dphi_i) - return state - - state = lax.while_loop(lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed), - body, - state) - - status = jnp.where( - state.failed, - jnp.array(1), # zoom failed - jnp.where( - state.i > maxiter, - jnp.array(3), # maxiter reached - jnp.array(0), # passed (should be) - ), - ) - # Step sizes which are too small causes the optimizer to get stuck with a - # direction of zero in <64 bit mode - avoid with a floor on minimum step size. - alpha_k = state.a_star - alpha_k = jnp.where((jnp.finfo(alpha_k).bits != 64) - & (jnp.abs(alpha_k) < 1e-8), - jnp.sign(alpha_k) * 1e-8, - alpha_k) - param_dtype = tree_single_dtype(xk) - results = _LineSearchResults( - failed=state.failed | (~state.done), - nit=state.i - 1, # because iterations started at 1 - nfev=state.nfev, - ngev=state.ngev, - k=state.i, - a_k=alpha_k.astype(param_dtype), - f_k=state.phi_star, - aux=state.aux_star, - g_k=state.g_star, - status=status, - ) - return results diff --git a/tests/lbfgsb_test.py b/tests/lbfgsb_test.py index f249059a..82e044aa 100644 --- a/tests/lbfgsb_test.py +++ b/tests/lbfgsb_test.py @@ -50,7 +50,7 @@ def fun(x): # Rosenbrock function. fun=fun, tol=1e-5, stepsize=-1.0, - maxiter=50, + maxiter=100, history_size=5, use_gamma=True, value_and_grad=value_and_grad, @@ -60,7 +60,7 @@ def fun(x): # Rosenbrock function. scipy_lbfgsb = ScipyBoundedMinimize( fun=fun, tol=1e-5, - maxiter=50, + maxiter=100, method="L-BFGS-B", options={"maxcor": 5}, value_and_grad=value_and_grad, diff --git a/tests/zoom_linesearch_test.py b/tests/zoom_linesearch_test.py index 5e7e8311..c3fdcaa9 100644 --- a/tests/zoom_linesearch_test.py +++ b/tests/zoom_linesearch_test.py @@ -1,192 +1,382 @@ -from absl.testing import absltest, parameterized +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
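# The tests below exercise the class-based API that replaces the removed
# functional zoom_linesearch. A rough sketch of the mapping (the wrapper name
# here is illustrative, not part of the library): old calls such as
#   res = zoom_linesearch(f, xk, pk, old_fval=f0, gfk=g0, maxiter=30)
#   stepsize, value, grad = res.a_k, res.f_k, res.g_k
# become, with the keyword names used throughout this file:
from jaxopt import ZoomLineSearch

def run_zoom(fun, xk, pk, f0=None, g0=None, maxiter=30):
  ls = ZoomLineSearch(fun=fun, maxiter=maxiter)
  stepsize, state = ls.run(
      init_stepsize=1.0, params=xk, value=f0, grad=g0, descent_direction=pk
  )
  # The accepted point and its value/gradient are stored in the final state.
  return stepsize, state.params, state.value, state.grad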
+"""Tests for ZoomLineSearch.""" + +from absl.testing import absltest +from absl.testing import parameterized import jax import jax.numpy as jnp +from jaxopt import objective +from jaxopt import ZoomLineSearch from jaxopt._src import test_util -from jaxopt._src.zoom_linesearch import zoom_linesearch +from jaxopt.tree_util import tree_add_scalar_mul +from jaxopt.tree_util import tree_negative import numpy as onp import scipy.optimize +from sklearn import datasets +# pylint: disable=invalid-name class ZoomLinesearchTest(test_util.JaxoptTestCase): - # -- scalar functions; must have dphi(0.) < 0 + """Tests for ZoomLineSearch.""" def setUp(self): - self.rng = lambda : onp.random.RandomState(0) - - def assert_wolfe(self, s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""): - """ - Check that strong Wolfe conditions apply - """ - phi1 = phi(s) - phi0 = phi(0) - derphi0 = derphi(0) - derphi1 = derphi(s) - msg = "s = {}; phi(0) = {}; phi(s) = {}; phi'(0) = {}; phi'(s) = {}; {}".format( - s, phi0, phi1, derphi0, derphi1, err_msg) - - self.assertTrue(phi1 <= phi0 + c1 * s * derphi0, "Wolfe 1 failed: " + msg) - self.assertTrue(abs(derphi1) <= abs(c2 * derphi0), "Wolfe 2 failed: " + msg) - - def assert_line_wolfe(self, x, p, s, f, fprime, **kw): - self.assert_wolfe(s, phi=lambda sp: f(x + p * sp), - derphi=lambda sp: jnp.dot(fprime(x + p * sp), p), **kw) - - def _scalar_func_1(self, s): - p = -s - s ** 3 + s ** 4 - dp = -1 - 3 * s ** 2 + 4 * s ** 3 + self.rng = lambda: onp.random.RandomState(0) + + def _check_step_in_state(self, x, p, s, fun, fun_der, state): + step = tree_add_scalar_mul(x, s, p) + self.assertAllClose(step, state.params, atol=1e-5, rtol=1e-5) + self.assertAllClose(fun(step), state.value, atol=1e-5, rtol=1e-5) + self.assertAllClose(fun_der(step), state.grad, atol=1e-5, rtol=1e-5) + self.assertTrue(~(state.failed & state.done)) + + def _assert_conds(self, s, value_fun, slope_fun, c1=1e-4, c2=0.9, err_msg=""): + value_init = value_fun(0) + value_step = value_fun(s) + slope_init = slope_fun(0) + slope_step = slope_fun(s) + msg = ( + "s = {}; value(0) = {}; value(s) = {}; slope(0) = {}; slope(s) = {}; {}" + .format(s, value_init, value_step, slope_init, slope_step, err_msg) + ) + + self.assertTrue( + value_step <= value_init + c1 * s * slope_init, + "Sufficient decrease (Armijo) failed: " + msg, + ) + self.assertTrue( + abs(slope_step) <= abs(c2 * slope_init), + "Small curvature (strong Wolfe) failed: " + msg, + ) + + def _assert_line_conds(self, x, p, s, fun, fun_der, **kw): + self._assert_conds( + s, + value_fun=lambda sp: fun(x + p * sp), + slope_fun=lambda sp: jnp.dot(fun_der(x + p * sp), p), + **kw, + ) + + # -- scalar functions + + def _scalar_fun_1(self, s): + p = -s - s**3 + s**4 + dp = -1 - 3 * s**2 + 4 * s**3 return p, dp - def _scalar_func_2(self, s): - p = jnp.exp(-4 * s) + s ** 2 + def _scalar_fun_2(self, s): + p = jnp.exp(-4 * s) + s**2 dp = -4 * jnp.exp(-4 * s) + 2 * s return p, dp - def _scalar_func_3(self, s): + def _scalar_fun_3(self, s): p = -jnp.sin(10 * s) dp = -10 * jnp.cos(10 * s) return p, dp # -- n-d functions - def _line_func_1(self, x): + def _line_fun_1(self, x): f = jnp.dot(x, x) df = 2 * x return f, df - def _line_func_2(self, x): + def _line_fun_2(self, x): f = jnp.dot(x, jnp.dot(self.A, x)) + 1 df = jnp.dot(self.A + self.A.T, x) return f, df - # -- Generic scalar searches + def _rosenbrock_fun(self, x): + return sum(100.0 * (x[1:] - x[:-1] ** 2.0) ** 2.0 + (1 - x[:-1]) ** 2.0) - @parameterized.product(name=["_scalar_func_1", - "_scalar_func_2", - "_scalar_func_3"]) 
- def test_scalar_search_wolfe2(self, name): + def _line_fun_3(self, x): + # Rosenbrock function + f = self._rosenbrock_fun(x) + df = jax.grad(self._rosenbrock_fun)(x) + return f, df - def bind_index(func, idx): + # -- Generic scalar searches + + @parameterized.product( + name=["_scalar_fun_1", "_scalar_fun_2", "_scalar_fun_3"] + ) + def test_scalar_search(self, name): + def bind_index(fun, idx): # Remember Python's closure semantics! - return lambda *a, **kw: func(*a, **kw)[idx] + return lambda *a, **kw: fun(*a, **kw)[idx] value = getattr(self, name) - phi = bind_index(value, 0) - derphi = bind_index(value, 1) - for old_phi0 in self.rng().randn(3): - res = zoom_linesearch(phi, 0., 1.) - s, phi1, derphi1 = res.a_k, res.f_k, res.g_k - self.assertAllClose(phi1, phi(s), check_dtypes=False, atol=1e-6) - if derphi1 is not None: - self.assertAllClose(derphi1, derphi(s), check_dtypes=False, atol=1e-6) - self.assert_wolfe(s, phi, derphi, err_msg=f"{name} {old_phi0:g}") + fun = bind_index(value, 0) + fun_der = bind_index(value, 1) + for old_value in self.rng().randn(3): + ls = ZoomLineSearch(fun) + x, p = 0.0, 1.0 + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self._check_step_in_state(x, p, s, fun, fun_der, state) + self._assert_conds(s, fun, fun_der, err_msg=f"{name} {old_value:g}") # -- Generic line searches - @parameterized.product(name=["_line_func_1", "_line_func_2"]) - def test_line_search_wolfe2(self, name): - def bind_index(func, idx): + @parameterized.product(name=["_line_fun_1", "_line_fun_2", "_line_fun_3"]) + def test_line_search(self, name): + def bind_index(fun, idx): # Remember Python's closure semantics! - return lambda *a, **kw: func(*a, **kw)[idx] + return lambda *a, **kw: fun(*a, **kw)[idx] value = getattr(self, name) - f = bind_index(value, 0) - fprime = bind_index(value, 1) + fun = bind_index(value, 0) + fun_der = bind_index(value, 1) k = 0 N = 20 rng = self.rng() - # sets A in one of the line funcs + # sets A in one of the line functions self.A = self.rng().randn(N, N) while k < 9: x = rng.randn(N) p = rng.randn(N) - if jnp.dot(p, fprime(x)) >= 0: + if jnp.dot(p, fun_der(x)) >= 0: # always pick a descent pk continue + if fun(x + 1e6 * p) < fun(x): + # If the function is unbounded below, the linesearch cannot finish + continue k += 1 - f0 = f(x) - g0 = fprime(x) - self.fcount = 0 - res = zoom_linesearch(f, x, p, old_fval=f0, gfk=g0) - s = res.a_k - fv = res.f_k - gv = res.g_k - self.assertAllClose(fv, f(x + s * p), check_dtypes=False, atol=1e-5) - if gv is not None: - self.assertAllClose(gv, fprime(x + s * p), check_dtypes=False, atol=1e-5) - - def test_line_search_wolfe2_bounds(self): + f0 = fun(x) + g0 = fun_der(x) + ls = ZoomLineSearch(fun) + s, state = ls.run( + init_stepsize=1.0, params=x, descent_direction=p, value=f0, grad=g0 + ) + self._check_step_in_state(x, p, s, fun, fun_der, state) + self._assert_line_conds(x, p, s, fun, fun_der, err_msg=f"{name}") + + def test_logreg(self): + x, y = datasets.make_classification( + n_samples=10, n_features=5, n_classes=2, n_informative=3, random_state=0 + ) + data = (x, y) + fun = objective.binary_logreg + + def fun_(w): + return fun(w, data) + + rng = onp.random.RandomState(0) + w_init = rng.randn(x.shape[1]) + initial_grad = jax.grad(fun)(w_init, data=data) + descent_dir = tree_negative(initial_grad) + + # Call to run. 
+ ls = ZoomLineSearch(fun=fun, maxiter=20) + stepsize, state = ls.run(init_stepsize=1.0, params=w_init, data=data) + + self._assert_line_conds( + w_init, descent_dir, stepsize, fun_, jax.grad(fun_), c1=ls.c1, c2=ls.c2 + ) + self._check_step_in_state( + w_init, descent_dir, stepsize, fun_, jax.grad(fun_), state + ) + + # Call to run with value_and_grad=True + ls = ZoomLineSearch( + fun=jax.value_and_grad(fun), maxiter=20, value_and_grad=True + ) + stepsize, state = ls.run(init_stepsize=1.0, params=w_init, data=data) + + self._assert_line_conds( + w_init, descent_dir, stepsize, fun_, jax.grad(fun_), c1=ls.c1, c2=ls.c2 + ) + self._check_step_in_state( + w_init, descent_dir, stepsize, fun_, jax.grad(fun_), state + ) + + def test_failure_cases(self): # See gh-7475 # For this f and p, starting at a point on axis 0, the strong Wolfe # condition 2 is met if and only if the step length s satisfies # |x + s| <= c2 * |x| - f = lambda x: jnp.dot(x, x) - fp = lambda x: 2 * x - p = jnp.array([1.0, 0.0]) + def fun(x): + return jnp.dot(x, x) - # Smallest s satisfying strong Wolfe conditions for these arguments is 30 - x = -60 * p - c2 = 0.5 + def fun_der(x): + return 2.0 * x - res = zoom_linesearch(f, x, p, c2=c2) - s = res.a_k - # s, _, _, _, _, _ = ls.zoom_linesearch(f, fp, x, p, amax=30, c2=c2) - self.assert_line_wolfe(x, p, s, f, fp) - self.assertTrue(s >= 30.) + c2 = 0.5 + p = jnp.array([1.0, 0.0]) - res = zoom_linesearch(f, x, p, c2=c2, maxiter=5) - self.assertTrue(res.failed) - # s=30 will only be tried on the 6th iteration, so this won't converge + # 1. Test that the line search fails for p not a descent direction + x = 60 * p + ls = ZoomLineSearch(fun, c2=c2) + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self._check_step_in_state(x, p, s, fun, fun_der, state) + # Check that we were not able to make a step or an infinitesimal one + self.assertTrue(s < 1e-5) + self.assertTrue(state.fail_code == 1) + + # 2. Test that the line search fails if the maximum stepsize is too small + # Here, smallest s satisfying strong Wolfe conditions for c2=0.5 is 30 + x = -60 * p + ls = ZoomLineSearch(fun, c2=c2, max_stepsize=10) + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self._check_step_in_state(x, p, s, fun, fun_der, state) + # Check that we still made a step + self.assertTrue(s == 10.0) + self.assertTrue(state.fail_code == 2) + + # 3. s=30 will only be tried on the 6th iteration, so this fails because + # the maximum number of iterations is reached. 
+ ls = ZoomLineSearch(fun, c2=c2, maxiter=5) + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self._check_step_in_state(x, p, s, fun, fun_der, state) + # Check that we still made a step + self.assertTrue(s == 16.0) + self.assertTrue(state.fail_code == 3) + + # Check if it works normally + ls = ZoomLineSearch(fun, c2=c2) + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self._assert_line_conds(x, p, s, fun, fun_der, c1=ls.c1, c2=c2) + self._check_step_in_state(x, p, s, fun, fun_der, state) + self.assertTrue(s >= 30.0) + + # Check failure for a very flat function + def fun_flat(x): + return jnp.exp(-1 / x**2) + + x = jnp.asarray(-0.2) + if x.dtype == "float64": + x = x / 2.0 + ls = ZoomLineSearch(fun_flat) + _, state = ls.run(init_stepsize=1.0, params=x) + self.assertTrue(state.fail_code == 4) + + # Check behavior for inf/nan values + def fun_inf(x): + return jnp.log(x) + + x = 1.0 + p = -2.0 + ls = ZoomLineSearch(fun_inf) + s, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + self.assertTrue(state.fail_code == 5) def test_aux_value(self): - def f(x): + def fun(x): return jnp.cos(jnp.sum(jnp.exp(-x)) ** 2), x - xk = jnp.ones(2) - pk = jnp.array([-0.5, -0.25]) - res = zoom_linesearch(f, xk, pk, maxiter=100, has_aux=True) - new_stepsize = res.a_k - new_xk = xk + new_stepsize * pk - self.assertArraysEqual(res.aux, new_xk) + x = jnp.ones(2) + p = jnp.array([-0.5, -0.25]) + ls = ZoomLineSearch(fun=fun, maxiter=100, has_aux=True) + new_stepsize, state = ls.run( + init_stepsize=1.0, params=x, descent_direction=p + ) + new_x = x + new_stepsize * p + self.assertArraysEqual(state.aux, new_x) + + def test_against_scipy(self): + def fun(x): + return jnp.cos(jnp.sum(jnp.exp(-x)) ** 2) - def test_line_search(self): + x = jnp.ones(2) + p = jnp.array([-0.5, -0.25]) - def f(x): - return jnp.cos(jnp.sum(jnp.exp(-x)) ** 2) + ls = ZoomLineSearch(fun=fun, maxiter=20) + stepsize, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) + + ls2 = ZoomLineSearch( + fun=jax.value_and_grad(fun), + maxiter=20, + value_and_grad=True, + ) + stepsize2, state2 = ls2.run( + init_stepsize=1.0, params=x, descent_direction=p + ) + + scipy_res = scipy.optimize.line_search(fun, jax.grad(fun), x, p) + + self.assertAllClose(scipy_res[0], stepsize, atol=1e-5, check_dtypes=False) + self.assertAllClose( + scipy_res[3], state.value, atol=1e-5, check_dtypes=False + ) + + self.assertAllClose(scipy_res[0], stepsize2, atol=1e-5, check_dtypes=False) + self.assertAllClose( + scipy_res[3], state2.value, atol=1e-5, check_dtypes=False + ) + + def test_high_smaller_than_low(self): + # See google/jax/issues/16236 + def fun(x): + return x**2 + + # Descent direction p chosen such that, with x+p + # the first trial of the algorithm, + # 1. p*f'(x) < 0 (valid descent direction) + # 2. x+p satisifies sufficient decrease + # 3. x+p does not satisfy small curvature + # 4. 
f'(x+p) > 0 + # As a result, the first trial starts with high < low + + x = -1.0 + p = -1.95 * x - # assert not zoom_linesearch(jax.value_and_grad(f), np.ones(2), np.array([-0.5, -0.25])).failed - xk = jnp.ones(2) - pk = jnp.array([-0.5, -0.25]) - res = zoom_linesearch(f, xk, pk, maxiter=100) - res2 = zoom_linesearch(jax.value_and_grad(f), xk, pk, maxiter=100, - value_and_grad=True) - scipy_res = scipy.optimize.line_search(f, jax.grad(f), xk, pk) + ls = ZoomLineSearch(fun) + _, state = ls.run(init_stepsize=1.0, params=x, descent_direction=p) - self.assertAllClose(scipy_res[0], res.a_k, atol=1e-5, check_dtypes=False) - self.assertAllClose(scipy_res[3], res.f_k, atol=1e-5, check_dtypes=False) - self.assertAllClose(scipy_res[0], res2.a_k, atol=1e-5, check_dtypes=False) - self.assertAllClose(scipy_res[3], res2.f_k, atol=1e-5, check_dtypes=False) + self.assertFalse(state.failed) @parameterized.product(out_dtype=[jnp.float32, jnp.float64]) def test_correct_dtypes(self, out_dtype): - def f(x): - return jnp.cos(jnp.sum(jnp.exp(-x)) ** 2).astype(out_dtype) + def fun(x): + return jnp.cos(jnp.sum(jnp.exp(-x)) ** 2).astype(out_dtype) with jax.experimental.enable_x64(): xk = jnp.ones(2, dtype=jnp.float32) pk = jnp.array([-0.5, -0.25], dtype=jnp.float32) - res = zoom_linesearch(f, xk, pk, maxiter=100) - for name in ("failed",): - self.assertEqual(getattr(res, name).dtype, jnp.bool_) - for name in ("k", "nit", "nfev", "ngev"): - self.assertEqual(getattr(res, name).dtype, jnp.int64) - for name in ("g_k",): - self.assertEqual(getattr(res, name).dtype, jnp.float32, name) - for name in ("f_k",): - self.assertEqual(getattr(res, name).dtype, out_dtype) + ls = ZoomLineSearch(fun, maxiter=100) + _, state = ls.run(init_stepsize=1.0, params=xk, descent_direction=pk) + for name in ("done", "interval_found", "failed"): + self.assertEqual(getattr(state, name).dtype, jnp.bool_) + for name in ("iter_num", "num_fun_eval", "num_grad_eval", "fail_code"): + self.assertEqual(getattr(state, name).dtype, jnp.int64) + for name in ( + "params", + "grad", + "descent_direction", + "slope_init", + "prev_slope_step", + "slope_low", + "slope_high", + ): + self.assertEqual(getattr(state, name).dtype, jnp.float32, name) + for name in ("prev_stepsize", "low", "high", "cubic_ref"): + self.assertEqual(getattr(state, name).dtype, jnp.float32, name) + for name in ( + "value", + "value_init", + "prev_value_step", + "value_low", + "value_high", + "value_cubic_ref", + "error", + ): + self.assertEqual(getattr(state, name).dtype, out_dtype) if __name__ == "__main__": + # Uncomment the line below in order to run in float64. + # jax.config.update("jax_enable_x64", True) absltest.main()
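# Illustrative end-to-end sketch (not part of the test suite): ZoomLineSearch
# driving a plain gradient-descent loop. The fail_code values exercised by
# test_failure_cases above are: 1 = not a descent direction, 2 = max_stepsize
# too small to satisfy the curvature condition (a step at max_stepsize is
# still returned), 3 = maxiter reached, 4 = the very flat function case,
# 5 = inf/nan value encountered; see _check_failure_status for the messages.
import jax
import jax.numpy as jnp
from jaxopt import ZoomLineSearch
from jaxopt.tree_util import tree_negative

def gradient_descent_with_zoom(fun, params, n_steps=10):
  ls = ZoomLineSearch(fun=fun)
  for _ in range(n_steps):
    direction = tree_negative(jax.grad(fun)(params))
    _, state = ls.run(
        init_stepsize=1.0, params=params, descent_direction=direction
    )
    # The accepted point (and its value/grad) is already stored in the state.
    params = state.params
  return params

# For instance:
# gradient_descent_with_zoom(lambda x: jnp.sum((x - 1.0) ** 2), jnp.zeros(3))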