-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfeature_extractor.py
32 lines (26 loc) · 1.09 KB
/
feature_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
from torchvision import models
#feature extractor
class FeatureExtractor(torch.nn.Module):
def __init__(self, n_feature, n_primitive):
# feature encoder
super().__init__()
self.n_feature = n_feature
#print("n_feature is " + str(n_feature))
self.n_primitive = n_primitive
self.fe = models.resnet18(pretrained=True)
self.fc1 = torch.nn.Linear(512, 512)
self.fc2 = torch.nn.Linear(512, self.n_feature)
self.fe.fc = nn.Sequential(self.fc1, self.fc2)
self.fe.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
# learnable primitie embedding
self.register_parameter(name='p', param=torch.nn.Parameter(torch.randn(n_primitive, n_feature)))
def forward(self, x):
F = self.fe(x)
F = torch.unsqueeze(F, 1)
F = F.expand(-1, self.n_primitive, -1) # should have size 4 * 5 * 256
P = torch.unsqueeze(self.p, 0)
P = P.expand(F.shape[0], -1, -1) # should have size 4 * 5 * 256
C = torch.cat([F,P], dim=-1)
return C