forked from facebookresearch/deepmask
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrainerSharpMask.lua
148 lines (121 loc) · 4.39 KB
/
TrainerSharpMask.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
Training and testing loop for SharpMask
------------------------------------------------------------------------------]]
paths.dofile('trainMeters.lua')
local Trainer = torch.class('Trainer')
--------------------------------------------------------------------------------
-- function: init
function Trainer:__init(model, criterion, config)
-- training params
self.model = model
self.criterion = criterion
self.lr = config.lr
-- allocate cuda tensors
self.inputs, self.labels = torch.CudaTensor(), torch.CudaTensor()
-- meters
self.lossmeter = LossMeter()
self.maskmeter = IouMeter(0.5,config.testmaxload*config.batch)
-- log
self.modelsv = {model=model:clone('weight', 'bias'),config=config}
self.rundir = config.rundir
self.log = torch.DiskFile(self.rundir .. '/log', 'rw'); self.log:seekEnd()
end
--------------------------------------------------------------------------------
-- function: train
function Trainer:train(epoch, dataloader)
self.model:training()
self:updateScheduler(epoch)
self.lossmeter:reset()
local timer = torch.Timer()
for n, sample in dataloader:run() do
-- copy samples to the GPU
self:copySamples(sample)
-- forward/backward
local outputs = self.model:forward(self.inputs)
local lossbatch = self.criterion:forward(outputs, self.labels)
local gradOutputs = self.criterion:backward(outputs, self.labels)
gradOutputs:mul(self.inputs:size(1))
self.model:zeroGradParameters()
self.model:backward(self.inputs, gradOutputs)
self.model:updateParameters(self.lr)
-- update loss
self.lossmeter:add(lossbatch)
end
-- write log
local logepoch =
string.format('[train] | epoch %05d | s/batch %04.2f | loss: %07.5f ',
epoch, timer:time().real/dataloader:size(),self.lossmeter:value())
print(logepoch)
self.log:writeString(string.format('%s\n',logepoch))
self.log:synchronize()
--save model
torch.save(string.format('%s/model.t7', self.rundir),self.modelsv)
if epoch%50 == 0 then
torch.save(string.format('%s/model_%d.t7', self.rundir, epoch),
self.modelsv)
end
collectgarbage()
end
--------------------------------------------------------------------------------
-- function: test
local maxacc = 0
function Trainer:test(epoch, dataloader)
self.model:evaluate()
self.maskmeter:reset()
for n, sample in dataloader:run() do
-- copy input and target to the GPU
self:copySamples(sample)
-- infer mask in batch
local outputs = self.model:forward(self.inputs):float()
cutorch.synchronize()
self.maskmeter:add(outputs:view(sample.labels:size()),sample.labels)
end
self.model:training()
-- check if bestmodel so far
local z,bestmodel = self.maskmeter:value('0.7')
if z > maxacc then
torch.save(string.format('%s/bestmodel.t7', self.rundir),self.modelsv)
maxacc = z
bestmodel = true
end
-- write log
local logepoch =
string.format('[test] | epoch %05d '..
'| IoU: mean %06.2f median %06.2f [email protected] %06.2f [email protected] %06.2f '..
'| bestmodel %s',
epoch,
self.maskmeter:value('mean'),self.maskmeter:value('median'),
self.maskmeter:value('0.5'), self.maskmeter:value('0.7'),
bestmodel and '*' or 'x')
print(logepoch)
self.log:writeString(string.format('%s\n',logepoch))
self.log:synchronize()
collectgarbage()
end
--------------------------------------------------------------------------------
-- function: copy inputs/labels to CUDA tensor
function Trainer:copySamples(sample)
self.inputs:resize(sample.inputs:size()):copy(sample.inputs)
self.labels:resize(sample.labels:size()):copy(sample.labels)
end
--------------------------------------------------------------------------------
-- function: update training schedule according to epoch
function Trainer:updateScheduler(epoch)
if self.lr == 0 then
local regimes = {
{ 1, 50, 1e-3},
{ 51, 80, 5e-4},
{ 81, 1e8, 1e-4}
}
for _, row in ipairs(regimes) do
if epoch >= row[1] and epoch <= row[2] then
self.lr = row[3]
end
end
end
end
return Trainer