Skip to content

Commit

Permalink
Merge pull request #239 from mblondel:lbfgs_bug2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 454228985
  • Loading branch information
JAXopt authors committed Jun 10, 2022
2 parents 7d23cc6 + a666044 commit 9c9c266
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 14 deletions.
14 changes: 14 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
Changelog
=========

Version 0.4.2
-------------

Bug fixes and enhancements
~~~~~~~~~~~~~~~~~~~~~~~~~~

- Fix issue with positional arguments in :class:`jaxopt.LBFGS` and :class:`jaxopt.NonlinearCG`,
by Mathieu Blondel.

Contributors
~~~~~~~~~~~~

Mathieu Blondel.

Version 0.4.1
-------------

Expand Down
8 changes: 4 additions & 4 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def init_state(self,
"""
return LbfgsState(iter_num=jnp.asarray(0),
value=jnp.asarray(jnp.inf),
stepsize=jnp.asarray(1.0),
stepsize=jnp.asarray(self.max_stepsize),
error=jnp.asarray(jnp.inf),
s_history=init_history(init_params, self.history_size),
y_history=init_history(init_params, self.history_size),
Expand Down Expand Up @@ -280,9 +280,9 @@ def update(self,
self.max_stepsize,
# Otherwise, we increase a bit the previous one.
state.stepsize * self.increase_factor)
new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,
params=params, value=value, grad=grad,
descent_direction=descent_direction,
new_stepsize, ls_state = ls.run(init_stepsize,
params, value, grad,
descent_direction,
*args, **kwargs)
new_value = ls_state.value
new_params = ls_state.params
Expand Down
28 changes: 21 additions & 7 deletions jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class NonlinearCG(base.IterativeSolver):
(default: 0.8).
increase_factor: factor by which to increase the stepsize during line search
(default: 1.2).
max_stepsize: upper bound on stepsize.
min_stepsize: lower bound on stepsize.
implicit_diff: whether to enable implicit diff or autodiff of unrolled
iterations.
implicit_diff_solve: the linear system solver to use.
Expand All @@ -87,6 +89,10 @@ class NonlinearCG(base.IterativeSolver):
maxls: int = 15
decrease_factor: float = 0.8
increase_factor: float = 1.2
max_stepsize: float = 1.0
# FIXME: should depend on whether float32 or float64 is used.
min_stepsize: float = 1e-6

implicit_diff: bool = True
implicit_diff_solve: Optional[Callable] = None

Expand All @@ -110,7 +116,7 @@ def init_state(self,
value, grad = self._value_and_grad_fun(init_params, *args, **kwargs)

return NonlinearCGState(iter_num=jnp.asarray(0),
stepsize=jnp.asarray(1.0),
stepsize=jnp.asarray(self.max_stepsize),
error=jnp.asarray(jnp.inf),
value=value,
grad=grad,
Expand All @@ -133,16 +139,24 @@ def update(self,

eps = 1e-6
value, grad, descent_direction = state.value, state.grad, state.descent_direction
init_stepsize = state.stepsize * self.increase_factor
ls = BacktrackingLineSearch(fun=self._value_and_grad_fun,
value_and_grad=True,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
condition=self.condition)
new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,
params=params,
value=value,
grad=grad,
condition=self.condition,
max_stepsize=self.max_stepsize)

init_stepsize = jnp.where(state.stepsize <= self.min_stepsize,
# If stepsize became too small, we restart it.
self.max_stepsize,
# Otherwise, we increase a bit the previous one.
state.stepsize * self.increase_factor)

new_stepsize, ls_state = ls.run(init_stepsize,
params,
value,
grad,
None, # descent_direction
*args, **kwargs)

new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction)
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

"""JAXopt version."""

__version__ = "0.4.1"
__version__ = "0.4.2"
4 changes: 3 additions & 1 deletion tests/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_binary_logreg(self, use_gamma):

w_init = jnp.zeros(X.shape[1])
lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma)
# Test with keyword argument.
w_fit, info = lbfgs.run(w_init, data=data)

# Check optimality conditions.
Expand All @@ -236,7 +237,8 @@ def test_multiclass_logreg(self, use_gamma):
pytree_init = (W_init, b_init)

lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma)
pytree_fit, info = lbfgs.run(pytree_init, data=data)
# Test with positional argument.
pytree_fit, info = lbfgs.run(pytree_init, data)

# Check optimality conditions.
self.assertLessEqual(info.error, 1e-2)
Expand Down
3 changes: 2 additions & 1 deletion tests/nonlinear_cg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_binary_logreg(self, method):

w_init = jnp.zeros(X.shape[1])
cg_model = NonlinearCG(fun=fun, tol=1e-3, maxiter=100, method=method)
w_fit, info = cg_model.run(w_init, data=data)
# Test with positional argument.
w_fit, info = cg_model.run(w_init, data)

# Check optimality conditions.
self.assertLessEqual(info.error, 5e-2)
Expand Down

0 comments on commit 9c9c266

Please sign in to comment.