From df8c0b647e6b7351f5d1268e80c2a4ff2a898c58 Mon Sep 17 00:00:00 2001
From: Mathieu Blondel
Date: Tue, 24 May 2022 09:53:08 +0200
Subject: [PATCH] Release 0.4.

---
 docs/api.rst                     |  9 -----
 docs/changelog.rst               | 18 +++++++---
 docs/index.rst                   | 19 ++++++-----
 docs/nonlinear_least_squares.rst | 56 ++++++++++++++------------------
 docs/requirements.txt            |  3 +-
 jaxopt/version.py                |  2 +-
 6 files changed, 51 insertions(+), 56 deletions(-)

diff --git a/docs/api.rst b/docs/api.rst
index ed7acb3b..8683b156 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -25,15 +25,6 @@ Constrained
   jaxopt.MirrorDescent
   jaxopt.ScipyBoundedMinimize
 
-Least-Squares
-~~~~~~~~~~~~~
-
-.. autosummary::
-  :toctree: _autosummary
-
-  jaxopt.GaussNewton
-  jaxopt.LevenbergMarquardt
-
 Quadratic programming
 ~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/changelog.rst b/docs/changelog.rst
index c03a825d..24e4bab7 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,7 +1,7 @@
 Changelog
 =========
 
-Main branch
+Version 0.4
 -----------
 
 New features
@@ -9,16 +9,26 @@ New features
 
 - Added solver :class:`jaxopt.LevenbergMarquardt`, by Amir Saadat.
 - Added solver :class:`jaxopt.BoxCDQP`, by Mathieu Blondel.
-
+- Added :func:`projection_hypercube <jaxopt.projection.projection_hypercube>`, by Mathieu Blondel.
 
 Bug fixes and enhancements
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-- Fixed ``solve_normal_cg`` when the linear operator is "nonsquare" (does not map to a space of same dimension),
+- Fixed :func:`solve_normal_cg <jaxopt.linear_solve.solve_normal_cg>`
+  when the linear operator is "nonsquare" (does not map to a space of same dimension),
   by Mathieu Blondel.
-- Added :func:`projection_hypercube <jaxopt.projection.projection_hypercube>`, by Mathieu Blondel.
 - Fixed edge case in :class:`jaxopt.Bisection`, by Mathieu Blondel.
 - Replaced deprecated tree_multimap with tree_map, by Fan Yang.
+- Added support for leaf cond pytrees in :func:`tree_where <jaxopt.tree_util.tree_where>`, by Felipe Llinares.
+- Added Python 3.10 support officially, by Jeppe Klitgaard.
+- Replaced deprecated tree_multimap with tree_map, by Fan Yang.
+- In scipy wrappers, converted pytree leaves to jax arrays to determine their shape/dtype, by Roy Frostig.
+- Converted the "Resnet" and "Adversarial Training" examples to notebooks, by Fabian Pedregosa.
+
+Contributors
+~~~~~~~~~~~~
+
+Amir Saadat, Fabian Pedregosa, Fan Yang, Felipe Llinares, Jeppe Klitgaard, Mathieu Blondel, Roy Frostig.
 
 Version 0.3.1.
 --------------
diff --git a/docs/index.rst b/docs/index.rst
index 95146c7e..b0a3941e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,7 +1,7 @@
 :github_url: https://github.com/google/jaxopt/tree/master/docs
 
-JAXopt Documentation
-====================
+JAXopt
+======
 
 Hardware accelerated, batchable and differentiable optimizers in
 `JAX <https://github.com/google/jax>`_.
@@ -36,15 +36,15 @@ Alternatively, it can be installed from sources with the following command::
    basics
    unconstrained
    constrained
-   nonlinear_least_squares
    quadratic_programming
    non_smooth
    stochastic
-   objective_and_loss
-   linear_system_solvers
    root_finding
   fixed_point
+   nonlinear_least_squares
+   linear_system_solvers
    implicit_diff
+   objective_and_loss
    line_search
    developer
 
@@ -92,12 +92,13 @@ its implicit differentiation framework:
 .. code-block:: bibtex
 
    @article{jaxopt_implicit_diff,
-     title={Efficient and Modular Implicit Differentiation},
-     author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
+      title={Efficient and Modular Implicit Differentiation},
+      author={Blondel, Mathieu and Berthet, Quentin and Cuturi, Marco and Frostig, Roy
       and Hoyer, Stephan and Llinares-L{\'o}pez, Felipe and Pedregosa, Fabian
       and Vert, Jean-Philippe},
-     journal={arXiv preprint arXiv:2105.15183},
-     year={2021}
+      journal={arXiv preprint arXiv:2105.15183},
+      year={2021}
+   }
 
 
 Indices and tables
diff --git a/docs/nonlinear_least_squares.rst b/docs/nonlinear_least_squares.rst
index 967a03dc..e416c205 100644
--- a/docs/nonlinear_least_squares.rst
+++ b/docs/nonlinear_least_squares.rst
@@ -1,21 +1,20 @@
 .. _nonlinear_least_squares:
 
-Least squares optimization
-==========================
+Nonlinear least squares
+=======================
 
 This section is concerned with problems of the form
 
 .. math::
 
-    \min_{x} f(x) = \frac{1}{2} * ||\textbf{r}(x, \theta)||^2=\sum_{i=1}^m r_i(x_1,...,x_n)^2,
+    \min_{x} \frac{1}{2} ||\textbf{r}(x, \theta)||^2,
 
-where :math:`r \colon \mathbb{R}^n \to \mathbb{R}^m` is :math:`r(x, \theta)` is
-:math:`r(x, \theta)` is differentiable (almost everywhere), :math:`x` are the
+where :math:`\textbf{r}` is a residual function, :math:`x` are the
 parameters with respect to which the function is minimized, and :math:`\theta`
 are optional additional arguments.
 
 
-Gauss Newton
+Gauss-Newton
 ------------
 
 .. autosummary::
@@ -23,6 +22,8 @@ Gauss Newton
 
   jaxopt.GaussNewton
 
+We can use the Gauss-Newton method, which is the standard approach for nonlinear least squares problems.
+
 Update equation
 ~~~~~~~~~~~~~~~
 
@@ -38,25 +39,26 @@ parameters.
 Instantiating and running the solver
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-To solve nonlinear least squares optimization problems, we can use Gauss Newton
-method, which is the standard approach for nonlinear least squares problems ::
+As an example, let us see how to minimize the Rosenbrock residual function::
 
   from jaxopt import GaussNewton
 
-  gn = GaussNewton(residual_fun=fun)
-  gn_sol = gn.run(x_init, *args, **kwargs).params
+  def rosenbrock(x):
+    return np.array([10 * (x[1] - x[0]**2), (1 - x[0])])
 
-As an example, consider the Rosenbrock residual function ::
+  gn = GaussNewton(residual_fun=rosenbrock)
+  gn_sol = gn.run(x_init).params
 
-  def rosenbrock_res_fun(x):
-    return np.array([10 * (x[1] - x[0]**2), (1 - x[0])]).
 
-The function can take arguments, for example for fitting a double exponential ::
+The residual function may take additional arguments, for example for fitting a
+double exponential::
 
-  def double_exponential_fit(x, x_data, y_data):
+  def double_exponential(x, x_data, y_data):
     return y_data - (x[0] * jnp.exp(-x[2] * x_data) + x[1] * jnp.exp(
     -x[3] * x_data)).
 
+  gn = GaussNewton(residual_fun=double_exponential)
+  gn_sol = gn.run(x_init, x_data=x_data, y_data=y_data).params
+
 Differentiation
 ~~~~~~~~~~~~~~~
 
@@ -65,10 +67,10 @@ with respect to some hyperparameters. Continuing the previous example, we can
 now differentiate the solution w.r.t. ``y``::
 
   def solution(y):
-    gn = GaussNewton(residual_fun=fun)
-    lm_sol = lm.run(x_init, X, y).params
+    gn = GaussNewton(residual_fun=double_exponential)
+    return gn.run(x_init, x_data, y).params
 
-  print(jax.jacobian(solution)(y))
+  print(jax.jacobian(solution)(y_data))
 
 Under the hood, we use the implicit function theorem if ``implicit_diff=True``
 and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -82,6 +84,9 @@ Levenberg Marquardt
 
   jaxopt.LevenbergMarquardt
 
+We can also use the Levenberg-Marquardt method, which is a more advanced method compared to Gauss-Newton, in
+that it regularizes the update equation. It helps for cases where Gauss-Newton method fails to converge.
+
 Update equation
 ~~~~~~~~~~~~~~~
 
@@ -92,17 +97,4 @@ parameters:
    (\mathbf{J} \mathbf{J^T} + \mu\mathbf{I}) h_{lm} = - \mathbf{J^T} \mathbf{r}
 
 where :math:`\mathbf{J}` is the Jacobian of the residual function w.r.t.
-parameters and :math:`\mu` is the damping parameter..
-
-Instantiating and running the solver
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-To solve nonlinear least squares optimization problems, we can use Levenberg
-Marquardt method, which is a more advanced method compared to Gauss Newton, in
-that it regularizes the update equation which helps for cases where Gauss
-Newton method fails to converge ::
-
-  from jaxopt import LevenbergMarquardt
-
-  lm = LevenbergMarquardt(residual_fun=fun)
-  lm_sol = lm.run(x_init, X, y).params
+parameters and :math:`\mu` is the damping parameter.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index fb92e0ef..66443fcd 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -7,6 +7,7 @@ ipython>=7.20.0
 ipykernel>=5.5.0
 sphinx-gallery>=0.9.0
 sphinx_copybutton>=0.4.0
+sphinx-remove-toctrees>=0.0.3
 jupyter-sphinx>=0.3.2
 myst-nb
-tensorflow-datasets
\ No newline at end of file
+tensorflow-datasets
diff --git a/jaxopt/version.py b/jaxopt/version.py
index c5e2737e..538e71d0 100644
--- a/jaxopt/version.py
+++ b/jaxopt/version.py
@@ -14,4 +14,4 @@
 
 """JAXopt version."""
 
-__version__ = "0.3.1"
+__version__ = "0.4"
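
As a quick check of the 0.4 release, here is a minimal, self-contained sketch of the two nonlinear least-squares solvers documented above, using the Rosenbrock residual from the Gauss-Newton example; the starting point ``x_init`` and the printed comparison are illustrative choices, not something prescribed by the docs::

  import jax.numpy as jnp
  from jaxopt import GaussNewton, LevenbergMarquardt

  def rosenbrock(x):
    # Residual vector whose squared norm is the classic Rosenbrock function.
    return jnp.array([10.0 * (x[1] - x[0] ** 2), 1.0 - x[0]])

  x_init = jnp.zeros(2)  # illustrative starting point

  # Gauss-Newton: standard choice for well-behaved residuals.
  gn = GaussNewton(residual_fun=rosenbrock)
  gn_sol = gn.run(x_init).params

  # Levenberg-Marquardt: damped update, more robust far from the solution.
  lm = LevenbergMarquardt(residual_fun=rosenbrock)
  lm_sol = lm.run(x_init).params

  print(gn_sol, lm_sol)  # both should approach the minimizer [1., 1.]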
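The implicit differentiation pattern from the ``Differentiation`` subsection can be exercised in the same way; the linear residual and the arrays ``A`` and ``y`` below are hypothetical stand-ins, chosen so that Gauss-Newton converges in a single step and the Jacobian has a simple closed form::

  import jax
  import jax.numpy as jnp
  from jaxopt import GaussNewton

  # Hypothetical data: residual of an ordinary linear least-squares problem.
  A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  def residual(x, A, y):
    return A @ x - y

  x_init = jnp.zeros(2)

  def solution(y):
    gn = GaussNewton(residual_fun=residual)
    # Extra positional arguments are forwarded to the residual function.
    return gn.run(x_init, A, y).params

  # Implicit differentiation of the fitted parameters w.r.t. the targets y;
  # for this linear residual it equals the pseudo-inverse (A^T A)^{-1} A^T.
  print(jax.jacobian(solution)(y))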