Skip to content

Commit

Permalink
Fix issue with positional arguments in LBFGS and NonlinearCG.
Browse files Browse the repository at this point in the history
  • Loading branch information
mblondel committed Jun 10, 2022
1 parent 7d23cc6 commit b265542
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 10 deletions.
14 changes: 14 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
Changelog
=========

Version 0.4.2
-------------

Bug fixes and enhancements
~~~~~~~~~~~~~~~~~~~~~~~~~~

- Fix issue with positional arguments in :class:`jaxopt.LBFGS` and :class:`jaxopt.NonlinearCG`,
by Mathieu Blondel.

Contributors
~~~~~~~~~~~~

Mathieu Blondel.

Version 0.4.1
-------------

Expand Down
6 changes: 3 additions & 3 deletions jaxopt/_src/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ def update(self,
self.max_stepsize,
# Otherwise, we increase a bit the previous one.
state.stepsize * self.increase_factor)
new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,
params=params, value=value, grad=grad,
descent_direction=descent_direction,
new_stepsize, ls_state = ls.run(init_stepsize,
params, value, grad,
descent_direction,
*args, **kwargs)
new_value = ls_state.value
new_params = ls_state.params
Expand Down
9 changes: 5 additions & 4 deletions jaxopt/_src/nonlinear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ def update(self,
maxiter=self.maxls,
decrease_factor=self.decrease_factor,
condition=self.condition)
new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize,
params=params,
value=value,
grad=grad,
new_stepsize, ls_state = ls.run(init_stepsize,
params,
value,
grad,
None, # descent_direction
*args, **kwargs)

new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction)
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@

"""JAXopt version."""

__version__ = "0.4.1"
__version__ = "0.4.2"
4 changes: 3 additions & 1 deletion tests/lbfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_binary_logreg(self, use_gamma):

w_init = jnp.zeros(X.shape[1])
lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma)
# Test with keyword argument.
w_fit, info = lbfgs.run(w_init, data=data)

# Check optimality conditions.
Expand All @@ -236,7 +237,8 @@ def test_multiclass_logreg(self, use_gamma):
pytree_init = (W_init, b_init)

lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma)
pytree_fit, info = lbfgs.run(pytree_init, data=data)
# Test with positional argument.
pytree_fit, info = lbfgs.run(pytree_init, data)

# Check optimality conditions.
self.assertLessEqual(info.error, 1e-2)
Expand Down
3 changes: 2 additions & 1 deletion tests/nonlinear_cg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_binary_logreg(self, method):

w_init = jnp.zeros(X.shape[1])
cg_model = NonlinearCG(fun=fun, tol=1e-3, maxiter=100, method=method)
w_fit, info = cg_model.run(w_init, data=data)
# Test with positional argument.
w_fit, info = cg_model.run(w_init, data)

# Check optimality conditions.
self.assertLessEqual(info.error, 5e-2)
Expand Down

0 comments on commit b265542

Please sign in to comment.