Added Pseudo-inverse preconditioner for EqQP. #133

Draft · wants to merge 1 commit into main

1 change: 1 addition & 0 deletions jaxopt/__init__.py
@@ -35,6 +35,7 @@
from jaxopt._src.polyak_sgd import PolyakSGD
from jaxopt._src.projected_gradient import ProjectedGradient
from jaxopt._src.proximal_gradient import ProximalGradient
from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP
from jaxopt._src.quadratic_prog import QuadraticProgramming
from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares
from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize
12 changes: 8 additions & 4 deletions jaxopt/_src/eq_qp.py
@@ -57,7 +57,8 @@ def eq_fun(primal_var, params_eq):

# It is required to post_process the output of `idf.make_kkt_optimality_fun`
# to make the signatures of optimality_fun() and run() agree.
def optimality_fun(params, params_obj, params_eq):
# The M argument is needed for using preconditioners.
def optimality_fun(params, params_obj, params_eq, M=None):
Collaborator:
At first sight, I'm rather -1 on introducing a preconditioner M in optimality_fun, since I don't think it plays any role in the optimality conditions (it would not make sense to differentiate with respect to M). Since run and optimality_fun need to have the same signature, this rules out adding M to run as well.

I think I would go for something like this instead:

preconditioner = PseudoInversePreconditioner(params_obj, params_eq)
qp = EqualityConstrainedQP(preconditioner=preconditioner)
qp.run(params_obj=params_obj, params_eq=params_eq)

Typically, stuff that doesn't need to be differentiated should go to the constructor.

If you want to differentiate wrt params_eq or params_obj, you may need to use

EqualityConstrainedQP(preconditioner=lax.stop_gradient(preconditioner))

instead. Not entirely sure if PseudoInversePreconditioner should live in JAXopt or in user land.

Contributor (author):
It doesn't play a role in the optimality conditions -- it could play a role in the backward pass though, if we manage to pass the argument to the backward solver as well. This would hopefully speed things up, since the forward and backward linear systems share the same linear operator.

With the current API, if we're solving the QP as the inner problem of a bi-level problem, we need to build a new QP solver instance at each step of the outer loop and pass partial(linearsolver, M=preconditioner) as both solve and implicit_diff_solve.

Something that is still unclear to me: does building a new instance of the QP solver necessarily trigger recompilation of the run method at each iteration?
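For reference, a minimal sketch of the two usage patterns discussed in this thread. The preconditioner= constructor argument and PseudoInversePreconditioner are hypothetical names from the proposal above, and forwarding M through jaxopt.linear_solve.solve_gmres to jax.scipy.sparse.linalg.gmres is an assumption, not something this PR guarantees.

import functools
from jaxopt import EqualityConstrainedQP
from jaxopt import linear_solve

# Current workaround in a bi-level setting: rebuild the solver at each outer
# step so that both the forward solve and the implicit-diff solve receive M.
def outer_step(params_obj, params_eq, preconditioner):
    solve = functools.partial(linear_solve.solve_gmres, M=preconditioner)
    qp = EqualityConstrainedQP(solve=solve, implicit_diff_solve=solve)
    return qp.run(params_obj=params_obj, params_eq=params_eq)

# Proposed API (hypothetical): the preconditioner goes to the constructor and
# is wrapped in lax.stop_gradient if params_obj / params_eq are differentiated.
# qp = EqualityConstrainedQP(preconditioner=lax.stop_gradient(preconditioner))
# qp.run(params_obj=params_obj, params_eq=params_eq)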

return optimality_fun_with_ineq(params, params_obj, params_eq, None)

return optimality_fun
@@ -103,7 +104,7 @@ class EqualityConstrainedQP(base.Solver):
implicit_diff_solve: Optional[Callable] = None
jit: bool = True

def _refined_solve(self, matvec, b, init, maxiter, tol):
def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs):
# Instead of solving S x = b
# We solve \bar{S} x = b
#
@@ -152,13 +153,14 @@ def matvec_regularized_qp(_, x):
maxiter=self.refine_maxiter,
tol=tol,
)
return solver.run(init_params=init, A=None, b=b)[0]
return solver.run(init_params=init, A=None, b=b, **kwargs)[0]

def run(
self,
init_params: Optional[base.KKTSolution] = None,
params_obj: Optional[Any] = None,
params_eq: Optional[Any] = None,
**kwargs,
) -> base.OptStep:
"""Solves 0.5 * x^T Q x + c^T x subject to Ax = b.

@@ -168,6 +170,7 @@
init_params: ignored.
params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
params_eq: (A, b) or (params_A, b) if matvec_A is provided.
**kwargs: Keyword args provided to the solver.
Returns:
(params, state), where params = (primal_var, dual_var_eq, None)
"""
@@ -200,10 +203,11 @@ def matvec(u):
init=init_params,
tol=self.tol,
maxiter=self.maxiter,
**kwargs,
)
else:
primal, dual_eq = self._refined_solve(
matvec, target, init_params, tol=self.tol, maxiter=self.maxiter
matvec, target, init_params, tol=self.tol, maxiter=self.maxiter, **kwargs
)

return base.OptStep(params=base.KKTSolution(primal, dual_eq, None), state=None)
58 changes: 58 additions & 0 deletions jaxopt/_src/eq_qp_preconditioned.py
@@ -0,0 +1,58 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Preconditioned solvers for equality constrained quadratic programming."""

from typing import Optional, Any
from dataclasses import dataclass
import jax.numpy as jnp
import jaxopt
from jaxopt._src import base
from jaxopt._src import linear_operator


@dataclass
class PseudoInversePreconditionedEqQP(base.Solver):
qp_solver: jaxopt.EqualityConstrainedQP

def init_params(self, params_obj, params_eq):
"""Computes the matvec associated to the pseudo inverse of the KKT matrix."""
Q, p = params_obj
A, b = params_eq
del p, b

kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]])
kkt_mat_pinv = jnp.linalg.pinv(kkt_mat)

d = Q.shape[0]

pinv_blocks = (
(kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]),
(kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]),
)
return linear_operator.BlockLinearOperator(pinv_blocks)

def run(
self,
init_params: Optional[base.KKTSolution] = None,
params_obj: Optional[Any] = None,
params_eq: Optional[Any] = None,
params_precond=None,
**kwargs
):
# TODO(gnegiar): the M parameter should be passed to both
# the QP solve and the implicit_diff_solve
return self.qp_solver.run(
init_params, params_obj, params_eq, M=params_precond, **kwargs
)
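A minimal usage sketch of the wrapper defined above, mirroring the test added at the bottom of this PR; it assumes this branch is installed so that PseudoInversePreconditionedEqQP is importable from jaxopt.

import jax.numpy as jnp
from jaxopt import EqualityConstrainedQP
from jaxopt import PseudoInversePreconditionedEqQP

Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])

qp = PseudoInversePreconditionedEqQP(EqualityConstrainedQP(tol=1e-7))
# The preconditioner (a BlockLinearOperator wrapping the pseudo-inverse of the
# KKT matrix) is computed once and can be reused, e.g. across the outer
# iterations of a bi-level problem.
M = qp.init_params(params_obj=(Q, c), params_eq=(A, b))
sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_precond=M).params
print(sol.primal, sol.dual_eq)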
82 changes: 76 additions & 6 deletions jaxopt/_src/linear_operator.py
@@ -14,15 +14,16 @@
"""Interface for linear operators."""

import functools
from dataclasses import dataclass
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as onp

from jaxopt.tree_util import tree_map, tree_sum, tree_mul
from jaxopt.tree_util import tree_map


class DenseLinearOperator:

def __init__(self, pytree):
self.pytree = pytree

@@ -33,7 +34,7 @@ def matvec(self, x):
return tree_map(jnp.dot, self.pytree, x)

def rmatvec(self, _, y):
return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y)
return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y)

def matvec_and_rmatvec(self, x, y):
return self.matvec(x), self.rmatvec(x, y)
@@ -52,11 +53,11 @@ def col_norm(w):
if not squared:
col_norms = jnp.sqrt(col_norms)
return col_norms

return tree_map(col_norm, self.pytree)


class FunctionalLinearOperator:

def __init__(self, fun, params):
self.fun = functools.partial(fun, params)

@@ -71,7 +72,7 @@ def rmatvec(self, x, y):

def matvec_and_rmatvec(self, x, y):
matvec_x, vjp = jax.vjp(self.matvec, x)
rmatvec_y, = vjp(y)
(rmatvec_y,) = vjp(y)
return matvec_x, rmatvec_y

def normal_matvec(self, x):
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec):
return DenseLinearOperator
else:
return functools.partial(FunctionalLinearOperator, matvec)


def block_row_matvec(block, x):
"""Performs a matvec for a row of block matrices.

The following matvec is performed:
[U1, ..., UN] * [x1, ..., xN]
where U1, ..., UN are matrices and x1, ..., xN are vectors
of compatible shapes.
"""
if len(block) != len(x):
raise ValueError(
"We need as many blocks in the matrix as in the vector."
)
return sum(jax.tree_util.tree_map(jnp.dot, block, x))
Collaborator:
If my understanding is correct, block is actually a tuple of blocks here, and x a tuple of vectors with the same structure? So technically it is not a row-vector product, since the result is not a scalar as one would expect. Maybe add a docstring and consider renaming the function.

Contributor (author):
Done, let me know :)
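To make the helper above concrete, a small sketch (assuming this branch) of what block_row_matvec computes for one block-row:

import jax.numpy as jnp
from jaxopt._src.linear_operator import block_row_matvec

U1 = jnp.ones((2, 3))
U2 = 2.0 * jnp.ones((2, 4))
x1 = jnp.arange(3.0)
x2 = jnp.arange(4.0)

# One block-row [U1, U2] applied to the block vector (x1, x2): the result is
# U1 @ x1 + U2 @ x2, a vector of shape (2,), not a scalar.
out = block_row_matvec((U1, U2), (x1, x2))
assert jnp.allclose(out, U1 @ x1 + U2 @ x2)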



# TODO(gnegiar): Extend to arbitrary block shapes.
@jax.tree_util.register_pytree_node_class
@dataclass
class BlockLinearOperator:
"""Represents a linear operator defined by blocks over a block pytree.
Collaborator:
What does the pytree refer to? The 2x2 tuple Tuple[Tuple[jnp.array]]? Or is it something more complicated?

Contributor (author):
It is -- without registering it as a pytree node, I get this error: type <class 'jaxopt._src.linear_operator.BlockLinearOperator'> is not a valid JAX type.


Attributes:
blocks: a 2x2 block matrix of the form
[[A, B]
[C, D]]
"""

blocks: Tuple[Tuple[jnp.array]]

def __call__(self, x):
return self.matvec(x)

def matvec(self, x):
"""Performs the block matvec with u defined by blocks.

The matvec is of form:
[u1, u2]
[[A, B] *
[C, D]]

"""
return jax.tree_util.tree_map(
Collaborator:
Since you are stopping the recursion at depth 1 (i.e. you only retrieve [A, B] and [C, D]), I believe it would be clearer to hardcode it instead of using tree_map, which is a bit overkill here. For example, consider writing upper_block = self.blocks[0].

Contributor (author):
True; my hope was to leave a stub that could be extended to block matrices of any size, not just 2x2. I can go either way on this, let me know what you think.

lambda row_of_blocks: block_row_matvec(row_of_blocks, x),
self.blocks,
is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1],
)

def rmatvec(self, x, y):
return self.matvec_and_rmatvec(x, y)[1]

def matvec_and_rmatvec(self, x, y):
matvec_x, vjp = jax.vjp(self.matvec, x)
(rmatvec_y,) = vjp(y)
return matvec_x, rmatvec_y

def normal_matvec(self, x):
"""Computes A^T A x from matvec(x) = A x."""
matvec_x, vjp = jax.vjp(self.matvec, x)
return vjp(matvec_x)[0]

def tree_flatten(self):
return self.blocks, None

@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(children)
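On the tree_map discussion above, a sketch (assuming this branch) of the hardcoded 2x2 alternative suggested in review, checked against the BlockLinearOperator defined in this file:

import jax.numpy as jnp
from jaxopt._src.linear_operator import BlockLinearOperator

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
B = jnp.array([[1.0], [1.0]])
C = B.T
D = jnp.zeros((1, 1))
x = (jnp.array([1.0, 2.0]), jnp.array([3.0]))

op = BlockLinearOperator(((A, B), (C, D)))
out = op(x)  # (A @ x1 + B @ x2, C @ x1 + D @ x2)

# Hardcoded 2x2 version: explicit about the block structure, no tree_map or
# is_leaf needed, but harder to extend to other block shapes.
def matvec_2x2(blocks, x):
    (A, B), (C, D) = blocks
    x1, x2 = x
    return (A @ x1 + B @ x2, C @ x1 + D @ x2)

expected = matvec_2x2(op.blocks, x)
assert all(jnp.allclose(o, e) for o, e in zip(out, expected))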
82 changes: 82 additions & 0 deletions tests/eq_qp_preconditioned_test.py
@@ -0,0 +1,82 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt import PseudoInversePreconditionedEqQP
from jaxopt import EqualityConstrainedQP
import numpy as onp


class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase):
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
def fun(Q, c, A, b):
Q = 0.5 * (Q + Q.T)

hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
# reduce the primal variables to a scalar value for test purpose.
return jnp.sum(solver.run(**hyperparams).params[0])

# Derivative w.r.t. A.
rng = onp.random.RandomState(0)
V = rng.rand(*A.shape)
V /= onp.sqrt(onp.sum(V ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b))
deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. b.
v = rng.rand(*b.shape)
v /= onp.sqrt(onp.sum(v ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b))
deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. Q
W = rng.rand(*Q.shape)
W /= onp.sqrt(onp.sum(W ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b))
deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

# Derivative w.r.t. c
w = rng.rand(*c.shape)
w /= onp.sqrt(onp.sum(w ** 2))
eps = 1e-4
deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b))
deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps)
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

def test_pseudoinverse_preconditioner(self):
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
qp = EqualityConstrainedQP(tol=1e-7)
preconditioned_qp = PseudoInversePreconditionedEqQP(qp)
params_obj = (Q, c)
params_eq = (A, b)
params_precond = preconditioned_qp.init_params(params_obj, params_eq)
hyperparams = dict(
params_obj=params_obj,
params_eq=params_eq,
)
sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(preconditioned_qp, Q, c, A, b)
6 changes: 3 additions & 3 deletions tests/eq_qp_test.py
@@ -26,7 +26,7 @@


class EqualityConstrainedQPTest(jtu.JaxTestCase):
def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b):
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
def fun(Q, c, A, b):
Q = 0.5 * (Q + Q.T)

@@ -77,7 +77,7 @@ def test_qp_eq_only(self):
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
sol = qp.run(**hyperparams).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
Collaborator:
Nice catch! Maybe you need self._check_derivative_Q_c_A_b(qp, None, Q, c, A, b) here.

Contributor (author):
Removed the second argument from the _check_derivative_Q_c_A_b method, since it wasn't actually used.


def test_qp_eq_with_init(self):
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
@@ -89,7 +89,7 @@ def test_qp_eq_with_init(self):
init_params = KKTSolution(jnp.array([1.0, 1.0]), jnp.array([1.0]))
sol = qp.run(init_params, **hyperparams).params
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)

def test_projection_hyperplane(self):
x = jnp.array([1.0, 2.0])