From 08fd2f852d06394db77673d116dc1f0fad5bd4aa Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Tue, 3 Oct 2023 23:34:09 +0200 Subject: [PATCH] made sure to use njev only when available in scipy optimize results --- jaxopt/_src/scipy_wrappers.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/jaxopt/_src/scipy_wrappers.py b/jaxopt/_src/scipy_wrappers.py index bba65419..43faddc4 100644 --- a/jaxopt/_src/scipy_wrappers.py +++ b/jaxopt/_src/scipy_wrappers.py @@ -359,19 +359,21 @@ def scipy_fun(x_onp: onp.ndarray) -> Tuple[onp.ndarray, onp.ndarray]: else: hess_inv = None - try: - num_hess_eval = jnp.asarray(res.nhev, base.NUM_EVAL_DTYPE) - except AttributeError: - num_hess_eval = jnp.array(0, base.NUM_EVAL_DTYPE) + nev_dict = {} + for attr in ['nfev', 'njev', 'nhev']: + try: + nev_dict[attr] = jnp.asarray(getattr(res, attr), base.NUM_EVAL_DTYPE) + except AttributeError: + nev_dict[attr] = jnp.array(0, base.NUM_EVAL_DTYPE) info = ScipyMinimizeInfo(fun_val=jnp.asarray(res.fun), success=res.success, status=res.status, iter_num=res.nit, hess_inv=hess_inv, - num_fun_eval=jnp.asarray(res.nfev, base.NUM_EVAL_DTYPE), - num_jac_eval=jnp.asarray(res.njev, base.NUM_EVAL_DTYPE), - num_hess_eval=num_hess_eval) + num_fun_eval=nev_dict['nfev'], + num_jac_eval=nev_dict['njev'], + num_hess_eval=nev_dict['nhev']) return base.OptStep(params, info) def run(self,