import itertools
from typing import Any, Optional

import tensorflow as tf

_EPSILON = tf.keras.backend.epsilon()


def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
                                  class_weight: Optional[Any] = None,
                                  from_logits: bool = False,
                                  axis: int = -1) -> tf.Tensor:
r"""Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by
introducing a hyperparameter called the *focusing parameter* that allows
hard-to-classify examples to be penalized more heavily relative to
easy-to-classify examples.
See :meth:`~focal_loss.binary_focal_loss` for a description of the focal
loss in the binary setting, as presented in the original work [1]_.
In the multiclass setting, with integer labels :math:`y`, focal loss is
defined as
.. math::
L(y, \hat{\mathbf{p}})
= -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y)
where
* :math:`y \in \{0, \ldots, K - 1\}` is an integer class label (:math:`K`
denotes the number of classes),
* :math:`\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1})
\in [0, 1]^K` is a vector representing an estimated probability
distribution over the :math:`K` classes,
* :math:`\gamma` (gamma, not :math:`y`) is the *focusing parameter* that
specifies how much higher-confidence correct predictions contribute to
the overall loss (the higher the :math:`\gamma`, the higher the rate at
which easy-to-classify examples are down-weighted).
The usual multiclass softmax cross-entropy loss is recovered by setting
:math:`\gamma = 0`.
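
    For intuition, a quick worked instance of the formula above: with
    :math:`\gamma = 2` and correct-class probability :math:`\hat{p}_y = 0.9`,
    the modulating factor is :math:`(1 - 0.9)^2 = 0.01`, so this easy example
    contributes only 1% of its plain cross-entropy loss
    :math:`-\log(0.9) \approx 0.105`.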

    Parameters
    ----------
    y_true : tensor-like
        Integer class labels.
    y_pred : tensor-like
        Either probabilities or logits, depending on the `from_logits`
        parameter.
    gamma : float or tensor-like of shape (K,)
        The focusing parameter :math:`\gamma`. Higher values of `gamma` make
        easy-to-classify examples contribute less to the loss relative to
        hard-to-classify examples. Must be non-negative. This can be a
        one-dimensional tensor, in which case it specifies a focusing parameter
        for each class.
    class_weight : tensor-like of shape (K,)
        Weighting factor for each of the :math:`K` classes. If not specified,
        then all classes are weighted equally.
    from_logits : bool, optional
        Whether `y_pred` contains logits or probabilities.
    axis : int, optional
        Channel axis in the `y_pred` tensor.

    Returns
    -------
    :class:`tf.Tensor`
        The focal loss for each example.

    Examples
    --------
    This function computes the per-example focal loss between a one-dimensional
    integer label vector and a two-dimensional prediction matrix:

    >>> import numpy as np
    >>> from focal_loss import sparse_categorical_focal_loss
    >>> y_true = [0, 1, 2]
    >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
    >>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
    >>> np.set_printoptions(precision=3)
    >>> print(loss.numpy())
    [0.009 0.032 0.082]
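
    A per-class weighting factor can also be supplied via `class_weight`, which
    scales each example's loss by the weight of its true class (a sketch using
    the same arrays as above, with an illustrative weight vector):

    >>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2,
    ...                                      class_weight=[1.0, 2.0, 4.0])
    >>> print(loss.numpy())
    [0.009 0.064 0.327]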

    Warnings
    --------
    This function does not reduce its output to a scalar, so it cannot be passed
    to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the
    wrapper class :class:`~focal_loss.SparseCategoricalFocalLoss`.

    References
    ----------
    .. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for
        dense object detection. IEEE Transactions on Pattern Analysis and
        Machine Intelligence, 2018.
        (`DOI <https://doi.org/10.1109/TPAMI.2018.2858826>`__)
        (`arXiv preprint <https://arxiv.org/abs/1708.02002>`__)

    See Also
    --------
    :meth:`~focal_loss.SparseCategoricalFocalLoss`
        A wrapper around this function that makes it a
        :class:`tf.keras.losses.Loss`.
    """
    # Process focusing parameter
    gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
    gamma_rank = gamma.shape.rank
    scalar_gamma = gamma_rank == 0

    # Process class weight
    if class_weight is not None:
        class_weight = tf.convert_to_tensor(class_weight,
                                            dtype=tf.dtypes.float32)

    # Process prediction tensor
    y_pred = tf.convert_to_tensor(y_pred)
    y_pred_rank = y_pred.shape.rank
    if y_pred_rank is not None:
        axis %= y_pred_rank
        if axis != y_pred_rank - 1:
            # Put channel axis last for sparse_softmax_cross_entropy_with_logits
            perm = list(itertools.chain(range(axis),
                                        range(axis + 1, y_pred_rank), [axis]))
            y_pred = tf.transpose(y_pred, perm=perm)
    elif axis != -1:
        raise ValueError(
            f'Cannot compute sparse categorical focal loss with axis={axis} on '
            'a prediction tensor with statically unknown rank.')

    y_pred_shape = tf.shape(y_pred)

    # Process ground truth tensor
    y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64)
    y_true_rank = y_true.shape.rank

    if y_true_rank is None:
        raise NotImplementedError('Sparse categorical focal loss not supported '
                                  'for target/label tensors of unknown rank')

    reshape_needed = (y_true_rank is not None and y_pred_rank is not None and
                      y_pred_rank != y_true_rank + 1)
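
    # When the prediction tensor has extra spatial/sequence dimensions (e.g.,
    # per-pixel classification), flatten everything to a (batch, classes)
    # matrix for the cross-entropy op; the loss is reshaped back at the end.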
    if reshape_needed:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]])

    if from_logits:
        logits = y_pred
        probs = tf.nn.softmax(y_pred, axis=-1)
    else:
        probs = y_pred
        logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON))
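
    # Note: softmax(log(p)) == p for a valid probability distribution p, so
    # feeding the clipped log-probabilities into
    # sparse_softmax_cross_entropy_with_logits below recovers -log(p_y).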
    xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y_true,
        logits=logits,
    )

    y_true_rank = y_true.shape.rank
    probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank)
    if not scalar_gamma:
        gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank)
    focal_modulation = (1 - probs) ** gamma
    loss = focal_modulation * xent_loss

    if class_weight is not None:
        class_weight = tf.gather(class_weight, y_true, axis=0,
                                 batch_dims=y_true_rank)
        loss *= class_weight

    if reshape_needed:
        loss = tf.reshape(loss, y_pred_shape[:-1])

    return loss


@tf.keras.utils.register_keras_serializable()
class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
r"""Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by
introducing a hyperparameter :math:`\gamma` (gamma), called the
*focusing parameter*, that allows hard-to-classify examples to be penalized
more heavily relative to easy-to-classify examples.
This class is a wrapper around
:class:`~focal_loss.sparse_categorical_focal_loss`. See the documentation
there for details about this loss function.
Parameters
----------
gamma : float or tensor-like of shape (K,)
The focusing parameter :math:`\gamma`. Higher values of `gamma` make
easy-to-classify examples contribute less to the loss relative to
hard-to-classify examples. Must be non-negative. This can be a
one-dimensional tensor, in which case it specifies a focusing parameter
for each class.
    class_weight : tensor-like of shape (K,)
        Weighting factor for each of the :math:`K` classes. If not specified,
        then all classes are weighted equally.
    from_logits : bool, optional
        Whether model prediction will be logits or probabilities.
    **kwargs : keyword arguments
        Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name`
        or `reduction`).

    Examples
    --------
    An instance of this class is a callable that takes a rank-one tensor of
    integer class labels `y_true` and a tensor of model predictions `y_pred` and
    returns a scalar tensor obtained by reducing the per-example focal loss (the
    default reduction is a batch-wise average).

    >>> from focal_loss import SparseCategoricalFocalLoss
    >>> loss_func = SparseCategoricalFocalLoss(gamma=2)
    >>> y_true = [0, 1, 2]
    >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
    >>> loss_func(y_true, y_pred)
    <tf.Tensor: shape=(), dtype=float32, numpy=0.040919524>

    Use this class in the :mod:`tf.keras` API like any other multiclass
    classification loss function class that accepts integer labels found in
    :mod:`tf.keras.losses` (e.g.,
    :class:`tf.keras.losses.SparseCategoricalCrossentropy`):

    .. code-block:: python

        # Typical usage
        model = tf.keras.Model(...)
        model.compile(
            optimizer=...,
            loss=SparseCategoricalFocalLoss(gamma=2),  # Used here like a tf.keras loss
            metrics=...,
        )
        history = model.fit(...)
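
    If the model emits raw logits instead of probabilities (e.g., its final
    layer has no softmax activation), pass ``from_logits=True``. A sketch,
    where ``num_classes`` is a placeholder for the actual number of classes:

    .. code-block:: python

        model = tf.keras.Sequential([
            ...,
            tf.keras.layers.Dense(num_classes),  # outputs logits: no activation
        ])
        model.compile(
            optimizer=...,
            loss=SparseCategoricalFocalLoss(gamma=2, from_logits=True),
        )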

    See Also
    --------
    :meth:`~focal_loss.sparse_categorical_focal_loss`
        The function that performs the focal loss computation, taking a label
        tensor and a prediction tensor and outputting a loss.
    """

    def __init__(self, gamma, class_weight: Optional[Any] = None,
                 from_logits: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.class_weight = class_weight
        self.from_logits = from_logits

    def get_config(self):
        """Returns the config of the loss function.

        A config is a Python dictionary containing the configuration of the
        loss function. The same loss function can be re-instantiated later from
        this configuration.

        Returns
        -------
        dict
            This loss function's config.
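
        Examples
        --------
        A sketch of a config round-trip; ``from_config`` here is the generic
        one inherited from :class:`tf.keras.losses.Loss`:

        >>> loss_func = SparseCategoricalFocalLoss(gamma=2)
        >>> config = loss_func.get_config()
        >>> restored = SparseCategoricalFocalLoss.from_config(config)
        >>> restored.gamma
        2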
"""
config = super().get_config()
config.update(gamma=self.gamma, class_weight=self.class_weight,
from_logits=self.from_logits)
return config

    def call(self, y_true, y_pred):
        """Compute the per-example focal loss.

        This method simply calls
        :meth:`~focal_loss.sparse_categorical_focal_loss` with the appropriate
        arguments.

        Parameters
        ----------
        y_true : tensor-like, shape (N,)
            Integer class labels.
        y_pred : tensor-like, shape (N, K)
            Either probabilities or logits, depending on the `from_logits`
            parameter.

        Returns
        -------
        :class:`tf.Tensor`
            The per-example focal loss. Reduction to a scalar is handled by
            this loss function's
            :meth:`~focal_loss.SparseCategoricalFocalLoss.__call__` method.
        """
        return sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred,
                                             class_weight=self.class_weight,
                                             gamma=self.gamma,
                                             from_logits=self.from_logits)
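

if __name__ == '__main__':
    # A minimal smoke-test sketch (an addition for illustration; assumes
    # TensorFlow 2.x with eager execution, and reuses the toy labels and
    # probabilities from the docstring examples above).
    import numpy as np

    demo_y_true = [0, 1, 2]
    demo_y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]

    # Functional form: one loss value per example.
    per_example = sparse_categorical_focal_loss(demo_y_true, demo_y_pred,
                                                gamma=2)
    print('per-example loss:', np.round(per_example.numpy(), 3))

    # Wrapper class: reduces to a scalar (batch-wise mean by default).
    loss_func = SparseCategoricalFocalLoss(gamma=2)
    print('reduced loss:', float(loss_func(demo_y_true, demo_y_pred)))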