Skip to content

Commit

Permalink
made sure to use njev only when available in scipy optimize results
Browse files Browse the repository at this point in the history
  • Loading branch information
zaccharieramzi committed Oct 3, 2023
1 parent b23c7a5 commit 08fd2f8
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions jaxopt/_src/scipy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,19 +359,21 @@ def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]:
else:
hess_inv = None

try:
num_hess_eval = jnp.asarray(res.nhev, base.NUM_EVAL_DTYPE)
except AttributeError:
num_hess_eval = jnp.array(0, base.NUM_EVAL_DTYPE)
nev_dict = {}
for attr in ['nfev', 'njev', 'nhev']:
try:
nev_dict[attr] = jnp.asarray(getattr(res, attr), base.NUM_EVAL_DTYPE)
except AttributeError:
nev_dict[attr] = jnp.array(0, base.NUM_EVAL_DTYPE)

info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun),
success=res.success,
status=res.status,
iter_num=res.nit,
hess_inv=hess_inv,
num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE),
num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE),
num_hess_eval=num_hess_eval)
num_fun_eval=nev_dict['nfev'],
num_jac_eval=nev_dict['njev'],
num_hess_eval=nev_dict['nhev'])
return base.OptStep(params, info)

def run(self,
Expand Down

0 comments on commit 08fd2f8

Please sign in to comment.