-
Notifications
You must be signed in to change notification settings - Fork 0
/
CEDiceLoss.py
54 lines (46 loc) · 2.15 KB
/
CEDiceLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import os
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from torchgeometry.losses import one_hot
class CEDiceLoss(nn.Module):
    """Weighted sum of cross-entropy loss and the soft Dice loss.

    The per-class weights are applied both as the ``weight`` of the
    cross-entropy term and to the per-class Dice scores before reduction.

    Args:
        weights: per-class weight tensor of shape ``(C,)`` or ``(1, C)``;
            the entries must sum to 1.
    """

    def __init__(self, weights: torch.Tensor) -> None:
        super(CEDiceLoss, self).__init__()
        # Smoothing term keeping the Dice denominator away from zero.
        self.eps: float = 1e-5
        self.weights: torch.Tensor = weights

    def forward(
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        """Compute ``mean(1 - dice_score) + cross_entropy`` for a batch.

        Args:
            input: raw (unnormalized) class scores of shape ``(B, C, H, W)``.
            target: integer class labels of shape ``(B, H, W)``.

        Returns:
            A scalar loss tensor.

        Raises:
            TypeError: if ``input`` is not a tensor.
            ValueError: on shape, device, or weight mismatches.
        """
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            # BUGFIX: original message formatted input.shape twice and had
            # only one placeholder for two arguments.
            raise ValueError(
                "input and target shapes must be the same. Got: {} and {}"
                .format(input.shape, target.shape))
        if not input.device == target.device:
            # BUGFIX: original message dropped target.device (one placeholder
            # for two arguments).
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}"
                .format(input.device, target.device))
        # Accept weights shaped (C,) or (1, C): flatten once, because
        # F.cross_entropy requires a 1-D weight of size C anyway.
        class_weights = self.weights.reshape(-1)
        if not class_weights.shape[0] == input.shape[1]:
            raise ValueError("The number of weights must equal the number of classes")
        # Tolerant comparison: exact float equality to 1 is fragile
        # (e.g. three weights of 1/3 do not sum to exactly 1.0).
        if abs(class_weights.sum().item() - 1.0) > 1e-6:
            raise ValueError("The sum of all weights must equal 1")
        # weighted cross entropy term (functional form avoids building a
        # fresh nn.CrossEntropyLoss module on every forward pass)
        celoss = F.cross_entropy(input, target, weight=class_weights)
        # compute softmax over the classes axis
        input_soft = F.softmax(input, dim=1)
        # one-hot encode the labels to (B, C, H, W) in input's dtype/device
        # (torch-native replacement for torchgeometry.losses.one_hot)
        target_one_hot = F.one_hot(target, num_classes=input.shape[1]) \
            .permute(0, 3, 1, 2).to(device=input.device, dtype=input.dtype)
        # compute the actual dice score over the spatial dimensions
        dims = (2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)
        dice_score = 2. * intersection / (cardinality + self.eps)
        # class-weighted Dice per sample, then batch mean of the Dice loss
        dice_score = torch.sum(dice_score * class_weights, dim=1)
        return torch.mean(1. - dice_score) + celoss