diff --git a/tests/scipy_wrappers_test.py b/tests/scipy_wrappers_test.py index 4a306af4..c5adf9a8 100644 --- a/tests/scipy_wrappers_test.py +++ b/tests/scipy_wrappers_test.py @@ -130,6 +130,18 @@ def test_maxiter_api(self): with self.assertRaises(ValueError): ScipyMinimize(fun=self.logreg_fun, maxiter=500, options={'maxiter': 100}) + @parameterized.product(method=[ + 'Nelder-Mead', + 'Powell', + 'TNC', + ]) + def test_no_njev(self, method): + # test methods that do not return njev + # see https://github.com/google/jaxopt/pull/542 + solver = ScipyMinimize(fun=self.logreg_fun, method=method) + pytree_init = jnp.zeros([self.n_features, self.n_classes]) + solver.run(pytree_init, l2reg=self.default_l2reg, data=self.data) + def test_callback(self): # test the callback API trace = [] @@ -388,7 +400,7 @@ def wrapper(b): jac_idf = [list(blk_row) for blk_row in jac_idf] jac_idf = jnp.block(jac_idf) self.assertArraysAllClose(jac_theo, jac_idf, atol=1e-3) - + def test_broyden(self): # non regression test based on issue https://github.com/google/jaxopt/issues/290 def simple_root(x): @@ -618,7 +630,7 @@ def setUp(self): def test_inverse_hessian_bfgs(self): # perform the bfgs algorithm search - res_scipy = osp.optimize.minimize(self.objective, self.x0, + res_scipy = osp.optimize.minimize(self.objective, self.x0, method='BFGS', jac=self.derivative) # run the same algorithm via jaxopt res_jaxopt = ScipyMinimize(method='BFGS', fun=self.objective).run(self.x0) @@ -627,7 +639,7 @@ def test_inverse_hessian_bfgs(self): def test_inverse_hessian_lbfgs(self): # perform the l-bfgs-b algorithm search - res_scipy = osp.optimize.minimize(self.objective, self.x0, + res_scipy = osp.optimize.minimize(self.objective, self.x0, method='L-BFGS-B', jac=self.derivative) # run the same algorithm via jaxopt res_jaxopt = ScipyMinimize(method='L-BFGS-B', fun=self.objective).run(self.x0)