-
Notifications
You must be signed in to change notification settings - Fork 6
/
losses.py
98 lines (72 loc) · 2.89 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# boundary loss code from https://github.com/yiskw713/boundary_loss_for_remote_sensing
import torch
import torch.nn as nn
import torch.nn.functional as F
def one_hot(label, n_classes, requires_grad=True):
    """Return a one-hot encoding of ``label`` with the class axis at dim 1.

    Args:
        label: integer tensor of class indices in ``[0, n_classes)`` with shape
            ``(N, *spatial)``, e.g. ``(N, H, W)`` for a segmentation map.
        n_classes: number of classes ``C``.
        requires_grad: forwarded to the identity matrix the encoding is
            gathered from; when True the result is attached to autograd.

    Returns:
        Tensor (default floating dtype, on ``label.device``) of shape
        ``(N, C, *spatial)``. A 1-D label ``(N,)`` gives ``(N, C)``.
    """
    eye = torch.eye(n_classes, device=label.device, requires_grad=requires_grad)
    # Row lookup appends the class axis at the end: (N, *spatial, C).
    encoded = eye[label]
    ndim = encoded.dim()
    if ndim < 3:
        # (C,) or (N, C): the class axis is already the last/second one.
        return encoded
    # Move the class axis from last to position 1: (N, C, *spatial).
    # For a 3-D label this is exactly the old transpose(1, 3).transpose(2, 3).
    return encoded.permute(0, ndim - 1, *range(1, ndim - 1))
class BoundaryLoss(nn.Module):
    """Boundary Loss proposed in:
    Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation
    https://arxiv.org/abs/1905.07852

    The loss is ``1 - BF1``, where BF1 is a boundary F1 score computed per
    sample and per class from max-pool based boundary maps, then averaged
    over classes and the mini-batch.

    Args:
        theta0: kernel size of the max-pool that extracts boundaries (odd).
        theta: kernel size of the max-pool that widens boundaries (odd).
        eps: small constant guarding the Precision, Recall and F1 divisions.

    Raises:
        ValueError: if ``theta0`` or ``theta`` is not a positive odd integer.
    """

    def __init__(self, theta0=3, theta=5, eps=1e-7):
        super().__init__()
        # With stride 1 and padding (k - 1) // 2 only an odd kernel preserves
        # H and W; an even one shrinks the map by a pixel and forward() would
        # later fail on a shape mismatch, so reject it up front.
        for name, value in (("theta0", theta0), ("theta", theta)):
            if value < 1 or value % 2 == 0:
                raise ValueError(
                    f"{name} must be a positive odd integer, got {value!r}")
        self.theta0 = theta0
        self.theta = theta
        self.eps = eps

    @staticmethod
    def _pool(x, kernel_size):
        """Stride-1 max-pool padded so an odd ``kernel_size`` keeps H and W."""
        return F.max_pool2d(
            x, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def _boundary(self, x):
        """Return the boundary map ``max_pool(1 - x) - (1 - x)`` using ``theta0``."""
        inverted = 1 - x
        # Out-of-place subtraction: no in-place mutation of a tensor that
        # autograd may be tracking.
        return self._pool(inverted, self.theta0) - inverted

    def forward(self, pred, gt):
        """
        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map of class indices
                    shape (N, H, W)
        Return:
            - boundary loss, averaged over classes and mini-batch
        """
        n, c, _, _ = pred.shape

        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)

        # one-hot vector of ground truth; the target needs no gradient, so
        # backward does not scatter into a throwaway identity matrix
        one_hot_gt = one_hot(gt, c, requires_grad=False)

        # boundary map
        gt_b = self._boundary(one_hot_gt)
        pred_b = self._boundary(pred)

        # extended boundary map
        gt_b_ext = self._pool(gt_b, self.theta)
        pred_b_ext = self._pool(pred_b, self.theta)

        # flatten spatial dims; reshape (unlike view) tolerates non-contiguous layouts
        gt_b = gt_b.reshape(n, c, -1)
        pred_b = pred_b.reshape(n, c, -1)
        gt_b_ext = gt_b_ext.reshape(n, c, -1)
        pred_b_ext = pred_b_ext.reshape(n, c, -1)

        # Precision, Recall (per sample, per class)
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + self.eps)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + self.eps)

        # Boundary F1 Score, shape (N, C)
        BF1 = 2 * P * R / (P + R + self.eps)

        # mean of (1 - BF1) over both classes and mini-batch
        loss = torch.mean(1 - BF1)
        return loss
class DiceLoss(nn.Module):
    """Soft Dice loss on one sigmoid-activated channel of the model output.

    The Dice coefficient is pooled over the whole mini-batch (all samples and
    pixels are summed together), not computed per sample.

    Args:
        smooth: additive term in numerator and denominator that keeps the
            ratio finite when prediction and target are (nearly) empty.
        channel: index of the channel of ``y_pred`` treated as the
            foreground logit.
    """

    def __init__(self, smooth=1.0, channel=1):
        super().__init__()
        self.smooth = smooth
        self.channel = channel

    def forward(self, y_pred, y_true):
        """
        Input:
            - y_pred: raw model output (before sigmoid), shape (N, C, H, W)
            - y_true: ground truth map; values > 0.5 count as foreground
        Return:
            - scalar Dice loss, 1 - dice
        """
        # Sigmoid (not softmax): the selected channel is scored as an
        # independent binary logit. Slicing first avoids activating unused channels.
        prob = torch.sigmoid(y_pred[:, self.channel, :, :])
        target = (y_true > 0.5).float()
        # NOTE(review): y_true presumably has shape (N, H, W) to match `prob`;
        # an (N, 1, H, W) target would silently broadcast to (N, N, H, W).
        intersection = torch.sum(prob * target)
        union = torch.sum(prob) + torch.sum(target)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice