-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
65 lines (56 loc) · 2.1 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
@Time : 2021/7/8 09:48
@Author : Haiyang Mei
@E-mail : [email protected]
@Project : TCSVT2021_DCENet
@File : loss.py
@Function: Loss Functions
"""
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
###################################################################
# ########################## iou loss #############################
###################################################################
def _iou(pred, target):
b = pred.shape[0]
IoU = 0.0
for i in range(0,b):
Iand1 = torch.sum(target[i,:,:,:] * pred[i,:,:,:])
Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:]) - Iand1
if Ior1:
IoU1 = Iand1 / Ior1
else:
IoU1 = 1
IoU = IoU + (1-IoU1)
return IoU / b
class IOU(torch.nn.Module):
    """``nn.Module`` wrapper around the module-level ``_iou`` loss.

    The module holds no parameters or buffers; calling it with
    ``(pred, target)`` returns whatever ``_iou(pred, target)`` returns, i.e.
    ``1 - IoU`` averaged over the batch dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        loss = _iou(pred, target)
        return loss
###################################################################
# ########################## edge loss ############################
###################################################################
def cross_entropy(logits, labels):
    """Mean binary cross-entropy computed from raw logits.

    Element-wise the loss is ``(1 - y) * x + log(1 + exp(-x))``, which equals
    ``-[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]``. It is averaged
    over every element of the broadcast shape of ``logits`` and ``labels``.

    Args:
        logits: tensor of raw scores ``x``.
        labels: tensor of targets ``y``, broadcastable against ``logits``.

    Returns:
        0-dim tensor with the mean loss.
    """
    # softplus(-x) == log(1 + exp(-x)), but it is evaluated stably. The naive
    # exp(-x) overflows to inf (float32) once -x exceeds ~88.7, turning the
    # loss into inf and its gradient into NaN.
    return torch.mean((1 - labels) * logits + F.softplus(-logits))
class EdgeLoss(nn.Module):
    """Loss comparing Laplacian edge maps of a prediction and its target.

    Both inputs are filtered with a fixed 3x3 Laplacian kernel and squashed
    with ``|tanh(.)|`` (see ``torchLaplace``). The resulting maps are compared
    with the module-level ``cross_entropy``: prediction edges go in the
    ``logits`` slot and target edges in the ``labels`` slot.
    """

    def __init__(self):
        super().__init__()
        # 8-neighbour Laplacian. PyTorch conv weight layout is
        # (out_channels, in_channels / groups, height, width).
        laplace = torch.tensor(
            [[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]],
            dtype=torch.float32,
        ).view(1, 1, 3, 3)
        # Kept as a frozen Parameter (rather than a buffer) so the state_dict
        # key and the contents of parameters() stay as they were.
        self.laplace = nn.Parameter(data=laplace, requires_grad=False)

    def torchLaplace(self, x):
        """Return ``|tanh(laplacian(x))|`` with the same shape as ``x``.

        The kernel is applied to each channel independently (a depthwise conv
        with ``groups=channels``), so inputs with any channel count work. For
        single-channel input this is exactly a plain 1-in/1-out convolution.
        The kernel is moved to ``x``'s device and dtype before use.
        """
        channels = x.shape[-3]
        weight = self.laplace.to(device=x.device, dtype=x.dtype)
        weight = weight.repeat(channels, 1, 1, 1)
        # padding=1 keeps the spatial size unchanged for the 3x3 kernel.
        edge = F.conv2d(x, weight, padding=1, groups=channels)
        return torch.abs(torch.tanh(edge))

    def forward(self, y_pred, y_true, mode=None):
        """Return the edge loss between ``y_pred`` and ``y_true``.

        ``mode`` is accepted for call-site compatibility but is not used.
        """
        y_true_edge = self.torchLaplace(y_true)
        y_pred_edge = self.torchLaplace(y_pred)
        # NOTE(review): the prediction's |tanh| edge magnitudes (in [0, 1]) are
        # passed to cross_entropy as logits rather than probabilities; this
        # matches the original design and is kept unchanged.
        return cross_entropy(y_pred_edge, y_true_edge)