-- MyOptimizerPerModule.lua
local MyOptimizerPerModule, parent = torch.class('MyOptimizerPerModule', 'MyOptimizer')

--NOTE: various bits of this code were inspired by fbnn Optim.lua 3/5/2015
--we're just using optInfo for regularization, etc.

function MyOptimizerPerModule:__init(model, submodel_to_update, criterion, trainingOptions, optInfo, perModuleOptInfo)
    parent.__init(self, model, submodel_to_update, criterion, trainingOptions, optInfo)

    self.optConfigs = {}
    self.optStates = {}
    self.paramsPerModule = {}
    self.gradParamsPerModule = {}

    for i = 1, #perModuleOptInfo do
        local oInfo = perModuleOptInfo[i]
        -- parameters() returns one tensor per parameter block (e.g. weight and bias)
        local p, g = oInfo.moduleToOptimize:parameters()
        table.insert(self.paramsPerModule, p)
        table.insert(self.gradParamsPerModule, g)

        -- give every parameter block of this module its own copy of the optimizer
        -- config/state, so per-block state is not shared across blocks
        local numBlocks = #self.paramsPerModule[i]
        self.optConfigs[i] = {}
        self.optStates[i] = {}
        for j = 1, numBlocks do
            table.insert(self.optConfigs[i], Util:deepcopy(oInfo.optConfig))
            table.insert(self.optStates[i], Util:deepcopy(oInfo.optState))
        end
    end
    self.numModulesToUpdate = #perModuleOptInfo
end
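
--[[
Shape sketch (an assumption inferred from the constructor above, not a documented
contract of this repo): if entry i wraps an nn.Linear module, parameters() returns
{weight, bias}, so after __init the per-module tables look like

    self.paramsPerModule[i]     = {weightTensor, biasTensor}
    self.gradParamsPerModule[i] = {gradWeightTensor, gradBiasTensor}
    self.optConfigs[i]          = {deepcopy(optConfig), deepcopy(optConfig)} -- one copy per block
    self.optStates[i]           = {deepcopy(optState),  deepcopy(optState)}
--]]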
function MyOptimizerPerModule:trainBatch(inputs, targets)
    assert(inputs)
    assert(targets)

    local parameters = self.parameters
    local gradParameters = self.gradParameters

    -- single forward/backward pass over the whole model; this fills the
    -- per-module gradient tensors gathered in __init
    local function fEval(x)
        assert(parameters == x) --this only works when we're evaluating at the current iterate
        self.model:zeroGradParameters()
        local output = self.model:forward(inputs)
        local err = self.criterion:forward(output, targets)
        local df_do = self.criterion:backward(output, targets)
        self.model:backward(inputs, df_do)

        --note: we don't bother adding the regularizer to the objective value. who selects models on the objective anyway?
        for i = 1, self.numRegularizers do
            local l2 = self.l2s[i]
            for j = 1, #self.params[i] do
                self.grads[i][j]:add(l2, self.params[i][j])
            end
        end
        self.totalError[1] = self.totalError[1] + err
        return err, gradParameters
    end

    local err = fEval(parameters)

    -- update each parameter block of each module with its own optimizer
    -- config/state; moduleFEval just hands back the gradient computed above
    for i = 1, self.numModulesToUpdate do
        local numBlocks = #self.paramsPerModule[i]
        for j = 1, numBlocks do
            local function moduleFEval(x)
                assert(x == self.paramsPerModule[i][j]) --only valid at the current iterate
                return err, self.gradParamsPerModule[i][j]
            end
            self.optimMethod(moduleFEval, self.paramsPerModule[i][j], self.optConfigs[i][j], self.optStates[i][j])
        end
    end

    return err
end
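
--[[
Minimal usage sketch, assuming the MyOptimizer base class accepts the constructor
arguments echoed above and that trainingOptions/optInfo follow its conventions
(neither is spelled out here). The option names below (learningRate, etc.) and the
direct assignment of optimMethod are illustrative assumptions, not this repo's
documented API.

    require 'nn'
    require 'optim'

    local model = nn.Sequential()
                    :add(nn.Linear(10, 20))
                    :add(nn.ReLU())
                    :add(nn.Linear(20, 2))
    local criterion = nn.CrossEntropyCriterion()

    -- give the two Linear layers different learning rates
    local perModuleOptInfo = {
        {moduleToOptimize = model:get(1), optConfig = {learningRate = 0.1},  optState = {}},
        {moduleToOptimize = model:get(3), optConfig = {learningRate = 0.01}, optState = {}},
    }

    local optimizer = MyOptimizerPerModule(model, model, criterion,
                                           trainingOptions, optInfo, perModuleOptInfo)
    optimizer.optimMethod = optim.sgd  -- any optim-style method(feval, x, config, state)

    local inputs  = torch.randn(5, 10)
    local targets = torch.LongTensor(5):random(1, 2)
    local err = optimizer:trainBatch(inputs, targets)
--]]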