diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c56db1f6..37842587 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -64,8 +64,8 @@ jobs:
         set -xe
         pip install --upgrade pip setuptools wheel
         pip install -r docs/requirements.txt
-    - name: Build documentation
+    - name: Test examples and docstrings
       run: |
         set -xe
         python -VV
-        cd docs && make clean && make html
+        cd docs && make doctest
diff --git a/docs/Makefile b/docs/Makefile
index 85a82ccb..f12c3874 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -3,16 +3,16 @@
 # You can set these variables from the command line, and also
 # from the environment for the first two.
 
-SPHINXOPTS    ?=
 SPHINXBUILD   ?= sphinx-build
 SOURCEDIR     = .
 BUILDDIR      = _build
+SPHINXOPTS    = -d $(BUILDDIR)/doctrees -T
 
 # Put it first so that "make" without argument is like "make help".
 help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
-.PHONY: help Makefile
+.PHONY: help Makefile doctest
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
@@ -25,6 +25,11 @@ clean:
	rm -rf _autosummary/
 
 html-noplot:
-	$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
+	$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+doctest:
+	$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
+	@echo "Testing of doctests in the sources finished, look at the " \
+	      "results in $(BUILDDIR)/doctest/output.txt."
diff --git a/docs/conf.py b/docs/conf.py
index 81552969..58c9367b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -50,13 +50,14 @@
     'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
     'sphinx.ext.autodoc',
     'sphinx.ext.autosummary',
+    'sphinx.ext.doctest',
     'sphinx.ext.intersphinx',
     'sphinx.ext.mathjax',
     'sphinx.ext.viewcode',
     'matplotlib.sphinxext.plot_directive',
     'sphinx_autodoc_typehints',
     'myst_nb',
-    "sphinx_remove_toctrees",
+    'sphinx_remove_toctrees',
     'sphinx_rtd_theme',
     'sphinx_gallery.gen_gallery',
     'sphinx_copybutton',
@@ -70,7 +71,12 @@
     "backreferences_dir": os.path.join("modules", "generated"),
 }
 
+# Specify how to identify the prompt when copying code snippets
+copybutton_prompt_text = r">>> |\.\.\. "
+copybutton_prompt_is_regexp = True
+copybutton_exclude = "style"
+trim_doctest_flags = True
 
 source_suffix = ['.rst', '.ipynb', '.md']
 
 autosummary_generate = True
diff --git a/docs/quadratic_programming.rst b/docs/quadratic_programming.rst
index 30009e84..6fb86489 100644
--- a/docs/quadratic_programming.rst
+++ b/docs/quadratic_programming.rst
@@ -112,20 +112,22 @@ The problem takes the form:
 This class is optimized for QPs with equality constraints only: it supports
 jit, pytrees and matvec. It is based on the KKT conditions of the problem.
 
-Example::
+Example:
 
-  from jaxopt import EqualityConstrainedQP
+.. doctest::
+
+  >>> from jax import numpy as jnp
+  >>> from jaxopt import EqualityConstrainedQP
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  A = jnp.array([[1.0, 1.0]])
-  b = jnp.array([1.0])
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> A = jnp.array([[1.0, 1.0]])
+  >>> b = jnp.array([1.0])
 
-  qp = EqualityConstrainedQP()
-  sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
+  >>> qp = EqualityConstrainedQP()
+  >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
 
-  print(sol.primal)
-  print(sol.dual_eq)
+  >>> print(sol.primal)
+  >>> print(sol.dual_eq)
 
 Ill-posed problems
 ~~~~~~~~~~~~~~~~~~
@@ -138,19 +140,20 @@ it is possible to enable `iterative refinement `_.
-This can be done by setting ``refine_regularization`` and ``refine_maxiter``::
+This can be done by setting ``refine_regularization`` and ``refine_maxiter``:
 
-  from jaxopt.eq_qp import EqualityConstrainedQP
+.. doctest::
+
+  >>> from jaxopt.eq_qp import EqualityConstrainedQP
 
