Skip to content

Commit

Permalink
Use optax losses as backend for jaxopt losses.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616117244
  • Loading branch information
mtthss authored and JAXopt authors committed Mar 27, 2024
1 parent 8da3350 commit 866f1f4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 52 deletions.
63 changes: 12 additions & 51 deletions jaxopt/_src/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from typing import Callable

import jax
from jax.nn import softplus
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jaxopt._src.projection import projection_simplex, projection_hypercube

from optax import losses as optax_losses


# Regression

Expand All @@ -39,10 +39,7 @@ def huber_loss(target: float, pred: float, delta: float = 1.0) -> float:
References:
https://en.wikipedia.org/wiki/Huber_loss
"""
abs_diff = jnp.abs(target - pred)
return jnp.where(abs_diff > delta,
delta * (abs_diff - .5 * delta),
0.5 * abs_diff ** 2)
return optax_losses.huber_loss(pred, target, delta)

# Binary classification.

Expand All @@ -56,12 +53,8 @@ def binary_logistic_loss(label: int, logit: float) -> float:
Returns:
loss value
"""
# Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1].
# softplus = proba * logit - xlogx(proba) - xlogx(1 - proba),
# where xlogx(proba) = proba * log(proba).
# Use -log sigmoid(logit) = softplus(-logit)
# and 1 - sigmoid(logit) = sigmoid(-logit).
return softplus(jnp.where(label, -logit, logit))
return optax_losses.sigmoid_binary_cross_entropy(
jnp.asarray(logit), jnp.asarray(label))


def binary_sparsemax_loss(label: int, logit: float) -> float:
Expand All @@ -77,33 +70,7 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
Vlad Niculae. JMLR 2020. (Sec. 4.4)
"""
return sparse_plus(jnp.where(label, -logit, logit))


def sparse_plus(x: float) -> float:
r"""Sparse plus function.
Computes the function:
.. math::
\mathrm{sparse\_plus}(x) = \begin{cases}
0, & x \leq -1\\
\frac{1}{4}(x+1)^2, & -1 < x < 1 \\
x, & 1 \leq x
\end{cases}
This is the twin function of the softplus activation ensuring a zero output
for inputs less than -1 and a linear output for inputs greater than 1,
while remaining smooth, convex, monotonic by an adequate definition between
-1 and 1.
Args:
x: input (float)
Returns:
sparse_plus(x) as defined above
"""
return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0)**2/4))
return jax.nn.sparse_plus(jnp.where(label, -logit, logit))


def sparse_sigmoid(x: float) -> float:
Expand Down Expand Up @@ -144,8 +111,7 @@ def binary_hinge_loss(label: int, score: float) -> float:
References:
https://en.wikipedia.org/wiki/Hinge_loss
"""
signed_label = 2.0 * label - 1.0
return jnp.maximum(0, 1 - score * signed_label)
return optax_losses.hinge_loss(score, 2.0 * label - 1.0)


def binary_perceptron_loss(label: int, score: float) -> float:
Expand All @@ -160,8 +126,7 @@ def binary_perceptron_loss(label: int, score: float) -> float:
References:
https://en.wikipedia.org/wiki/Perceptron
"""
signed_label = 2.0 * label - 1.0
return jnp.maximum(0, - score * signed_label)
return optax_losses.perceptron_loss(score, 2.0 * label - 1.0)

# Multiclass classification.

Expand All @@ -175,13 +140,8 @@ def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float:
Returns:
loss value
"""
logits = jnp.asarray(logits)
# Logsumexp is the Fenchel conjugate of the Shannon negentropy on the simplex.
# logsumexp = jnp.dot(proba, logits) - jnp.dot(proba, jnp.log(proba))
# To avoid roundoff error, subtract target inside logsumexp.
# logsumexp(logits) - logits[y] = logsumexp(logits - logits[y])
logits = (logits - logits[label]).at[label].set(0.0)
return logsumexp(logits)
return optax_losses.softmax_cross_entropy_with_integer_labels(
jnp.asarray(logits), jnp.asarray(label))


def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
Expand Down Expand Up @@ -272,5 +232,6 @@ def make_fenchel_young_loss(max_fun: Callable[[jnp.ndarray], float]):
"""

def fy_loss(y_true, scores, *args, **kwargs):
return max_fun(scores, *args, **kwargs) - jnp.vdot(y_true, scores)
return optax_losses.make_fenchel_young_loss(max_fun)(
scores.ravel(), y_true.ravel(), *args, **kwargs)
return fy_loss
5 changes: 4 additions & 1 deletion jaxopt/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
sparse_plus = jax.nn.sparse_plus

from jaxopt._src.loss import binary_logistic_loss
from jaxopt._src.loss import binary_sparsemax_loss, sparse_plus, sparse_sigmoid
from jaxopt._src.loss import binary_sparsemax_loss, sparse_sigmoid
from jaxopt._src.loss import huber_loss
from jaxopt._src.loss import make_fenchel_young_loss
from jaxopt._src.loss import multiclass_logistic_loss
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
jax>=0.2.18
jaxlib>=0.1.69
numpy>=1.18.4
optax>=0.2.2
scipy>=1.0.0

0 comments on commit 866f1f4

Please sign in to comment.