-
Notifications
You must be signed in to change notification settings - Fork 24
/
loss.py
22 lines (19 loc) · 783 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
class NLL_OHEM(torch.nn.NLLLoss):
    """Negative log-likelihood loss with online hard example mining (OHEM).

    Only the ``ratio`` fraction of samples in a batch with the highest
    per-sample NLL contribute to the returned (mean-reduced) loss. Needs
    log-probabilities as input, e.g. the output of ``nn.LogSoftmax()``.

    Args:
        ratio: Fraction of the batch to keep as hard examples. The number of
            kept samples is ``int(ratio * batch_size)``, clamped to
            ``[1, batch_size]``.
    """

    def __init__(self, ratio):
        # weight=None with mean reduction: what the old positional
        # (None, True) == (weight, size_average) call meant, without the
        # deprecated size_average argument.
        super().__init__(weight=None, reduction="mean")
        self.ratio = ratio

    def forward(self, x, y, ratio=None):
        """Return the mean NLL over the hardest samples of the batch.

        Args:
            x: Log-probabilities of shape (N, C).
            y: Integer class targets of shape (N,).
            ratio: Optional override for ``self.ratio``. When given it is
                stored on the module, so it also applies to later calls.

        Returns:
            Scalar tensor: mean NLL of the selected hard examples.
        """
        if ratio is not None:
            self.ratio = ratio
        num_inst = x.size(0)
        # Keep at least one sample (a mean over zero samples would be nan) and
        # at most the whole batch (topk raises if k exceeds the size).
        num_hns = min(max(int(self.ratio * num_inst), 1), num_inst)
        # Per-sample loss -log p(y_i | x_i), used only to rank samples; computed
        # without autograd and on x's own device.
        with torch.no_grad():
            inst_losses = -x.gather(1, y.view(-1, 1)).squeeze(1)
        _, idxs = inst_losses.topk(num_hns)
        # Select from x (not y) so gradients flow back to the chosen rows.
        x_hn = x.index_select(0, idxs)
        y_hn = y.index_select(0, idxs)
        return torch.nn.functional.nll_loss(x_hn, y_hn)