forked from google/youtube-8m
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
102 lines (84 loc) · 4 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides definitions for non-regularized training or test losses."""
import tensorflow as tf
class BaseLoss(object):
    """Abstract interface that every loss implementation derives from."""

    def calculate_loss(self, unused_predictions, unused_labels, **unused_params):
        """Computes the mean loss over the examples of one mini-batch.

        Subclasses override this method; the base implementation always raises.

        Args:
          unused_predictions: a 2-d tensor of prediction scores, one row per
            sample in the mini-batch and one column per class.
          unused_labels: a 2-d tensor of labels with the same shape as
            unused_predictions. Label values must lie in the range [0, 1].
          unused_params: loss specific parameters.

        Returns:
          A scalar loss tensor.

        Raises:
          NotImplementedError: always, since the base class defines no loss.
        """
        raise NotImplementedError()
class CrossEntropyLoss(BaseLoss):
    """Per-class binary cross entropy between predictions and labels."""

    def calculate_loss(self,
                       predictions,
                       labels,
                       label_weights=None,
                       **unused_params):
        """Computes the mean over the batch of the summed per-class cross entropy.

        Args:
          predictions: tensor of prediction scores; `log(predictions)` and
            `log(1 - predictions)` are taken, so values are presumably
            probabilities in [0, 1] -- confirm against the caller.
          labels: tensor of labels, cast to float32 before use.
          label_weights: optional tensor multiplied element-wise into the
            per-class loss; skipped when None.
          **unused_params: ignored.

        Returns:
          A scalar tensor: the per-class loss summed over axis 1, then averaged.
        """
        with tf.name_scope("loss_xent"):
            # Keeps both logarithms finite when a prediction is exactly 0 or 1.
            epsilon = 1e-5
            targets = tf.cast(labels, tf.float32)

            positive_term = targets * tf.math.log(predictions + epsilon)
            negative_term = (1 - targets) * tf.math.log(1 - predictions + epsilon)
            per_class_loss = tf.negative(positive_term + negative_term)

            if label_weights is not None:
                per_class_loss = per_class_loss * label_weights

            per_example_loss = tf.reduce_sum(per_class_loss, 1)
            return tf.reduce_mean(per_example_loss)
class HingeLoss(BaseLoss):
    """Hinge loss between predictions and labels.

    Backpropagation relies on the subgradient, so optimization may converge
    more slowly. Predictions trained with the hinge loss are between -1 and +1.
    """

    def calculate_loss(self, predictions, labels, b=1.0, **unused_params):
        """Computes the mean over the batch of the summed per-class hinge loss.

        Args:
          predictions: tensor of prediction scores.
          labels: tensor of labels, cast to float32 and remapped with
            `2 * label - 1` (so 0 -> -1 and 1 -> +1).
          b: the margin; the per-class loss is `max(0, b - sign * prediction)`.
          **unused_params: ignored.

        Returns:
          A scalar tensor: the per-class loss summed over axis 1, then averaged.
        """
        with tf.name_scope("loss_hinge"):
            targets = tf.cast(labels, tf.float32)
            ones = tf.ones(tf.shape(targets), dtype=tf.float32)
            zeros = tf.zeros(tf.shape(targets), dtype=tf.float32)

            # Remap the labels to signed targets: 2 * label - 1.
            signed_targets = tf.subtract(tf.scalar_mul(2, targets), ones)
            margin = tf.scalar_mul(b, ones)

            per_class_loss = tf.maximum(
                zeros, margin - signed_targets * predictions)
            per_example_loss = tf.reduce_sum(per_class_loss, 1)
            return tf.reduce_mean(per_example_loss)
class SoftmaxLoss(BaseLoss):
    """Calculate the softmax loss between the predictions and labels.

    The predictions are passed through a log-softmax, and the loss is the
    negated dot product between those log-probabilities and the L1-normalized
    ground truth labels.

    It is an extension of the one-hot label case: it allows more than one
    positive label per sample.
    """

    def calculate_loss(self, predictions, labels, **unused_params):
        """Computes the batch mean of the softmax cross entropy.

        Args:
          predictions: tensor of scores fed to the softmax (applied over the
            last axis, the default of `tf.nn.log_softmax`).
          labels: tensor of non-negative labels, cast to float32 and normalized
            so that each row sums to 1 (rows with no positive label stay zero).
          **unused_params: ignored.

        Returns:
          A scalar tensor: the per-example loss averaged over the batch.
        """
        with tf.name_scope("loss_softmax"):
            # Lower bound on the label row sum, so that a row without any
            # positive label divides by epsilon instead of by zero.
            epsilon = 1e-7
            float_labels = tf.cast(labels, tf.float32)
            # l1 normalization (labels are no less than 0)
            label_rowsum = tf.maximum(
                tf.reduce_sum(float_labels, 1, keepdims=True), epsilon)
            norm_float_labels = tf.math.divide(float_labels, label_rowsum)
            # log_softmax is computed directly rather than as log(softmax(x)):
            # the latter underflows to log(0) = -inf for strongly negative
            # scores, and 0 * -inf then poisons the loss with NaN.
            log_softmax_outputs = tf.nn.log_softmax(predictions)
            softmax_loss = tf.negative(
                tf.reduce_sum(
                    tf.multiply(norm_float_labels, log_softmax_outputs), 1))
            return tf.reduce_mean(softmax_loss)