import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Losses(object):
    def __init__(self, *args, **kwargs):
        """
        Class used to initialize and handle all available loss types in ViP

        Args:
            loss_type (String): String indicating which custom loss function is to be loaded.

        Returns:
            Loss object
        """
        self.loss_type   = kwargs['loss_type']
        self.loss_object = None

        if self.loss_type == 'MSE':
            self.loss_object = MSE(*args, **kwargs)

        elif self.loss_type == 'M_XENTROPY':
            self.loss_object = M_XENTROPY(*args, **kwargs)

        elif self.loss_type == 'YC2BB_Attention_Loss':
            self.loss_object = YC2BB_Attention_Loss(*args, **kwargs)

        else:
            raise ValueError('Invalid loss type selected: {}'.format(self.loss_type))

    def loss(self, predictions, data, **kwargs):
        """
        Function that calculates loss from selected loss type

        Args:
            predictions (Tensor, shape [N,*]): Tensor output by the network
            data (dictionary): Contains the target tensor used with predictions to compute the loss

        Returns:
            Calculated loss value
        """
        return self.loss_object.loss(predictions, data, **kwargs)
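
# A minimal usage sketch (not part of the original file): the wrapper is
# constructed with keyword arguments and dispatches to the selected loss.
# The 'MSE' configuration and tensor shapes below are illustrative.
def _example_losses_wrapper():
    loss_fn = Losses(loss_type='MSE', device='cpu')

    predictions = torch.randn(4, 10)
    data        = {'labels': torch.randn(4, 10)}

    return loss_fn.loss(predictions, data)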
class MSE(object):
    def __init__(self, *args, **kwargs):
        """
        Mean squared error (squared L2 norm) between predictions and targets

        Args:
            reduction (String): 'none', 'mean', 'sum' (see PyTorch docs). Default: 'mean'
            device    (String): 'cpu' or 'cuda'

        Returns:
            None
        """
        reduction   = kwargs.get('reduction', 'mean')
        self.device = kwargs['device']

        self.mse_loss = torch.nn.MSELoss(reduction=reduction)

    def loss(self, predictions, data):
        """
        Args:
            predictions (Tensor, shape [N,*]): Output by the network
            data (dictionary)
                - labels (Tensor, shape [N,*]): Targets from ground truth data

        Returns:
            Mean squared error loss
        """
        targets = data['labels'].to(self.device)

        return self.mse_loss(predictions, targets)
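
# A short sketch (shapes assumed) showing MSE used directly with a non-default
# reduction; 'none' keeps the per-element losses instead of averaging them.
def _example_mse():
    mse = MSE(device='cpu', reduction='none')

    predictions = torch.randn(8, 3)
    data        = {'labels': torch.randn(8, 3)}

    return mse.loss(predictions, data)  # Tensor of shape [8, 3]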
class M_XENTROPY(object):
    def __init__(self, *args, **kwargs):
        """
        Cross-entropy loss that accepts a distribution of values, not just 1-hot vectors

        Args:
            dim (integer): Dimension to reduce

        Returns:
            None
        """
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def loss(self, predictions, data):
        """
        Args:
            predictions (Tensor, shape [N,*]): Output by the network
            data (dictionary)
                - labels (Tensor, shape [N,*]): Targets from ground truth data

        Returns:
            Cross-entropy loss
        """
        targets = data['labels']

        # Convert the class index stored in the last column of each label row
        # into a 1-hot distribution over the prediction classes
        one_hot = np.zeros((targets.shape[0], predictions.shape[1]))
        one_hot[np.arange(targets.shape[0]), targets.cpu().numpy().astype('int32')[:, -1]] = 1
        one_hot = torch.Tensor(one_hot).to(predictions.device)

        return torch.mean(torch.sum(-one_hot * self.logsoftmax(predictions), dim=1))
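
# A sanity-check sketch (shapes assumed): when the labels are hard class
# indices, the 1-hot path above reduces to standard cross-entropy, so the
# value should match F.cross_entropy on the same inputs.
def _example_m_xentropy():
    xent = M_XENTROPY()

    predictions = torch.randn(4, 10)
    labels      = torch.randint(0, 10, (4, 1)).float()  # class index in the last column

    loss      = xent.loss(predictions, {'labels': labels})
    reference = F.cross_entropy(predictions, labels[:, -1].long())

    return loss, reference  # the two values should agree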
#Code source: https://github.com/MichiganCOG/Video-Grounding-from-Text/blob/master/train.py
class YC2BB_Attention_Loss(object):
    def __init__(self, *args, **kwargs):
        """
        Frame-wise attention loss used in Weakly-Supervised Video Object Grounding...
        https://arxiv.org/pdf/1805.02834.pdf

        Weakly-supervised: no ground-truth labels are used.
        """
        self.loss_weighting = kwargs['has_loss_weighting']
        self.obj_interact   = kwargs['obj_interact']
        self.ranking_margin = kwargs['ranking_margin']
        self.loss_factor    = kwargs['loss_factor']

    def loss(self, predictions, data):
        """
        Args:
            predictions (List):
                - output     (Tensor, shape [2*T, 2]): Positive and negative attention weights for each sample
                - loss_weigh (Tensor, shape [2*T, 1]): Loss weighting applied to each sampled frame
            data (None)

            T: number of sampled frames from video (default: 5)

        Returns:
            Frame-wise weighting loss
        """
        output, loss_weigh = predictions

        if self.loss_weighting or self.obj_interact:
            # Rank each positive attention weight above its negative counterpart
            # by at least the ranking margin, keeping the per-frame losses
            rank_target = torch.ones(output.size(0), 1).type(output.data.type())
            rank_batch  = F.margin_ranking_loss(output[:, 0:1], output[:, 1:2],
                rank_target, margin=self.ranking_margin, reduction='none')

            if self.loss_weighting and self.obj_interact:
                loss_weigh = (output[:, 0:1] + loss_weigh) / 2.  # avg
            elif self.loss_weighting:
                loss_weigh = output[:, 0:1]
            else:
                loss_weigh = loss_weigh.unsqueeze(1)

            # Weighted ranking loss, regularized by the negative log of the weights
            cls_loss = self.loss_factor * (rank_batch * loss_weigh).mean() + \
                       (1 - self.loss_factor) * -torch.log(2 * loss_weigh).mean()
        else:
            # Plain ranking loss with mean reduction
            cls_loss = F.margin_ranking_loss(output[:, 0:1], output[:, 1:2],
                torch.ones(output.size(0), 1).type(output.data.type()), margin=self.ranking_margin)

        return cls_loss
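
# A minimal input-format sketch (illustrative flag values, not defaults from
# the original config), assuming T=5 sampled frames so the network emits
# 2*T rows of positive/negative attention weights.
def _example_yc2bb_attention_loss():
    loss_fn = YC2BB_Attention_Loss(has_loss_weighting=True, obj_interact=False,
                                   ranking_margin=0.1, loss_factor=0.9)

    output     = F.softmax(torch.randn(10, 2), dim=1)  # [2*T, 2] attention weights
    loss_weigh = torch.rand(10)                        # per-frame loss weighting

    return loss_fn.loss((output, loss_weigh), None)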