-
Notifications
You must be signed in to change notification settings - Fork 202
/
Copy pathprune.py
136 lines (110 loc) · 5.14 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
from torch.autograd import Variable
from torchvision import models
import cv2
import sys
import numpy as np
def replace_layers(model, i, indexes, layers):
    """Pick the module that should occupy position ``i`` of a rebuilt container.

    Args:
        model: indexable container of modules (e.g. ``torch.nn.Sequential``).
        i: position currently being rebuilt.
        indexes: positions that must be replaced.
        layers: replacement modules, parallel to ``indexes``.

    Returns:
        ``layers[k]`` where ``k`` is the first position of ``i`` in
        ``indexes``; otherwise the untouched ``model[i]``.
    """
    if i not in indexes:
        return model[i]
    slot = indexes.index(i)
    return layers[slot]
def prune_vgg16_conv_layer(model, layer_index, filter_index, use_cuda=False):
    """Remove one output filter from a Conv2d inside ``model.features``.

    The conv at ``layer_index`` is rebuilt with ``out_channels - 1`` filters,
    and the layer that consumes its output is rebuilt to match:

    * if another ``Conv2d`` follows in ``model.features``, input channel
      ``filter_index`` is removed from it;
    * otherwise (the last conv was pruned) the first ``torch.nn.Linear`` in
      ``model.classifier`` loses the ``in_features // out_channels`` columns
      that belong to that channel.

    NOTE(review): modules sitting between the pruned conv and the next Conv2d
    (e.g. BatchNorm2d in ``vgg16_bn``) are NOT adjusted, and ``groups > 1`` is
    not handled specially; this assumes the plain VGG16 layout — confirm
    before using on other architectures.

    Args:
        model: object exposing ``features`` and ``classifier`` as
            ``torch.nn.Sequential`` (e.g. torchvision VGG16). Its attributes
            are reassigned in place and the same object is returned.
        layer_index: position of the Conv2d to prune within ``model.features``.
        filter_index: output filter to drop, ``0 <= filter_index < out_channels``.
        use_cuda: when true, the weights built here (and the pruned conv's
            bias) are moved to CUDA. Biases that carry over unchanged are shared
            with the old layers as-is.

    Returns:
        ``model`` with rebuilt ``features`` (and ``classifier`` when the last
        conv was pruned).

    Raises:
        IndexError: ``layer_index`` is outside ``model.features``.
        TypeError: the module at ``layer_index`` is not a ``Conv2d``.
        ValueError: ``filter_index`` out of range, the conv has a single filter,
            the classifier has no ``Linear``, or that Linear's ``in_features``
            is not a multiple of the conv's ``out_channels``.
    """
    modules = list(model.features)
    if not 0 <= layer_index < len(modules):
        raise IndexError("layer_index %d out of range for %d feature modules"
                         % (layer_index, len(modules)))
    conv = modules[layer_index]
    if not isinstance(conv, torch.nn.Conv2d):
        raise TypeError("module at layer_index %d is %s, not Conv2d"
                        % (layer_index, type(conv).__name__))
    if not 0 <= filter_index < conv.out_channels:
        raise ValueError("filter_index %d out of range for %d filters"
                         % (filter_index, conv.out_channels))
    if conv.out_channels < 2:
        raise ValueError("cannot prune the only filter of a convolution")

    # The first Conv2d after the pruned one consumes its output channels.
    next_index = next((k for k in range(layer_index + 1, len(modules))
                       if isinstance(modules[k], torch.nn.Conv2d)), None)

    new_conv = _conv_without_output(conv, filter_index, use_cuda)

    if next_index is not None:
        next_new_conv = _conv_without_input(modules[next_index], filter_index,
                                            use_cuda)
        model.features = _rebuild_sequential(
            modules, {layer_index: new_conv, next_index: next_new_conv})
        return model

    # Pruning the last conv layer: this affects the first linear layer of the
    # classifier. Build everything before mutating ``model`` so that an error
    # leaves the model untouched.
    classifier = list(model.classifier)
    linear_index = next((k for k, module in enumerate(classifier)
                         if isinstance(module, torch.nn.Linear)), None)
    if linear_index is None:
        raise ValueError("No linear layer found in classifier")
    new_linear = _linear_without_channel(classifier[linear_index],
                                         conv.out_channels, filter_index,
                                         use_cuda)

    model.features = _rebuild_sequential(modules, {layer_index: new_conv})
    model.classifier = _rebuild_sequential(classifier,
                                           {linear_index: new_linear})
    return model


