-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathvggfeatures.py
66 lines (49 loc) Β· 1.97 KB
/
vggfeatures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Extract features from a pretrained VGG16
# Code from https://github.com/futscdav/strotss
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
class VGG16_Extractor(nn.Module):
def __init__(self):
super().__init__()
self.vgg_layers = models.vgg16(pretrained=True).features
for param in self.parameters():
param.requires_grad = False
self.capture_layers = [1,3,6,8,11,13,15,22,29]
def forward_base(self, x):
feat = [x]
for i in range(len(self.vgg_layers)):
x = self.vgg_layers[i](x)
if i in self.capture_layers: feat.append(x)
return feat
def forward(self, x):
x = (x + 1.) / 2.
x = x - (torch.Tensor([0.485, 0.456, 0.406]).to(x.device).view(1, -1, 1, 1))
x = x / (torch.Tensor([0.229, 0.224, 0.225]).to(x.device).view(1, -1, 1, 1))
feat = self.forward_base(x)
return feat
def forward_samples_hypercolumn(self, X, samps=100):
feat = self.forward(X)
xx,xy = np.meshgrid(np.arange(X.shape[2]), np.arange(X.shape[3]))
xx = np.expand_dims(xx.flatten(),1)
xy = np.expand_dims(xy.flatten(),1)
xc = np.concatenate([xx,xy],1)
samples = min(samps,xc.shape[0])
np.random.shuffle(xc)
xx = xc[:samples,0]
yy = xc[:samples,1]
feat_samples = []
for i in range(len(feat)):
layer_feat = feat[i]
# Hack to detect lower resolution
if i > 0 and feat[i].size(2) < feat[i-1].size(2):
xx = xx/2.0
yy = yy/2.0
xx = np.clip(xx, 0, layer_feat.shape[2]-1).astype(np.int32)
yy = np.clip(yy, 0, layer_feat.shape[3]-1).astype(np.int32)
features = layer_feat[:,:, xx[range(samples)], yy[range(samples)]]
feat_samples.append(features.clone().detach())
feat = torch.cat(feat_samples,1)
return feat