diff --git a/jaxopt/__init__.py b/jaxopt/__init__.py index 96aef470..fc7f3063 100644 --- a/jaxopt/__init__.py +++ b/jaxopt/__init__.py @@ -35,6 +35,7 @@ from jaxopt._src.polyak_sgd import PolyakSGD from jaxopt._src.projected_gradient import ProjectedGradient from jaxopt._src.proximal_gradient import ProximalGradient +from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP from jaxopt._src.quadratic_prog import QuadraticProgramming from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize diff --git a/jaxopt/_src/eq_qp.py b/jaxopt/_src/eq_qp.py index 65fc3936..27af4db3 100644 --- a/jaxopt/_src/eq_qp.py +++ b/jaxopt/_src/eq_qp.py @@ -103,7 +103,7 @@ class EqualityConstrainedQP(base.Solver): implicit_diff_solve: Optional[Callable] = None jit: bool = True - def _refined_solve(self, matvec, b, init, maxiter, tol): + def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs): # Instead of solving S x = b # We solve \bar{S} x = b # @@ -152,13 +152,14 @@ def matvec_regularized_qp(_, x): maxiter=self.refine_maxiter, tol=tol, ) - return solver.run(init_params=init, A=None, b=b)[0] + return solver.run(init_params=init, A=None, b=b, **kwargs)[0] def run( self, init_params: Optional[base.KKTSolution] = None, params_obj: Optional[Any] = None, params_eq: Optional[Any] = None, + **kwargs, ) -> base.OptStep: """Solves 0.5 * x^T Q x + c^T x subject to Ax = b. @@ -168,6 +169,7 @@ def run( init_params: ignored. params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided. params_eq: (A, b) or (params_A, b) if matvec_A is provided. + **kwargs: Keyword args provided to the solver. Returns: (params, state), where params = (primal_var, dual_var_eq, None) """ @@ -200,6 +202,7 @@ def matvec(u): init=init_params, tol=self.tol, maxiter=self.maxiter, + **kwargs, ) else: primal, dual_eq = self._refined_solve( diff --git a/jaxopt/_src/eq_qp_preconditioned.py b/jaxopt/_src/eq_qp_preconditioned.py new file mode 100644 index 00000000..3490f6f6 --- /dev/null +++ b/jaxopt/_src/eq_qp_preconditioned.py @@ -0,0 +1,58 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Preconditioned solvers for equality constrained quadratic programming.""" + +from typing import Optional, Any +from dataclasses import dataclass +import jax.numpy as jnp +import jaxopt +from jaxopt._src import base +from jaxopt._src import linear_operator + + +@dataclass +class PseudoInversePreconditionedEqQP(base.Solver): + qp_solver: jaxopt.EqualityConstrainedQP + + def init_params(self, params_obj, params_eq): + """Computes the matvec associated to the pseudo inverse of the KKT matrix.""" + Q, p = params_obj + A, b = params_eq + del p, b + + kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]]) + kkt_mat_pinv = jnp.linalg.pinv(kkt_mat) + + d = Q.shape[0] + + pinv_blocks = ( + (kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]), + (kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]), + ) + return linear_operator.BlockLinearOperator(pinv_blocks) + + def run( + self, + init_params: Optional[base.KKTSolution] = None, + params_obj: Optional[Any] = None, + params_eq: Optional[Any] = None, + params_precond=None, + **kwargs + ): + # TODO(gnegiar): the M parameter should be passed to both + # the QP solve and the implicit_diff_solve + return self.qp_solver.run( + init_params, params_obj, params_eq, M=params_precond, **kwargs + ) diff --git a/jaxopt/_src/linear_operator.py b/jaxopt/_src/linear_operator.py index 34e4941e..0bd0ce2f 100644 --- a/jaxopt/_src/linear_operator.py +++ b/jaxopt/_src/linear_operator.py @@ -14,15 +14,16 @@ """Interface for linear operators.""" import functools +from dataclasses import dataclass +from typing import Tuple + import jax import jax.numpy as jnp -import numpy as onp -from jaxopt.tree_util import tree_map, tree_sum, tree_mul +from jaxopt.tree_util import tree_map class DenseLinearOperator: - def __init__(self, pytree): self.pytree = pytree @@ -33,7 +34,7 @@ def matvec(self, x): return tree_map(jnp.dot, self.pytree, x) def rmatvec(self, _, y): - return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y) + return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y) def matvec_and_rmatvec(self, x, y): return self.matvec(x), self.rmatvec(x, y) @@ -52,11 +53,11 @@ def col_norm(w): if not squared: col_norms = jnp.sqrt(col_norms) return col_norms + return tree_map(col_norm, self.pytree) class FunctionalLinearOperator: - def __init__(self, fun, params): self.fun = functools.partial(fun, params) @@ -71,7 +72,7 @@ def rmatvec(self, x, y): def matvec_and_rmatvec(self, x, y): matvec_x, vjp = jax.vjp(self.matvec, x) - rmatvec_y, = vjp(y) + (rmatvec_y,) = vjp(y) return matvec_x, rmatvec_y def normal_matvec(self, x): @@ -85,3 +86,72 @@ def _make_linear_operator(matvec): return DenseLinearOperator else: return functools.partial(FunctionalLinearOperator, matvec) + + +def block_row_matvec(block, x): + """Performs a matvec for a row of block matrices. + + The following matvec is performed: + [U1, ..., UN] * [x1, ..., xN] + where U1, ..., UN are matrices and x1, ..., xN are vectors + of compatible shapes. + """ + if len(block) != len(x): + raise ValueError( + "We need as many blocks in the matrix as in the vector." + ) + return sum(jax.tree_util.tree_map(jnp.dot, block, x)) + + +# TODO(gnegiar): Extend to arbitrary block shapes. +@jax.tree_util.register_pytree_node_class +@dataclass +class BlockLinearOperator: + """Represents a linear operator defined by blocks over a block pytree. + + Attributes: + blocks: a 2x2 block matrix of the form + [[A, B] + [C, D]] + """ + + blocks: Tuple[Tuple[jnp.array]] + + def __call__(self, x): + return self.matvec(x) + + def matvec(self, x): + """Performs the block matvec with u defined by blocks. + + The matvec is of form: + [u1, u2] + [[A, B] * + [C, D]] + + """ + return jax.tree_util.tree_map( + lambda row_of_blocks: block_row_matvec(row_of_blocks, x), + self.blocks, + is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1], + ) + + def rmatvec(self, x, y): + return self.matvec_and_rmatvec(x, y)[1] + + def matvec_and_rmatvec(self, x, y): + matvec_x, vjp = jax.vjp(self.matvec, x) + (rmatvec_y,) = vjp(y) + return matvec_x, rmatvec_y + + def normal_matvec(self, x): + """Computes A^T A x from matvec(x) = A x.""" + matvec_x, vjp = jax.vjp(self.matvec, x) + return vjp(matvec_x)[0] + + def tree_flatten(self): + return self.blocks, None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + return cls(children) diff --git a/tests/eq_qp_preconditioned_test.py b/tests/eq_qp_preconditioned_test.py new file mode 100644 index 00000000..90f1b3fb --- /dev/null +++ b/tests/eq_qp_preconditioned_test.py @@ -0,0 +1,82 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax import test_util as jtu +import jax.numpy as jnp + +from jaxopt import PseudoInversePreconditionedEqQP +from jaxopt import EqualityConstrainedQP +import numpy as onp + + +class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase): + def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): + def fun(Q, c, A, b): + Q = 0.5 * (Q + Q.T) + + hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) + # reduce the primal variables to a scalar value for test purpose. + return jnp.sum(solver.run(**hyperparams).params[0]) + + # Derivative w.r.t. A. + rng = onp.random.RandomState(0) + V = rng.rand(*A.shape) + V /= onp.sqrt(onp.sum(V ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b)) + deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. b. + v = rng.rand(*b.shape) + v /= onp.sqrt(onp.sum(v ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b)) + deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. Q + W = rng.rand(*Q.shape) + W /= onp.sqrt(onp.sum(W ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b)) + deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + # Derivative w.r.t. c + w = rng.rand(*c.shape) + w /= onp.sqrt(onp.sum(w ** 2)) + eps = 1e-4 + deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b)) + deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps) + self.assertAllClose(deriv_jax, deriv_num, atol=1e-3) + + def test_pseudoinverse_preconditioner(self): + Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + c = jnp.array([1.0, 1.0]) + A = jnp.array([[1.0, 1.0]]) + b = jnp.array([1.0]) + qp = EqualityConstrainedQP(tol=1e-7) + preconditioned_qp = PseudoInversePreconditionedEqQP(qp) + params_obj = (Q, c) + params_eq = (A, b) + params_precond = preconditioned_qp.init_params(params_obj, params_eq) + hyperparams = dict( + params_obj=params_obj, + params_eq=params_eq, + ) + sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params + self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) + self._check_derivative_Q_c_A_b(qp, Q, c, A, b) diff --git a/tests/eq_qp_test.py b/tests/eq_qp_test.py index 071c3d64..8da3102e 100644 --- a/tests/eq_qp_test.py +++ b/tests/eq_qp_test.py @@ -26,7 +26,7 @@ class EqualityConstrainedQPTest(jtu.JaxTestCase): - def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b): + def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b): def fun(Q, c, A, b): Q = 0.5 * (Q + Q.T) @@ -77,7 +77,7 @@ def test_qp_eq_only(self): hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) sol = qp.run(**hyperparams).params self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) - self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b) + self._check_derivative_Q_c_A_b(qp, Q, c, A, b) def test_qp_eq_with_init(self): Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])