Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the computed initialization of the jac approx in Broyden #529

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
34 changes: 27 additions & 7 deletions jaxopt/_src/broyden.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,21 @@ def inv_jacobian_product(pytree: Any,
Leaves contain v variables, i.e., `(x[k] - x[k-1])^T B / ((g[k] - g[k-1])^T (x[k] - x[k-1])^T B)`.
c_history: pytree with the same structure as `pytree`.
Leaves contain u variables, i.e., `(x[k] - x[k-1]) - B(g[k] - g[k-1])`.
gamma: scalar to use for the initial inverse jacobian approximation,
gamma: pytree with scalars to use for the initial inverse jacobian approximation,
i.e., `gamma * I`.
start: starting index in the circular buffer.
"""
fun = partial(inv_jacobian_product_leaf,
gamma=gamma,
start=start)
return tree_map(fun, pytree, d_history, c_history)
return tree_map(fun, pytree, d_history, c_history, gamma)

def inv_jacobian_rproduct(pytree: Any,
d_history: Any,
c_history: Any,
gamma: float = 1.0,
start: int = 0):
return inv_jacobian_product(pytree, c_history, d_history, jnp.conjugate(gamma), start)
gamma_conj = tree_map(jnp.conjugate, gamma)
return inv_jacobian_product(pytree, c_history, d_history, gamma_conj, start)


def init_history(pytree, history_size):
Expand All @@ -115,7 +115,7 @@ class BroydenState(NamedTuple):
error: float
d_history: Any
c_history: Any
gamma: jnp.ndarray
gamma: Any
aux: Optional[Any] = None
failed_linesearch: bool = False

Expand Down Expand Up @@ -205,6 +205,7 @@ class Broyden(base.IterativeSolver):

history_size: int = None
gamma: float = 1.0
compute_gamma: bool = True

implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None
Expand Down Expand Up @@ -244,18 +245,37 @@ def init_state(self,
iter_num=init_params.state.iter_num,
stepsize=init_params.state.stepsize,
)
# XXX: not computing the jacobian init approx
# when starting from an OptStep object
init_params = init_params.params
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
else:
dtype = tree_single_dtype(init_params)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
if self.compute_gamma:
# we use scipy's formula:
# https://github.com/scipy/scipy/blob/main/scipy/optimize/_nonlin.py#L569
# self.alpha = 0.5*max(norm(x0), 1) / normf0
normf0 = tree_map(jnp.linalg.norm, value)
normx0 = tree_map(jnp.linalg.norm, init_params)
clipped_normx0 = tree_map(lambda x: 0.5 * jnp.maximum(x, 1), normx0)
def safe_divide_by_zero(x, y):
# a classical division of x by y
# when y == 0 then return 1
return jnp.where(y == 0, 1, x / y)
gamma = tree_map(safe_divide_by_zero, clipped_normx0, normf0)
else:
gamma = self.gamma
# repeat gamma as a pytree of the shape of init_params
gamma = tree_map(lambda x: jnp.array(gamma), init_params)
state_kwargs = dict(
d_history=init_history(init_params, self.history_size),
c_history=init_history(init_params, self.history_size),
gamma=jnp.asarray(self.gamma, dtype=dtype),
gamma=gamma,
iter_num=jnp.asarray(0),
stepsize=jnp.asarray(self.max_stepsize, dtype=dtype),
)
value, aux = self._value_with_aux(init_params, *args, **kwargs)
return BroydenState(value=value,
error=jnp.asarray(jnp.inf),
**state_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions tests/broyden_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def g(x): # Another fixed point exists for x[0] : ~1.11
return jnp.sin(x[0]) * (x[0] ** 2) - x[0], x[1] ** 3 - x[1]
x0 = jnp.array([0.6, 0., -0.1]), jnp.array([[0.7], [0.5]])
tol = 1e-6
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit, gamma=-1).run(x0)
sol, state = Broyden(g, maxiter=100, tol=tol, jit=jit).run(x0)
self.assertLess(state.error, tol)
g_sol_norm = tree_l2_norm(g(sol))
self.assertLess(g_sol_norm, tol)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_affine_contractive_mapping(self):
def g(x, M, b):
return M @ x + b - x
tol = 1e-6
fp = Broyden(g, maxiter=100, tol=tol, implicit_diff=True, gamma=-1)
fp = Broyden(g, maxiter=5000, tol=tol, implicit_diff=True)
x0 = jnp.zeros_like(b)
sol, state = fp.run(x0, M, b)
self.assertLess(state.error, tol)
Expand Down