-
Notifications
You must be signed in to change notification settings - Fork 0
/
beheaded_inception3.py
52 lines (47 loc) · 1.78 KB
/
beheaded_inception3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models.inception import Inception3
from torch.utils.model_zoo import load_url
from warnings import warn
class BeheadedInception3(Inception3):
    """Inception3 whose forward pass also exposes pre-classifier features.

    Like torchvision.models.inception.Inception3 but the head goes
    separately: ``forward`` returns the last convolutional feature map and
    the pooled feature vector in addition to the classifier output.
    """

    def forward(self, x):
        """Run the network and return intermediate features plus logits.

        Args:
            x: Image batch indexed as ``x[:, channel]``; channels 0..2 are
                rescaled when ``self.transform_input`` is set.

        Returns:
            A tuple ``(x_for_attn, x_for_capt, logits)`` where
            ``x_for_attn`` is the output of ``Mixed_7c`` (spatial feature
            map), ``x_for_capt`` is that map averaged over all spatial
            positions and flattened to one vector per sample, and
            ``logits`` is ``self.fc`` applied to ``x_for_capt``.
        """
        if self.transform_input:
            # Work on a copy so the caller's tensor is not modified in place.
            x = x.clone()
            # Convert per-channel (mean, std) normalisation with
            # means 0.485/0.456/0.406 and stds 0.229/0.224/0.225 into
            # the (p - 0.5) / 0.5 scaling.
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        else:
            warn("Input isn't transformed")

        # Stem: plain convolutions interleaved with max pooling.
        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Inception blocks.
        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)
        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x_for_attn = x = self.Mixed_7c(x)
        # H x W x 2048 (8 x 8 for the input size the original code assumed)

        # Pool over the whole spatial extent instead of a fixed 8x8 window:
        # identical on an 8x8 map, but other map sizes still yield exactly
        # one 2048-vector per sample, which is what self.fc expects.
        x = F.adaptive_avg_pool2d(x, 1)
        # 1 x 1 x 2048
        x_for_capt = x = x.view(x.size(0), -1)
        # 2048
        x = self.fc(x)
        # 1000 (num_classes)
        return x_for_attn, x_for_capt, x
def beheaded_inception_v3(transform_input=True):
    """Build a BeheadedInception3 and load pretrained weights into it.

    Args:
        transform_input: Forwarded unchanged to the BeheadedInception3
            constructor.

    Returns:
        The BeheadedInception3 instance after ``load_state_dict`` has been
        called with the state dict fetched from the checkpoint URL below.
    """
    weights_url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
    # Construct first, then fetch, so a constructor failure never triggers
    # the download.
    net = BeheadedInception3(transform_input=transform_input)
    pretrained_state = load_url(weights_url)
    net.load_state_dict(pretrained_state)
    return net