From 66446d6cf619f5a89ded1b53f8ad3868107c4055 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Wed, 13 Sep 2023 14:51:32 +0200 Subject: [PATCH] test documentation docstrings --- .github/workflows/tests.yml | 4 +- docs/Makefile | 11 +- docs/conf.py | 8 +- docs/non_smooth.rst | 51 +++-- docs/quadratic_programming.rst | 380 ++++++++++++++++++--------------- docs/requirements.txt | 3 +- docs/root_finding.rst | 73 ++++--- docs/stochastic.rst | 61 +++--- docs/unconstrained.rst | 35 +-- 9 files changed, 360 insertions(+), 266 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c56db1f6..37842587 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -64,8 +64,8 @@ jobs: set -xe pip install --upgrade pip setuptools wheel pip install -r docs/requirements.txt - - name: Build documentation + - name: Test examples and docstrings run: | set -xe python -VV - cd docs && make clean && make html + make doctest diff --git a/docs/Makefile b/docs/Makefile index 85a82ccb..f12c3874 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,16 +3,16 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build +SPHINXOPTS = -d $(BUILDDIR)/doctrees -T # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +.PHONY: help Makefile doctest # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). @@ -25,6 +25,11 @@ clean: rm -rf _autosummary/ html-noplot: - $(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html + $(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +doctest: + $(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." diff --git a/docs/conf.py b/docs/conf.py index 81552969..58c9367b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,13 +50,14 @@ 'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.mathjax', 'sphinx.ext.viewcode', 'matplotlib.sphinxext.plot_directive', 'sphinx_autodoc_typehints', 'myst_nb', - "sphinx_remove_toctrees", + 'sphinx_remove_toctrees', 'sphinx_rtd_theme', 'sphinx_gallery.gen_gallery', 'sphinx_copybutton', @@ -70,7 +71,12 @@ "backreferences_dir": os.path.join("modules", "generated"), } +# Specify how to identify the prompt when copying code snippets +copybutton_prompt_text = r">>> |\.\.\. " +copybutton_prompt_is_regexp = True +copybutton_exclude = "style" +trim_doctests_flags = True source_suffix = ['.rst', '.ipynb', '.md'] autosummary_generate = True diff --git a/docs/non_smooth.rst b/docs/non_smooth.rst index 79c5a2a1..62cc8977 100644 --- a/docs/non_smooth.rst +++ b/docs/non_smooth.rst @@ -40,17 +40,23 @@ which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot | corresponding ``prox`` operator is :func:`prox_lasso `. 
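+For intuition, :func:`prox_lasso <jaxopt.prox.prox_lasso>` is elementwise
+soft-thresholding, :math:`\text{sign}(x) \cdot \max(|x| - \text{l1reg}, 0)`
+(further scaled by the step size inside the solver). A small illustrative
+check, assuming the threshold is passed as the second positional argument:
+
+.. doctest::
+
+  >>> import jax.numpy as jnp
+  >>> from jaxopt.prox import prox_lasso
+  >>> shrunk = prox_lasso(jnp.array([2.0, -0.3]), 0.5)  # soft-threshold at 0.5
+  >>> bool(jnp.allclose(shrunk, jnp.array([1.5, 0.0])))  # 2.0 -> 1.5, -0.3 -> 0.0
+  True
+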
We can therefore write::

-  from jaxopt import ProximalGradient
-  from jaxopt.prox import prox_lasso
+.. doctest::
+
+  >>> import jax
+  >>> import jax.numpy as jnp
+  >>> from jaxopt import ProximalGradient
+  >>> from jaxopt.prox import prox_lasso
+  >>> from sklearn import datasets
+  >>> X, y = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
+  >>> n_features = X.shape[1]

-  def least_squares(w, data):
-    X, y = data
-    residuals = jnp.dot(X, w) - y
-    return jnp.mean(residuals ** 2)
+  >>> def least_squares(w, data):
+  ...   inputs, targets = data
+  ...   residuals = jnp.dot(inputs, w) - targets
+  ...   return jnp.mean(residuals ** 2)
+
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
+  >>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-  l1reg = 1.0
-  pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
-  pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

 Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter ``l1reg``,
 which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +71,15 @@ Differentiation
 In some applications, it is useful to differentiate the solution of the solver
 with respect to some hyperparameters. Continuing the previous example, we can
-now differentiate the solution w.r.t. ``l1reg``::
+now differentiate the solution w.r.t. ``l1reg``:
+

-  def solution(l1reg):
-    pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
-    return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+.. doctest::
+
+  >>> def solution(l1reg):
+  ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
+  ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-  print(jax.jacobian(solution)(l1reg))
+  >>> jac = jax.jacobian(solution)(l1reg)
+  >>> print(jac.shape)
+  (3,)

 Under the hood, we use the implicit function theorem if ``implicit_diff=True``
 and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -95,15 +103,16 @@ Block coordinate descent
 Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
 :ref:`composite linear objective functions `.

-Example::
+Example:

-  from jaxopt import objective
-  from jaxopt import prox
+.. doctest::
+
+  >>> from jaxopt import BlockCoordinateDescent, objective, prox

-  l1reg = 1.0
-  w_init = jnp.zeros(n_features)
-  bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
-  lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
+  >>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

 .. topic:: Examples

diff --git a/docs/quadratic_programming.rst b/docs/quadratic_programming.rst
index 30009e84..458dc0ba 100644
--- a/docs/quadratic_programming.rst
+++ b/docs/quadratic_programming.rst
@@ -112,20 +112,25 @@ The problem takes the form:
 This class is optimized for QPs with equality constraints only: it supports jit,
 pytrees and matvec. It is based on the KKT conditions of the problem.

-Example::
+Example:

-  from jaxopt import EqualityConstrainedQP
+..
doctest:: - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - A = jnp.array([[1.0, 1.0]]) - b = jnp.array([1.0]) + >>> import jax.numpy as jnp + >>> import jaxopt - qp = EqualityConstrainedQP() - sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> A = jnp.array([[1.0, 1.0]]) + >>> b = jnp.array([1.0]) - print(sol.primal) - print(sol.dual_eq) + >>> qp = jaxopt.EqualityConstrainedQP() + >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params + + >>> print(sol.primal) + [0.2499998 0.74999976] + >>> print(sol.dual_eq) + [-2.7499995] Ill-posed problems ~~~~~~~~~~~~~~~~~~ @@ -138,19 +143,22 @@ it is possible to enable `iterative refinement `_. This can be done by setting ``refine_regularization`` and ``refine_maxiter``:: - from jaxopt.eq_qp import EqualityConstrainedQP +.. doctest:: - Q = 2 * jnp.array([[3000., 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - A = jnp.array([[1.0, 1.0]]) - b = jnp.array([1.0]) + >>> Q = 2 * jnp.array([[3000., 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> A = jnp.array([[1.0, 1.0]]) + >>> b = jnp.array([1.0]) - qp = EqualityConstrainedQP(tol=1e-5, refine_regularization=3., refine_maxiter=50) - sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params + >>> qp = jaxopt.EqualityConstrainedQP(tol=1e-5, refine_regularization=3., refine_maxiter=50) + >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params - print(sol.primal) - print(sol.dual_eq) - print(qp.l2_optimality_error(sol, params_obj=(Q, c), params_eq=(A, b))) + >>> print(sol.primal) + [1.6666646e-04 9.9981850e-01] + >>> print(sol.dual_eq) + [-2.9998174] + >>> print(qp.l2_optimality_error(sol, params_obj=(Q, c), params_eq=(A, b))) + 2.0285292e-05 General QPs @@ -175,23 +183,29 @@ However, it is not jittable, and does not support matvec and pytrees. jaxopt.CvxpyQP -Example:: +Example: - from jaxopt import CvxpyQP +.. doctest:: - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - A = jnp.array([[1.0, 1.0]]) - b = jnp.array([1.0]) - G = jnp.array([[-1.0, 0.0], [0.0, -1.0]]) - h = jnp.array([0.0, 0.0]) + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> A = jnp.array([[1.0, 1.0]]) + >>> b = jnp.array([1.0]) + >>> G = jnp.array([[-1.0, 0.0], [0.0, -1.0]]) + >>> h = jnp.array([0.0, 0.0]) - qp = CvxpyWrapper() - sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params + >>> qp = jaxopt.CvxpyQP() + >>> init_params = jnp.zeros(2) + >>> sol = qp.run( + ... init_params=init_params, params_obj=(Q, c), + ... params_eq=(A, b), params_ineq=(G, h)).params - print(sol.primal) - print(sol.dual_eq) - print(sol.dual_ineq) + >>> print(sol.primal) + [0.25 0.75] + >>> print(sol.dual_eq) + [-2.75] + >>> print(sol.dual_ineq) + [0. 0.] It is also possible to specify only equality constraints or only inequality constraints by setting ``params_eq`` or ``params_ineq`` to ``None``. @@ -211,23 +225,27 @@ Hence we recommend to use :class:`BoxOSQP` to avoid a costly problem transformat jaxopt.OSQP -Example:: +Example: - from jaxopt import OSQP +.. 
doctest:: + >>> from jaxopt import OSQP - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - A = jnp.array([[1.0, 1.0]]) - b = jnp.array([1.0]) - G = jnp.array([[-1.0, 0.0], [0.0, -1.0]]) - h = jnp.array([0.0, 0.0]) + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> A = jnp.array([[1.0, 1.0]]) + >>> b = jnp.array([1.0]) + >>> G = jnp.array([[-1.0, 0.0], [0.0, -1.0]]) + >>> h = jnp.array([0.0, 0.0]) - qp = OSQP() - sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params + >>> qp = OSQP() + >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params - print(sol.primal) - print(sol.dual_eq) - print(sol.dual_ineq) + >>> print(sol.primal) + [0.24996418 0.7500219 ] + >>> print(sol.dual_eq) + [-2.750001] + >>> print(sol.dual_ineq) + [0. 0.] See :class:`jaxopt.BoxOSQP` for a full description of the parameters. @@ -254,22 +272,27 @@ but accepts problems in the above box-constrained format instead. The bounds ``u`` (resp. ``l``) can be set to ``inf`` (resp. ``-inf``) if required. Equality can be enforced with ``l = u``. -Example:: +Example: + +.. doctests:: - from jaxopt import BoxOSQP + >>> from jaxopt import BoxOSQP - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]) - l = jnp.array([1.0, -jnp.inf, -jnp.inf]) - u = jnp.array([1.0, 0.0, 0.0]) + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]) + >>> l = jnp.array([1.0, -jnp.inf, -jnp.inf]) + >>> u = jnp.array([1.0, 0.0, 0.0]) - qp = BoxOSQP() - sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params + >>> qp = BoxOSQP() + >>> sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params - print(sol.primal) - print(sol.dual_eq) - print(sol.dual_ineq) + >>> print(sol.primal) + (Array([0.25004143, 0.7500388 ], dtype=float32), Array([ 1. , -0.2500382 , -0.75000846], dtype=float32)) + >>> print(sol.dual_eq) + [-2.7502570e+00 1.5411481e-09 0.0000000e+00] + >>> print(sol.dual_ineq) + (Array([0.0000000e+00, 1.5411481e-09, 0.0000000e+00], dtype=float32), Array([ 2.750257, 0. , -0. ], dtype=float32)) If required the algorithm can be sped up by setting ``check_primal_dual_infeasability`` to ``False``, and by setting @@ -304,20 +327,23 @@ The problem takes the form: :class:`jaxopt.BoxCDQP` uses a coordinate descent solver. The solver returns only the primal solution. -Example:: +Example: - from jaxopt import BoxCDQP +.. doctest:: + >>> from jaxopt import BoxCDQP - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, -1.0]) - l = jnp.array([0.0, 0.0]) - u = jnp.array([1.0, 1.0]) + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, -1.0]) + >>> l = jnp.array([0.0, 0.0]) + >>> u = jnp.array([1.0, 1.0]) - qp = BoxCDQP() - init = jnp.zeros(2) - sol = qp.run(init, params_obj=(Q, c), params_ineq=(l, u)).params + >>> qp = BoxCDQP() + >>> init = jnp.zeros(2) + >>> sol = qp.run(init, params_obj=(Q, c), params_ineq=(l, u)).params + + >>> print(sol) + [0. 0.5] - print(sol) Unconstrained QPs ----------------- @@ -332,17 +358,18 @@ quadratics of the form: The optimality condition rewrites :math:`\nabla \frac{1}{2} x^\top Q x + c^\top x=Qx+c=0`. Therefore, this is equivalent to solving the linear system :math:`Qx=-c`. Since the matrix :math:`Q` is assumed PSD, one of the best -algorithms is *conjugate gradient*. 
In JAXopt, this can be done as follows:: - - from jaxopt.linear_solve import solve_cg +algorithms is *conjugate gradient*. In JAXopt, this can be done as follows: - Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) - c = jnp.array([1.0, 1.0]) - matvec = lambda x: jnp.dot(Q, x) +.. doctest:: + >>> from jaxopt.linear_solve import solve_cg - sol = solve_cg(matvec, b=-c) + >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]]) + >>> c = jnp.array([1.0, 1.0]) + >>> matvec = lambda x: jnp.dot(Q, x) - print(sol) + >>> sol = solve_cg(matvec, b=-c) + >>> print(sol) + [-0.14285713 -0.42857143] Pytree of matrices API ---------------------- @@ -360,30 +387,34 @@ It offers several advantages: * The tolerance is globally defined and shared by all the problems, and the number of iterations is the same for all the problems. -We illustrate below the parallel solving of two problems with different shapes:: +We illustrate below the parallel solving of two problems with different shapes: - Q1 = jnp.array([[1.0, -0.5], - [-0.5, 1.0]]) - Q2 = jnp.array([[2.0]]) - Q = {'problem1': Q1, 'problem2': Q2} +.. doctest:: - c1 = jnp.array([-0.4, 0.3]) - c2 = jnp.array([0.1]) - c = {'problem1': c1, 'problem2': c2} + >>> Q1 = jnp.array([[1.0, -0.5], + ... [-0.5, 1.0]]) + >>> Q2 = jnp.array([[2.0]]) + >>> Q = {'problem1': Q1, 'problem2': Q2} - a1 = jnp.array([[-0.5, 1.5]]) - a2 = jnp.array([[10.0]]) - A = {'problem1': a1, 'problem2': a2} + >>> c1 = jnp.array([-0.4, 0.3]) + >>> c2 = jnp.array([0.1]) + >>> c = {'problem1': c1, 'problem2': c2} - b1 = jnp.array([0.3]) - b2 = jnp.array([5.0]) - b = {'problem1': b1, 'problem2': b2} + >>> a1 = jnp.array([[-0.5, 1.5]]) + >>> a2 = jnp.array([[10.0]]) + >>> A = {'problem1': a1, 'problem2': a2} + + >>> b1 = jnp.array([0.3]) + >>> b2 = jnp.array([5.0]) + >>> b = {'problem1': b1, 'problem2': b2} + + >>> qp = jaxopt.EqualityConstrainedQP(tol=1e-3) + >>> hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) + >>> # Solve the two problems in parallel with a single call. + >>> sol = qp.run(**hyperparams).params + >>> print(sol.primal['problem1'], sol.primal['problem2']) + [0.42857167 0.34285742] [0.5] - qp = EqualityConstrainedQP(tol=1e-3) - hyperparams = dict(params_obj=(Q, c), params_eq=(A, b)) - # Solve the two problems in parallel with a single call. - sol = qp.run(**hyperparams).params - print(sol.primal['problem1'], sol.primal['problem2']) Matvec API ---------- @@ -401,8 +432,9 @@ It offers several advantages: This is the recommended API to use when the matrices are not block diagonal operators, especially when there are other sparsity patterns involved, or in conjunction with -implicit differentiation:: +implicit differentiation: +.. doctest:: # Objective: # min ||data @ x - targets||_2^2 + 2 * n * lam ||x||_1 # @@ -412,27 +444,30 @@ implicit differentiation:: # under targets = data @ x - y # 0 <= x + t <= infinity # -infinity <= x - t <= 0 - data, targets = datasets.make_regression(n_samples=10, n_features=3, random_state=0) - lam = 10.0 + >>> from sklearn import datasets + >>> n = 10 + >>> data, targets = datasets.make_regression(n_samples=n, n_features=3, random_state=0) + >>> lam = 10.0 + + >>> def matvec_Q(params_Q, xyt): + ... del params_Q # unused + ... x, y, t = xyt + ... 
return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t) - def matvec_Q(params_Q, xyt): - del params_Q # unused - x, y, t = xyt - return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t) + >>> c = jnp.zeros(data.shape[1]), jnp.zeros(data.shape[0]), 2*n*lam * jnp.ones(data.shape[1]) - c = jnp.zeros(data.shape[1]), jnp.zeros(data.shape[0]), 2*n*lam * jnp.ones(data.shape[1]) + >>> def matvec_A(params_A, xyt): + ... x, y, t = xyt + ... residuals = params_A @ x - y + ... return residuals, x + t, x - t - def matvec_A(params_A, xyt): - x, y, t = xyt - residuals = params_A @ x - y - return residuals, x + t, x - t + >>> l = targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf) + >>> u = targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0]) - l = targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf) - u = targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0]) + >>> hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u)) + >>> osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2) + >>> sol, state = osqp.run(None, **hyper_params) - hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u)) - osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2) - sol, state = osqp.run(None, **hyper_params) Quadratic function API ---------------------- @@ -452,33 +487,36 @@ Take care that this API also have drawbacks: * to extract `x -> Qx` and `c` from the function, we need to compute the Hessian-vector product and the gradient of ``fun``, which may be expensive. * for this API `init_params` must be provided to `run`, contrary to the other APIs. -We illustrate this API with Non Negative Least Squares (NNLS):: +We illustrate this API with Non Negative Least Squares (NNLS): +.. doctest:: # min_W \|Y-UW\|_F^2 # s.t. W>=0 - n, m, rank = 20, 10, 3 - onp.random.seed(654) - U = jax.nn.relu(onp.random.randn(n, rank)) - W_0 = jax.nn.relu(onp.random.randn(rank, m)) - Y = U @ W_0 + >>> import numpy as onp + >>> import jax + >>> n, m, rank = 20, 10, 3 + >>> onp.random.seed(654) + >>> U = jax.nn.relu(onp.random.randn(n, rank)) + >>> W_0 = jax.nn.relu(onp.random.randn(rank, m)) + >>> Y = U @ W_0 - def fun(W, params_obj): - Y, U = params_obj - # Write the objective as an implicit quadratic polynomial - return jnp.sum(jnp.square(Y - U @ W)) + >>> def fun(W, params_obj): + ... Y, U = params_obj + ... # Write the objective as an implicit quadratic polynomial + ... return jnp.sum(jnp.square(Y - U @ W)) - def matvec_G(params_G, W): - del params_G # unused - return -W + >>> def matvec_G(params_G, W): + ... del params_G # unused + ... return -W - zeros = jnp.zeros_like(W_0) - hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros)) + >>> zeros = jnp.zeros_like(W_0) + >>> hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros)) - solver = OSQP(fun=fun, matvec_G=matvec_G) + >>> solver = OSQP(fun=fun, matvec_G=matvec_G) - init_W = jnp.zeros_like(W_0) # mandatory with `fun` API. - init_params = solver.init_params(init_W, **hyper_params) - W_sol = solver.run(init_params=init_params, **hyper_params).params.primal + >>> init_W = jnp.zeros_like(W_0) # mandatory with `fun` API. + >>> init_params = solver.init_params(init_W, **hyper_params) + >>> W_sol = solver.run(init_params=init_params, **hyper_params).params.primal This API is not recommended for large-scale problems or nested differentiations, use matvec API instead. 
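+As a quick, purely illustrative sanity check on the solution above (reusing
+``W_sol``, ``Y`` and ``U`` from the previous block; the ``1e-3`` slack is an
+arbitrary tolerance):
+
+.. doctest::
+
+  >>> feasible = jnp.all(W_sol >= -1e-3)              # W >= 0 up to solver tolerance
+  >>> obj_value = jnp.sum(jnp.square(Y - U @ W_sol))  # objective value at the solution
+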
@@ -487,52 +525,56 @@ Implicit differentiation pitfalls When using implicit differentiation, the parameters w.r.t which we differentiate must be passed to `params_obj`, `params_eq` or `params_ineq`. They should not be captured -from the global scope by `fun` or `matvec`. We illustrate below this common mistake:: - - def wrong_solver(Q): # don't do this! - - def matvec_Q(params_Q, x): - del params_Q # unused - # error! Q is captured from the global scope. - # it does not fail now, but it will fail later. - return jnp.dot(Q, x) - - c = jnp.zeros(Q.shape[0]) - - A = jnp.array([[1.0, 2.0]]) - b = jnp.array([1.0]) - - eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q) - sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params - loss = jnp.sum(sol.primal) - return loss - - Q = jnp.array([[1.0, 0.5], [0.5, 4.0]]) - _ = wrong_solver(Q) # no error... but it will fail later. - _ = jax.grad(wrong_solver)(Q) # raise CustomVJPException +from the global scope by `fun` or `matvec`. We illustrate below this common mistake: + +.. doctest:: + + >>> def _matvec_Q(params_Q, x): + ... del params_Q # unused + ... # error! Q is captured from the global scope. + ... # it does not fail now, but it will fail later. + ... return jnp.dot(Q, x) + + >>> def wrong_solver(Q): # don't do this! + ... c = jnp.zeros(Q.shape[0]) + ... A = jnp.array([[1.0, 2.0]]) + ... b = jnp.array([1.0]) + ... eq_qp = jaxopt.EqualityConstrainedQP(matvec_Q=_matvec_Q) + ... sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params + ... loss = jnp.sum(sol.primal) + ... return loss + + >>> Q = jnp.array([[1.0, 0.5], [0.5, 4.0]]) + >>> _ = wrong_solver(Q) # no error... but it will fail later. + >>> jax.grad(wrong_solver)(Q) # raise exception + Traceback (most recent call last): + ... + TypeError: Gradient only defined for scalar-output functions. Output was None. Also, notice that since the problems are convex, the optimum is independent of the starting point `init_params`. Hence, derivatives w.r.t `init_params` are always zero (mathematically). -The correct implementation is given below:: +The correct implementation is given below: - def correct_solver(Q): +.. doctest:: - def matvec_Q(params_Q, x): - return jnp.dot(params_Q, x) + >>> def correct_solver(Q): + ... def _matvec_Q(params_Q, x): + ... return jnp.dot(params_Q, x) - c = jnp.zeros(Q.shape[0]) + ... c = jnp.zeros(Q.shape[0]) - A = jnp.array([[1.0, 2.0]]) - b = jnp.array([1.0]) + ... A = jnp.array([[1.0, 2.0]]) + ... b = jnp.array([1.0]) - eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q) - # Q is passed as a parameter, not captured from the global scope. - sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params - loss = jnp.sum(sol.primal) - return loss + ... eq_qp = jaxopt.EqualityConstrainedQP(matvec_Q=_matvec_Q) + ... # Q is passed as a parameter, not captured from the global scope. + ... sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params + ... loss = jnp.sum(sol.primal) + ... 
return loss - Q = jnp.array([[1.0, 0.5], [0.5, 4.0]]) - _ = correct_solver(Q) # no error - _ = jax.grad(correct_solver)(Q) # no error + >>> Q = jnp.array([[1.0, 0.5], [0.5, 4.0]]) + >>> print(correct_solver(Q)) # no error + 0.74999994 + >>> _ = jax.grad(correct_solver)(Q) # no error diff --git a/docs/requirements.txt b/docs/requirements.txt index 6cae9c58..6f201840 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,4 +15,5 @@ tensorflow dm-haiku flax jupytext -scikit-learn \ No newline at end of file +scikit-learn +cvxpy \ No newline at end of file diff --git a/docs/root_finding.rst b/docs/root_finding.rst index b3c76d49..8df760f4 100644 --- a/docs/root_finding.rst +++ b/docs/root_finding.rst @@ -27,15 +27,17 @@ First, let us consider the case :math:`F(x)`, i.e., without extra argument in this interval as long as :math:`F` is continuous. For instance, suppose that we want to find the root of :math:`F(x) = x^3 - x - 2`. We have :math:`F(1) = -2` and :math:`F(2) = 4`. Since the function is continuous, there -must be a :math:`x` between 1 and 2 such that :math:`F(x) = 0`:: +must be a :math:`x` between 1 and 2 such that :math:`F(x) = 0`: - from jaxopt import Bisection +.. doctest:: + >>> from jaxopt import Bisection + >>> def F(x): + ... return x ** 3 - x - 2 - def F(x): - return x ** 3 - x - 2 + >>> bisec = Bisection(optimality_fun=F, lower=1, upper=2) + >>> print(bisec.run().params) + 1.5213814 - bisec = Bisection(optimality_fun=F, lower=1, upper=2) - print(bisec.run().params) ``Bisection`` successfully finds the root ``x = 1.521``. Notice that ``Bisection`` does not require an initialization, @@ -46,17 +48,20 @@ Differentiation Now, let us consider the case :math:`F(x, \theta)`. For instance, suppose that ``F`` takes an additional argument ``factor``. We can easily differentiate -with respect to ``factor``:: +with respect to ``factor``: - def F(x, factor): - return factor * x ** 3 - x - 2 +.. doctest:: + >>> import jax + >>> def F(x, factor): + ... return factor * x ** 3 - x - 2 - def root(factor): - bisec = Bisection(optimality_fun=F, lower=1, upper=2) - return bisec.run(factor=factor).params + >>> def root(factor): + ... bisec = Bisection(optimality_fun=F, lower=1, upper=2) + ... return bisec.run(factor=factor).params - # Derivative of root with respect to factor at 2.0. - print(jax.grad(root)(2.0)) + >>> # Derivative of root with respect to factor at 2.0. + >>> print(jax.grad(root)(2.0)) + -0.22139914 Under the hood, we use the implicit function theorem in order to differentiate the root. See the :ref:`implicit differentiation ` section for more details. @@ -87,30 +92,34 @@ updates. One can control the number of updates with the ``history_size`` argument. Furthermore, Broyden's method uses a line search to ensure the rank-one updates are stable. -Example:: +Example: - import jax.numpy as jnp - from jaxopt import Broyden +.. doctest:: + >>> import jax.numpy as jnp + >>> from jaxopt import Broyden - def F(x): - return x ** 3 - x - 2 + >>> def F(x): + ... return x ** 3 - x - 2 - broyden = Broyden(fun=F) - print(broyden.run(jnp.array(1.0)).params) + >>> broyden = Broyden(fun=F) + >>> print(broyden.run(jnp.array(1.0)).params) + 1.5213826 -For implicit differentiation:: +For implicit differentiation: - import jax - import jax.numpy as jnp - from jaxopt import Broyden +.. doctest:: + >>> import jax + >>> import jax.numpy as jnp + >>> from jaxopt import Broyden - def F(x, factor): - return factor * x ** 3 - x - 2 + >>> def F(x, factor): + ... 
return factor * x ** 3 - x - 2

-  def root(factor):
-    broyden = Broyden(fun=F)
-    return broyden.run(jnp.array(1.0), factor=factor).params
+  >>> def root(factor):
+  ...   broyden = Broyden(fun=F)
+  ...   return broyden.run(jnp.array(1.0), factor=factor).params

-  # Derivative of root with respect to factor at 2.0.
-  print(jax.grad(root)(2.0))
+  >>> # Derivative of root with respect to factor at 2.0.
+  >>> print(jax.grad(root)(2.0))
+  -0.22141123
\ No newline at end of file
diff --git a/docs/stochastic.rst b/docs/stochastic.rst
index cf6ad00f..e4bbf563 100644
--- a/docs/stochastic.rst
+++ b/docs/stochastic.rst
@@ -26,24 +26,30 @@ Defining an objective function
 Objective functions must contain a ``data`` argument corresponding to :math:`D`
 above.

-Example::
+Example:
+
+.. doctest::
+
+  >>> import jax.numpy as jnp
+
+  >>> def ridge_reg_objective(params, l2reg, data):
+  ...   X, y = data
+  ...   residuals = jnp.dot(X, params) - y
+  ...   return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

-  def ridge_reg_objective(params, l2reg, data):
-    X, y = data
-    residuals = jnp.dot(X, params) - y
-    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.dot(w ** 2)

 Data iterator
 -------------

 Sampling realizations of the random variable :math:`D` can be done using an iterator.

-Example::
+Example:
+
+.. doctest::
+
+  >>> import numpy as onp
+  >>> rng = onp.random.RandomState(0)
+  >>> X, y = rng.randn(100, 3), rng.randn(100)
+  >>> n_samples, batch_size, n_iter = 100, 32, 10
+
+  >>> def data_iterator():
+  ...   for _ in range(n_iter):
+  ...     perm = rng.permutation(n_samples)[:batch_size]
+  ...     yield (X[perm], y[perm])

-  def data_iterator():
-    for _ in range(n_iter):
-      perm = rng.permutation(n_samples)[:batch_size]
-      yield (X[perm], y[perm])

 Solvers
 -------
@@ -59,12 +65,16 @@ Optax solvers
~~~~~~~~~~~~~

 `Optax `_ solvers can be used in JAXopt using
-:class:`OptaxSolver `. Here's an example with Adam::
+:class:`OptaxSolver <jaxopt.OptaxSolver>`.
+Here's an example with Adam:

-  from jaxopt import OptaxSolver
+.. doctest::
+
+  >>> from jaxopt import OptaxSolver
+  >>> import optax
+  >>> opt = optax.adam(0.1)  # Adam with a learning rate of 0.1.
+  >>> solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)

-  opt = optax.adam(learning_rate)
-  solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)

 See `common optimizers `_ in the
@@ -90,16 +100,19 @@ in classification tasks with separable classes, or on regression tasks without n

 Run iterator vs. manual loop
 ----------------------------

-The following::
+The following:

-  iterator = data_iterator()
-  solver.run_iterator(init_params, iterator, l2reg=l2reg)
+.. doctest::
+
+  >>> init_params = jnp.zeros(X.shape[1])
+  >>> l2reg = 0.1
+  >>> iterator = data_iterator()
+  >>> params, state = solver.run_iterator(init_params, iterator, l2reg=l2reg)

-is equivalent to::
+is equivalent to:

-  iterator = data_iterator()
-  state = solver.init_state(init_params, l2reg=l2reg)
-  params = init_params
-  for _ in range(maxiter):
-    data = next(iterator)
-    params, state = solver.update(params, state, l2reg=l2reg, data=data)
+.. doctest::
+
+  >>> iterator = data_iterator()
+  >>> state = solver.init_state(init_params, l2reg=l2reg)
+  >>> params = init_params
+  >>> for _ in range(n_iter):  # one update per batch yielded by the iterator
+  ...   data = next(iterator)
+  ...   params, state = solver.update(params, state, l2reg=l2reg, data=data)
diff --git a/docs/unconstrained.rst b/docs/unconstrained.rst
index 5f304b85..5ac26f26 100644
--- a/docs/unconstrained.rst
+++ b/docs/unconstrained.rst
@@ -20,11 +20,15 @@ Objective functions must always include as first argument the variables with
respect to which the function is minimized. The function can also contain
extra arguments.
-The following illustrates how to express the ridge regression objective::
+The following illustrates how to express the ridge regression objective:
+
+.. doctest::
+
+  >>> import jax.numpy as jnp
+
+  >>> def ridge_reg_objective(params, l2reg, X, y):
+  ...   residuals = jnp.dot(X, params) - y
+  ...   return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

-  def ridge_reg_objective(params, l2reg, X, y):
-    residuals = jnp.dot(X, params) - y
-    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

 The model parameters ``params`` correspond to :math:`x` while ``l2reg``, ``X`` and ``y``
 correspond to the extra arguments :math:`\theta` in the mathematical
@@ -48,13 +52,16 @@ Instantiating and running the solver
-Continuing the ridge regression example above, gradient descent can be
-instantiated and run as follows::
+Continuing the ridge regression example above, an L-BFGS solver can be
+instantiated and run as follows:

-  solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=maxiter)
-  res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
+.. doctest::
+
+  >>> import jaxopt
+  >>> X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+  >>> y = jnp.array([1.0, 2.0, 3.0])
+  >>> l2reg = 1.0
+  >>> init_params = jnp.zeros(X.shape[1])
+  >>> solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=500)
+  >>> res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
+
+  >>> # Alternatively, we could have used one of these solvers as well:
+  >>> # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
+  >>> # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
+  >>> # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)

-  # Alternatively, we could have used one of these solvers as well:
-  # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
-  # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
-  # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)

 Unpacking results
 ~~~~~~~~~~~~~~~~~

 solver-specific information about convergence.

-Because ``res`` is a ``NamedTuple``, we can unpack it as::
+Because ``res`` is a ``NamedTuple``, we can unpack it as:

-  params, state = res
-  print(params, state)
+.. doctest::
+
+  >>> params, state = res
+  >>> print(params.shape)
+  (2,)

-Alternatively, we can also access attributes directly::
+Alternatively, we can also access attributes directly:

-  print(res.params, res.state)
+.. doctest::
+
+  >>> print(res.params.shape)
+  (2,)
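+
+For L-BFGS, ``res.state`` is also a ``NamedTuple``; a couple of its fields are
+handy for quick diagnostics (a small illustrative snippet reusing ``res`` from
+above):
+
+.. doctest::
+
+  >>> num_iters = res.state.iter_num   # number of iterations actually performed
+  >>> final_value = res.state.value    # objective value at the returned parameters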