-
Notifications
You must be signed in to change notification settings - Fork 64
/
sanitize.lua
93 lines (73 loc) · 2.44 KB
/
sanitize.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
require('torch')
require('nn')
require('cunn')
require('cudnn')
-- Field names that sanitize() resets on every module, whatever its type.
local common = {'output', 'gradInput'}

-- Extra per-type field names to reset, keyed by the module's torch type name
-- (the value sanitize() reads from `module.__typename`).
local t = {}
do
  --- Record `fields` as the buffer list for each type name in `typenames`.
  -- Every type name receives its own copy of the list.
  local function register(typenames, fields)
    for _, typename in ipairs(typenames) do
      local copy = {}
      for i, field_name in ipairs(fields) do
        copy[i] = field_name
      end
      t[typename] = copy
    end
  end

  -- convolution
  register({'nn.SpatialConvolution', 'nn.SpatialConvolutionMM'},
           {'finput', 'fgradInput'})

  -- pooling
  register({'nn.SpatialMaxPooling', 'nn.TemporalMaxPooling',
            'nn.VolumetricMaxPooling', 'nn.SpatialFractionalMaxPooling'},
           {'indices'})

  -- regularizers
  register({'nn.BatchNormalization', 'nn.SpatialBatchNormalization'},
           {'buffer', 'buffer2', 'centered', 'normalized'})
  register({'nn.Dropout', 'nn.SpatialDropout'}, {'noise'})

  -- transfer functions
  register({'nn.PReLU'}, {'gradWeightBuf', 'gradWeightBuf2'})
  register({'nn.LogSigmoid'}, {'buffer'})

  -- miscellaneous
  register({'nn.Mean'}, {'_gradInput'})
  register({'nn.Normalize'}, {'_output', 'norm', 'normp'})
  register({'nn.PairwiseDistance'}, {'diff'})
  register({'nn.Reshape'}, {'_input', '_gradOutput'})

  -- fbcunn multi-GPU containers
  register({'nn.AbstractParallel', 'nn.DataParallel', 'nn.ModelParallel'},
           {'homeGradBuffers', 'input_gpu', 'gradOutput_gpu', 'gradInput_gpu'})
end
--- Replace `val[name]` with a fresh, empty value of the same kind.
-- A table becomes a new empty table. A userdata value (expected to be a torch
-- tensor) is replaced with `field.new()`. Any other value is left untouched.
-- @param val   the module table being sanitized
-- @param name  key of the field inside `val`
-- @param field current value of `val[name]`; its `new` constructor is used for userdata
local function free_table_or_tensor(val, name, field)
  local kind = type(val[name])
  if kind == 'userdata' then
    val[name] = field.new()
  elseif kind == 'table' then
    val[name] = {}
  end
end
--- Report whether `name` appears as a value in `candidates`.
-- Every value reachable through pairs() is checked, regardless of key.
-- A nil `candidates` (for example, a module type with no entry in the
-- buffer table) is treated as an empty list.
-- @param name       value to search for
-- @param candidates table to search, or nil
-- @return true if found, false otherwise
local function is_member(name, candidates)
  if candidates ~= nil then
    for _, candidate in pairs(candidates) do
      if name == candidate then
        return true
      end
    end
  end
  return false
end
-- Taken and modified from Soumith's imagenet-multiGPU.torch code
-- https://github.com/soumith/imagenet-multiGPU.torch/blob/master/train.lua

--- Strip transient state from every module of `model`, in place.
-- For each module returned by `model:listModules()`, each field is handled as follows:
--   * a field whose value is FFI cdata (per `torch.type`) is removed;
--   * otherwise, a field named in `common` or in `t[module.__typename]` is
--     reset through free_table_or_tensor().
-- The table is mutated while being traversed. Only existing keys are cleared
-- or reassigned, which `next`/`pairs` permit.
-- @param model an nn module or container that provides listModules()
-- @return the same `model` object
local function sanitize(model)
  for _, module in ipairs(model:listModules()) do
    for key, value in pairs(module) do
      if torch.type(value) == 'cdata' then
        -- drop FFI objects outright
        module[key] = nil
      elseif is_member(key, common) or is_member(key, t[module.__typename]) then
        -- generic output/gradInput, or a buffer registered for this type
        free_table_or_tensor(module, key, value)
      end
    end
  end
  return model
end
return sanitize