-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcc.py
86 lines (70 loc) · 3.56 KB
/
mcc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
@author: Ying Jin
@contact: [email protected]
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
from ..modules.entropy import entropy
# Public names exported by ``from <this module> import *``.
__all__ = ['MinimumClassConfusionLoss', 'ImageClassifier']
class MinimumClassConfusionLoss(nn.Module):
    r"""
    Minimum Class Confusion loss minimizes the class confusion in the target predictions.

    You can see more details in `Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) <https://arxiv.org/abs/1912.03699>`_

    Args:
        temperature (float) : The temperature used to rescale the logits before the softmax.
            The predictions reduce to a vanilla softmax when temperature is 1.0.

    .. note::
        The temperature must be larger than 0; a ``ValueError`` is raised otherwise.

    Inputs: g_t
        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`

    Shape:
        - g_t: :math:`(minibatch, C)` where C means the number of classes.
        - Output: scalar.

    Raises:
        ValueError: if ``temperature`` is not larger than 0 (including NaN), or if ``g_t`` is not 2-D.

    Examples::

        >>> temperature = 2.0
        >>> batch_size, num_classes = 10, 2
        >>> loss = MinimumClassConfusionLoss(temperature)
        >>> # logits output from target domain
        >>> g_t = torch.randn(batch_size, num_classes)
        >>> output = loss(g_t)

    MCC can also serve as a regularizer for existing methods.

    Examples::

        >>> from tllib.modules.domain_discriminator import DomainDiscriminator
        >>> from tllib.alignment.cdan import ConditionalDomainAdversarialLoss
        >>> num_classes = 2
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> temperature = 2.0
        >>> discriminator = DomainDiscriminator(in_feature=feature_dim, hidden_size=1024)
        >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
        >>> mcc_loss = MinimumClassConfusionLoss(temperature)
        >>> # features from source domain and target domain
        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # logits output from source domain and target domain
        >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)
    """

    def __init__(self, temperature: float):
        super().__init__()
        # Written as ``not (t > 0)`` rather than ``t <= 0`` so that NaN is rejected as well;
        # a non-positive temperature would otherwise silently produce inf/NaN losses.
        if not (temperature > 0):
            raise ValueError(f"temperature must be larger than 0, got {temperature}")
        self.temperature = temperature

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute the MCC loss for a batch of target-domain logits.

        Args:
            logits: unnormalized predictions of shape ``(batch_size, num_classes)``.

        Returns:
            A scalar tensor: the sum of the off-diagonal entries of the normalized
            class confusion matrix, divided by ``num_classes``.

        Raises:
            ValueError: if ``logits`` is not 2-D.
        """
        if logits.dim() != 2:
            raise ValueError(
                f"expected logits of shape (minibatch, C), got shape {tuple(logits.shape)}"
            )
        batch_size, num_classes = logits.shape
        predictions = F.softmax(logits / self.temperature, dim=1)  # batch_size x num_classes

        # Per-sample weights 1 + exp(-H): lower-entropy (more confident) samples get larger weight.
        # Detached so that no gradient flows through the weighting itself.
        # NOTE(review): assumes ``entropy`` returns one value per sample, shape (batch_size,).
        entropy_weight = entropy(predictions).detach()
        entropy_weight = 1 + torch.exp(-entropy_weight)
        # Rescale so the weights sum to batch_size (i.e. average to 1).
        entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1)  # batch_size x 1

        # Weighted class confusion matrix P^T diag(w) P, which is symmetric by construction.
        class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions)  # num_classes x num_classes
        # Broadcasting divides entry (i, j) by the sum of row j; since the matrix is symmetric
        # that equals the sum of column j, so every column of the result sums to 1.
        class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1)
        # Off-diagonal mass (total minus trace), averaged over the classes.
        mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes
        return mcc_loss
class ImageClassifier(ClassifierBase):
    """Classifier whose bottleneck is ``Linear -> BatchNorm1d -> ReLU``.

    Args:
        backbone (torch.nn.Module): feature extractor; its ``out_features`` attribute sets
            the input width of the bottleneck.
        num_classes (int): number of classes, forwarded to the base classifier.
        bottleneck_dim (int, optional): output width of the bottleneck, also forwarded to the
            base classifier. Default: 256
        **kwargs: forwarded unchanged to the base classifier.

    .. note::
        No pooling or flattening layer is added here, so the backbone presumably already
        returns a flat ``(minibatch, out_features)`` tensor -- TODO confirm.
        NOTE(review): although annotated ``Optional``, ``bottleneck_dim`` is passed straight to
        ``nn.Linear`` and ``nn.BatchNorm1d`` as a size, so ``None`` is not actually supported.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        in_dim = backbone.out_features
        # Project backbone features to ``bottleneck_dim``, then normalize and apply a non-linearity.
        layers = (
            nn.Linear(in_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super().__init__(backbone, num_classes, nn.Sequential(*layers), bottleneck_dim, **kwargs)