Skip to content

Commit

Permalink
Merge pull request #542 from zaccharieramzi:fix-njev-scipy-minimize
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 581902442
  • Loading branch information
JAXopt authors committed Nov 13, 2023
2 parents 16cad2d + b5c7dd3 commit 5b9b62c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
16 changes: 9 additions & 7 deletions jaxopt/_src/scipy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,21 @@ def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]:
else:
hess_inv = None

try:
num_hess_eval = jnp.asarray(res.nhev, base.NUM_EVAL_DTYPE)
except AttributeError:
num_hess_eval = jnp.array(0, base.NUM_EVAL_DTYPE)
nev_dict = {}
for attr in ['nfev', 'njev', 'nhev']:
try:
nev_dict[attr] = jnp.asarray(getattr(res, attr), base.NUM_EVAL_DTYPE)
except AttributeError:
nev_dict[attr] = jnp.array(0, base.NUM_EVAL_DTYPE)

info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
iter_num=res.nit,
hess_inv=hess_inv,
num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE),
num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
num_hess_eval=num_hess_eval)
num_fun_eval=nev_dict['nfev'],
num_jac_eval=nev_dict['njev'],
num_hess_eval=nev_dict['nhev'])
return base.OptStep(params, info)

def run(self,
Expand Down
18 changes: 15 additions & 3 deletions tests/scipy_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_maxiter_api(self):
with self.assertRaises(ValueError):
ScipyMinimize(fun=self.logreg_fun, maxiter=500, options={'maxiter': 100})

@parameterized.product(method=[
'Nelder-Mead',
'Powell',
'TNC',
])
def test_no_njev(self, method):
# test methods that do not return njev
# see https://github.com/google/jaxopt/pull/542
solver = ScipyMinimize(fun=self.logreg_fun, method=method)
pytree_init = jnp.zeros([self.n_features, self.n_classes])
solver.run(pytree_init, l2reg=self.default_l2reg, data=self.data)

def test_callback(self):
# test the callback API
trace = []
Expand Down Expand Up @@ -388,7 +400,7 @@ def wrapper(b):
jac_idf = [list(blk_row) for blk_row in jac_idf]
jac_idf = jnp.block(jac_idf)
self.assertArraysAllClose(jac_theo, jac_idf, atol=1e-3)

def test_broyden(self):
# non regression test based on issue https://github.com/google/jaxopt/issues/290
def simple_root(x):
Expand Down Expand Up @@ -618,7 +630,7 @@ def setUp(self):

def test_inverse_hessian_bfgs(self):
# perform the bfgs algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='BFGS', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='BFGS', fun=self.objective).run(self.x0)
Expand All @@ -627,7 +639,7 @@ def test_inverse_hessian_bfgs(self):

def test_inverse_hessian_lbfgs(self):
# perform the l-bfgs-b algorithm search
res_scipy = osp.optimize.minimize(self.objective, self.x0,
res_scipy = osp.optimize.minimize(self.objective, self.x0,
method='L-BFGS-B', jac=self.derivative)
# run the same algorithm via jaxopt
res_jaxopt = ScipyMinimize(method='L-BFGS-B', fun=self.objective).run(self.x0)
Expand Down

0 comments on commit 5b9b62c

Please sign in to comment.