diff --git a/jaxopt/_src/osqp.py b/jaxopt/_src/osqp.py index 7eae7f86..6d47e6bb 100644 --- a/jaxopt/_src/osqp.py +++ b/jaxopt/_src/osqp.py @@ -31,7 +31,7 @@ from jaxopt._src import base from jaxopt._src import implicit_diff as idf from jaxopt._src.cond import cond -from jaxopt.tree_util import tree_add, tree_sub, tree_mul +from jaxopt.tree_util import tree_add, tree_sub, tree_mul, tree_sum from jaxopt.tree_util import tree_scalar_mul, tree_add_scalar_mul from jaxopt.tree_util import tree_map, tree_vdot from jaxopt.tree_util import tree_ones_like, tree_zeros_like, tree_where @@ -792,7 +792,7 @@ def matvec_Q(params_obj, x): # nabla_x f(x) = Q x + c # Qx = nabla_x f(x) - c def fun_minus_cx(xx): - return self.fun(xx, params_Q) - jnp.sum(c*xx) + return self.fun(xx, params_Q) - tree_sum(tree_mul(c, xx)) Qx = jax.grad(fun_minus_cx)(x) return Qx self.matvec_Q = matvec_Q