diff --git a/docs/changelog.rst b/docs/changelog.rst
index 9d7b0b21..a6703d2f 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,14 +1,27 @@
 Changelog
 =========
 
-Version 0.1.1 (development version)
------------------------------------
-
-- :class:`jaxopt.ArmijoSGD`
-- :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
-- :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`
-
-
+Version 0.1.1
+-------------
+
+New features
+~~~~~~~~~~~~
+
+- Added solver :class:`jaxopt.ArmijoSGD`
+- Added example :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
+- Added example :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`
+
+Bug fixes
+~~~~~~~~~
+
+- Allow non-jittable proximity operators in :class:`jaxopt.ProximalGradient`
+- Raise an exception if a quadratic program is infeasible or unbounded
+
+Contributors
+~~~~~~~~~~~~
+
+Fabian Pedregosa, Louis Bethune, Mathieu Blondel.
+
 Version 0.1 (initial release)
 -----------------------------
 
diff --git a/examples/deep_learning/plot_sgd_solvers.py b/examples/deep_learning/plot_sgd_solvers.py
index 28e9369a..a618b23c 100644
--- a/examples/deep_learning/plot_sgd_solvers.py
+++ b/examples/deep_learning/plot_sgd_solvers.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 r"""
-Comparison of different GD algorithms.
-======================================
+Comparison of different SGD algorithms.
+=======================================
 
 The purpose of this example is to illustrate the power of adaptive stepsize
 algorithms.
@@ -29,9 +29,11 @@
 * SGD with constant stepsize
 * RMSprop
 
-The reported ``training loss`` is an estimation of the true training loss based on the current minibatch.
-
-This experiment was conducted without momentum, with popular default values for learning rate.
+The reported ``training loss`` is an estimation of the true training loss based
+on the current minibatch.
+
+This experiment was conducted without momentum, with popular default values for
+learning rate.
 """
 
 from absl import flags
@@ -112,7 +114,7 @@ def main(argv):
   # manual flags parsing to avoid conflicts between absl.app.run and sphinx-gallery
   flags.FLAGS(argv)
   FLAGS = flags.FLAGS
-  
+
   train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size)
 
   # Initialize parameters.
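For reference, a minimal sketch of the new ``jaxopt.ArmijoSGD`` solver named in the changelog and example above, applied to a made-up least-squares problem. The objective, data, and variable names below are illustrative only, not taken from the patch:

    # Illustrative sketch, not part of the patch.
    import jax
    import jax.numpy as jnp
    import jaxopt

    def loss_fun(w, data):
      # Mean squared error of a linear model; data is an (X, y) pair.
      X, y = data
      return jnp.mean((jnp.dot(X, w) - y) ** 2)

    key_X, key_w = jax.random.split(jax.random.PRNGKey(0))
    X = jax.random.normal(key_X, (32, 4))
    y = jnp.dot(X, jax.random.normal(key_w, (4,)))

    # ArmijoSGD chooses its stepsize by an Armijo backtracking line search,
    # so no learning-rate tuning is needed here.
    solver = jaxopt.ArmijoSGD(fun=loss_fun, maxiter=100)
    w_fit, state = solver.run(jnp.zeros(4), (X, y))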
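Similarly, a small sketch of the ``jaxopt.ProximalGradient`` class named in the bug fixes above. This uses the stock jittable ``jaxopt.prox.prox_lasso`` operator and made-up data; it does not itself exercise the non-jittable-prox path:

    # Illustrative sketch, not part of the patch.
    import jax
    import jax.numpy as jnp
    from jaxopt import ProximalGradient
    from jaxopt.prox import prox_lasso

    def least_squares(w, data):
      X, y = data
      return 0.5 * jnp.mean((jnp.dot(X, w) - y) ** 2)

    X = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
    y = jnp.dot(X, jnp.array([1.0, 0.0, -2.0, 0.0]))  # sparse ground truth

    pg = ProximalGradient(fun=least_squares, prox=prox_lasso, maxiter=200)
    # hyperparams_prox is forwarded to the prox at each step; for prox_lasso
    # it is the l1 regularization strength.
    w_sparse, pg_state = pg.run(jnp.zeros(4), hyperparams_prox=0.1, data=(X, y))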
diff --git a/examples/fixed_point/deep_equilibrium_model.py b/examples/fixed_point/deep_equilibrium_model.py
index 81bb0fa9..f56ba1f4 100644
--- a/examples/fixed_point/deep_equilibrium_model.py
+++ b/examples/fixed_point/deep_equilibrium_model.py
@@ -15,12 +15,13 @@
 """
 Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
 ================================================================
-
-
-This implementation is strongly inspired by the Pytorch code snippets in [3]. 
-
+
+This implementation is strongly inspired by the Pytorch code snippets in [3].
+
+A similar model called "implicit deep learning" is also proposed in [2].
+
 In practice BatchNormalization and initialization of weights in convolutions are
-important to ensure convergence. 
+important to ensure convergence.
 
 [1] Bai, S., Kolter, J.Z. and Koltun, V., 2019. Deep Equilibrium Models.
 Advances in Neural Information Processing Systems, 32, pp.690-701.
@@ -136,7 +137,7 @@ def block_apply(z, x, block_params):
     solver = self.fixed_point_solver(fixed_point_fun=block_apply)
     def batch_run(x, block_params):
       return solver.run(x, x, block_params)[0]
-    
+
     return jax.vmap(batch_run, in_axes=(0,None), out_axes=0)(x, block_params)
 
 
@@ -147,7 +148,7 @@ class FullDEQ(nn.Module):
   fixed_point_solver: Callable
 
   @nn.compact
-  def __call__(self, x, train): 
+  def __call__(self, x, train):
     x = nn.Conv(features=self.channels, kernel_size=(3,3), use_bias=True, padding='SAME')(x)
     x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(x)
     block = ResNetBlock(self.channels, self.channels_bottleneck)
@@ -261,7 +262,7 @@ def jitted_update(params, state, batch_stats, data):
    batch_stats = state.aux['batch_stats']
    print_accuracy(params, state)
    params, state = jitted_update(params, state, batch_stats, next(train_ds))
-  
+
 
 if __name__ == "__main__":
   app.run(main)
diff --git a/jaxopt/version.py b/jaxopt/version.py
index d158538a..7be8372b 100644
--- a/jaxopt/version.py
+++ b/jaxopt/version.py
@@ -14,4 +14,4 @@
 
 """JAXopt version."""
 
-__version__ = "0.1"
+__version__ = "0.1.1"
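For reference, a minimal sketch of the fixed-point pattern used in the DEQ hunks above: ``jaxopt.AndersonAcceleration`` solving z = tanh(Wz + x), with ``run()`` vmapped over a batch as in ``batch_run``. The map, sizes, and data are made up for illustration:

    # Illustrative sketch, not part of the patch.
    import jax
    import jax.numpy as jnp
    from jaxopt import AndersonAcceleration

    def fixed_point_fun(z, x, W):
      # A contractive map: its fixed point satisfies z = tanh(W z + x).
      return jnp.tanh(jnp.dot(W, z) + x)

    key_W, key_x = jax.random.split(jax.random.PRNGKey(0))
    W = 0.1 * jax.random.normal(key_W, (8, 8))   # small weights => contraction
    batch_x = jax.random.normal(key_x, (16, 8))  # batch of 16 inputs

    solver = AndersonAcceleration(fixed_point_fun=fixed_point_fun,
                                  history_size=5, maxiter=100, tol=1e-5)

    def run_one(x):
      # As in batch_run above, the input also serves as the initial iterate.
      return solver.run(x, x, W).params

    batch_z = jax.vmap(run_one)(batch_x)  # fixed points, shape (16, 8)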