def alternating_projections(initial_guess: Any,
                            projections: List,
                            hyperparams: List,
                            **fixed_point_params) -> Any:
  """Alternating projections algorithm.

  This algorithm returns a point in the intersection of convex sets
  by projecting onto each set in turn.

  If the sets are not convex, or if their intersection is empty,
  this algorithm may not converge.

  If the sets are convex and their intersection is non empty,
  the algorithm converges to a point `p*` in the intersection of the sets.
  However this point `p*` is not necessarily the closest to the initial guess,
  i.e. alternating_projections is not a valid projection itself.

  If the initial guess lies in the intersection of the sets, then
  the algorithm converges to this point.
  Hence this algorithm is a retraction.
  If the initial guess lies outside the intersection, and if the intersection
  contains more than one point, then the algorithm converges to an arbitrary
  point in the intersection.

  Implicit differentiation will measure the sensitivity of `p*`
  to perturbations in the `hyperparams`, but not to perturbations
  in the initial guess.

  Args:
    initial_guess: starting point handed to the fixed point solver.
    projections: a sequence of projections, each of which is a function
      with signature ``x, hyperparams -> x``.
    hyperparams: a list of hyperparameters for each projection, each being a
      pytree. Must have the same length as ``projections``.
    **fixed_point_params: parameters for the fixed point solver. When not
      supplied, ``maxiter`` defaults to 100 and ``tol`` to 1e-5.
  Returns:
    A Pytree lying in the intersection of the sets.
  Raises:
    ValueError: if ``projections`` and ``hyperparams`` differ in length.

  References:
    Escalante, R. and Raydan, M., 2011. Alternating projection methods.
    Society for Industrial and Applied Mathematics.
  """
  # Explicit check instead of `assert`: asserts are stripped under `-O`, and
  # `zip` below would then silently drop the unmatched projections.
  if len(projections) != len(hyperparams):
    raise ValueError(
        "projections and hyperparams must have the same length, got "
        f"{len(projections)} and {len(hyperparams)}.")

  def composed_projections(x, hyperparams):
    # Apply every projection once, in order, each with its own parameters.
    for proj, hparam in zip(projections, hyperparams):
      x = proj(x, hparam)
    return x

  # `fixed_point_params` is a fresh dict built from **kwargs, so filling in
  # defaults here never touches caller-owned state.
  fixed_point_params.setdefault("maxiter", 100)
  fixed_point_params.setdefault("tol", 1e-5)

  # look for a fixed point of this operator
  solver = FixedPointIteration(fixed_point_fun=composed_projections,
                               **fixed_point_params)
  return solver.run(initial_guess, hyperparams).params
def test_alternating_projections(self):
  """Alternating projections lands in a ball/hyperplane intersection.

  The sets are the l2 ball of radius 1.5 and the hyperplane x1 + x2 = 1;
  their intersection is a disk. The result must satisfy both constraints up
  to a small tolerance, and differentiating w.r.t. the hyperplane offset must
  run without raising.
  """
  start = jnp.array([-2.0, 1.0, 3.0])
  # Hyperplane x1 + x2 = 1, given as (normal, offset).
  normal = jnp.array([1.0, 1.0, 0.])
  offset = jnp.array(1.0)
  # l2 ball of radius 1.5
  ball_radius = jnp.array(1.5)
  tolerance = 1e-5

  def point_in_disk(intercept):
    # The intersection of a ball with an hyperplane is a disk.
    return projection.alternating_projections(
        start,
        [projection.projection_l2_ball, projection.projection_hyperplane],
        [ball_radius, (normal, intercept)])

  feasible = point_in_disk(offset)
  self.assertLessEqual(jnp.linalg.norm(feasible), ball_radius + tolerance)
  self.assertArraysAllClose(jnp.dot(normal, feasible), jnp.array(offset),
                            atol=tolerance)

  # test that there is no error.
  jax.jacrev(point_in_disk)(offset)