From b265542ce04e9421704a17322cb0b2bdf89fa717 Mon Sep 17 00:00:00 2001 From: Mathieu Blondel Date: Fri, 10 Jun 2022 17:47:30 +0200 Subject: [PATCH] Fix issue with positional arguments in LBFGS and NonlinearCG. --- docs/changelog.rst | 14 ++++++++++++++ jaxopt/_src/lbfgs.py | 6 +++--- jaxopt/_src/nonlinear_cg.py | 9 +++++---- jaxopt/version.py | 2 +- tests/lbfgs_test.py | 4 +++- tests/nonlinear_cg_test.py | 3 ++- 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 516fec9e..f245e0fe 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,20 @@ Changelog ========= +Version 0.4.2 +------------- + +Bug fixes and enhancements +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Fix issue with positional arguments in :class:`jaxopt.LBFGS` and :class:`jaxopt.NonlinearCG`, + by Mathieu Blondel. + +Contributors +~~~~~~~~~~~~ + +Mathieu Blondel. + Version 0.4.1 ------------- diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 301d7cf9..b34339a7 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -280,9 +280,9 @@ def update(self, self.max_stepsize, # Otherwise, we increase a bit the previous one. state.stepsize * self.increase_factor) - new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize, - params=params, value=value, grad=grad, - descent_direction=descent_direction, + new_stepsize, ls_state = ls.run(init_stepsize, + params, value, grad, + descent_direction, *args, **kwargs) new_value = ls_state.value new_params = ls_state.params diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 15e63668..88603f41 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -139,10 +139,11 @@ def update(self, maxiter=self.maxls, decrease_factor=self.decrease_factor, condition=self.condition) - new_stepsize, ls_state = ls.run(init_stepsize=init_stepsize, - params=params, - value=value, - grad=grad, + new_stepsize, ls_state = ls.run(init_stepsize, + params, + value, + grad, + None, # descent_direction *args, **kwargs) new_params = tree_add_scalar_mul(params, new_stepsize, descent_direction) diff --git a/jaxopt/version.py b/jaxopt/version.py index c9cfab72..e0b5b397 100644 --- a/jaxopt/version.py +++ b/jaxopt/version.py @@ -14,4 +14,4 @@ """JAXopt version.""" -__version__ = "0.4.1" +__version__ = "0.4.2" diff --git a/tests/lbfgs_test.py b/tests/lbfgs_test.py index 0d64a2d2..55ee85a4 100644 --- a/tests/lbfgs_test.py +++ b/tests/lbfgs_test.py @@ -213,6 +213,7 @@ def test_binary_logreg(self, use_gamma): w_init = jnp.zeros(X.shape[1]) lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma) + # Test with keyword argument. w_fit, info = lbfgs.run(w_init, data=data) # Check optimality conditions. @@ -236,7 +237,8 @@ def test_multiclass_logreg(self, use_gamma): pytree_init = (W_init, b_init) lbfgs = LBFGS(fun=fun, tol=1e-3, maxiter=500, use_gamma=use_gamma) - pytree_fit, info = lbfgs.run(pytree_init, data=data) + # Test with positional argument. + pytree_fit, info = lbfgs.run(pytree_init, data) # Check optimality conditions. self.assertLessEqual(info.error, 1e-2) diff --git a/tests/nonlinear_cg_test.py b/tests/nonlinear_cg_test.py index adf02c0b..84bbe2e6 100644 --- a/tests/nonlinear_cg_test.py +++ b/tests/nonlinear_cg_test.py @@ -81,7 +81,8 @@ def test_binary_logreg(self, method): w_init = jnp.zeros(X.shape[1]) cg_model = NonlinearCG(fun=fun, tol=1e-3, maxiter=100, method=method) - w_fit, info = cg_model.run(w_init, data=data) + # Test with positional argument. + w_fit, info = cg_model.run(w_init, data) # Check optimality conditions. self.assertLessEqual(info.error, 5e-2)