-- BN-absorber.lua
require('nn')
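-- Batch normalization computes y = gamma * (x - mean) / sqrt(var + eps) + beta
-- (gamma/beta only when affine). When x itself is the output of a convolution
-- or linear layer, x = W*z + b, the two layers can be fused at inference time:
--   W' = gamma * invstd * W
--   b' = gamma * (b - mean) * invstd + beta,  where invstd = (var + eps)^-1/2
-- The helpers below apply this rewrite to (w, b) in place, given invstd.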
local absorb_bn_conv = function (w, b, mean, invstd, affine, gamma, beta)
   w:cmul(invstd:view(w:size(1),1):repeatTensor(1,w:nElement()/w:size(1)))
   b:add(-mean):cmul(invstd)
   if affine then
      w:cmul(gamma:view(w:size(1),1):repeatTensor(1,w:nElement()/w:size(1)))
      b:cmul(gamma):add(beta)
   end
end
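-- nn.SpatialFullConvolution stores its weight as
-- nInputPlane x nOutputPlane x kH x kW, so the per-output-channel scale runs
-- along the second dimension (of size b:size(1)) rather than the first; the
-- repeatTensor calls below tile invstd/gamma to match that flattened layout.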
local absorb_bn_deconv = function (w, b, mean, invstd, affine, gamma, beta)
   w:cmul(invstd:view(b:size(1),1):repeatTensor(w:size(1),w:nElement()/w:size(1)/b:nElement()))
   b:add(-mean):cmul(invstd)
   if affine then
      w:cmul(gamma:view(b:size(1),1):repeatTensor(w:size(1),w:nElement()/w:size(1)/b:nElement()))
      b:cmul(gamma):add(beta)
   end
end
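-- Older versions of nn stored running_std = 1/sqrt(running_var + eps) instead
-- of running_var; recover running_var = running_std^-2 - eps so the rest of
-- the code can work with variances uniformly.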
local backward_compat_running_std = function(x, i)
   if x.modules[i].running_std then
      x.modules[i].running_var = x.modules[i].running_std:pow(-2):add(-x.modules[i].eps)
      x.modules[i].running_std = nil
   end
end
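-- Recursively walk the network: fold every (Spatial)BatchNormalization module
-- into the convolution, deconvolution or linear module immediately preceding
-- it, then remove the BN module. The network is modified in place and returned.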
local function BN_absorber(x)
   local i = 1
   while (i <= #x.modules) do
      if x.modules[i].__typename == 'nn.Sequential'
         or x.modules[i].__typename == 'nn.Parallel'
         or x.modules[i].__typename == 'nn.Concat'
         or x.modules[i].__typename == 'nn.DataParallel'
         or x.modules[i].__typename == 'nn.ModelParallel'
         or x.modules[i].__typename == 'nn.ConcatTable' then
         -- container module: recurse into its children
         BN_absorber(x.modules[i])
      else
         -- check BN
         if x.modules[i].__typename == 'nn.SpatialBatchNormalization' then
            backward_compat_running_std(x, i)
            if x.modules[i-1] and
               (x.modules[i-1].__typename == 'nn.SpatialConvolution' or
                x.modules[i-1].__typename == 'nn.SpatialConvolutionMM') then
               absorb_bn_conv(x.modules[i-1].weight,
                              x.modules[i-1].bias,
                              x.modules[i].running_mean,
                              x.modules[i].running_var:clone():add(x.modules[i].eps):pow(-0.5),
                              x.modules[i].affine,
                              x.modules[i].weight,
                              x.modules[i].bias)
               x:remove(i)
               i = i - 1
            elseif x.modules[i-1] and
               x.modules[i-1].__typename == 'nn.SpatialFullConvolution' then
               absorb_bn_deconv(x.modules[i-1].weight,
                                x.modules[i-1].bias,
                                x.modules[i].running_mean,
                                x.modules[i].running_var:clone():add(x.modules[i].eps):pow(-0.5),
                                x.modules[i].affine,
                                x.modules[i].weight,
                                x.modules[i].bias)
               x:remove(i)
               i = i - 1
            else
               assert(false, 'Convolution module must exist right before batch normalization layer')
            end
         elseif x.modules[i].__typename == 'nn.BatchNormalization' then
            backward_compat_running_std(x, i)
            if x.modules[i-1] and x.modules[i-1].__typename == 'nn.Linear' then
               absorb_bn_conv(x.modules[i-1].weight,
                              x.modules[i-1].bias,
                              x.modules[i].running_mean,
                              x.modules[i].running_var:clone():add(x.modules[i].eps):pow(-0.5),
                              x.modules[i].affine,
                              x.modules[i].weight,
                              x.modules[i].bias)
               x:remove(i)
               i = i - 1
            else
               assert(false, 'Linear module must exist right before batch normalization layer')
            end
         end
      end
      i = i + 1
   end
   collectgarbage()
   return x
end
return BN_absorber
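-- Example usage (a minimal sketch; the require path and file names are
-- assumptions, not part of this module):
--   local BN_absorber = require 'BN-absorber'
--   local model = torch.load('model.t7')
--   model:evaluate()               -- fusion is only valid at inference time
--   model = BN_absorber(model)
--   torch.save('model-no-bn.t7', model)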