From 4bb4f110db510657e810c12f9f4e9c4a4ec5b548 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:29:49 +0200 Subject: [PATCH 1/8] added the computed initialization of the jac approx in Broyden --- jaxopt/_src/broyden.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index dc7fe79b..de911960 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -80,14 +80,13 @@ def inv_jacobian_product(pytree: Any, Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`. c_history: pytree with the same structure as `pytree`. Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`. - gamma: scalar to use for the initial inverse jacobian approximation, + gamma: pytree with scalars to use for the initial inverse jacobian approximation, i.e., `gamma * I`. start: starting index in the circular buffer. """ fun = partial(inv_jacobian_product_leaf, - gamma=gamma, start=start) - return tree_map(fun, pytree, d_history, c_history) + return tree_map(fun, pytree, d_history, c_history, gamma) def inv_jacobian_rproduct(pytree: Any, d_history: Any, @@ -115,7 +114,7 @@ class BroydenState(NamedTuple): error: float d_history: Any c_history: Any - gamma: jnp.ndarray + gamma: Any aux: Optional[Any] = None failed_linesearch: bool = False @@ -205,6 +204,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 + compute_gamma: bool = False implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -244,24 +244,44 @@ def init_state(self, iter_num=init_params.state.iter_num, stepsize=init_params.state.stepsize, ) + # XXX: not computing the jacobian init approx + # when starting from an OptStep object init_params = init_params.params dtype = tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) else: dtype = 
tree_single_dtype(init_params) + value, aux = self._value_with_aux(init_params, *args, **kwargs) + if self.compute_gamma: + # we use scipy's formula: + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569 + # self.alpha = 0.5*max(norm(x0), 1) / normf0 + normf0 = tree_map(jnp.linalg.norm, value) + normx0 = tree_map(jnp.linalg.norm, init_params) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.max(x, 1), normx0) + def safe_divide_by_zero(x, y): + # a classical division of x by x + # when y == 0 then return 1 + return jnp.where(y == 0, 1, x / y) + gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) + return gamma + else: + gamma = self.gamma + # repeat gamma as a pytre of the shape of init_params + gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), c_history=init_history(init_params, self.history_size), - gamma=jnp.asarray(self.gamma, dtype=dtype), + gamma=gamma, iter_num=jnp.asarray(0), stepsize=jnp.asarray(self.max_stepsize, dtype=dtype), ) - value, aux = self._value_with_aux(init_params, *args, **kwargs) return BroydenState(value=value, error=jnp.asarray(jnp.inf), **state_kwargs, aux=aux, failed_linesearch=jnp.asarray(False), - num_fun_eval=jnp.array(1, base.NUM_EVAL_DTYPE), + num_fun_eval=jnp.array(1, base.NUM_EVAL_DTYPE), num_linesearch_iter=jnp.array(0, base.NUM_EVAL_DTYPE) ) From c6b20e698dd10cfa2c0ba2b78cd6beae84e09058 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:33:41 +0200 Subject: [PATCH 2/8] corrected right product gamma --- jaxopt/_src/broyden.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index de911960..c9877451 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -93,7 +93,8 @@ def inv_jacobian_rproduct(pytree: Any, c_history: Any, gamma: float = 1.0, start: int = 0): - return inv_jacobian_product(pytree, c_history, d_history, 
jnp.conjugate(gamma), start) + gamma_conj = tree_map(jnp.conjugate, gamma) + return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start) def init_history(pytree, history_size): From e2d4ec687f8d32b9f7e8f7af39a17584243d5dd0 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:45:45 +0200 Subject: [PATCH 3/8] few corrections to gamma computation --- jaxopt/_src/broyden.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index c9877451..dbaaa597 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -205,7 +205,7 @@ class Broyden(base.IterativeSolver): history_size: int = None gamma: float = 1.0 - compute_gamma: bool = False + compute_gamma: bool = True implicit_diff: bool = True implicit_diff_solve: Optional[Callable] = None @@ -259,13 +259,12 @@ def init_state(self, # self.alpha = 0.5*max(norm(x0), 1) / normf0 normf0 = tree_map(jnp.linalg.norm, value) normx0 = tree_map(jnp.linalg.norm, init_params) - clipped_normx0 = tree_map(lambda x: 0.5 * jnp.max(x, 1), normx0) + clipped_normx0 = tree_map(lambda x: - 0.5 * jnp.maximum(x, 1), normx0) def safe_divide_by_zero(x, y): # a classical division of x by x # when y == 0 then return 1 return jnp.where(y == 0, 1, x / y) gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) - return gamma else: gamma = self.gamma # repeat gamma as a pytre of the shape of init_params From 44646a307c2a8b18bf538146821a2b99eb8ca7e7 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 10 Sep 2023 18:46:06 +0200 Subject: [PATCH 4/8] increased tolerance for a broyden test --- tests/broyden_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 60503a55..5e24b6bd 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -133,7 +133,7 @@ def test_affine_contractive_mapping(self): b = jax.random.uniform(subkey, shape=(n,)) def g(x, M, b): return 
M @ x + b - x - tol = 1e-6 + tol = 5e-6 fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) From 23529d5dfd11e7d25ea60e8c85e5b5a7d89a9743 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 16:40:32 +0100 Subject: [PATCH 5/8] typos correction + test correction => needed more iterations --- jaxopt/_src/broyden.py | 6 +++--- tests/broyden_test.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index c6d42258..53267b21 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -259,15 +259,15 @@ def init_state(self, # self.alpha = 0.5*max(norm(x0), 1) / normf0 normf0 = tree_map(jnp.linalg.norm, value) normx0 = tree_map(jnp.linalg.norm, init_params) - clipped_normx0 = tree_map(lambda x: - 0.5 * jnp.maximum(x, 1), normx0) + clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0) def safe_divide_by_zero(x, y): - # a classical division of x by x + # a classical division of x by y # when y == 0 then return 1 return jnp.where(y == 0, 1, x / y) gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0) else: gamma = self.gamma - # repeat gamma as a pytre of the shape of init_params + # repeat gamma as a pytree of the shape of init_params gamma = tree_map(lambda x: jnp.array(gamma), init_params) state_kwargs = dict( d_history=init_history(init_params, self.history_size), diff --git a/tests/broyden_test.py b/tests/broyden_test.py index 5e24b6bd..ac42bd5e 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -133,8 +133,8 @@ def test_affine_contractive_mapping(self): b = jax.random.uniform(subkey, shape=(n,)) def g(x, M, b): return M @ x + b - x - tol = 5e-6 - fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1) + tol = 1e-6 + fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True) x0 = jnp.zeros_like(b) sol, state = fp.run(x0, M, b) self.assertLess(state.error, tol) 
From 6b3672239e441b451941c1f71835d957c4e6a699 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 16:54:50 +0100 Subject: [PATCH 6/8] removed gamma setting in broyden test, since smart init is available --- tests/broyden_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index ac42bd5e..e14196a3 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0) + sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol) From 29171f2b582850addc1db89617b6dcae3df4fdc0 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 17:27:39 +0100 Subject: [PATCH 7/8] switched to new API for linesearch in broyden --- jaxopt/_src/backtracking_linesearch.py | 8 +- jaxopt/_src/broyden.py | 104 +++++++++++++++---------- tests/broyden_test.py | 2 +- 3 files changed, 67 insertions(+), 47 deletions(-) diff --git a/jaxopt/_src/backtracking_linesearch.py b/jaxopt/_src/backtracking_linesearch.py index 189e3df5..60f032f4 100644 --- a/jaxopt/_src/backtracking_linesearch.py +++ b/jaxopt/_src/backtracking_linesearch.py @@ -39,7 +39,7 @@ class BacktrackingLineSearchState(NamedTuple): params: Any value: float grad: Any # either initial or final for armijo or glodstein - value_init: float + value_init: float grad_init: Any error: float done: bool @@ -260,11 +260,11 @@ def update( if self.condition in ["armijo", "goldstein"]: # If we are done for the armijo or the goldstein conditions, - # we compute the final gradient (we had not computed it before since + # we compute the final gradient (we had not computed it before since 
# these conditions did not require it) new_grad = cond(done | failed, self._compute_final_grad, - lambda *_: grad, + lambda *_: grad, new_params, fun_args, fun_kwargs, jit=self.jit) maybe_additional_eval = jnp.asarray(done | failed, dtype=base.NUM_EVAL_DTYPE) @@ -284,7 +284,7 @@ def update( num_grad_eval=num_grad_eval) return base.LineSearchStep(stepsize=new_stepsize, state=new_state) - + def _compute_final_grad(self, params, fun_args, fun_kwargs): return self._grad_with_aux(params, *fun_args, **fun_kwargs)[0] diff --git a/jaxopt/_src/broyden.py b/jaxopt/_src/broyden.py index 53267b21..83e34988 100644 --- a/jaxopt/_src/broyden.py +++ b/jaxopt/_src/broyden.py @@ -14,6 +14,8 @@ """Limited-memory Broyden method""" +import warnings + from functools import partial from typing import Any @@ -28,7 +30,7 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.backtracking_linesearch import BacktrackingLineSearch +from jaxopt._src.linesearch_util import _setup_linesearch, _init_stepsize from jaxopt.tree_util import tree_map from jaxopt.tree_util import tree_vdot from jaxopt.tree_util import tree_add_scalar_mul @@ -194,10 +196,11 @@ class Broyden(base.IterativeSolver): stepsize: Union[float, Callable] = 0.0 linesearch: str = "backtracking" + linesearch_init: str = "increase" stop_if_linesearch_fails: bool = False - condition: str = "wolfe" + condition: Any = None # deprecated in v0.8 maxls: int = 15 - decrease_factor: float = 0.8 + decrease_factor: Any = None # deprecated in v0.8 increase_factor: float = 1.5 max_stepsize: float = 1.0 # FIXME: should depend on whether float32 or float64 is used. 
@@ -328,45 +331,29 @@ def update(self, use_linesearch = not isinstance(self.stepsize, Callable) and self.stepsize <= 0 if use_linesearch: - if self.linesearch == "backtracking": - # we need to build the function used for the line search - # which is going to be the squared norm of the original function - # as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278 - # we then need to check if the gradient can be obtained with jax - # and if not we can build it in the same fashion as scipy - # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285 - def ls_fun_with_aux(params, *args, **kwargs): - f, aux = self._value_with_aux(params, *args, **kwargs) - norm_squared = tree_l2_norm(f, squared=True) - return norm_squared, (f, aux) - # here we need a check if the function is not smooth - ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True) - ls = BacktrackingLineSearch(fun=ls_fun_with_aux_and_grad, - value_and_grad=True, - maxiter=self.maxls, - decrease_factor=self.decrease_factor, - max_stepsize=self.max_stepsize, - condition=self.condition, - jit=self.jit, - unroll=self.unroll, - has_aux=True, - tol=1e-2) - init_stepsize = jnp.where(state.stepsize <= self.min_stepsize, - # If stepsize became too small, we restart it. - self.max_stepsize, - # Else, we increase a bit the previous one. 
- state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize, - params, value, None, - descent_direction, - fun_args=args, fun_kwargs=kwargs) - new_value, new_aux = ls_state.aux - new_params = ls_state.params - new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num - new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval - failed_linesearch = ls_state.failed - else: - raise ValueError("Invalid name in 'linesearch' option.") + init_stepsize = _init_stepsize( + self.linesearch_init, + self.max_stepsize, + self.min_stepsize, + self.increase_factor, + state.stepsize, + ) + new_stepsize, ls_state = self.run_ls( + init_stepsize, + params, + value=tree_l2_norm(value), + # in the case of Broyden, it's the value that's actually the equivalent + # of the gradient in the optimization case. + grad=value, + descent_direction=descent_direction, + fun_args=args, + fun_kwargs=kwargs, + ) + new_value, new_aux = ls_state.aux + new_params = ls_state.params + new_num_linesearch_iter = state.num_linesearch_iter + ls_state.iter_num + new_num_fun_eval = state.num_fun_eval + ls_state.num_fun_eval + failed_linesearch = ls_state.failed else: # without line search if isinstance(self.stepsize, Callable): @@ -429,3 +416,36 @@ def __post_init__(self): if self.history_size is None: self.history_size = self.maxiter + + # we need to build the function used for the line search + # which is going to be the squared norm of the original function + # as in scipy https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L278 + # we then need to check if the gradient can be obtained with jax + # and if not we can build it in the same fashion as scipy + # https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L285 + def ls_fun_with_aux(params, *args, **kwargs): + f, aux = self._value_with_aux(params, *args, **kwargs) + norm_squared = tree_l2_norm(f, squared=True) + return norm_squared, (f, aux) + # here we need a check if
the function is not smooth + ls_fun_with_aux_and_grad = jax.value_and_grad(ls_fun_with_aux, has_aux=True) + self.linesearch_solver = _setup_linesearch( + linesearch=self.linesearch, + fun=ls_fun_with_aux_and_grad, + value_and_grad=True, + has_aux=True, + maxlsiter=self.maxls, + max_stepsize=self.max_stepsize, + jit=self.jit, + unroll=self.unroll, + verbose=self.verbose, + ) + self.run_ls = self.linesearch_solver.run + + # FIXME: to remove in future releases + if self.condition is not None: + warnings.warn("Argument condition is deprecated", DeprecationWarning) + if self.decrease_factor is not None: + warnings.warn( + "Argument decrease_factor is deprecated", DeprecationWarning + ) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index e14196a3..a7822030 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0) + sol, state = Broyden(g, maxiter=1000, tol=tol, jit=jit, stop_if_linesearch_fails=True, linesearch="backtracking", maxls=30).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol) From 0fc0beb186056b25aef2d4fd0d378b85d8acebf5 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Sun, 3 Dec 2023 17:28:40 +0100 Subject: [PATCH 8/8] corrected test --- tests/broyden_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/broyden_test.py b/tests/broyden_test.py index a7822030..1c9427eb 100644 --- a/tests/broyden_test.py +++ b/tests/broyden_test.py @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11 return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1] x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]]) tol = 1e-6 - sol, state = Broyden(g, maxiter=1000, tol=tol, jit=jit, 
stop_if_linesearch_fails=True, linesearch="backtracking", maxls=30).run(x0) + sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, stop_if_linesearch_fails=True).run(x0) self.assertLess(state.error, tol) g_sol_norm = tree_l2_norm(g(sol)) self.assertLess(g_sol_norm, tol)