test documentation docstrings
fabianp committed Sep 15, 2023
1 parent c488cd9 commit 66446d6
Showing 9 changed files with 360 additions and 266 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -64,8 +64,8 @@ jobs:
set -xe
pip install --upgrade pip setuptools wheel
pip install -r docs/requirements.txt
- name: Build documentation
- name: Test examples and docstrings
run: |
set -xe
python -VV
cd docs && make clean && make html
make doctest
11 changes: 8 additions & 3 deletions docs/Makefile
@@ -3,16 +3,16 @@

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
SPHINXOPTS = -d $(BUILDDIR)/doctrees -T

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
.PHONY: help Makefile doctest

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
@@ -25,6 +25,11 @@ clean:
rm -rf _autosummary/

html-noplot:
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

doctest:
$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
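The new `doctest` target runs Sphinx's doctest builder over the `>>>` examples embedded in the documentation. The same kind of check can be sketched for a single docstring with Python's standard `doctest` module; `soft_threshold` below is a hypothetical example function, not part of jaxopt:

```python
import doctest


def soft_threshold(x, tau):
    """Scalar soft-thresholding (a hypothetical example, not jaxopt code).

    >>> soft_threshold(3.0, 1.0)
    2.0
    >>> soft_threshold(-0.5, 1.0)
    0.0
    """
    if x > tau:
        return x - tau
    if x < -tau:
        return x + tau
    return 0.0


# Find and run the ">>>" examples embedded in the docstring above; this is
# the same kind of check the Sphinx doctest builder applies to the docs.
runner = doctest.DocTestRunner(verbose=False)
tests = doctest.DocTestFinder().find(
    soft_threshold,
    name="soft_threshold",
    globs={"soft_threshold": soft_threshold},
)
for test in tests:
    runner.run(test)
print(runner.tries, runner.failures)
```

If an example's printed output ever drifts from the code's behavior, `runner.failures` becomes nonzero, which is exactly what makes the CI step above fail.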
8 changes: 7 additions & 1 deletion docs/conf.py
@@ -50,13 +50,14 @@
'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
'matplotlib.sphinxext.plot_directive',
'sphinx_autodoc_typehints',
'myst_nb',
"sphinx_remove_toctrees",
'sphinx_remove_toctrees',
'sphinx_rtd_theme',
'sphinx_gallery.gen_gallery',
'sphinx_copybutton',
@@ -70,7 +71,12 @@
"backreferences_dir": os.path.join("modules", "generated"),
}

# Specify how to identify the prompt when copying code snippets
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
copybutton_exclude = "style"

trim_doctests_flags = True
source_suffix = ['.rst', '.ipynb', '.md']

autosummary_generate = True
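The `copybutton_prompt_text` pattern added above tells sphinx-copybutton which prompt prefixes to strip when a reader copies a snippet. A small sketch of the prompt-stripping part, using the same regex as the config (the real extension also handles output lines, which this sketch does not):

```python
import re

# Same pattern as copybutton_prompt_text above: a ">>> " or "... " prefix.
prompt = re.compile(r">>> |\.\.\. ")

snippet = """>>> def f(x):
...     return x + 1
>>> f(1)
2"""

# Mimic the prompt-stripping part of sphinx-copybutton: drop one leading
# prompt per line so the copied text is directly pasteable.
copied = "\n".join(prompt.sub("", line, count=1) for line in snippet.splitlines())
print(copied)
```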
51 changes: 30 additions & 21 deletions docs/non_smooth.rst
@@ -40,17 +40,23 @@ which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot |
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
We can therefore write::

from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso
.. doctest::

>>> import jax.numpy as jnp
>>> from jaxopt import ProximalGradient
>>> from jaxopt.prox import prox_lasso
>>> from sklearn import datasets
>>> X, y = datasets.make_regression()

def least_squares(w, data):
X, y = data
residuals = jnp.dot(X, w) - y
return jnp.mean(residuals ** 2)
>>> def least_squares(w, data):
... inputs, targets = data
... residuals = jnp.dot(inputs, w) - targets
... return jnp.mean(residuals ** 2)

>>> l1reg = 1.0
>>> n_features = X.shape[1]
>>> w_init = jnp.zeros(n_features)
>>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
>>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

l1reg = 1.0
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
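For intuition, `ProximalGradient` with `prox_lasso` implements the classic ISTA scheme: a gradient step on the smooth least-squares term followed by elementwise soft-thresholding. A minimal NumPy sketch of that scheme on synthetic data (not jaxopt's actual implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 10))
w_true = np.zeros(10)
w_true[:3] = [1.0, -2.0, 3.0]
y = X @ w_true                                   # noiseless sparse problem

def prox_l1(v, tau):
    # Elementwise soft-thresholding: the proximal operator of tau * ||w||_1.
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

l1reg = 0.1
n = len(y)
lipschitz = 2 * np.linalg.norm(X.T @ X, 2) / n   # of the least-squares gradient
step = 1.0 / lipschitz

w = np.zeros(10)
for _ in range(500):
    grad = 2 * X.T @ (X @ w - y) / n             # gradient of mean squared residuals
    w = prox_l1(w - step * grad, step * l1reg)   # proximal (ISTA) step
```

After a few hundred iterations `w` recovers the sparse support of `w_true`, with the usual small L1 shrinkage bias on the nonzero coefficients.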

Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +71,15 @@ Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``l1reg``::
now differentiate the solution w.r.t. ``l1reg``:


def solution(l1reg):
pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
.. doctest::

>>> import jax
>>> def solution(l1reg):
... pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
... return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

print(jax.jacobian(solution)(l1reg))
>>> print(jax.jacobian(solution)(l1reg))

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
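As a toy illustration of the implicit-differentiation idea, consider minimizing f(w, theta) = (w - 1)^2 + theta * w^2, whose minimizer w*(theta) = 1 / (1 + theta) is known in closed form. The implicit function theorem gives dw*/dtheta = -(d2f/dw2)^{-1} d2f/(dw dtheta); a hypothetical NumPy check, not jaxopt code:

```python
import numpy as np

def solve(theta, steps=200, lr=0.1):
    # Inner solver: gradient descent on f(w, theta) = (w - 1)**2 + theta * w**2.
    w = 0.0
    for _ in range(steps):
        w -= lr * (2 * (w - 1) + 2 * theta * w)
    return w

theta = 0.5
w_star = solve(theta)                       # closed form: 1 / (1 + theta)

# Implicit function theorem at the optimum:
#   dw*/dtheta = -(d2f/dw2)^{-1} * d2f/(dw dtheta) = -(2 * w) / (2 + 2 * theta)
ift_grad = -2 * w_star / (2 + 2 * theta)

analytic = -1.0 / (1 + theta) ** 2          # derivative of 1 / (1 + theta)
print(w_star, ift_grad, analytic)
```

The point of the theorem is that `ift_grad` only needs the solution itself, not the trajectory of inner iterations, whereas unrolling differentiates through every gradient step.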
@@ -95,15 +103,16 @@ Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.

Example::
Example:

from jaxopt import objective
from jaxopt import prox
.. doctest::

>>> from jaxopt import objective
>>> from jaxopt import prox
>>> from jaxopt import BlockCoordinateDescent

l1reg = 1.0
w_init = jnp.zeros(n_features)
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
>>> l1reg = 1.0
>>> w_init = jnp.zeros(n_features)
>>> bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
>>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
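Coordinate descent for the lasso updates one coordinate (a block of size one, here) at a time, and each update has a closed form: a scalar soft-threshold. A rough NumPy sketch of the idea (jaxopt's implementation differs in detail):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 5))
w_true = np.array([2.0, 0.0, -1.0, 0.0, 0.0])
y = X @ w_true                              # noiseless sparse problem

def soft(x, tau):
    return np.sign(x) * np.maximum(np.abs(x) - tau, 0.0)

l1reg = 0.05
n = len(y)
w = np.zeros(5)
col_sq = (X ** 2).sum(axis=0) / n           # per-coordinate curvature
for _ in range(100):                        # full sweeps over the coordinates
    for j in range(5):
        # Residual with coordinate j removed, then the closed-form 1-D
        # minimizer of mean((X @ w - y) ** 2) + l1reg * ||w||_1 over w[j].
        r_j = y - X @ w + X[:, j] * w[j]
        rho = X[:, j] @ r_j / n
        w[j] = soft(rho, l1reg / 2) / col_sq[j]
```

Each inner update is exact in its coordinate, which is why coordinate descent needs no step size, but it relies on the objective decomposing as in the composite linear form mentioned above.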

.. topic:: Examples

