Merge pull request #394 from mblondel:release_0.6
PiperOrigin-RevId: 508371158
JAXopt authors committed Feb 9, 2023
2 parents 0c8b25b + 730b5a6 commit e1d8355
Showing 6 changed files with 67 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docs/api.rst
@@ -135,7 +135,7 @@ Line search
:toctree: _autosummary

jaxopt.BacktrackingLineSearch

jaxopt.HagerZhangLineSearch


Perturbed optimizers
38 changes: 38 additions & 0 deletions docs/changelog.rst
@@ -1,6 +1,44 @@
Changelog
=========

Version 0.6
-----------

New features
~~~~~~~~~~~~

- Added new Hager-Zhang linesearch in LBFGS, by Srinivas Vasudevan (code review by Emily Fertig); see the usage sketch after this list.
- Added perceptron and hinge losses, by Quentin Berthet.
- Added binary sparsemax loss, sparse_plus and sparse_sigmoid, by Vincent Roulet.
- Added isotonic regression, by Michael Sander.
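
A rough usage sketch for the new Hager-Zhang linesearch (assuming the LBFGS ``linesearch`` option accepts the string ``"hager-zhang"``; the objective below is purely illustrative)::

    import jax.numpy as jnp
    import jaxopt

    def fun(w):
      # Simple least-squares objective, used only for illustration.
      X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
      y = jnp.array([1.0, 2.0, 3.0])
      return jnp.sum((X @ w - y) ** 2)

    solver = jaxopt.LBFGS(fun=fun, linesearch="hager-zhang", maxiter=100)
    params, state = solver.run(jnp.zeros(2))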

Bug fixes and enhancements
~~~~~~~~~~~~~~~~~~~~~~~~~~

- Added TPU support to notebooks, by Ayush Shridhar.
- Allowed users to restart from a previous optimizer state in LBFGS, by Zaccharie Ramzi; see the resumption sketch after this list.
- Added faster error computation in gradient descent algorithm, by Zaccharie Ramzi.
- Got rid of extra function call in BFGS and LBFGS, by Zaccharie Ramzi.
- Improved dtype consistency between input and output of update method, by Mathieu Blondel.
- Added perturbed optimizers notebook and narrative documentation, by Quentin Berthet and Fabian Pedregosa.
- Enabled linesearch methods to return an auxiliary value, by Zaccharie Ramzi.
- Added distributed examples to the website, by Fabian Pedregosa.
- Added custom loop pjit example, by Felipe Llinares.
- Fixed incorrect LaTeX in maml.ipynb, by Fabian Pedregosa.
- Fixed bug in backtracking line search, by Srinivas Vasudevan (code review by Emily Fertig).
- Added pylintrc to top level directory, by Fabian Pedregosa.
- Corrected the condition function in LBFGS, by Zaccharie Ramzi.
- Added custom loop pmap example, by Felipe Llinares.
- Fixed pytree support in IterativeRefinement, by Louis Béthune.
- Fixed has_aux support in ArmijoSGD, by Louis Béthune.
- Documentation improvements, by Fabian Pedregosa and Mathieu Blondel.
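
A minimal sketch of carrying an LBFGS state across two phases of optimization (assuming the generic ``init_state``/``update`` solver API; the manual loop below is illustrative, not the only way to resume)::

    import jax.numpy as jnp
    import jaxopt

    def fun(w):
      return jnp.sum((w - 1.0) ** 2)

    solver = jaxopt.LBFGS(fun=fun)
    params = jnp.zeros(3)
    state = solver.init_state(params)

    # First phase: run a few manual update steps.
    for _ in range(5):
      params, state = solver.update(params, state)

    # Later phase: resume from the saved (params, state) pair
    # instead of starting from scratch.
    for _ in range(5):
      params, state = solver.update(params, state)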

Contributors
~~~~~~~~~~~~

Ayush Shridhar, Fabian Pedregosa, Felipe Llinares, Louis Bethune,
Mathieu Blondel, Michael Sander, Quentin Berthet, Srinivas Vasudevan, Vincent Roulet, Zaccharie Ramzi.

Version 0.5.5
-------------

1 change: 1 addition & 0 deletions docs/line_search.rst
@@ -51,6 +51,7 @@ Algorithms
:toctree: _autosummary

jaxopt.BacktrackingLineSearch
jaxopt.HagerZhangLineSearch

The :class:`BacktrackingLineSearch <jaxopt.BacktrackingLineSearch>` algorithm
iteratively reduces the step size by some decrease factor until the conditions
10 changes: 8 additions & 2 deletions docs/objective_and_loss.rst
@@ -29,6 +29,14 @@ Binary classification
Binary classification losses are of the form ``loss(int: label, float: score) -> float``,
where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output.

The following utility functions are useful for the binary sparsemax loss.

.. autosummary::
:toctree: _autosummary

jaxopt.loss.sparse_plus
jaxopt.loss.sparse_sigmoid
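
A hedged sketch of how these pieces fit together, following the ``loss(int: label, float: score) -> float`` convention above (the data is made up for illustration)::

    import jax
    import jax.numpy as jnp
    from jaxopt import loss

    labels = jnp.array([0, 1, 1, 0])            # ground-truth labels in {0, 1}
    scores = jnp.array([-2.0, 0.3, 1.5, 0.1])   # model outputs (scores/logits)

    # Per-example binary sparsemax loss, vectorized over the batch.
    losses = jax.vmap(loss.binary_sparsemax_loss)(labels, scores)

    # sparse_sigmoid maps a score into [0, 1]; sparse_plus is the associated
    # smoothed "plus" function used by the binary sparsemax loss.
    probs = loss.sparse_sigmoid(scores)
    values = loss.sparse_plus(scores)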

Multiclass classification
~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -79,5 +87,3 @@ Other functions
jaxopt.objective.multiclass_logreg_with_intercept
jaxopt.objective.l2_multiclass_logreg
jaxopt.objective.l2_multiclass_logreg_with_intercept
jaxopt.loss.sparse_plus
jaxopt.loss.sparse_sigmoid
34 changes: 18 additions & 16 deletions jaxopt/_src/loss.py
@@ -74,58 +74,60 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
loss value
References:
Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
Vlad Niculae. JMLR 2020. (Sec. 4.4)
"""
return sparse_plus(jnp.where(label, -logit, logit))


def sparse_plus(x: float) -> float:
"""Sparse plus function.
r"""Sparse plus function.
Computes the function:
.. math:
\mathrm{sparseplus}(x) = \begin{cases}
.. math::
\mathrm{sparse\_plus}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{4}(x+1)^2, & -1 < x < 1 \\
x, & 1 \leq x
\end{cases}
This is the twin function of the softplus activation ensuring a zero output
for inputs less than -1 and a linear output for inputs greater than 1,
while remaining smooth, convex, monotonic by an adequate definition between
-1 and 1.
Args:
x: input (float)
Returns:
sparseplus(x) as defined above
sparse_plus(x) as defined above
"""
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))


def sparse_sigmoid(x: float) -> float:
"""Sparse sigmoid function.
r"""Sparse sigmoid function.
Computes the function:
.. math::
.. math:
\mathrm{sparsesigmoid}(x) = \begin{cases}
\mathrm{sparse\_sigmoid}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{2}(x+1), & -1 < x < 1 \\
1, & 1 \leq x
\end{cases}
This is the twin function of the sigmoid activation ensuring a zero output
for inputs less than -1, a 1 output for inputs greater than 1, and a linear
output for inputs between -1 and 1. This is the derivative of the sparse
plus function.
Args:
x: input (float)
Returns:
sparsesigmoid(x) as defined above
sparse_sigmoid(x) as defined above
"""
return 0.5 * projection_hypercube(x + 1.0, 2.0)
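
As a quick sanity check of the docstrings above (an illustrative sketch, assuming the public jaxopt.loss exports), sparse_sigmoid should coincide with the derivative of sparse_plus on all three pieces:

    import jax
    import jax.numpy as jnp
    from jaxopt import loss

    xs = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

    # Pointwise derivative of sparse_plus, computed with jax.grad.
    dplus = jax.vmap(jax.grad(loss.sparse_plus))(xs)

    # Expected: 0 for x <= -1, (x + 1) / 2 for -1 < x < 1, and 1 for x >= 1.
    print(dplus)
    print(loss.sparse_sigmoid(xs))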

2 changes: 1 addition & 1 deletion jaxopt/version.py
@@ -14,4 +14,4 @@

"""JAXopt version."""

__version__ = "0.5.5"
__version__ = "0.6"
