Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BoxOSQP when configured with fun and used with pytree-based APIs #597

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jaxopt/_src/osqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src.cond import cond
from jaxopt.tree_util import tree_add, tree_sub, tree_mul
from jaxopt.tree_util import tree_add, tree_sub, tree_mul, tree_sum
from jaxopt.tree_util import tree_scalar_mul, tree_add_scalar_mul
from jaxopt.tree_util import tree_map, tree_vdot
from jaxopt.tree_util import tree_ones_like, tree_zeros_like, tree_where
Expand Down Expand Up @@ -792,7 +792,7 @@ def matvec_Q(params_obj, x):
# nabla_x f(x) = Q x + c
# Qx = nabla_x f(x) - c
def fun_minus_cx(xx):
return self.fun(xx, params_Q) - jnp.sum(c*xx)
return self.fun(xx, params_Q) - tree_sum(tree_mul(c, xx))
Qx = jax.grad(fun_minus_cx)(x)
return Qx
self.matvec_Q = matvec_Q
Expand Down
Loading