From c78943e1d8b78523e7700f76fc7669b741e77c84 Mon Sep 17 00:00:00 2001 From: Geoffrey Negiar Date: Thu, 16 Dec 2021 12:48:27 -0800 Subject: [PATCH] Added Pseudo-inverse preconditioner for EqQP. This allows to precompute a preconditioner, and share it across multiple outer loops, where the inner loop is solving an Equality Constrained QP. This should provide speedups when the parameters of the inner loop QP don't change too much. TODO: modify the implicit diff decorator so that the jvp also uses the preconditioner. --- jaxopt/__init__.py | 1 + jaxopt/_src/eq_qp.py | 7 ++- jaxopt/_src/eq_qp_preconditioned.py | 58 ++++++++++++++++++++ jaxopt/_src/linear_operator.py | 82 ++++++++++++++++++++++++++--- tests/eq_qp_preconditioned_test.py | 82 +++++++++++++++++++++++++++++ tests/eq_qp_test.py | 6 +-- 6 files changed, 225 insertions(+), 11 deletions(-) create mode 100644 jaxopt/_src/eq_qp_preconditioned.py create mode 100644 tests/eq_qp_preconditioned_test.py diff --git a/jaxopt/__init__.py b/jaxopt/__init__.py index 96aef470..fc7f3063 100644 --- a/jaxopt/__init__.py +++ b/jaxopt/__init__.py @@ -35,6 +35,7 @@ from jaxopt._src.polyak_sgd import PolyakSGD from jaxopt._src.projected_gradient import ProjectedGradient from jaxopt._src.proximal_gradient import ProximalGradient +from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP from jaxopt._src.quadratic_prog import QuadraticProgramming from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize diff --git a/jaxopt/_src/eq_qp.py b/jaxopt/_src/eq_qp.py index 65fc3936..27af4db3 100644 --- a/jaxopt/_src/eq_qp.py +++ b/jaxopt/_src/eq_qp.py @@ -103,7 +103,7 @@ class EqualityConstrainedQP(base.Solver): implicit_diff_solve: Optional[Callable] = None jit: bool = True - def _refined_solve(self, matvec, b, init, maxiter, tol): + def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs): # Instead of solving S x = b # We solve \bar{S} x = b # @@ -152,13 +152,14 @@ def matvec_regularized_qp(_, x): maxiter=self.refine_maxiter, tol=tol, ) - return solver.run(init_params=init, A=None, b=b)[0] + return solver.run(init_params=init, A=None, b=b, **kwargs)[0] def run( self, init_params: Optional[base.KKTSolution] = None, params_obj: Optional[Any] = None, params_eq: Optional[Any] = None, + **kwargs, ) -> base.OptStep: """Solves 0.5 * x^T Q x + c^T x subject to Ax = b. @@ -168,6 +169,7 @@ def run( init_params: ignored. params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided. params_eq: (A, b) or (params_A, b) if matvec_A is provided. + **kwargs: Keyword args provided to the solver. Returns: (params, state), where params = (primal_var, dual_var_eq, None) """ @@ -200,6 +202,7 @@ def matvec(u): init=init_params, tol=self.tol, maxiter=self.maxiter, + **kwargs, ) else: primal, dual_eq = self._refined_solve( diff --git a/jaxopt/_src/eq_qp_preconditioned.py b/jaxopt/_src/eq_qp_preconditioned.py new file mode 100644 index 00000000..3490f6f6 --- /dev/null +++ b/jaxopt/_src/eq_qp_preconditioned.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preconditioned solvers for equality constrained quadratic programming.""" + +from typing import Optional, Any +from dataclasses import dataclass +import jax.numpy as jnp +import jaxopt +from jaxopt._src import base +from jaxopt._src import linear_operator + + +@dataclass +class PseudoInversePreconditionedEqQP(base.Solver): + qp_solver: jaxopt.EqualityConstrainedQP + + def init_params(self, params_obj, params_eq): + """Computes the matvec associated to the pseudo inverse of the KKT matrix.""" + Q, p = params_obj + A, b = params_eq + del p, b + + kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]]) + kkt_mat_pinv = jnp.linalg.pinv(kkt_mat) + + d = Q.shape[0] + + pinv_blocks = ( + (kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]), + (kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]), + ) + return linear_operator.BlockLinearOperator(pinv_blocks) + + def run( + self, + init_params: Optional[base.KKTSolution] = None, + params_obj: Optional[Any] = None, + params_eq: Optional[Any] = None, + params_precond=None, + **kwargs + ): + # TODO(gnegiar): the M parameter should be passed to both + # the QP solve and the implicit_diff_solve + return self.qp_solver.run( + init_params, params_obj, params_eq, M=params_precond, **kwargs + ) diff --git a/jaxopt/_src/linear_operator.py b/jaxopt/_src/linear_operator.py index 34e4941e..0bd0ce2f 100644 --- a/jaxopt/_src/linear_operator.py +++ b/jaxopt/_src/linear_operator.py @@ -14,15 +14,16 @@ """Interface for linear operators.""" import functools +from dataclasses import dataclass +from typing import Tuple + import jax import jax.numpy as jnp -import numpy as onp -from jaxopt.tree_util import tree_map, tree_sum, tree_mul +from jaxopt.tree_util import tree_map class DenseLinearOperator: - def __init__(self, pytree): self.pytree = pytree @@ -33,7 +34,7 @@ def matvec(self, x): return tree_map(jnp.dot, self.pytree, x) def rmatvec(self, _, y): - return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y) + return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y) def matvec_and_rmatvec(self, x, y): return self.matvec(x), self.rmatvec(x, y) @@ -52,11 +53,11 @@ def col_norm(w): if not squared: col_norms = jnp.sqrt(col_norms) return col_norms + return tree_map(col_norm, self.pytree) class FunctionalLinearOperator: - def __init__(self, fun, params): self.fun = functools.partial(fun, params) @@ -71,7 +72,7 @@ def rmatvec(self, x, y): def matvec_and_rmatvec(self, x, y): matvec_x, vjp = jax.vjp(self.matvec, x) - rmatvec_y, = vjp(y) + (rmatvec_y,) = vjp(y) return matvec_x, rmatvec_y def normal_matvec(self, x): @@ -85,3 +86,72 @@ def _make_linear_operator(matvec): return DenseLinearOperator else: return functools.partial(FunctionalLinearOperator, matvec) + + +def block_row_matvec(block, x): + """Performs a matvec for a row of block matrices. + + The following matvec is performed: + [U1, ..., UN] * [x1, ..., xN] + where U1, ..., UN are matrices and x1, ..., xN are vectors + of compatible shapes. + """ + if len(block) != len(x): + raise ValueError( + "We need as many blocks in the matrix as in the vector." + ) + return sum(jax.tree_util.tree_map(jnp.dot, block, x)) + + +# TODO(gnegiar): Extend to arbitrary block shapes. +@jax.tree_util.register_pytree_node_class +@dataclass +class BlockLinearOperator: + """Represents a linear operator defined by blocks over a block pytree. + + Attributes: + blocks: a 2x2 block matrix of the form + [[A, B] + [C, D]] + """ + + blocks: Tuple[Tuple[jnp.array]] + + def __call__(self, x): + return self.matvec(x) + + def matvec(self, x): + """Performs the block matvec with u defined by blocks. + + The matvec is of form: + [u1, u2] + [[A, B] * + [C, D]] + + """ + return jax.tree_util.tree_map( + lambda row_of_blocks: block_row_matvec(row_of_blocks, x), + self.blocks, + is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1], + ) + + def rmatvec(self, x, y): + return self.matvec_and_rmatvec(x, y)[1] + + def matvec_and_rmatvec(self, x, y): + matvec_x, vjp = jax.vjp(self.matvec, x) + (rmatvec_y,) = vjp(y) + return matvec_x, rmatvec_y + + def normal_matvec(self, x): + """Computes A^T A x from matvec(x) = A x.""" + matvec_x, vjp = jax.vjp(self.matvec, x) + return vjp(matvec_x)[0] + + def tree_flatten(self): + return self.blocks, None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(children) diff --git a/tests/eq_qp_preconditioned_test.py b/tests/eq_qp_preconditioned_test.py new file mode 100644 index 00000000..f05f7750 --- /dev/null +++ b/tests/eq_qp_preconditioned_test.py @@ -0,0 +1,82 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax import test_util as jtu +import jax.numpy as jnp + +from jaxopt import PseudoInversePreconditionedEqQP +from jaxopt import EqualityConstrainedQP +import numpy as onp + + +class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase): + def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): + def fun(Q, c, A, b): + Q = 0.5 * (Q + Q.T) + + hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) + # reduce the primal variables to a scalar value for test purpose. + return jnp.sum(solver.run(**hyperparams).params[0]) + + # Derivative w.r.t. A. + rng = onp.random.RandomState(0) + V = rng.rand(*A.shape) + V /= onp.sqrt(onp.sum(V ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b)) + deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. b. + v = rng.rand(*b.shape) + v /= onp.sqrt(onp.sum(v ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b)) + deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. Q + W = rng.rand(*Q.shape) + W /= onp.sqrt(onp.sum(W ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b)) + deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. c + w = rng.rand(*c.shape) + w /= onp.sqrt(onp.sum(w ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b)) + deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + def test_pseudoinverse_preconditioner(self): + Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + c = jnp.array([1.0, 1.0]) + A = jnp.array([[1.0, 1.0]]) + b = jnp.array([1.0]) + qp = EqualityConstrainedQP(tol=1e-7) + preconditioned_qp = PseudoInversePreconditionedEqQP(qp) + params_obj = (Q, c) + params_eq = (A, b) + params_precond = preconditioned_qp.init_params(params_obj, params_eq) + hyperparams = dict( + params_obj=params_obj, + params_eq=params_eq, + ) + sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params + self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) + self._check_derivative_Q_c_A_b(preconditioned_qp, Q, c, A, b) diff --git a/tests/eq_qp_test.py b/tests/eq_qp_test.py index 071c3d64..11aeae32 100644 --- a/tests/eq_qp_test.py +++ b/tests/eq_qp_test.py @@ -26,7 +26,7 @@ class EqualityConstrainedQPTest(jtu.JaxTestCase): - def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b): + def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): def fun(Q, c, A, b): Q = 0.5 * (Q + Q.T) @@ -77,7 +77,7 @@ def test_qp_eq_only(self): hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) sol = qp.run(**hyperparams).params self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) - self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) + self._check_derivative_Q_c_A_b(qp, Q, c, A, b) def test_qp_eq_with_init(self): Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) @@ -89,7 +89,7 @@ def test_qp_eq_with_init(self): init_params = KKTSolution(jnp.array([1.0, 1.0]), jnp.array([1.0])) sol = qp.run(init_params, **hyperparams).params self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) - self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) + self._check_derivative_Q_c_A_b(qp, Q, c, A, b) def test_projection_hyperplane(self): x = jnp.array([1.0, 2.0])