forked from Mhaiyang/CVPR2021_PFNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
52 lines (42 loc) · 1.71 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
@Time : 2021/7/6 14:31
@Author : Haiyang Mei
@E-mail : [email protected]
@Project : CVPR2021_PFNet
@File : loss.py
@Function: Loss
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
###################################################################
# ########################## iou loss #############################
###################################################################
class IOU(torch.nn.Module):
    """Soft IoU loss computed from logits.

    ``pred`` holds raw logits; a sigmoid turns them into probabilities before
    the soft intersection and union with ``target`` are accumulated over
    dims 2 and 3. The per-(dim 0, dim 1) losses ``1 - inter / union`` are
    then averaged into a single scalar.
    Inputs are presumably (N, C, H, W) tensors -- TODO confirm with callers.
    """

    def __init__(self):
        super().__init__()

    def _iou(self, pred, target):
        """Return the mean of ``1 - soft_intersection / soft_union``."""
        spatial = (2, 3)
        prob = pred.sigmoid()
        overlap = torch.sum(prob * target, dim=spatial)
        total = torch.sum(prob + target, dim=spatial)
        # total counts the overlap twice, so subtract it once to get the union.
        return torch.mean(1 - overlap / (total - overlap))

    def forward(self, pred, target):
        return self._iou(pred, target)
###################################################################
# #################### structure loss #############################
###################################################################
class structure_loss(torch.nn.Module):
    """Edge-weighted BCE plus edge-weighted soft IoU loss, computed from logits.

    A per-pixel weight ``1 + 5 * |avg_pool(mask, 31x31) - mask|`` grows where
    the mask disagrees with its local 31x31 mean, which emphasises pixels near
    label transitions. Both the BCE term and the IoU term use this weight.
    Both terms are reduced over dims 2 and 3, then averaged into one scalar.

    Expects 4-D tensors (the ``dim=(2, 3)`` sums require at least four dims);
    presumably (N, C, H, W) with ``mask`` in [0, 1] -- TODO confirm with callers.
    """

    def __init__(self):
        super().__init__()

    def _structure_loss(self, pred, mask):
        # Pixel weights; for masks in [0, 1] they lie between 1 and 6.
        weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
        # Per-pixel BCE is required for the weighting to matter. Use
        # ``reduction='none'``, not the deprecated ``reduce`` flag: passing
        # reduce='none' is read as a truthy legacy boolean, so PyTorch returns
        # the mean BCE scalar and the weighted average below becomes a no-op.
        wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        pred = torch.sigmoid(pred)
        inter = ((pred * mask) * weit).sum(dim=(2, 3))
        union = ((pred + mask) * weit).sum(dim=(2, 3))
        # union above counts the overlap twice; subtracting inter gives the
        # weighted soft union.
        wiou = 1 - inter / (union - inter)
        return (wbce + wiou).mean()

    def forward(self, pred, mask):
        return self._structure_loss(pred, mask)