Merge pull request #85 from mblondel:version_0.1.1
PiperOrigin-RevId: 404245958
JAXopt authors committed Oct 19, 2021
2 parents 9a318e6 + f629fe0 commit 9692f90
Showing 4 changed files with 39 additions and 23 deletions.
29 changes: 21 additions & 8 deletions docs/changelog.rst
@@ -1,14 +1,27 @@
Changelog
=========

Version 0.1.1 (development version)
-----------------------------------

- :class:`jaxopt.ArmijoSGD`
- :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
- :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`


Version 0.1.1
-------------

New features
~~~~~~~~~~~~

- Added solver :class:`jaxopt.ArmijoSGD`
- Added example :ref:`sphx_glr_auto_examples_fixed_point_deep_equilibrium_model.py`
- Added example :ref:`sphx_glr_auto_examples_deep_learning_plot_sgd_solvers.py`
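A minimal usage sketch of the new solver, assuming the shared JAXopt solver interface (``fun``, ``maxiter``, ``run``); the least-squares problem and data below are purely illustrative:

import jax.numpy as jnp
import jaxopt

# Illustrative data: a small least-squares problem.
X = jnp.ones((8, 3))
y = jnp.zeros(8)
init_params = jnp.zeros(3)

def loss_fun(params, data):
  inputs, targets = data
  return jnp.mean((inputs @ params - targets) ** 2)

# ArmijoSGD selects its stepsize with an Armijo line search at every update.
solver = jaxopt.ArmijoSGD(fun=loss_fun, maxiter=100)
params, state = solver.run(init_params, data=(X, y))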

Bug fixes
~~~~~~~~~

- Allow non-jittable proximity operators in :class:`jaxopt.ProximalGradient`
- Raise an exception if a quadratic program is infeasible or unbounded
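A sketch of the ``ProximalGradient`` calling convention the first fix above relates to; ``prox_lasso`` below is the built-in soft-thresholding operator, and, per that fix, a hand-written Python callable with the same ``(x, hyperparams, scaling)`` signature can be substituted even when it cannot be jitted (problem data are illustrative):

import jax.numpy as jnp
import jaxopt
from jaxopt import prox

X = jnp.ones((8, 3))
y = jnp.zeros(8)
w_init = jnp.zeros(3)

def least_squares(w, data):
  inputs, targets = data
  return 0.5 * jnp.mean((inputs @ w - targets) ** 2)

# The prox is called as prox(x, hyperparams_prox, scaling), where the scaling
# is the current stepsize; any callable with this signature can be plugged in.
pg = jaxopt.ProximalGradient(fun=least_squares, prox=prox.prox_lasso)
w, state = pg.run(w_init, hyperparams_prox=0.1, data=(X, y))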

Contributors
~~~~~~~~~~~~

Fabian Pedregosa, Louis Bethune, Mathieu Blondel.

Version 0.1 (initial release)
-----------------------------

14 changes: 8 additions & 6 deletions examples/deep_learning/plot_sgd_solvers.py
@@ -13,8 +13,8 @@
# limitations under the License.

r"""
Comparison of different GD algorithms.
======================================
Comparison of different SGD algorithms.
=======================================
The purpose of this example is to illustrate the power
of adaptive stepsize algorithms.
@@ -29,9 +29,11 @@
* SGD with constant stepsize
* RMSprop
The reported ``training loss`` is an estimation of the true training loss based on the current minibatch.
This experiment was conducted without momentum, with popular default values for learning rate.
The reported ``training loss`` is an estimation of the true training loss based
on the current minibatch.
This experiment was conducted without momentum, with popular default values for
learning rate.
"""

from absl import flags
@@ -112,7 +114,7 @@ def main(argv):
# manual flags parsing to avoid conflicts between absl.app.run and sphinx-gallery
flags.FLAGS(argv)
FLAGS = flags.FLAGS

train_ds, ds_info = load_dataset(FLAGS.dataset, FLAGS.batch_size)

# Initialize parameters.
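The gist of the comparison in the new example, as a minimal sketch assuming the shared JAXopt stochastic-solver interface (``init_state``/``update``); the synthetic batches below stand in for the batches produced by ``load_dataset``:

import jax.numpy as jnp
import jaxopt
import optax

def loss_fun(params, batch):
  inputs, targets = batch
  return jnp.mean((inputs @ params - targets) ** 2)

# Two of the solvers compared in the example: SGD with a constant stepsize
# (via optax) and ArmijoSGD, which adapts its stepsize with a line search.
solvers = {
    "sgd": jaxopt.OptaxSolver(fun=loss_fun, opt=optax.sgd(1e-2)),
    "armijo_sgd": jaxopt.ArmijoSGD(fun=loss_fun),
}

# Illustrative minibatches; the real example streams batches from load_dataset.
batches = [(jnp.ones((4, 3)), jnp.zeros(4)) for _ in range(10)]

for name, solver in solvers.items():
  params = jnp.zeros(3)
  state = solver.init_state(params, batches[0])
  for batch in batches:
    params, state = solver.update(params, state, batch)
  print(name, loss_fun(params, batches[0]))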
17 changes: 9 additions & 8 deletions examples/fixed_point/deep_equilibrium_model.py
@@ -15,12 +15,13 @@
"""
Deep Equilibrium (DEQ) model in Flax with Anderson acceleration.
================================================================
This implementation is strongly inspired by the Pytorch code snippets in [3].
A similar model called "implicit deep learning" is also proposed in [2].
In practice BatchNormalization and initialization of weights in convolutions are
important to ensure convergence.
[1] Bai, S., Kolter, J.Z. and Koltun, V., 2019. Deep Equilibrium Models.
Advances in Neural Information Processing Systems, 32, pp.690-701.
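A toy illustration of the fixed-point machinery the DEQ model builds on, assuming ``jaxopt.AndersonAcceleration`` and its ``fixed_point_fun``/``run`` interface (the contraction below is illustrative, not the ResNet block):

import jax.numpy as jnp
import jaxopt

# A contractive toy map T(z, x) = 0.5 * z + x, whose fixed point is z = 2 * x.
def T(z, x):
  return 0.5 * z + x

x = jnp.array([1.0, -2.0])
solver = jaxopt.AndersonAcceleration(fixed_point_fun=T, history_size=5,
                                     maxiter=100, tol=1e-6)
z_star, state = solver.run(jnp.zeros_like(x), x)
# In the DEQ model the map is the ResNet block, and gradients flow through the
# fixed point via the solver's implicit differentiation.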
@@ -136,7 +137,7 @@ def block_apply(z, x, block_params):
solver = self.fixed_point_solver(fixed_point_fun=block_apply)
def batch_run(x, block_params):
  return solver.run(x, x, block_params)[0]

return jax.vmap(batch_run, in_axes=(0,None), out_axes=0)(x, block_params)


@@ -147,7 +148,7 @@ class FullDEQ(nn.Module):
fixed_point_solver: Callable

@nn.compact
def __call__(self, x, train):
x = nn.Conv(features=self.channels, kernel_size=(3,3), use_bias=True, padding='SAME')(x)
x = nn.BatchNorm(use_running_average=not train, momentum=0.9, epsilon=1e-5)(x)
block = ResNetBlock(self.channels, self.channels_bottleneck)
@@ -261,7 +262,7 @@ def jitted_update(params, state, batch_stats, data):
batch_stats = state.aux['batch_stats']
print_accuracy(params, state)
params, state = jitted_update(params, state, batch_stats, next(train_ds))


if __name__ == "__main__":
app.run(main)
2 changes: 1 addition & 1 deletion jaxopt/version.py
@@ -14,4 +14,4 @@

"""JAXopt version."""

__version__ = "0.1"
__version__ = "0.1.1"
