-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Pseudo-inverse preconditioner for EqQP. #133
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Preconditioned solvers for equality constrained quadratic programming.""" | ||
|
||
from typing import Optional, Any | ||
from dataclasses import dataclass | ||
import jax.numpy as jnp | ||
import jaxopt | ||
from jaxopt._src import base | ||
from jaxopt._src import linear_operator | ||
|
||
|
||
@dataclass | ||
class PseudoInversePreconditionedEqQP(base.Solver): | ||
qp_solver: jaxopt.EqualityConstrainedQP | ||
|
||
def init_params(self, params_obj, params_eq): | ||
"""Computes the matvec associated to the pseudo inverse of the KKT matrix.""" | ||
Q, p = params_obj | ||
A, b = params_eq | ||
del p, b | ||
|
||
kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]]) | ||
kkt_mat_pinv = jnp.linalg.pinv(kkt_mat) | ||
|
||
d = Q.shape[0] | ||
|
||
pinv_blocks = ( | ||
(kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]), | ||
(kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]), | ||
) | ||
return linear_operator.BlockLinearOperator(pinv_blocks) | ||
|
||
def run( | ||
self, | ||
init_params: Optional[base.KKTSolution] = None, | ||
params_obj: Optional[Any] = None, | ||
params_eq: Optional[Any] = None, | ||
params_precond=None, | ||
**kwargs | ||
): | ||
# TODO(gnegiar): the M parameter should be passed to both | ||
# the QP solve and the implicit_diff_solve | ||
return self.qp_solver.run( | ||
init_params, params_obj, params_eq, M=params_precond, **kwargs | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,15 +14,16 @@ | |
"""Interface for linear operators.""" | ||
|
||
import functools | ||
from dataclasses import dataclass | ||
from typing import Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as onp | ||
|
||
from jaxopt.tree_util import tree_map, tree_sum, tree_mul | ||
from jaxopt.tree_util import tree_map | ||
|
||
|
||
class DenseLinearOperator: | ||
|
||
def __init__(self, pytree): | ||
self.pytree = pytree | ||
|
||
|
@@ -33,7 +34,7 @@ def matvec(self, x): | |
return tree_map(jnp.dot, self.pytree, x) | ||
|
||
def rmatvec(self, _, y): | ||
return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y) | ||
return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y) | ||
|
||
def matvec_and_rmatvec(self, x, y): | ||
return self.matvec(x), self.rmatvec(x, y) | ||
|
@@ -52,11 +53,11 @@ def col_norm(w): | |
if not squared: | ||
col_norms = jnp.sqrt(col_norms) | ||
return col_norms | ||
|
||
return tree_map(col_norm, self.pytree) | ||
|
||
|
||
class FunctionalLinearOperator: | ||
|
||
def __init__(self, fun, params): | ||
self.fun = functools.partial(fun, params) | ||
|
||
|
@@ -71,7 +72,7 @@ def rmatvec(self, x, y): | |
|
||
def matvec_and_rmatvec(self, x, y): | ||
matvec_x, vjp = jax.vjp(self.matvec, x) | ||
rmatvec_y, = vjp(y) | ||
(rmatvec_y,) = vjp(y) | ||
return matvec_x, rmatvec_y | ||
|
||
def normal_matvec(self, x): | ||
|
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec): | |
return DenseLinearOperator | ||
else: | ||
return functools.partial(FunctionalLinearOperator, matvec) | ||
|
||
|
||
def block_row_matvec(block, x): | ||
"""Performs a matvec for a row of block matrices. | ||
|
||
The following matvec is performed: | ||
[U1, ..., UN] * [x1, ..., xN] | ||
where U1, ..., UN are matrices and x1, ..., xN are vectors | ||
of compatible shapes. | ||
""" | ||
if len(block) != len(x): | ||
raise ValueError( | ||
"We need as many blocks in the matrix as in the vector." | ||
) | ||
return sum(jax.tree_util.tree_map(jnp.dot, block, x)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If my understanding is correct, here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, let me know :) |
||
|
||
|
||
# TODO(gnegiar): Extend to arbitrary block shapes. | ||
@jax.tree_util.register_pytree_node_class | ||
@dataclass | ||
class BlockLinearOperator: | ||
"""Represents a linear operator defined by blocks over a block pytree. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is -- without registering it as a pytree node, I get this error: |
||
|
||
Attributes: | ||
blocks: a 2x2 block matrix of the form | ||
[[A, B] | ||
[C, D]] | ||
""" | ||
|
||
blocks: Tuple[Tuple[jnp.array]] | ||
|
||
def __call__(self, x): | ||
return self.matvec(x) | ||
|
||
def matvec(self, x): | ||
"""Performs the block matvec with u defined by blocks. | ||
|
||
The matvec is of form: | ||
[u1, u2] | ||
[[A, B] * | ||
[C, D]] | ||
|
||
""" | ||
return jax.tree_util.tree_map( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you are stopping the recursion at depth 1 (i.e you only retrieve There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True ; my hope was to leave a stub which could be extended for any sized block matrix, and not just 2x2. I can go either way on this, let me know what you think. |
||
lambda row_of_blocks: block_row_matvec(row_of_blocks, x), | ||
self.blocks, | ||
is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1], | ||
) | ||
|
||
def rmatvec(self, x, y): | ||
return self.matvec_and_rmatvec(x, y)[1] | ||
|
||
def matvec_and_rmatvec(self, x, y): | ||
matvec_x, vjp = jax.vjp(self.matvec, x) | ||
(rmatvec_y,) = vjp(y) | ||
return matvec_x, rmatvec_y | ||
|
||
def normal_matvec(self, x): | ||
"""Computes A^T A x from matvec(x) = A x.""" | ||
matvec_x, vjp = jax.vjp(self.matvec, x) | ||
return vjp(matvec_x)[0] | ||
|
||
def tree_flatten(self): | ||
return self.blocks, None | ||
|
||
@classmethod | ||
def tree_unflatten(cls, aux_data, children): | ||
del aux_data | ||
return cls(children) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2021 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import jax | ||
from jax import test_util as jtu | ||
import jax.numpy as jnp | ||
|
||
from jaxopt import PseudoInversePreconditionedEqQP | ||
from jaxopt import EqualityConstrainedQP | ||
import numpy as onp | ||
|
||
|
||
class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase): | ||
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): | ||
def fun(Q, c, A, b): | ||
Q = 0.5 * (Q + Q.T) | ||
|
||
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) | ||
# reduce the primal variables to a scalar value for test purpose. | ||
return jnp.sum(solver.run(**hyperparams).params[0]) | ||
|
||
# Derivative w.r.t. A. | ||
rng = onp.random.RandomState(0) | ||
V = rng.rand(*A.shape) | ||
V /= onp.sqrt(onp.sum(V ** 2)) | ||
eps = 1e-4 | ||
deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b)) | ||
deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps) | ||
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) | ||
|
||
# Derivative w.r.t. b. | ||
v = rng.rand(*b.shape) | ||
v /= onp.sqrt(onp.sum(v ** 2)) | ||
eps = 1e-4 | ||
deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b)) | ||
deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps) | ||
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) | ||
|
||
# Derivative w.r.t. Q | ||
W = rng.rand(*Q.shape) | ||
W /= onp.sqrt(onp.sum(W ** 2)) | ||
eps = 1e-4 | ||
deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b)) | ||
deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps) | ||
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) | ||
|
||
# Derivative w.r.t. c | ||
w = rng.rand(*c.shape) | ||
w /= onp.sqrt(onp.sum(w ** 2)) | ||
eps = 1e-4 | ||
deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b)) | ||
deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps) | ||
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) | ||
|
||
def test_pseudoinverse_preconditioner(self): | ||
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) | ||
c = jnp.array([1.0, 1.0]) | ||
A = jnp.array([[1.0, 1.0]]) | ||
b = jnp.array([1.0]) | ||
qp = EqualityConstrainedQP(tol=1e-7) | ||
preconditioned_qp = PseudoInversePreconditionedEqQP(qp) | ||
params_obj = (Q, c) | ||
params_eq = (A, b) | ||
params_precond = preconditioned_qp.init_params(params_obj, params_eq) | ||
hyperparams = dict( | ||
params_obj=params_obj, | ||
params_eq=params_eq, | ||
) | ||
sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params | ||
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) | ||
self._check_derivative_Q_c_A_b(preconditioned_qp, Q, c, A, b) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
|
||
|
||
class EqualityConstrainedQPTest(jtu.JaxTestCase): | ||
def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b): | ||
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): | ||
def fun(Q, c, A, b): | ||
Q = 0.5 * (Q + Q.T) | ||
|
||
|
@@ -77,7 +77,7 @@ def test_qp_eq_only(self): | |
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) | ||
sol = qp.run(**hyperparams).params | ||
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) | ||
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) | ||
self._check_derivative_Q_c_A_b(qp, Q, c, A, b) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch ! Maybe you need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed the second argument from the |
||
|
||
def test_qp_eq_with_init(self): | ||
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) | ||
|
@@ -89,7 +89,7 @@ def test_qp_eq_with_init(self): | |
init_params = KKTSolution(jnp.array([1.0, 1.0]), jnp.array([1.0])) | ||
sol = qp.run(init_params, **hyperparams).params | ||
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) | ||
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) | ||
self._check_derivative_Q_c_A_b(qp, Q, c, A, b) | ||
|
||
def test_projection_hyperplane(self): | ||
x = jnp.array([1.0, 2.0]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On first sight, I'm rather -1 on introducing a pre-conditioner
M
inoptimality_fun
since I don't think it plays any role in the optimality conditions (it would not make sense to differentiate with respect toM
). Sincerun
andoptimality_fun
need to have the same signature, this rules out addingM
torun
as well.I think I would go for something like this instead:
Typically, stuff that doesn't need to be differentiated should go to the constructor.
If you want to differentiate wrt
params_eq
orparams_obj
, you may need to useinstead. Not entirely sure if
PseudoInversePreconditioner
should live in JAXopt or in user land.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't play a role in the optimality conditions -- it could play a role in the backwards though, if we manage to pass the argument to the backward solver as well. This would hopefully speed things up, since the forward and backward linear systems share the linear operator.
With the current API, if we're solving the QP as the inner problem of a bi-level problem, then we need to build a new QP solver instance at each step of the outer loop, and pass solve=partial(linearsolver, M=preconditioner) to both
solve
andimplicit_diff_solve
.Something which is then unclear to me: does building a new instance of the QP solver necessarily trigger recompilation of the
run
method at each iteration ?