Commit

Release 0.4.
mblondel committed May 24, 2022
1 parent 8531330 commit df8c0b6
Showing 6 changed files with 51 additions and 56 deletions.
9 changes: 0 additions & 9 deletions docs/api.rst
@@ -25,15 +25,6 @@ Constrained
jaxopt.MirrorDescent
jaxopt.ScipyBoundedMinimize

-Least-Squares
-~~~~~~~~~~~~~
-
-.. autosummary::
-  :toctree: _autosummary
-
-  jaxopt.GaussNewton
-  jaxopt.LevenbergMarquardt
-
Quadratic programming
~~~~~~~~~~~~~~~~~~~~~

18 changes: 14 additions & 4 deletions docs/changelog.rst
@@ -1,24 +1,34 @@
Changelog
=========

-Main branch
+Version 0.4
-----------

New features
~~~~~~~~~~~~

- Added solver :class:`jaxopt.LevenbergMarquardt`, by Amir Saadat.
- Added solver :class:`jaxopt.BoxCDQP`, by Mathieu Blondel.

- Added :func:`projection_hypercube <jaxopt.projection.projection_hypercube>`, by Mathieu Blondel.
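
The additions above can be exercised as follows (an illustrative sketch only: the
matrices, vectors and bounds are made up, and the exact signatures of
:class:`jaxopt.BoxCDQP` and :func:`projection_hypercube
<jaxopt.projection.projection_hypercube>` should be checked against the 0.4 API
reference)::

  import jax.numpy as jnp
  from jaxopt import BoxCDQP
  from jaxopt.projection import projection_hypercube

  # Minimize 0.5 * x^T Q x + c^T x subject to l <= x <= u.
  Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
  c = jnp.array([1.0, -1.0])
  l, u = -jnp.ones(2), jnp.ones(2)
  sol = BoxCDQP().run(jnp.zeros(2), params_obj=(Q, c), params_ineq=(l, u)).params

  # Euclidean projection onto the unit hypercube [0, 1]^d.
  x_proj = projection_hypercube(jnp.array([1.5, -0.2]))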

Bug fixes and enhancements
~~~~~~~~~~~~~~~~~~~~~~~~~~

-- Fixed ``solve_normal_cg`` when the linear operator is "nonsquare" (does not map to a space of same dimension),
+- Fixed :func:`solve_normal_cg <jaxopt.linear_solve.solve_normal_cg>`
+  when the linear operator is "nonsquare" (does not map to a space of same dimension),
   by Mathieu Blondel.
-- Added :func:`projection_hypercube <jaxopt.projection.projection_hypercube>`, by Mathieu Blondel.
- Fixed edge case in :class:`jaxopt.Bisection`, by Mathieu Blondel.
-- Replaced deprecated tree_multimap with tree_map, by Fan Yang.
- Added support for leaf cond pytrees in :func:`tree_where <jaxopt.tree_util.tree_where>`, by Felipe Llinares.
- Added Python 3.10 support officially, by Jeppe Klitgaard.
- Replaced deprecated tree_multimap with tree_map, by Fan Yang.
- In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
- Converted the "Resnet" and "Adversarial Training" examples to notebooks, by Fabian Pedregosa.

Contributors
~~~~~~~~~~~~

Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.

Version 0.3.1.
--------------
19 changes: 10 additions & 9 deletions docs/index.rst
@@ -1,7 +1,7 @@
:github_url: https://github.com/google/jaxopt/tree/master/docs

-JAXopt Documentation
-====================
+JAXopt
+======

Hardware accelerated, batchable and differentiable optimizers in
`JAX <https://github.com/google/jax>`_.
@@ -36,15 +36,15 @@ Alternatively, it can be installed from sources with the following command::
basics
unconstrained
constrained
-nonlinear_least_squares
 quadratic_programming
 non_smooth
 stochastic
-objective_and_loss
-linear_system_solvers
 root_finding
 fixed_point
+nonlinear_least_squares
+linear_system_solvers
 implicit_diff
+objective_and_loss
line_search
developer

@@ -92,12 +92,13 @@ its implicit differentiation framework:
.. code-block:: bibtex
   @article{jaxopt_implicit_diff,
     title={Efficient and Modular Implicit Differentiation},
     author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
             and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian
             and Vert, Jean-Philippe},
     journal={arXiv preprint arXiv:2105.15183},
     year={2021}
   }
Indices and tables
56 changes: 24 additions & 32 deletions docs/nonlinear_least_squares.rst
@@ -1,28 +1,29 @@

.. _nonlinear_least_squares:

-Least squares optimization
-==========================
+Nonlinear least squares
+=======================

This section is concerned with problems of the form

.. math::
-    \min_{x} f(x) = \frac{1}{2} * ||\textbf{r}(x, \theta)||^2=\sum_{i=1}^m r_i(x_1,...,x_n)^2,
+    \min_{x} \frac{1}{2} ||\textbf{r}(x, \theta)||^2,

-where :math:`r \colon \mathbb{R}^n \to \mathbb{R}^m` is :math:`r(x, \theta)` is
-:math:`r(x, \theta)` is differentiable (almost everywhere), :math:`x` are the
+where :math:`\textbf{r}` is a residual function, :math:`x` are the
parameters with respect to which the function is minimized, and :math:`\theta`
are optional additional arguments.

-Gauss Newton
+Gauss-Newton
------------

.. autosummary::
:toctree: _autosummary

jaxopt.GaussNewton

We can use the Gauss-Newton method, which is the standard approach for nonlinear least squares problems.

Update equation
~~~~~~~~~~~~~~~
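
As a reminder, the standard Gauss-Newton update :math:`h_{gn}` (a restatement of the
textbook normal equations, not a quotation of the collapsed lines below) is obtained
by solving

.. math::

    \mathbf{J^T} \mathbf{J} \, h_{gn} = - \mathbf{J^T} \mathbf{r},

where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t. the parameters.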

@@ -38,25 +39,26 @@ parameters.
Instantiating and running the solver
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-To solve nonlinear least squares optimization problems, we can use Gauss Newton
-method, which is the standard approach for nonlinear least squares problems ::
+As an example, let us see how to minimize the Rosenbrock residual function::

   from jaxopt import GaussNewton

-  gn = GaussNewton(residual_fun=fun)
-  gn_sol = gn.run(x_init, *args, **kwargs).params
+  def rosenbrock(x):
+    return np.array([10 * (x[1] - x[0]**2), (1 - x[0])])

-As an example, consider the Rosenbrock residual function ::
+  gn = GaussNewton(residual_fun=rosenbrock)
+  gn_sol = gn.run(x_init).params

-  def rosenbrock_res_fun(x):
-    return np.array([10 * (x[1] - x[0]**2), (1 - x[0])]).
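
Pieced together, the new example can be run end to end roughly as follows (a sketch:
``x_init`` is an arbitrary starting point chosen here, and ``jnp.asarray`` keeps the
residual JAX-traceable)::

  import jax.numpy as jnp
  from jaxopt import GaussNewton

  def rosenbrock(x):
    # Residuals of the 2D Rosenbrock function; the minimum is at (1, 1).
    return jnp.asarray([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])

  gn = GaussNewton(residual_fun=rosenbrock)
  x_init = jnp.zeros(2)
  gn_sol = gn.run(x_init).params  # expected to be close to (1., 1.)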

-The function can take arguments, for example for fitting a double exponential ::
+The residual function may take additional arguments, for example for fitting a double exponential::

-  def double_exponential_fit(x, x_data, y_data):
+  def double_exponential(x, x_data, y_data):
     return y_data - (x[0] * jnp.exp(-x[2] * x_data) + x[1] * jnp.exp(
         -x[3] * x_data))

   gn = GaussNewton(residual_fun=double_exponential)
   gn_sol = gn.run(x_init, x_data=x_data, y_data=y_data).params
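
A self-contained version of this fit might look as follows (a sketch: the data is
synthetic, generated from made-up "true" parameters purely for illustration)::

  import jax.numpy as jnp
  from jaxopt import GaussNewton

  def double_exponential(x, x_data, y_data):
    return y_data - (x[0] * jnp.exp(-x[2] * x_data)
                     + x[1] * jnp.exp(-x[3] * x_data))

  # Synthetic observations from made-up "true" parameters.
  x_true = jnp.array([2.0, 1.0, 0.5, 2.5])
  x_data = jnp.linspace(0.0, 5.0, 50)
  y_data = (x_true[0] * jnp.exp(-x_true[2] * x_data)
            + x_true[1] * jnp.exp(-x_true[3] * x_data))

  x_init = jnp.ones(4)
  gn = GaussNewton(residual_fun=double_exponential)
  gn_sol = gn.run(x_init, x_data=x_data, y_data=y_data).params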

Differentiation
~~~~~~~~~~~~~~~

@@ -65,10 +67,10 @@ with respect to some hyperparameters. Continuing the previous example, we can
now differentiate the solution w.r.t. ``y``::

   def solution(y):
-    gn = GaussNewton(residual_fun=fun)
-    lm_sol = lm.run(x_init, X, y).params
+    gn = GaussNewton(residual_fun=double_exponential)
+    lm_sol = lm.run(x_init, x_data, y).params

-  print(jax.jacobian(solution)(y))
+  print(jax.jacobian(solution)(y_data))
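
Spelled out with the names from the double-exponential example, the pattern amounts
to the following (a sketch; the solution is returned explicitly so that
``jax.jacobian`` differentiates the fitted parameters w.r.t. ``y_data``)::

  import jax

  def solution(y):
    gn = GaussNewton(residual_fun=double_exponential)
    return gn.run(x_init, x_data, y).params

  # Jacobian of the fitted parameters with respect to the observations.
  print(jax.jacobian(solution)(y_data))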

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -82,6 +84,9 @@ Levenberg Marquardt

jaxopt.LevenbergMarquardt

We can also use the Levenberg-Marquardt method, which is a more advanced method compared to Gauss-Newton, in
that it regularizes the update equation. It helps in cases where the Gauss-Newton method fails to converge.

Update equation
~~~~~~~~~~~~~~~

@@ -92,17 +97,4 @@ parameters:
(\mathbf{J^T} \mathbf{J} + \mu\mathbf{I}) h_{lm} = - \mathbf{J^T} \mathbf{r}
where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t.
-parameters and :math:`\mu` is the damping parameter..
-
-Instantiating and running the solver
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-To solve nonlinear least squares optimization problems, we can use Levenberg
-Marquardt method, which is a more advanced method compared to Gauss Newton, in
-that it regularizes the update equation which helps for cases where Gauss
-Newton method fails to converge ::
-
-  from jaxopt import LevenbergMarquardt
-
-  lm = LevenbergMarquardt(residual_fun=fun)
-  lm_sol = lm.run(x_init, X, y).params
+parameters and :math:`\mu` is the damping parameter.
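
Usage mirrors :class:`jaxopt.GaussNewton`. For instance, the double-exponential fit
above could equally be run with Levenberg-Marquardt (a sketch; all damping-related
options are left at their defaults)::

  from jaxopt import LevenbergMarquardt

  lm = LevenbergMarquardt(residual_fun=double_exponential)
  lm_sol = lm.run(x_init, x_data, y_data).params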
3 changes: 2 additions & 1 deletion docs/requirements.txt
@@ -7,6 +7,7 @@ ipython>=7.20.0
ipykernel>=5.5.0
sphinx-gallery>=0.9.0
sphinx_copybutton>=0.4.0
sphinx-remove-toctrees>=0.0.3
jupyter-sphinx>=0.3.2
myst-nb
tensorflow-datasets
2 changes: 1 addition & 1 deletion jaxopt/version.py
@@ -14,4 +14,4 @@

"""JAXopt version."""

-__version__ = "0.3.1"
+__version__ = "0.4"
