diff --git a/jaxopt/_src/scipy_wrappers.py b/jaxopt/_src/scipy_wrappers.py index 55c3320d..d7c57825 100644 --- a/jaxopt/_src/scipy_wrappers.py +++ b/jaxopt/_src/scipy_wrappers.py @@ -359,19 +359,21 @@ def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]: else: hess_inv = None - try: - num_hess_eval = jnp.asarray(res.nhev, base.NUM_EVAL_DTYPE) - except AttributeError: - num_hess_eval = jnp.array(0, base.NUM_EVAL_DTYPE) + nev_dict = {} + for attr in ['nfev', 'njev', 'nhev']: + try: + nev_dict[attr] = jnp.asarray(getattr(res, attr), base.NUM_EVAL_DTYPE) + except AttributeError: + nev_dict[attr] = jnp.array(0, base.NUM_EVAL_DTYPE) info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun), success=res.success, status=res.status, iter_num=res.nit, hess_inv=hess_inv, - num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE), - num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE), - num_hess_eval=num_hess_eval) + num_fun_eval=nev_dict['nfev'], + num_jac_eval=nev_dict['njev'], + num_hess_eval=nev_dict['nhev']) return base.OptStep(params, info) def run(self, diff --git a/tests/scipy_wrappers_test.py b/tests/scipy_wrappers_test.py index 4a306af4..c5adf9a8 100644 --- a/tests/scipy_wrappers_test.py +++ b/tests/scipy_wrappers_test.py @@ -130,6 +130,18 @@ def test_maxiter_api(self): with self.assertRaises(ValueError): ScipyMinimize(fun=self.logreg_fun, maxiter=500, options={'maxiter': 100}) + @parameterized.product(method=[ + 'Nelder-Mead', + 'Powell', + 'TNC', + ]) + def test_no_njev(self, method): + # test methods that do not return njev + # see https://github.com/google/jaxopt/pull/542 + solver = ScipyMinimize(fun=self.logreg_fun, method=method) + pytree_init = jnp.zeros([self.n_features, self.n_classes]) + solver.run(pytree_init, l2reg=self.default_l2reg, data=self.data) + def test_callback(self): # test the callback API trace = [] @@ -388,7 +400,7 @@ def wrapper(b): jac_idf = [list(blk_row) for blk_row in jac_idf] jac_idf = jnp.block(jac_idf) self.assertArraysAllClose(jac_theo, jac_idf, atol=1e-3) - + def test_broyden(self): # non regression test based on issue https://github.com/google/jaxopt/issues/290 def simple_root(x): @@ -618,7 +630,7 @@ def setUp(self): def test_inverse_hessian_bfgs(self): # perform the bfgs algorithm search - res_scipy = osp.optimize.minimize(self.objective, self.x0, + res_scipy = osp.optimize.minimize(self.objective, self.x0, method='BFGS', jac=self.derivative) # run the same algorithm via jaxopt res_jaxopt = ScipyMinimize(method='BFGS', fun=self.objective).run(self.x0) @@ -627,7 +639,7 @@ def test_inverse_hessian_bfgs(self): def test_inverse_hessian_lbfgs(self): # perform the l-bfgs-b algorithm search - res_scipy = osp.optimize.minimize(self.objective, self.x0, + res_scipy = osp.optimize.minimize(self.objective, self.x0, method='L-BFGS-B', jac=self.derivative) # run the same algorithm via jaxopt res_jaxopt = ScipyMinimize(method='L-BFGS-B', fun=self.objective).run(self.x0)