test documentation docstrings with doctest #534

Draft: wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -64,8 +64,8 @@ jobs:
set -xe
pip install --upgrade pip setuptools wheel
pip install -r docs/requirements.txt
- name: Build documentation
- name: Test examples and docstrings
run: |
set -xe
python -VV
cd docs && make clean && make html
make doctest
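
For reviewers who want to reproduce this CI step locally, a rough Python equivalent of ``cd docs && make doctest`` (assuming the packages from docs/requirements.txt are installed; this sketch is not part of the PR itself) is:

    # Sketch only: invoke the sphinx.ext.doctest builder directly instead of via make.
    import sys
    from sphinx.cmd.build import build_main

    # Equivalent to running `make doctest` from the docs/ directory.
    exit_code = build_main(["-b", "doctest", "docs", "docs/_build/doctest"])
    sys.exit(exit_code)
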
11 changes: 8 additions & 3 deletions docs/Makefile
@@ -3,16 +3,16 @@

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
SPHINXOPTS = -d $(BUILDDIR)/doctrees -T

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
.PHONY: help Makefile doctest

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
@@ -25,6 +25,11 @@ clean:
rm -rf _autosummary/

html-noplot:
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

doctest:
$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
8 changes: 7 additions & 1 deletion docs/conf.py
@@ -50,13 +50,14 @@
'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'matplotlib.sphinxext.plot_directive',
'sphinx_autodoc_typehints',
'myst_nb',
"sphinx_remove_toctrees",
'sphinx_remove_toctrees',
'sphinx_rtd_theme',
'sphinx_gallery.gen_gallery',
'sphinx_copybutton',
@@ -70,7 +71,12 @@
"backreferences_dir": os.path.join("modules", "generated"),
}

# Specify how to identify the prompt when copying code snippets
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
copybutton_exclude = "style"

trim_doctest_flags = True
source_suffix = ['.rst', '.ipynb', '.md']

autosummary_generate = True
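
The ``copybutton_prompt_text`` regex above tells sphinx-copybutton which prompt prefixes to strip, so copying a doctest yields plain Python. A tiny illustration of what that pattern matches (illustrative only, not part of the build):

    import re

    prompt = re.compile(r">>> |\.\.\. ")  # same pattern as copybutton_prompt_text

    line = ">>> pg = ProjectedGradient(fun=fun, projection=projection_non_negative)"
    print(prompt.sub("", line, count=1))  # the leading '>>> ' is removed
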
38 changes: 23 additions & 15 deletions docs/constrained.rst
@@ -29,13 +29,16 @@
To solve constrained optimization problems, we can use projected gradient
descent, which is gradient descent with an additional projection onto the
constraint set. Constraints are specified by setting the ``projection``
argument. For instance, non-negativity constraints can be specified using
:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`::
:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`:

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative
.. doctest::

>>> from jaxopt import ProjectedGradient
>>> from jaxopt.projection import projection_non_negative

>>> pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
>>> pg_sol = pg.run(w_init, data=(X, y)).params

pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params

Numerous projections are available, see below.
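
The snippet above reuses ``fun``, ``w_init``, ``X`` and ``y`` from earlier sections of the page that are outside this diff. A self-contained stand-in with made-up data (sketch only) would be::

    import jax.numpy as jnp
    from jaxopt import ProjectedGradient
    from jaxopt.projection import projection_non_negative

    # Tiny made-up regression problem standing in for the page's earlier setup.
    X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    y = jnp.array([1.0, 2.0, 3.0])
    w_init = jnp.zeros(X.shape[1])

    def fun(w, data):
        inputs, targets = data
        residuals = jnp.dot(inputs, w) - targets
        return jnp.mean(residuals ** 2)

    pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
    pg_sol = pg.run(w_init, data=(X, y)).params  # every coordinate is >= 0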

@@ -45,13 +48
Specifying projection parameters
Some projections have a hyperparameter that can be specified. For
instance, the hyperparameter of :func:`projection_l2_ball
<jaxopt.projection.projection_l2_ball>` is the radius of the :math:`L_2` ball.
This can be passed using the ``hyperparams_proj`` argument of ``run``::
This can be passed using the ``hyperparams_proj`` argument of ``run``:

.. doctest::

from jaxopt.projection import projection_l2_ball
>>> from jaxopt.projection import projection_l2_ball

radius = 1.0
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
>>> radius = 1.0
>>> pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
>>> pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
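
For intuition, the projection can also be called on its own: points outside the ball are rescaled onto the sphere of the given radius (a sketch, assuming the usual ``(x, radius)`` call signature)::

    import jax.numpy as jnp
    from jaxopt.projection import projection_l2_ball

    x = jnp.array([3.0, 4.0])          # Euclidean norm 5
    print(projection_l2_ball(x, 1.0))  # rescaled to unit norm: [0.6 0.8]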

.. topic:: Examples

@@ -62,13 +67,16 @@
Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``radius``::
now differentiate the solution w.r.t. ``radius``:

.. doctest::

>>> def solution(radius):
... pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
... return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

def solution(radius):
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
>>> print(jax.jacobian(solution)(radius))

print(jax.jacobian(solution)(radius))
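
Because ``solution`` is just a differentiable function of ``radius``, it composes with the rest of JAX. For example, the radius could be tuned against a held-out loss (sketch only; ``X_val`` and ``y_val`` are hypothetical validation data)::

    import jax
    import jax.numpy as jnp

    def validation_loss(radius):
        w = solution(radius)  # `solution` as defined above
        return jnp.mean((jnp.dot(X_val, w) - y_val) ** 2)

    # Scalar derivative of the validation loss w.r.t. the ball radius.
    print(jax.grad(validation_loss)(1.0))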

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
57 changes: 35 additions & 22 deletions docs/non_smooth.rst
@@ -38,19 +38,26 @@
For instance, suppose we want to solve the following optimization problem

.. math::

    \min_{w} \frac{1}{n} ||Xw - y||^2_2 + \text{l1reg} \cdot ||w||_1

which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`. The
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
We can therefore write::
We can therefore write:

from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso
.. doctest::

def least_squares(w, data):
X, y = data
residuals = jnp.dot(X, w) - y
return jnp.mean(residuals ** 2)
>>> import jax.numpy as jnp
>>> from jaxopt import ProximalGradient
>>> from jaxopt.prox import prox_lasso
>>> from sklearn import datasets
>>> X, y = datasets.make_regression()

>>> def least_squares(w, data):
... inputs, targets = data
... residuals = jnp.dot(inputs, w) - targets
... return jnp.mean(residuals ** 2)

>>> l1reg = 1.0
>>> w_init = jnp.zeros(n_features)
>>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
>>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

l1reg = 1.0
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
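
For context, the prox operator attached here amounts to soft-thresholding: each coordinate is shrunk toward zero by the (step-size scaled) regularization strength and clipped at zero. A standalone sketch of that operation, not the exact :func:`prox_lasso <jaxopt.prox.prox_lasso>` signature::

    import jax.numpy as jnp

    def soft_thresholding(x, threshold):
        # Proximal operator of threshold * ||x||_1.
        return jnp.sign(x) * jnp.maximum(jnp.abs(x) - threshold, 0.0)

    print(soft_thresholding(jnp.array([1.5, -0.2, 0.7]), 0.5))  # -> 1.0, 0.0, 0.2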

Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +72,16 @@
Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``l1reg``::
now differentiate the solution w.r.t. ``l1reg``:


def solution(l1reg):
pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
.. doctest::

print(jax.jacobian(solution)(l1reg))
>>> def solution(l1reg):
... pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
... return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

>>> print(jax.jacobian(solution)(l1reg))
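
jaxopt solvers are designed to work under ``jit`` and ``vmap``; assuming that holds for this example, a regularization path can be sketched by vmapping ``solution`` over several strengths::

    import jax
    import jax.numpy as jnp

    l1regs = jnp.array([0.01, 0.1, 1.0, 10.0])
    # One weight vector per regularization strength, solved in a single vmapped call.
    path = jax.vmap(solution)(l1regs)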

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -95,15 +105,18 @@
Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.

Example::
Example:

.. doctest::

from jaxopt import objective
from jaxopt import prox
>>> from jaxopt import objective
>>> from jaxopt import prox
>>> import jax.numpy as jnp

l1reg = 1.0
w_init = jnp.zeros(n_features)
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
>>> l1reg = 1.0
>>> w_init = jnp.zeros(n_features)
>>> bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
>>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
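
A self-contained variant of the block above, with a made-up dataset so it runs on its own (sizes and data are illustrative)::

    import jax.numpy as jnp
    import jaxopt
    from jaxopt import objective, prox
    from sklearn import datasets

    X, y = datasets.make_regression(n_samples=50, n_features=10, random_state=0)
    l1reg = 1.0
    w_init = jnp.zeros(X.shape[1])
    bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares,
                                        block_prox=prox.prox_lasso)
    lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params  # shape (10,)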

.. topic:: Examples