def _drop_range(tensor, dim, start, stop):
    """Return a new tensor equal to ``tensor`` without positions ``start:stop`` along ``dim``."""
    keep = torch.cat((torch.arange(0, start),
                      torch.arange(stop, tensor.size(dim))))
    return tensor.index_select(dim, keep.to(tensor.device))


def _place(tensor, use_cuda):
    """Move ``tensor`` to CUDA when requested, otherwise return it unchanged."""
    return tensor.cuda() if use_cuda else tensor


def _conv_without_output(conv, filter_index, use_cuda):
    """Build a copy of ``conv`` with output filter ``filter_index`` removed."""
    new_conv = torch.nn.Conv2d(in_channels=conv.in_channels,
                               out_channels=conv.out_channels - 1,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               dilation=conv.dilation,
                               groups=conv.groups,
                               bias=conv.bias is not None)
    # Weight layout is (out, in, kH, kW): filters live on dim 0.
    new_conv.weight.data = _place(
        _drop_range(conv.weight.data.cpu(), 0, filter_index, filter_index + 1),
        use_cuda)
    if conv.bias is not None:
        new_conv.bias.data = _place(
            _drop_range(conv.bias.data.cpu(), 0, filter_index, filter_index + 1),
            use_cuda)
    return new_conv


def _conv_without_input(conv, channel_index, use_cuda):
    """Build a copy of ``conv`` with input channel ``channel_index`` removed."""
    new_conv = torch.nn.Conv2d(in_channels=conv.in_channels - 1,
                               out_channels=conv.out_channels,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               dilation=conv.dilation,
                               groups=conv.groups,
                               bias=conv.bias is not None)
    # Input channels live on dim 1 of the (out, in, kH, kW) weight.
    new_conv.weight.data = _place(
        _drop_range(conv.weight.data.cpu(), 1, channel_index, channel_index + 1),
        use_cuda)
    if conv.bias is not None:
        # Bias is per output channel, which is unchanged, so it carries over.
        new_conv.bias.data = conv.bias.data
    return new_conv


def _linear_without_channel(linear, num_channels, channel_index, use_cuda):
    """Build a copy of ``linear`` without the input columns of one conv channel.

    Assumes (TODO confirm for non-VGG models) that ``linear`` consumes the
    flattened conv output channel-major, so channel ``c`` owns the contiguous
    columns ``[c * k, (c + 1) * k)`` with ``k = in_features // num_channels``.
    """
    if linear.in_features % num_channels:
        raise ValueError("Linear in_features %d is not a multiple of %d channels"
                         % (linear.in_features, num_channels))
    per_channel = linear.in_features // num_channels
    new_linear = torch.nn.Linear(linear.in_features - per_channel,
                                 linear.out_features,
                                 bias=linear.bias is not None)
    start = channel_index * per_channel
    new_linear.weight.data = _place(
        _drop_range(linear.weight.data.cpu(), 1, start, start + per_channel),
        use_cuda)
    if linear.bias is not None:
        new_linear.bias.data = linear.bias.data
    return new_linear


def _rebuild_sequential(modules, replacements):
    """Return an ``nn.Sequential`` of ``modules`` with positions in ``replacements`` swapped."""
    return torch.nn.Sequential(*(replacements.get(k, module)
                                 for k, module in enumerate(modules)))
if __name__ == '__main__':
    # ``time`` was used here without ever being imported; import it locally so
    # the script entry point runs.
    import time

    model = models.vgg16(pretrained=True)
    model.train()
    t0 = time.perf_counter()
    # Layer 28 is presumably the last Conv2d of torchvision's VGG16 `features`
    # (so this exercises the classifier-pruning path) — verify if the model changes.
    # The original called the undefined name `prune_conv_layer`.
    model = prune_vgg16_conv_layer(model, 28, 10)
    print("The prunning took", time.perf_counter() - t0)