Skip to content

Commit

Permalink
test documentation docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianp committed Sep 15, 2023
1 parent c488cd9 commit 74f89d2
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 283 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ jobs:
set -xe
pip install --upgrade pip setuptools wheel
pip install -r docs/requirements.txt
- name: Build documentation
- name: Test examples and docstrings
run: |
set -xe
python -VV
cd docs && make clean && make html
make doctest
11 changes: 8 additions & 3 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
SPHINXOPTS = -d $(BUILDDIR)/doctrees -T

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
.PHONY: help Makefile doctest

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
Expand All @@ -25,6 +25,11 @@ clean:
rm -rf _autosummary/

html-noplot:
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

doctest:
$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
8 changes: 7 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@
'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'matplotlib.sphinxext.plot_directive',
'sphinx_autodoc_typehints',
'myst_nb',
"sphinx_remove_toctrees",
'sphinx_remove_toctrees',
'sphinx_rtd_theme',
'sphinx_gallery.gen_gallery',
'sphinx_copybutton',
Expand All @@ -70,7 +71,12 @@
"backreferences_dir": os.path.join("modules", "generated"),
}

# Specify how to identify the prompt when copying code snippets
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
copybutton_exclude = "style"

trim_doctests_flags = True
source_suffix = ['.rst', '.ipynb', '.md']

autosummary_generate = True
Expand Down
38 changes: 23 additions & 15 deletions docs/constrained.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ To solve constrained optimization problems, we can use projected gradient
descent, which is gradient descent with an additional projection onto the
constraint set. Constraints are specified by setting the ``projection``
argument. For instance, non-negativity constraints can be specified using
:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`::
:func:`projection_non_negative <jaxopt.projection.projection_non_negative>`:

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative
.. doctest::

>>> from jaxopt import ProjectedGradient
>>> from jaxopt.projection import projection_non_negative

>>> pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
>>> pg_sol = pg.run(w_init, data=(X, y)).params

pg = ProjectedGradient(fun=fun, projection=projection_non_negative)
pg_sol = pg.run(w_init, data=(X, y)).params

Numerous projections are available, see below.

Expand All @@ -45,13 +48,15 @@ Specifying projection parameters
Some projections have a hyperparameter that can be specified. For
instance, the hyperparameter of :func:`projection_l2_ball
<jaxopt.projection.projection_l2_ball>` is the radius of the :math:`L_2` ball.
This can be passed using the ``hyperparams_proj`` argument of ``run``::
This can be passed using the ``hyperparams_proj`` argument of ``run``:

.. doctest::

from jaxopt.projection import projection_l2_ball
>>> from jaxopt.projection import projection_l2_ball

radius = 1.0
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
>>> radius = 1.0
>>> pg = ProjectedGradient(fun=fun, projection=projection_l2_ball)
>>> pg_sol = pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

.. topic:: Examples

Expand All @@ -62,13 +67,16 @@ Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``radius``::
now differentiate the solution w.r.t. ``radius``:

.. doctest::

>>> def solution(radius):
... pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
... return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params

def solution(radius):
pg = ProjectedGradient(fun=fun, projection=projection_l2_ball, implicit_diff=True)
return pg.run(w_init, hyperparams_proj=radius, data=(X, y)).params
>>> print(jax.jacobian(solution)(radius))

print(jax.jacobian(solution)(radius))

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
Expand Down
57 changes: 35 additions & 22 deletions docs/non_smooth.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,26 @@ For instance, suppose we want to solve the following optimization problem
which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`. The
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
We can therefore write::
We can therefore write:

from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso
.. doctest::

def least_squares(w, data):
X, y = data
residuals = jnp.dot(X, w) - y
return jnp.mean(residuals ** 2)
>>> import jax.numpy as jnp
>>> from jaxopt import ProximalGradient
>>> from jaxopt.prox import prox_lasso
>>> from sklearn import datasets
>>> X, y = datasets.make_regression()

>>> def least_squares(w, data):
... inputs, targets = data
... residuals = jnp.dot(inputs, w) - targets
... return jnp.mean(residuals ** 2)

>>> l1reg = 1.0
>>> w_init = jnp.zeros(n_features)
>>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
>>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

l1reg = 1.0
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
Expand All @@ -65,13 +72,16 @@ Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``l1reg``::
now differentiate the solution w.r.t. ``l1reg``:


def solution(l1reg):
pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
.. doctest::

print(jax.jacobian(solution)(l1reg))
>>> def solution(l1reg):
... pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
... return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

>>> print(jax.jacobian(solution)(l1reg))

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
Expand All @@ -95,15 +105,18 @@ Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.

Example::
Example:

.. doctest::

from jaxopt import objective
from jaxopt import prox
>>> from jaxopt import objective
>>> from jaxopt import prox
>>> import jax.numpy as jnp

l1reg = 1.0
w_init = jnp.zeros(n_features)
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
>>> l1reg = 1.0
>>> w_init = jnp.zeros(n_features)
>>> bcd = jaxopt.BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
>>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

.. topic:: Examples

Expand Down
Loading

0 comments on commit 74f89d2

Please sign in to comment.