-  Q = 2 * jnp.array([[3000., 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  A = jnp.array([[1.0, 1.0]])
-  b = jnp.array([1.0])
+  >>> Q = 2 * jnp.array([[3000., 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> A = jnp.array([[1.0, 1.0]])
+  >>> b = jnp.array([1.0])
 
-  qp = EqualityConstrainedQP(tol=1e-5, refine_regularization=3., refine_maxiter=50)
-  sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
+  >>> qp = EqualityConstrainedQP(tol=1e-5, refine_regularization=3., refine_maxiter=50)
+  >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b)).params
 
-  print(sol.primal)
-  print(sol.dual_eq)
-  print(qp.l2_optimality_error(sol, params_obj=(Q, c), params_eq=(A, b)))
+  >>> print(sol.primal)
+  >>> print(sol.dual_eq)
+  >>> print(qp.l2_optimality_error(sol, params_obj=(Q, c), params_eq=(A, b)))
 
 
 General QPs
 -----------
@@ -177,21 +180,22 @@ However, it is not jittable, and does not support matvec and pytrees.
 
-Example::
+Example:
 
-  from jaxopt import CvxpyQP
+.. doctest::
+
+  >>> from jaxopt import CvxpyQP
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  A = jnp.array([[1.0, 1.0]])
-  b = jnp.array([1.0])
-  G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
-  h = jnp.array([0.0, 0.0])
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> A = jnp.array([[1.0, 1.0]])
+  >>> b = jnp.array([1.0])
+  >>> G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
+  >>> h = jnp.array([0.0, 0.0])
 
-  qp = CvxpyWrapper()
-  sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
+  >>> qp = CvxpyQP()
+  >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
 
-  print(sol.primal)
-  print(sol.dual_eq)
-  print(sol.dual_ineq)
+  >>> print(sol.primal)
+  >>> print(sol.dual_eq)
+  >>> print(sol.dual_ineq)
 
 It is also possible to specify only equality constraints or only inequality
 constraints by setting ``params_eq`` or ``params_ineq`` to ``None``.
@@ -213,21 +217,22 @@ Hence we recommend to use :class:`BoxOSQP` to avoid a costly problem transformat
 
-Example::
+Example:
 
-  from jaxopt import OSQP
+.. doctest::
+
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> A = jnp.array([[1.0, 1.0]])
+  >>> b = jnp.array([1.0])
+  >>> G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
+  >>> h = jnp.array([0.0, 0.0])
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  A = jnp.array([[1.0, 1.0]])
-  b = jnp.array([1.0])
-  G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
-  h = jnp.array([0.0, 0.0])
 
-  qp = OSQP()
-  sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
+  >>> qp = OSQP()
+  >>> sol = qp.run(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h)).params
 
-  print(sol.primal)
-  print(sol.dual_eq)
-  print(sol.dual_ineq)
+  >>> print(sol.primal)
+  >>> print(sol.dual_eq)
+  >>> print(sol.dual_ineq)
 
 See :class:`jaxopt.BoxOSQP` for a full description of the parameters.
 
@@ -256,20 +261,22 @@ Equality can be enforced with ``l = u``.
 
-Example::
+Example:
 
-  from jaxopt import BoxOSQP
+.. doctest::
+
+  >>> from jaxopt import BoxOSQP
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
-  l = jnp.array([1.0, -jnp.inf, -jnp.inf])
-  u = jnp.array([1.0, 0.0, 0.0])
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
+  >>> l = jnp.array([1.0, -jnp.inf, -jnp.inf])
+  >>> u = jnp.array([1.0, 0.0, 0.0])
 
-  qp = BoxOSQP()
-  sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
+  >>> qp = BoxOSQP()
+  >>> sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
 
-  print(sol.primal)
-  print(sol.dual_eq)
-  print(sol.dual_ineq)
+  >>> print(sol.primal)
+  >>> print(sol.dual_eq)
+  >>> print(sol.dual_ineq)
 
 If required the algorithm can be sped up by setting
 ``check_primal_dual_infeasability`` to ``False``, and by setting
@@ -306,18 +313,19 @@ the primal solution.
 
-Example::
+Example:
 
-  from jaxopt import BoxCDQP
+.. doctest::
+
+  >>> from jaxopt import BoxCDQP
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, -1.0])
-  l = jnp.array([0.0, 0.0])
-  u = jnp.array([1.0, 1.0])
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, -1.0])
+  >>> l = jnp.array([0.0, 0.0])
+  >>> u = jnp.array([1.0, 1.0])
 
