
Commit

Internal change
PiperOrigin-RevId: 506873395
q-berthet authored and JAXopt authors committed Feb 3, 2023
1 parent 0b4985c commit 60259a4
Showing 5 changed files with 136 additions and 15 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
@@ -65,11 +65,15 @@ Loss functions

jaxopt.loss.binary_logistic_loss
jaxopt.loss.binary_sparsemax_loss
jaxopt.loss.binary_hinge_loss
jaxopt.loss.binary_perceptron_loss
jaxopt.loss.sparse_plus
jaxopt.loss.sparse_sigmoid
jaxopt.loss.huber_loss
jaxopt.loss.multiclass_logistic_loss
jaxopt.loss.multiclass_sparsemax_loss
jaxopt.loss.multiclass_hinge_loss
jaxopt.loss.multiclass_perceptron_loss

Linear system solving
---------------------
4 changes: 4 additions & 0 deletions docs/objective_and_loss.rst
@@ -23,6 +23,8 @@ Binary classification

jaxopt.loss.binary_logistic_loss
jaxopt.loss.binary_sparsemax_loss
jaxopt.loss.binary_hinge_loss
jaxopt.loss.binary_perceptron_loss

Binary classification losses are of the form ``loss(int: label, float: score) -> float``,
where ``label`` is the ground-truth (``0`` or ``1``) and ``score`` is the model's output.
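For illustration, here is a minimal sketch of this calling convention (an editorial example, not part of the commit; the label and score values are made up):

from jaxopt import loss

loss.binary_logistic_loss(1, 0.7)   # log(1 + exp(-0.7)), about 0.403
loss.binary_hinge_loss(1, 0.7)      # max(0, 1 - 0.7) = 0.3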
@@ -35,6 +37,8 @@ Multiclass classification

jaxopt.loss.multiclass_logistic_loss
jaxopt.loss.multiclass_sparsemax_loss
jaxopt.loss.multiclass_hinge_loss
jaxopt.loss.multiclass_perceptron_loss

Multiclass classification losses are of the form ``loss(int: label, jnp.ndarray: scores) -> float``,
where ``label`` is the ground-truth (between ``0`` and ``n_classes - 1``) and
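A hedged sketch of the multiclass calling convention (an editorial example, not part of the commit; the scores array is made up and holds one entry per class):

import jax.numpy as jnp
from jaxopt import loss

scores = jnp.array([1.2, 0.3, -0.5])        # one score per class
loss.multiclass_logistic_loss(0, scores)    # logsumexp(scores) - scores[0]
loss.multiclass_hinge_loss(0, scores)       # max(scores + 1 - one_hot) - scores[0] = 0.1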
74 changes: 71 additions & 3 deletions jaxopt/_src/loss.py
@@ -78,11 +78,11 @@ def binary_sparsemax_loss(label: int, logit: float) -> float:
Vlad Niculae. JMLR 2020. (Sec. 4.4)
"""
return sparse_plus(jnp.where(label, -logit, logit))


def sparse_plus(x: float) -> float:
"""Sparse plus function.
Computes the function:
.. math::
@@ -107,7 +107,7 @@ def sparse_plus(x: float) -> float:

def sparse_sigmoid(x: float) -> float:
"""Sparse sigmoid function.
Computes the function:
.. math::
@@ -130,6 +130,37 @@ def sparse_sigmoid(x: float) -> float:
return 0.5 * projection_hypercube(x + 1.0, 2.0)


def binary_hinge_loss(label: int, score: float) -> float:
"""Binary hinge loss.
Args:
label: ground-truth integer label (0 or 1).
score: score produced by the model (float).
Returns:
loss value.
References:
https://en.wikipedia.org/wiki/Hinge_loss
"""
signed_label = 2.0 * label - 1.0
return jnp.maximum(0, 1 - score * signed_label)
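# Illustrative check (editorial sketch, not part of this commit): the 0/1 label
# is mapped to a sign in {-1, +1}; a correctly-signed score outside the unit
# margin gives zero loss, while a score inside the margin is penalized linearly.
#   binary_hinge_loss(1, 2.0)  -> max(0, 1 - 2.0) = 0.0
#   binary_hinge_loss(1, 0.5)  -> max(0, 1 - 0.5) = 0.5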


def binary_perceptron_loss(label: int, score: float) -> float:
"""Binary perceptron loss.
Args:
label: ground-truth integer label (0 or 1).
score: score produced by the model (float).
Returns:
loss value.
References:
https://en.wikipedia.org/wiki/Perceptron
"""
signed_label = 2.0 * label - 1.0
return jnp.maximum(0, - score * signed_label)
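# Illustrative check (editorial sketch, not part of this commit): unlike the
# hinge loss there is no margin, so any correctly-signed score gives zero loss
# and a mis-signed score is penalized by its magnitude.
#   binary_perceptron_loss(1, 0.2)  -> max(0, -0.2) = 0.0
#   binary_perceptron_loss(0, 0.7)  -> max(0,  0.7) = 0.7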

# Multiclass classification.


@@ -174,6 +205,43 @@ def multiclass_sparsemax_loss(label: int, scores: jnp.ndarray) -> float:
+ 0.5 * (1.0 - jnp.dot(proba, proba)))


def multiclass_hinge_loss(label: int,
scores: jnp.ndarray) -> float:
"""Multiclass hinge loss.
Args:
label: ground-truth integer label.
scores: scores produced by the model (floats).
Returns:
loss value.
References:
https://en.wikipedia.org/wiki/Hinge_loss
"""
one_hot_label = jax.nn.one_hot(label, scores.shape[0])
return jnp.max(scores + 1.0 - one_hot_label) - jnp.dot(scores, one_hot_label)
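# Illustrative check (editorial sketch, not part of this commit): with label 0
# and scores [2.0, 0.5, 0.0], the margin-augmented vector is [2.0, 1.5, 1.0],
# so the loss is max(...) - scores[0] = 2.0 - 2.0 = 0.0; lowering the
# true-class score to 1.2 gives max([1.2, 1.5, 1.0]) - 1.2 = 0.3.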


def multiclass_perceptron_loss(label: int,
scores: jnp.ndarray) -> float:
"""Multiclass perceptron loss.
Args:
label: ground-truth integer label.
scores: scores produced by the model (floats).
Returns:
loss value.
References:
Michael Collins. Discriminative training methods for Hidden Markov Models:
Theory and experiments with perceptron algorithms. EMNLP 2002.
"""
one_hot_label = jax.nn.one_hot(label, scores.shape[0])
return jnp.max(scores) - jnp.dot(scores, one_hot_label)
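# Illustrative check (editorial sketch, not part of this commit): the loss is
# the gap between the best score and the true-class score, so it is zero
# exactly when the argmax prediction is correct.
#   multiclass_perceptron_loss(0, jnp.array([2.0, 0.5]))  -> 2.0 - 2.0 = 0.0
#   multiclass_perceptron_loss(1, jnp.array([2.0, 0.5]))  -> 2.0 - 0.5 = 1.5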

# Fenchel-Young losses


def make_fenchel_young_loss(max_fun: Callable[[jnp.array], float]):
"""Creates a Fenchel-Young loss from a max function.
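A hedged usage sketch for the Fenchel-Young factory above (an editorial example, not part of the commit; it assumes the returned loss takes a one-hot target followed by a score vector, as exercised in tests/loss_test.py below):

from jax.scipy.special import logsumexp
from jaxopt import loss

fy_loss = loss.make_fenchel_young_loss(max_fun=logsumexp)
# With logsumexp as max_fun, fy_loss(one_hot_label, scores) agrees with
# multiclass_logistic_loss(label, scores) up to the label encoding.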
6 changes: 5 additions & 1 deletion jaxopt/loss.py
@@ -17,4 +17,8 @@
from jaxopt._src.loss import huber_loss
from jaxopt._src.loss import make_fenchel_young_loss
from jaxopt._src.loss import multiclass_logistic_loss
from jaxopt._src.loss import multiclass_sparsemax_loss
from jaxopt._src.loss import multiclass_sparsemax_loss
from jaxopt._src.loss import binary_hinge_loss
from jaxopt._src.loss import binary_perceptron_loss
from jaxopt._src.loss import multiclass_hinge_loss
from jaxopt._src.loss import multiclass_perceptron_loss
63 changes: 52 additions & 11 deletions tests/loss_test.py
@@ -100,13 +100,28 @@ def reference_impl(label: int, logit: float) -> float:
elif scores >= 1.0:
return scores
else:
return (scores + 1.0)**2/4
return (scores + 1.0) ** 2 / 4

self._test_binary_loss_function(loss.binary_sparsemax_loss, loss.sparse_sigmoid,
self._test_binary_loss_function(
loss.binary_sparsemax_loss, loss.sparse_sigmoid, reference_impl
)

def test_binary_hinge_loss(self):
def reference_impl(label: int, logit: float) -> float:
return jax.nn.relu(1 - logit * (2.0 * label - 1.0))
self._test_binary_loss_function(loss.binary_hinge_loss, jnp.sign,
reference_impl)

def _test_multiclass_loss_function(self, loss_fun, inv_link_fun,
reference_impl):
def test_perceptron_loss(self):
def reference_impl(label: int, logit: float) -> float:
return jax.nn.relu(- logit * (2.0 * label - 1.0))
self._test_binary_loss_function(loss.binary_perceptron_loss, jnp.sign,
reference_impl)

def _test_multiclass_loss_function(
self, loss_fun, inv_link_fun, reference_impl, large_inputs_behavior=True,
incorrect_label_infty=True
):
# Check that loss is zero when all the weight goes to the correct label.
loss_val = loss_fun(0, jnp.array([1e5, 0, 0]))
self.assertEqual(loss_val, 0)
@@ -138,14 +153,16 @@ def _test_multiclass_loss_function(self, loss_fun, inv_link_fun,
self.assertAllClose(loss_val, expected)

# Check that correct value is obtained for large inputs.
loss_val = loss_fun(0, jnp.array([1e9, 1e9]))
expected = loss_fun(0, jnp.array([1, 1]))
self.assertAllClose(loss_val, expected)
if large_inputs_behavior:
loss_val = loss_fun(0, jnp.array([1e9, 1e9]))
expected = loss_fun(0, jnp.array([1, 1]))
self.assertAllClose(loss_val, expected)

# Check that -inf for incorrect label has no impact.
loss_val = loss_fun(0, jnp.array([0.0, 0.0, -jnp.inf]))
expected = loss_fun(0, jnp.array([0.0, 0.0]))
self.assertAllClose(loss_val, expected)
if incorrect_label_infty:
loss_val = loss_fun(0, jnp.array([0.0, 0.0, -jnp.inf]))
expected = loss_fun(0, jnp.array([0.0, 0.0]))
self.assertAllClose(loss_val, expected)
# Check that -inf for correct label results in infinite loss.
loss_val = loss_fun(0, jnp.array([-jnp.inf, 0.0, 0.0]))
self.assertEqual(loss_val, jnp.inf)
@@ -186,6 +203,30 @@ def reference_impl(label, scores):
projection.projection_simplex,
reference_impl)

def test_multiclass_hinge_loss(self):
def reference_impl(label, scores):
one_hot_label = jax.nn.one_hot(label, scores.shape[-1])
return jnp.max(scores + 1.0 - one_hot_label) - scores[label]
def inv_link_fun(scores):
return jax.nn.one_hot(jnp.argmax(scores), scores.shape[-1])

self._test_multiclass_loss_function(loss.multiclass_hinge_loss,
inv_link_fun,
reference_impl,
large_inputs_behavior=False,
incorrect_label_infty=False)

def test_multiclass_perceptron_loss(self):
def reference_impl(label, scores):
return jnp.max(scores) - scores[label]
def inv_link_fun(scores):
return jax.nn.one_hot(jnp.argmax(scores), scores.shape[-1])

self._test_multiclass_loss_function(loss.multiclass_perceptron_loss,
inv_link_fun,
reference_impl,
incorrect_label_infty=False)

def test_huber(self):
self.assertAllClose(0.0, loss.huber_loss(0, 0, .1))
self.assertAllClose(0.0, loss.huber_loss(1, 1, .1))
@@ -212,7 +253,7 @@ def test_fenchel_young_reg(self):
y_one_hot = jax.vmap(one_hot_argmax)(theta_true)
int_one_hot = jnp.where(y_one_hot == 1.)[1]
loss_one_hot = jax.vmap(fy_loss)(y_one_hot, theta_random)
log_loss = jax.vmap(loss.multiclass_logistic_loss)(int_one_hot,
log_loss = jax.vmap(loss.multiclass_logistic_loss)(int_one_hot,
theta_random)
# Checks that the FY loss associated to logsumexp is correct.
self.assertArraysAllClose(loss_one_hot, log_loss)
