diff --git a/jaxopt/_src/hager_zhang_linesearch.py b/jaxopt/_src/hager_zhang_linesearch.py
index d8dd6e32..820e7d2d 100644
--- a/jaxopt/_src/hager_zhang_linesearch.py
+++ b/jaxopt/_src/hager_zhang_linesearch.py
@@ -49,6 +49,7 @@ class HagerZhangLineSearchState(NamedTuple):
   error: float
   params: Any
   grad: Any
+  aux: Optional[Any] = None
 
 
 @dataclass(eq=False)
@@ -62,6 +63,9 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
     value_and_grad: if ``False``, ``fun`` should return the function value only.
       If ``True``, ``fun`` should return both the function value and the
       gradient.
+    has_aux: if ``False``, ``fun`` should return the function value only.
+      If ``True``, ``fun`` should return a pair ``(value, aux)`` where ``aux``
+      is a pytree of auxiliary values.
     c1: constant used by the Wolfe and Approximate Wolfe condition.
     c2: constant strictly less than 1 used by the Wolfe and Approximate Wolfe
@@ -79,6 +83,7 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
   """
   fun: Callable  # pylint:disable=g-bare-generic
   value_and_grad: bool = False
+  has_aux: bool = False
   maxiter: int = 30
   tol: float = 0.
@@ -95,8 +100,11 @@ class HagerZhangLineSearch(base.IterativeLineSearch):
   unroll: base.AutoOrBoolean = "auto"
 
   def _value_and_grad_on_line(self, x, c, descent_direction, *args, **kwargs):
-    value, grad = self._value_and_grad_fun(
-        tree_add_scalar_mul(x, c, descent_direction), *args, **kwargs)
+    z = tree_add_scalar_mul(x, c, descent_direction)
+    if self.has_aux:
+      (value, _), grad = self._value_and_grad_fun(z, *args, **kwargs)
+    else:
+      value, grad = self._value_and_grad_fun(z, *args, **kwargs)
     return value, tree_vdot(grad, descent_direction)
 
   def _satisfies_wolfe_and_approx_wolfe(
@@ -324,7 +332,11 @@ def init_state(self,  # pylint:disable=keyword-arg-before-vararg
     del init_stepsize
 
     if value is None or grad is None:
-      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+      if self.has_aux:
+        (value, _), grad = self._value_and_grad_fun(params, *args, **kwargs)
+      else:
+        value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+
     if descent_direction is None:
       descent_direction = tree_scalar_mul(-1, grad)
@@ -364,6 +376,7 @@ def init_state(self,  # pylint:disable=keyword-arg-before-vararg
         error=error,
         done=done,
         value=value,
+        aux=None,  # we do not need to have aux in the initial state
         params=params,
         grad=grad)
@@ -392,7 +405,11 @@ def update(self,  # pylint:disable=keyword-arg-before-vararg
     """
     if value is None or grad is None:
-      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+      if self.has_aux:
+        (value, _), grad = self._value_and_grad_fun(params, *args, **kwargs)
+      else:
+        value, grad = self._value_and_grad_fun(params, *args, **kwargs)
+
     if descent_direction is None:
       descent_direction = tree_scalar_mul(-1, grad)
@@ -436,7 +453,13 @@ def _reupdate():
     new_stepsize = jnp.where(state.done, stepsize, best_point)
     new_params = tree_add_scalar_mul(params, best_point, descent_direction)
-    new_value, new_grad = self._value_and_grad_fun(new_params, *args, **kwargs)
+    if self.has_aux:
+      (new_value, new_aux), new_grad = self._value_and_grad_fun(
+          new_params, *args, **kwargs)
+    else:
+      new_value, new_grad = self._value_and_grad_fun(
+          new_params, *args, **kwargs)
+      new_aux = None
 
     error = jnp.where(state.done, state.error,
                       self._satisfies_wolfe_and_approx_wolfe(
@@ -453,6 +476,7 @@
         iter_num=state.iter_num + 1,
         value=new_value,
         grad=new_grad,
+        aux=new_aux,
         params=new_params,
         low=new_low,
         high=new_high,
@@ -465,4 +489,4 @@ def __post_init__(self):
     if self.value_and_grad:
       self._value_and_grad_fun = self.fun
     else:
-      self._value_and_grad_fun = jax.value_and_grad(self.fun)
+      self._value_and_grad_fun = jax.value_and_grad(self.fun, has_aux=self.has_aux)
diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py
index c9ed19ab..24d87e3f 100644
--- a/jaxopt/_src/lbfgs.py
+++ b/jaxopt/_src/lbfgs.py
@@ -385,11 +385,10 @@ def update(self,
                                            params, value, grad,
                                            descent_direction,
                                            *args, **kwargs)
-
       new_params = ls_state.params
-      (new_value, new_aux), new_grad = self._value_and_grad_with_aux(
-          new_params, *args, **kwargs)
-
+      new_value = ls_state.value
+      new_grad = ls_state.grad
+      new_aux = ls_state.aux
     else:
       raise ValueError("Invalid name in 'linesearch' option.")
 
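
Usage sketch (not part of the patch): with ``has_aux=True``, a function that returns a ``(value, aux)`` pair can be passed straight to the line search, and the auxiliary pytree is carried in the returned state. This is a minimal illustration assuming the patched jaxopt; the quadratic objective and the ``residuals`` aux pytree are invented for the example.

    import jax.numpy as jnp
    from jaxopt import HagerZhangLineSearch

    def fun(w):
      residuals = w - jnp.arange(3.0)
      # aux may be any pytree; here we expose the residuals.
      return jnp.sum(residuals ** 2), {"residuals": residuals}

    ls = HagerZhangLineSearch(fun=fun, has_aux=True)
    stepsize, state = ls.run(init_stepsize=1.0, params=jnp.zeros(3))
    print(stepsize, state.value, state.aux)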
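
The lbfgs.py hunk is the payoff: with ``linesearch="hager-zhang"``, ``LBFGS.update`` now takes ``value``, ``grad`` and ``aux`` from the line-search state instead of re-evaluating the objective at the new parameters, saving one value-and-gradient evaluation per outer iteration. An end-to-end sketch under the same assumptions (patched jaxopt, illustrative objective):

    import jax.numpy as jnp
    from jaxopt import LBFGS

    def fun(w):
      residuals = w - jnp.arange(3.0)
      return jnp.sum(residuals ** 2), {"residuals": residuals}

    # has_aux=True threads the aux pytree through the solver state; the
    # hager-zhang line search now supplies it without an extra evaluation.
    solver = LBFGS(fun=fun, has_aux=True, linesearch="hager-zhang")
    params, state = solver.run(jnp.zeros(3))
    print(state.value, state.aux)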