-  qp = BoxCDQP()
-  init = jnp.zeros(2)
-  sol = qp.run(init, params_obj=(Q, c), params_ineq=(l, u)).params
+  >>> qp = BoxCDQP()
+  >>> init = jnp.zeros(2)
+  >>> sol = qp.run(init, params_obj=(Q, c), params_ineq=(l, u)).params
 
-  print(sol)
+  >>> print(sol)
 
 Unconstrained QPs
 -----------------
@@ -334,15 +342,16 @@ x=Qx+c=0`.
 Therefore, this is equivalent to solving the linear system :math:`Qx=-c`.
 Since the matrix :math:`Q` is assumed PSD, one of the best algorithms is
-*conjugate gradient*. In JAXopt, this can be done as follows::
+*conjugate gradient*. In JAXopt, this can be done as follows:
 
-  from jaxopt.linear_solve import solve_cg
+.. doctest::
+
+  >>> from jaxopt.linear_solve import solve_cg
 
-  Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
-  c = jnp.array([1.0, 1.0])
-  matvec = lambda x: jnp.dot(Q, x)
+  >>> Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+  >>> c = jnp.array([1.0, 1.0])
+  >>> matvec = lambda x: jnp.dot(Q, x)
 
-  sol = solve_cg(matvec, b=-c)
+  >>> sol = solve_cg(matvec, b=-c)
 
-  print(sol)
+  >>> print(sol)
 
 Pytree of matrices API
 ----------------------
@@ -362,28 +371,29 @@ It offers several advantages:
 
-We illustrate below the parallel solving of two problems with different shapes::
+We illustrate below the parallel solving of two problems with different shapes:
 
-  Q1 = jnp.array([[1.0, -0.5],
-                  [-0.5, 1.0]])
+.. doctest::
+
+  >>> Q1 = jnp.array([[1.0, -0.5],
+  ...                 [-0.5, 1.0]])
-  Q2 = jnp.array([[2.0]])
-  Q = {'problem1': Q1, 'problem2': Q2}
+  >>> Q2 = jnp.array([[2.0]])
+  >>> Q = {'problem1': Q1, 'problem2': Q2}
 
-  c1 = jnp.array([-0.4, 0.3])
-  c2 = jnp.array([0.1])
-  c = {'problem1': c1, 'problem2': c2}
+  >>> c1 = jnp.array([-0.4, 0.3])
+  >>> c2 = jnp.array([0.1])
+  >>> c = {'problem1': c1, 'problem2': c2}
 
-  a1 = jnp.array([[-0.5, 1.5]])
-  a2 = jnp.array([[10.0]])
-  A = {'problem1': a1, 'problem2': a2}
+  >>> a1 = jnp.array([[-0.5, 1.5]])
+  >>> a2 = jnp.array([[10.0]])
+  >>> A = {'problem1': a1, 'problem2': a2}
 
-  b1 = jnp.array([0.3])
-  b2 = jnp.array([5.0])
-  b = {'problem1': b1, 'problem2': b2}
+  >>> b1 = jnp.array([0.3])
+  >>> b2 = jnp.array([5.0])
+  >>> b = {'problem1': b1, 'problem2': b2}
 
-  qp = EqualityConstrainedQP(tol=1e-3)
-  hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
+  >>> qp = EqualityConstrainedQP(tol=1e-3)
+  >>> hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
 
-  # Solve the two problems in parallel with a single call.
-  sol = qp.run(**hyperparams).params
-  print(sol.primal['problem1'], sol.primal['problem2'])
+  >>> # Solve the two problems in parallel with a single call.
+  >>> sol = qp.run(**hyperparams).params
+  >>> print(sol.primal['problem1'], sol.primal['problem2'])
 
 Matvec API
 ----------
@@ -403,6 +413,7 @@ This is the recommended API to use when the matrices are not block diagonal oper
 especially when there are other sparsity patterns involved, or in conjunction with
-implicit differentiation::
+implicit differentiation:
 
+.. doctest::
+
   # Objective:
   # min ||data @ x - targets||_2^2 + 2 * n * lam ||x||_1
   #
@@ -412,27 +423,27 @@ implicit differentiation::
   # under targets = data @ x - y
   #       0 <= x + t <= infinity
   #       -infinity <= x - t <= 0
-  data, targets = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
-  lam = 10.0
+  >>> from sklearn import datasets
+  >>> data, targets = datasets.make_regression(n_samples=10, n_features=3, random_state=0)
+  >>> n = data.shape[0]  # number of samples
+  >>> lam = 10.0
 
-  def matvec_Q(params_Q, xyt):
-    del params_Q # unused
-    x, y, t = xyt
-    return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t)
+  >>> def matvec_Q(params_Q, xyt):
+  ...   del params_Q # unused
+  ...   x, y, t = xyt
+  ...   return jnp.zeros_like(x), 2 * y, jnp.zeros_like(t)
 
-  c = jnp.zeros(data.shape[1]), jnp.zeros(data.shape[0]), 2*n*lam * jnp.ones(data.shape[1])
+  >>> c = jnp.zeros(data.shape[1]), jnp.zeros(data.shape[0]), 2*n*lam * jnp.ones(data.shape[1])
 
-  def matvec_A(params_A, xyt):
-    x, y, t = xyt
-    residuals = params_A @ x - y
-    return residuals, x + t, x - t
+  >>> def matvec_A(params_A, xyt):
+  ...   x, y, t = xyt
+  ...   residuals = params_A @ x - y
+  ...   return residuals, x + t, x - t
 
-  l = targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf)
-  u = targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0])
+  >>> l = targets, jnp.zeros_like(c[0]), jnp.full(data.shape[1], -jnp.inf)
+  >>> u = targets, jnp.full(data.shape[1], jnp.inf), jnp.zeros_like(c[0])
 
-  hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u))
-  osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2)
-  sol, state = osqp.run(None, **hyper_params)
+  >>> hyper_params = dict(params_obj=(None, c), params_eq=data, params_ineq=(l, u))
+  >>> osqp = BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A, tol=1e-2)
+  >>> sol, state = osqp.run(None, **hyper_params)
 
 Quadratic function API
 ----------------------
@@ -454,31 +465,32 @@ Take care that this API also have drawbacks:
 
-We illustrate this API with Non Negative Least Squares (NNLS)::
+We illustrate this API with Non Negative Least Squares (NNLS):
 
+.. doctest::
+
   # min_W \|Y-UW\|_F^2
   # s.t. W>=0
 
-  n, m, rank = 20, 10, 3
-  onp.random.seed(654)
-  U = jax.nn.relu(onp.random.randn(n, rank))
-  W_0 = jax.nn.relu(onp.random.randn(rank, m))
-  Y = U @ W_0
+  >>> import jax
+  >>> import numpy as onp
+  >>> n, m, rank = 20, 10, 3
+  >>> onp.random.seed(654)
+  >>> U = jax.nn.relu(onp.random.randn(n, rank))
+  >>> W_0 = jax.nn.relu(onp.random.randn(rank, m))
+  >>> Y = U @ W_0
 
-  def fun(W, params_obj):
-    Y, U = params_obj
-    # Write the objective as an implicit quadratic polynomial
-    return jnp.sum(jnp.square(Y - U @ W))
+  >>> def fun(W, params_obj):
+  ...   Y, U = params_obj
+  ...   # Write the objective as an implicit quadratic polynomial
+  ...   return jnp.sum(jnp.square(Y - U @ W))
 
-  def matvec_G(params_G, W):
-    del params_G # unused
-    return -W
+  >>> def matvec_G(params_G, W):
+  ...   del params_G # unused
+  ...   return -W
 
-  zeros = jnp.zeros_like(W_0)
-  hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros))
+  >>> zeros = jnp.zeros_like(W_0)
+  >>> hyper_params = dict(params_obj=(Y, U), params_eq=None, params_ineq=(None, zeros))
 
-  solver = OSQP(fun=fun, matvec_G=matvec_G)
+  >>> solver = OSQP(fun=fun, matvec_G=matvec_G)
 
-  init_W = jnp.zeros_like(W_0) # mandatory with `fun` API.
-  init_params = solver.init_params(init_W, **hyper_params)
-  W_sol = solver.run(init_params=init_params, **hyper_params).params.primal
+  >>> init_W = jnp.zeros_like(W_0) # mandatory with `fun` API.
+  >>> init_params = solver.init_params(init_W, **hyper_params)
+  >>> W_sol = solver.run(init_params=init_params, **hyper_params).params.primal
 
 This API is not recommended for large-scale problems or nested differentiations,
 use matvec API instead.
@@ -489,27 +501,27 @@ When using implicit differentiation, the parameters w.r.t which we differentiate
 must be passed to `params_obj`, `params_eq` or `params_ineq`. They should not be
 captured from the global scope by `fun` or `matvec`.
 We illustrate below this common mistake::
 
-  def wrong_solver(Q): # don't do this!
+  >>> def wrong_solver(Q): # don't do this!
 
-    def matvec_Q(params_Q, x):
-      del params_Q # unused
-      # error! Q is captured from the global scope.
-      # it does not fail now, but it will fail later.
-      return jnp.dot(Q, x)
+  ...   def matvec_Q(params_Q, x):
+  ...     del params_Q # unused
+  ...     # error! Q is captured from the global scope.
+  ...     # it does not fail now, but it will fail later.
+  ...     return jnp.dot(Q, x)
 
-    c = jnp.zeros(Q.shape[0])
+  ...   c = jnp.zeros(Q.shape[0])
 
-    A = jnp.array([[1.0, 2.0]])
-    b = jnp.array([1.0])
+  ...   A = jnp.array([[1.0, 2.0]])
+  ...   b = jnp.array([1.0])
 
-    eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
-    sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params
-    loss = jnp.sum(sol.primal)
-    return loss
+  ...   eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
+  ...   sol = eq_qp.run(None, params_obj=(None, c), params_eq=(A, b)).params
+  ...   loss = jnp.sum(sol.primal)
+  ...   return loss
 
-  Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
-  _ = wrong_solver(Q) # no error... but it will fail later.
-  _ = jax.grad(wrong_solver)(Q) # raise CustomVJPException
+  >>> Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
+  >>> _ = wrong_solver(Q) # no error... but it will fail later.
+  >>> _ = jax.grad(wrong_solver)(Q) # raise CustomVJPException
 
 Also, notice that since the problems are convex, the optimum is independent of the
 starting point `init_params`. Hence, derivatives w.r.t `init_params` are always
@@ -517,22 +529,23 @@ zero (mathematically).
 
-The correct implementation is given below::
+The correct implementation is given below:
 
-  def correct_solver(Q):
+.. doctest::
+
+  >>> def correct_solver(Q):
-
-    def matvec_Q(params_Q, x):
-      return jnp.dot(params_Q, x)
+  ...
+  ...   def matvec_Q(params_Q, x):
+  ...     return jnp.dot(params_Q, x)
-
-    c = jnp.zeros(Q.shape[0])
+  ...
+  ...   c = jnp.zeros(Q.shape[0])
-
-    A = jnp.array([[1.0, 2.0]])
-    b = jnp.array([1.0])
+  ...
+  ...   A = jnp.array([[1.0, 2.0]])
+  ...   b = jnp.array([1.0])
-
-    eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
-    # Q is passed as a parameter, not captured from the global scope.
-    sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params
-    loss = jnp.sum(sol.primal)
-    return loss
+  ...
+  ...   eq_qp = EqualityConstrainedQP(matvec_Q=matvec_Q)
+  ...   # Q is passed as a parameter, not captured from the global scope.
+  ...   sol = eq_qp.run(None, params_obj=(Q, c), params_eq=(A, b)).params
+  ...   loss = jnp.sum(sol.primal)
+  ...   return loss
 
-  Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
-  _ = correct_solver(Q) # no error
-  _ = jax.grad(correct_solver)(Q) # no error
+  >>> Q = jnp.array([[1.0, 0.5], [0.5, 4.0]])
+  >>> _ = correct_solver(Q) # no error
+  >>> _ = jax.grad(correct_solver)(Q) # no error
diff --git a/docs/root_finding.rst b/docs/root_finding.rst
index b3c76d49..6b703f6a 100644
--- a/docs/root_finding.rst
+++ b/docs/root_finding.rst
@@ -27,15 +27,18 @@ First, let us consider the case :math:`F(x)`, i.e., without extra argument
 in this interval as long as :math:`F` is continuous. For instance, suppose
 that we want to find the root of :math:`F(x) = x^3 - x - 2`. We have
 :math:`F(1) = -2` and :math:`F(2) = 4`. Since the function is continuous, there
-must be a :math:`x` between 1 and 2 such that :math:`F(x) = 0`::
+must be a :math:`x` between 1 and 2 such that :math:`F(x) = 0`:
 
-  from jaxopt import Bisection
+.. doctest::
+
+  >>> from jaxopt import Bisection
 
-  def F(x):
-    return x ** 3 - x - 2
+  >>> def F(x):
+  ...   return x ** 3 - x - 2
+
+  >>> bisec = Bisection(optimality_fun=F, lower=1, upper=2)
+  >>> print(bisec.run().params)
+  1.5213814
 
-  bisec = Bisection(optimality_fun=F, lower=1, upper=2)
-  print(bisec.run().params)
 
 ``Bisection`` successfully finds the root ``x = 1.521``.
 Notice that ``Bisection`` does not require an initialization,
@@ -46,17 +49,19 @@ Differentiation
 Now, let us consider the case :math:`F(x, \theta)`. For instance, suppose
 that ``F`` takes an additional argument ``factor``. We can easily differentiate
-with respect to ``factor``::
+with respect to ``factor``:
 
-  def F(x, factor):
-    return factor * x ** 3 - x - 2
+.. doctest::
+
+  >>> import jax
+  >>> def F(x, factor):
+  ...   return factor * x ** 3 - x - 2
 
-  def root(factor):
-    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
-    return bisec.run(factor=factor).params
+  >>> def root(factor):
+  ...   bisec = Bisection(optimality_fun=F, lower=1, upper=2)
+  ...   return bisec.run(factor=factor).params
 
-  # Derivative of root with respect to factor at 2.0.
-  print(jax.grad(root)(2.0))
+  >>> # Derivative of root with respect to factor at 2.0.
+  >>> print(jax.grad(root)(2.0))
 
 Under the hood, we use the implicit function theorem in order to differentiate
 the root. See the :ref:`implicit differentiation ` section for more details.
@@ -87,30 +92,35 @@ updates. One can control the number of updates with the ``history_size``
 argument. Furthermore, Broyden's method uses a line search to ensure the
 rank-one updates are stable.
 
-Example::
+Example:
 
-  import jax.numpy as jnp
-  from jaxopt import Broyden
+.. doctest::
+
+  >>> import jax.numpy as jnp
+  >>> from jaxopt import Broyden
 
-  def F(x):
-    return x ** 3 - x - 2
+  >>> def F(x):
+  ...   return x ** 3 - x - 2
 
-  broyden = Broyden(fun=F)
-  print(broyden.run(jnp.array(1.0)).params)
+  >>> broyden = Broyden(fun=F)
+  >>> print(broyden.run(jnp.array(1.0)).params)
+  1.5213826
 
-For implicit differentiation::
+For implicit differentiation:
 
-  import jax
-  import jax.numpy as jnp
-  from jaxopt import Broyden
+.. doctest::
+
+  >>> import jax
+  >>> import jax.numpy as jnp
+  >>> from jaxopt import Broyden
 
-  def F(x, factor):
-    return factor * x ** 3 - x - 2
+  >>> def F(x, factor):
+  ...   return factor * x ** 3 - x - 2
 
-  def root(factor):
-    broyden = Broyden(fun=F)
-    return broyden.run(jnp.array(1.0), factor=factor).params
+  >>> def root(factor):
+  ...   broyden = Broyden(fun=F)
+  ...   return broyden.run(jnp.array(1.0), factor=factor).params
 
-  # Derivative of root with respect to factor at 2.0.
-  print(jax.grad(root)(2.0))
+  >>> # Derivative of root with respect to factor at 2.0.
+  >>> print(jax.grad(root)(2.0))
+  -0.22141123
\ No newline at end of file
diff --git a/docs/stochastic.rst b/docs/stochastic.rst
index cf6ad00f..e4d4e51a 100644
--- a/docs/stochastic.rst
+++ b/docs/stochastic.rst
@@ -26,24 +26,30 @@ Defining an objective function
 Objective functions must contain a ``data`` argument corresponding to :math:`D` above.
 
-Example::
+Example:
+
+.. doctest::
+
+  >>> from jax import numpy as jnp
+
+  >>> def ridge_reg_objective(params, l2reg, data):
+  ...   X, y = data
+  ...   residuals = jnp.dot(X, params) - y
+  ...   return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
 
-  def ridge_reg_objective(params, l2reg, data):
-    X, y = data
-    residuals = jnp.dot(X, params) - y
-    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.dot(w ** 2)
 
 Data iterator
 -------------
 
 Sampling realizations of the random variable :math:`D` can be done using an iterator.
 
-Example::
+Example:
+
+.. doctest::
+
+  >>> def data_iterator():
+  ...   for _ in range(n_iter):
+  ...     perm = rng.permutation(n_samples)[:batch_size]
+  ...     yield (X[perm], y[perm])
 
-  def data_iterator():
-    for _ in range(n_iter):
-      perm = rng.permutation(n_samples)[:batch_size]
-      yield (X[perm], y[perm])
 
 Solvers
 -------
@@ -59,12 +65,16 @@ Optax solvers
 ~~~~~~~~~~~~~
 
 `Optax `_ solvers can be used in JAXopt using
-:class:`OptaxSolver `. Here's an example with Adam::
+:class:`OptaxSolver `.
+Here's an example with Adam:
 
-  from jaxopt import OptaxSolver
+.. doctest::
+
+  >>> from jaxopt import OptaxSolver
+  >>> import optax
+  >>> opt = optax.adam(0.1) # adam with a learning rate of 0.1
+  >>> solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)
-
-  opt = optax.adam(learning_rate)
-  solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)
 
 See `common optimizers `_ in the
@@ -90,16 +100,19 @@ in classification tasks with separable classes, or on regression tasks without n
 Run iterator vs. manual loop
 ----------------------------
 
-The following::
+The following:
 
-  iterator = data_iterator()
-  solver.run_iterator(init_params, iterator, l2reg=l2reg)
+.. doctest::
+
+  >>> iterator = data_iterator()
+  >>> solver.run_iterator(init_params, iterator, l2reg=l2reg)
 
-is equivalent to::
+is equivalent to:
 
-  iterator = data_iterator()
-  state = solver.init_state(init_params, l2reg=l2reg)
-  params = init_params
-  for _ in range(maxiter):
-    data = next(iterator)
-    params, state = solver.update(params, state, l2reg=l2reg, data=data)
+.. doctest::
+
+  >>> iterator = data_iterator()
+  >>> state = solver.init_state(init_params, l2reg=l2reg)
+  >>> params = init_params
+  >>> maxiter = 1000
+  >>> for _ in range(maxiter):
+  ...   data = next(iterator)
+  ...   params, state = solver.update(params, state, l2reg=l2reg, data=data)
diff --git a/docs/unconstrained.rst b/docs/unconstrained.rst
index 5f304b85..b01fc3f8 100644
--- a/docs/unconstrained.rst
+++ b/docs/unconstrained.rst
@@ -20,11 +20,16 @@ Objective functions must always include as first argument the variables with
 respect to which the function is minimized. The function can also contain extra
 arguments.
 
-The following illustrates how to express the ridge regression objective::
+The following illustrates how to express the ridge regression objective:
+
+.. doctest::
+
+  >>> from jax import numpy as jnp
+
+  >>> def ridge_reg_objective(params, l2reg, X, y):
+  ...   residuals = jnp.dot(X, params) - y
+  ...   return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
 
-  def ridge_reg_objective(params, l2reg, X, y):
-    residuals = jnp.dot(X, params) - y
-    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
 
 The model parameters ``params`` correspond to :math:`x` while ``l2reg``, ``X`` and
 ``y`` correspond to the extra arguments :math:`\theta` in the mathematical
@@ -48,13 +53,16 @@ Instantiating and running the solver
 
 Continuing the ridge regression example above, gradient descent can be
-instantiated and run as follows::
+instantiated and run as follows:
 
-  solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=maxiter)
-  res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
+.. doctest::
+
+  >>> import jaxopt
+  >>> solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=500)
+  >>> res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
+
+  >>> # Alternatively, we could have used one of these solvers as well:
+  >>> # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
+  >>> # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
+  >>> # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)
 
-  # Alternatively, we could have used one of these solvers as well:
-  # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
-  # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
-  # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)
 
 Unpacking results
 ~~~~~~~~~~~~~~~~~
@@ -65,9 +73,11 @@ solver-specific information about convergence.
 
-Because ``res`` is a ``NamedTuple``, we can unpack it as::
+Because ``res`` is a ``NamedTuple``, we can unpack it as:
 
-  params, state = res
-  print(params, state)
+.. doctest::
+
+  >>> params, state = res
+  >>> print(params, state)
 
-Alternatively, we can also access attributes directly::
+Alternatively, we can also access attributes directly:
 
-  print(res.params, res.state)
+.. doctest::
+
+  >>> print(res.params, res.state)