# dst.py (forked from thuml/Transfer-Learning-Library)
"""
@author: Baixu Chen
@contact: [email protected]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.modules.classifier import Classifier


class ImageClassifier(Classifier):
r"""
    Classifier with a non-linear pseudo head :math:`h_{\text{pseudo}}` and a worst-case estimation head
    :math:`h_{\text{worst}}` from `Debiased Self-Training for Semi-Supervised Learning <https://arxiv.org/abs/2202.07136>`_.
    Both heads are directly connected to the feature extractor :math:`\psi`. We implement an end-to-end adversarial
    training procedure between :math:`\psi` and :math:`h_{\text{worst}}` by introducing a gradient reverse layer.
    Note that both heads can be safely discarded at inference time, and thus introduce no extra inference cost.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): Hidden dimension of the non-linear pseudo head and worst-case estimation head.
            Default: 2048

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        - outputs: predictions of the main head :math:`h`
        - outputs_adv: predictions of the worst-case estimation head :math:`h_{\text{worst}}`
        - outputs_pseudo: predictions of the pseudo head :math:`h_{\text{pseudo}}`

    Shape:
        - Inputs: (minibatch, *), where * means any number of additional dimensions
        - outputs, outputs_adv, outputs_pseudo: (minibatch, `num_classes`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, width=2048, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # initialize the bottleneck projection with small random weights
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        # non-linear pseudo head h_pseudo (same width as the worst-case head)
        self.pseudo_head = nn.Sequential(
            nn.Linear(self.features_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, self.num_classes)
        )
        # gradient reverse layer whose coefficient warms up from lo to hi, so
        # adversarial training between psi and h_worst starts gently
        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        # worst-case estimation head h_worst, trained adversarially via the GRL
        self.adv_head = nn.Sequential(
            nn.Linear(self.features_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, self.num_classes)
        )

    def forward(self, x: torch.Tensor):
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        # reverse gradients flowing from the adversarial head into the
        # feature extractor
        f_adv = self.grl_layer(f)
        outputs_adv = self.adv_head(f_adv)
        outputs = self.head(f)
        outputs_pseudo = self.pseudo_head(f)
        if self.training:
            return outputs, outputs_adv, outputs_pseudo
        else:
            return outputs

    def get_parameters(self, base_lr=1.0):
        """Return a parameter list that decides optimization hyper-parameters,
        such as the relative learning rate of each layer.
        """
params = [
{"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
{"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
{"params": self.head.parameters(), "lr": 1.0 * base_lr},
{"params": self.pseudo_head.parameters(), "lr": 1.0 * base_lr},
{"params": self.adv_head.parameters(), "lr": 1.0 * base_lr}
]
return params
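
    # Typical use (illustrative, not from the original library docs): the
    # returned parameter groups feed straight into a PyTorch optimizer so the
    # backbone gets a smaller learning rate when fine-tuning, e.g.
    #
    #     optimizer = torch.optim.SGD(model.get_parameters(base_lr=0.03),
    #                                 momentum=0.9, weight_decay=5e-4)
    #
    # base_lr=0.03, momentum and weight_decay are assumed values here, not
    # taken from the paper.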

    def step(self):
        """Advance the warm-start schedule of the gradient reverse layer."""
        self.grl_layer.step()
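

# Usage sketch (illustrative, not from the original library docs): assuming a
# backbone that exposes an `out_features` attribute, as tllib backbones do,
# the classifier returns three logit tensors in training mode and only the
# main-head logits in eval mode:
#
#     model = ImageClassifier(backbone, num_classes=10).train()
#     outputs, outputs_adv, outputs_pseudo = model(x)
#     model.step()  # advance the GRL coefficient once per training iteration
#     model.eval()
#     outputs = model(x)  # extra heads skipped: no additional inference cost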


def shift_log(x, offset=1e-6):
    """
    First shift, then calculate log, for numerical stability.
    """
    return torch.log(torch.clamp(x + offset, max=1.))
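
# A quick numeric illustration (hypothetical values): for probabilities that
# can reach exactly 0, a plain torch.log would return -inf and poison the
# gradients, while the shifted version stays finite:
#
#     shift_log(torch.tensor([0.0, 0.5, 1.0]))
#     # -> tensor([-13.8155, -0.6931, 0.0000])   log(1e-6), log(~0.5), log(1.0)
#
# The clamp at 1. keeps the log non-positive even after the offset is added.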


class WorstCaseEstimationLoss(nn.Module):
    r"""
    Worst-case estimation loss from `Debiased Self-Training for Semi-Supervised Learning <https://arxiv.org/abs/2202.07136>`_
    that forces the worst possible head :math:`h_{\text{worst}}` to predict correctly on all labeled samples
    :math:`\mathcal{L}` while making as many mistakes as possible on unlabeled data :math:`\mathcal{U}`. In the
    classification task, it is defined as:

    .. math::
        loss(\mathcal{L}, \mathcal{U}) =
        \eta' \mathbb{E}_{y^l, y_{adv}^l \sim \hat{\mathcal{L}}} -\log\left(\frac{\exp(y_{adv}^l[h_{y^l}])}{\sum_j \exp(y_{adv}^l[j])}\right) +
        \mathbb{E}_{y^u, y_{adv}^u \sim \hat{\mathcal{U}}} -\log\left(1-\frac{\exp(y_{adv}^u[h_{y^u}])}{\sum_j \exp(y_{adv}^u[j])}\right),

    where :math:`y^l` and :math:`y^u` are logits output by the main head :math:`h` on labeled data and unlabeled data,
    respectively. :math:`y_{adv}^l` and :math:`y_{adv}^u` are logits output by the worst-case estimation
    head :math:`h_{\text{worst}}`. :math:`h_y` refers to the predicted label when the logits output is :math:`y`.

    Args:
        eta_prime (float): the trade-off hyper-parameter :math:`\eta'`.

    Inputs:
        - y_l: logits output :math:`y^l` by the main head on labeled data
        - y_l_adv: logits output :math:`y^l_{adv}` by the worst-case estimation head on labeled data
        - y_u: logits output :math:`y^u` by the main head on unlabeled data
        - y_u_adv: logits output :math:`y^u_{adv}` by the worst-case estimation head on unlabeled data

    Shape:
        - Inputs: :math:`(minibatch, C)` where C denotes the number of classes.
        - Output: scalar.
    """

    def __init__(self, eta_prime):
        super(WorstCaseEstimationLoss, self).__init__()
        self.eta_prime = eta_prime

    def forward(self, y_l, y_l_adv, y_u, y_u_adv):
        # on labeled data, the adversarial head should agree with the main head
        _, prediction_l = y_l.max(dim=1)
        loss_l = self.eta_prime * F.cross_entropy(y_l_adv, prediction_l)

        # on unlabeled data, the adversarial head should disagree with the main
        # head; shift_log keeps log(1 - p) finite when p approaches 1
        _, prediction_u = y_u.max(dim=1)
        loss_u = F.nll_loss(shift_log(1. - F.softmax(y_u_adv, dim=1)), prediction_u)

        return loss_l + loss_u
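

if __name__ == '__main__':
    # Minimal end-to-end smoke test with random data. The toy backbone, input
    # shapes and eta_prime below are illustrative assumptions, not values from
    # the DST paper or the tllib examples.
    class ToyBackbone(nn.Module):
        def __init__(self, out_features=64):
            super(ToyBackbone, self).__init__()
            self.out_features = out_features  # read by Classifier.__init__
            self.conv = nn.Conv2d(3, out_features, kernel_size=3, padding=1)

        def forward(self, x):
            # returns (N, out_features, H, W); Classifier's pool layer then
            # reduces this to (N, out_features)
            return self.conv(x)

    model = ImageClassifier(ToyBackbone(), num_classes=10).train()
    worst_case_loss = WorstCaseEstimationLoss(eta_prime=2.0)

    x_l = torch.randn(4, 3, 32, 32)  # labeled batch
    x_u = torch.randn(4, 3, 32, 32)  # unlabeled batch
    y_l, y_l_adv, _ = model(x_l)
    y_u, y_u_adv, _ = model(x_u)

    loss = worst_case_loss(y_l, y_l_adv, y_u, y_u_adv)
    loss.backward()  # the GRL reverses the gradients reaching the backbone
    model.step()     # advance the warm-start GRL coefficient
    print(loss.item())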