-
Notifications
You must be signed in to change notification settings - Fork 206
/
main.py
452 lines (403 loc) · 13.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
import argparse
import os
import time
import torch
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import flow_transforms
import models
import datasets
from multiscaleloss import multiscaleEPE, realEPE
import datetime
from torch.utils.tensorboard import SummaryWriter
from util import flow2rgb, AverageMeter, save_checkpoint
import numpy as np
# Architecture / dataset names exposed by the project packages; used as the
# allowed `choices` for --arch and --dataset below.
model_names = sorted(
    name for name in models.__dict__ if name.islower() and not name.startswith("__")
)
dataset_names = sorted(name for name in datasets.__all__)

parser = argparse.ArgumentParser(
    description="PyTorch FlowNet Training on several datasets",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "--dataset",
    metavar="DATASET",
    default="flying_chairs",
    choices=dataset_names,
    help="dataset type : " + " | ".join(dataset_names),
)
# A split file and a split proportion are mutually exclusive ways of choosing
# the train/val split.
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "-s", "--split-file", default=None, type=str, help="test-val split file"
)
group.add_argument(
    "--split-value",
    default=0.8,
    type=float,
    help="test-val split proportion between 0 (only test) and 1 (only train), "
    "will be overwritten if a split file is set",
)
parser.add_argument(
    "--split-seed",
    type=int,
    default=None,
    help="Seed the train-val split to enforce reproducibility (consistent restart too)",
)
parser.add_argument(
    "--arch",
    "-a",
    metavar="ARCH",
    default="flownets",
    choices=model_names,
    help="model architecture, overwritten if pretrained is specified: "
    + " | ".join(model_names),
)
parser.add_argument(
    "--solver", default="adam", choices=["adam", "sgd"], help="solver algorithms"
)
parser.add_argument(
    "-j",
    "--workers",
    default=8,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=300, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "--epoch-size",
    default=1000,
    type=int,
    metavar="N",
    help="manual epoch size (will match dataset size if set to 0)",
)
parser.add_argument(
    "-b", "--batch-size", default=8, type=int, metavar="N", help="mini-batch size"
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.0001,
    type=float,
    metavar="LR",
    help="initial learning rate",
)
parser.add_argument(
    "--momentum",
    default=0.9,
    type=float,
    metavar="M",
    help="momentum for sgd, alpha parameter for adam",
)
parser.add_argument(
    "--beta", default=0.999, type=float, metavar="M", help="beta parameter for adam"
)
parser.add_argument(
    "--weight-decay", "--wd", default=4e-4, type=float, metavar="W", help="weight decay"
)
parser.add_argument(
    "--bias-decay", default=0, type=float, metavar="B", help="bias decay"
)
parser.add_argument(
    "--multiscale-weights",
    "-w",
    default=[0.005, 0.01, 0.02, 0.08, 0.32],
    type=float,
    nargs=5,
    help="training weight for each scale, from highest resolution (flow2) to lowest (flow6)",
    metavar=("W2", "W3", "W4", "W5", "W6"),
)
parser.add_argument(
    "--sparse",
    action="store_true",
    help="look for NaNs in target flow when computing EPE, avoid if flow is garantied to be dense,"
    "automatically seleted when choosing a KITTIdataset",
)
parser.add_argument(
    "--print-freq", "-p", default=10, type=int, metavar="N", help="print frequency"
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained", dest="pretrained", default=None, help="path to pre-trained model"
)
parser.add_argument(
    "--no-date", action="store_true", help="don't append date timestamp to folder"
)
# type=float is required: without it a value given on the command line stays a
# str, and it is later used as a Normalize std and as a multiplier of the EPE.
parser.add_argument(
    "--div-flow",
    default=20,
    type=float,
    help="value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results",
)
# type=int is required: MultiStepLR compares its integer epoch counter against
# these milestones, so string milestones from the CLI would never trigger.
parser.add_argument(
    "--milestones",
    default=[100, 150, 200],
    type=int,
    metavar="N",
    nargs="*",
    help="epochs at which learning rate is divided by 2",
)

# Mutable training state shared by main/train/validate.
best_EPE = -1  # -1 means "no validation result recorded yet"
n_iter = 0  # global step counter, incremented once per training batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _make_save_path(args):
    """Return the run output directory ``<dataset>/[<timestamp>/]<run name>``.

    The run name encodes arch, solver, epoch count, optional epoch size, batch
    size and learning rate. The timestamp level is omitted with ``--no-date``.
    """
    run_name = "{},{},{}epochs{},b{},lr{}".format(
        args.arch,
        args.solver,
        args.epochs,
        ",epochSize" + str(args.epoch_size) if args.epoch_size > 0 else "",
        args.batch_size,
        args.lr,
    )
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        run_name = os.path.join(timestamp, run_name)
    return os.path.join(args.dataset, run_name)


def _make_transforms(args):
    """Return ``(input_transform, target_transform, co_transform)``.

    * input: array -> tensor, scaled by 1/255, then the per-channel mean
      (0.45, 0.432, 0.411) is subtracted.
    * target: array -> tensor, divided by ``args.div_flow``.
    * co_transform: random crop to (320, 448) plus random vertical/horizontal
      flips; when ``args.sparse`` is False a random translate and rotate are
      prepended. (Passed to the dataset as ``co_transform`` - presumably
      applied jointly to the image pair and the flow target.)
    """
    input_transform = transforms.Compose(
        [
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
            transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
        ]
    )
    target_transform = transforms.Compose(
        [
            flow_transforms.ArrayToTensor(),
            transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow]),
        ]
    )
    augmentations = []
    if not args.sparse:
        # Translate/rotate are only used for dense targets; sparse targets
        # get crop + flips only.
        augmentations += [
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
        ]
    augmentations += [
        flow_transforms.RandomCrop((320, 448)),
        flow_transforms.RandomVerticalFlip(),
        flow_transforms.RandomHorizontalFlip(),
    ]
    co_transform = flow_transforms.Compose(augmentations)
    return input_transform, target_transform, co_transform


def main():
    """Parse CLI arguments, then either evaluate once or train for ``args.epochs``.

    Creates the output directory, writes TensorBoard logs under it and saves a
    checkpoint after every training epoch via ``save_checkpoint``. Sets the
    module globals ``args``, ``best_EPE`` and ``n_iter`` that ``train`` and
    ``validate`` read. All SummaryWriters are closed on exit, including on
    error, so buffered events are flushed.
    """
    global args, best_EPE, n_iter
    args = parser.parse_args()

    save_path = _make_save_path(args)
    print("=> will save everything to {}".format(save_path))
    os.makedirs(save_path, exist_ok=True)
    if args.split_seed is not None:
        # Seeds numpy's global RNG before the dataset is built; presumably the
        # train/val split draws from it (per the --split-seed help text).
        np.random.seed(args.split_seed)

    train_writer = SummaryWriter(os.path.join(save_path, "train"))
    test_writer = SummaryWriter(os.path.join(save_path, "test"))
    output_writers = [
        SummaryWriter(os.path.join(save_path, "test", str(i))) for i in range(3)
    ]
    try:
        # Data loading code
        if "KITTI" in args.dataset:
            args.sparse = True
        input_transform, target_transform, co_transform = _make_transforms(args)

        print("=> fetching img pairs in '{}'".format(args.data))
        train_set, test_set = datasets.__dict__[args.dataset](
            args.data,
            transform=input_transform,
            target_transform=target_transform,
            co_transform=co_transform,
            split=args.split_file if args.split_file else args.split_value,
            split_save_path=os.path.join(save_path, "split.txt"),
        )
        print(
            "{} samples found, {} train samples and {} test samples ".format(
                len(test_set) + len(train_set), len(train_set), len(test_set)
            )
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=True,
        )
        val_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=args.batch_size,
            num_workers=args.workers,
            pin_memory=True,
            shuffle=False,
        )

        # n_iter counts training batches (train() increments it once per
        # batch), so resume it from batches-per-epoch, not from samples.
        iters_per_epoch = (
            len(train_loader)
            if args.epoch_size == 0
            else min(len(train_loader), args.epoch_size)
        )
        n_iter = args.start_epoch * iters_per_epoch

        # create model
        if args.pretrained:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            network_data = torch.load(args.pretrained, map_location=device)
            args.arch = network_data["arch"]
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            network_data = None
            print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](network_data).to(device)

        print("=> setting {} solver".format(args.solver))
        # Biases and weights get separate weight-decay values.
        param_groups = [
            {"params": model.bias_parameters(), "weight_decay": args.bias_decay},
            {"params": model.weight_parameters(), "weight_decay": args.weight_decay},
        ]
        if device.type == "cuda":
            model = torch.nn.DataParallel(model).cuda()
            cudnn.benchmark = True

        if args.solver == "adam":
            optimizer = torch.optim.Adam(
                param_groups, args.lr, betas=(args.momentum, args.beta)
            )
        elif args.solver == "sgd":
            optimizer = torch.optim.SGD(param_groups, args.lr, momentum=args.momentum)
        else:  # unreachable via the CLI: argparse `choices` restricts --solver
            raise ValueError("unknown solver: {}".format(args.solver))

        if args.evaluate:
            with torch.no_grad():
                best_EPE = validate(val_loader, model, 0, output_writers)
            return

        # NOTE(review): the scheduler always starts counting at 0, even when
        # resuming with --start-epoch > 0.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.milestones, gamma=0.5
        )
        for epoch in range(args.start_epoch, args.epochs):
            # train for one epoch
            _, train_EPE = train(train_loader, model, optimizer, epoch, train_writer)
            scheduler.step()
            train_writer.add_scalar("mean EPE", train_EPE, epoch)

            # evaluate on validation set
            with torch.no_grad():
                EPE = validate(val_loader, model, epoch, output_writers)
            test_writer.add_scalar("mean EPE", EPE, epoch)

            # best_EPE < 0 means nothing has been validated yet, so the first
            # result must count as the best one.
            is_best = best_EPE < 0 or EPE < best_EPE
            if is_best:
                best_EPE = EPE

            # The model is only wrapped in DataParallel on CUDA; unwrap it
            # only when it actually is wrapped.
            net = model.module if isinstance(model, torch.nn.DataParallel) else model
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": net.state_dict(),
                    "best_EPE": best_EPE,
                    "div_flow": args.div_flow,
                },
                is_best,
                save_path,
            )
    finally:
        for writer in (train_writer, test_writer, *output_writers):
            writer.close()
def train(train_loader, model, optimizer, epoch, train_writer):
    """Train ``model`` for one epoch and return ``(mean loss, mean flow2 EPE)``.

    An epoch is exactly ``min(len(train_loader), args.epoch_size)`` batches, or
    the whole loader when ``args.epoch_size`` is 0. Reads the module globals
    ``args`` and ``device``, increments the global ``n_iter`` once per batch,
    and logs each batch loss to ``train_writer`` as "train_loss" at step
    ``n_iter``.
    """
    global n_iter, args
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    flow2_EPEs = AverageMeter()

    epoch_size = (
        len(train_loader)
        if args.epoch_size == 0
        else min(len(train_loader), args.epoch_size)
    )

    # switch to train mode
    model.train()
    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.to(device)
        # The loader yields the inputs as a sequence of tensors; concatenate
        # them along dim 1 into a single network input.
        inputs = torch.cat(inputs, 1).to(device)

        # compute output
        output = model(inputs)
        if args.sparse:
            # Since target pooling is not very precise when sparse, take the
            # highest-resolution prediction and upsample it to the target size
            # instead of downsampling the target.
            h, w = target.size()[-2:]
            output = [F.interpolate(output[0], (h, w)), *output[1:]]

        loss = multiscaleEPE(
            output, target, weights=args.multiscale_weights, sparse=args.sparse
        )
        # Targets were divided by div_flow in the transform; scale back so the
        # reported EPE is in the dataset's original flow units.
        flow2_EPE = args.div_flow * realEPE(output[0], target, sparse=args.sparse)

        # record loss and EPE
        loss_value = loss.item()
        batch_len = target.size(0)
        losses.update(loss_value, batch_len)
        train_writer.add_scalar("train_loss", loss_value, n_iter)
        flow2_EPEs.update(flow2_EPE.item(), batch_len)

        # compute gradient and do optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(
                "Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}".format(
                    epoch, i, epoch_size, batch_time, data_time, losses, flow2_EPEs
                )
            )
        n_iter += 1
        # i is 0-based: stop after exactly epoch_size batches. Checking here
        # (not at the loop top) avoids fetching one extra batch.
        if i + 1 >= epoch_size:
            break

    return losses.avg, flow2_EPEs.avg
@torch.no_grad()
def validate(val_loader, model, epoch, output_writers):
    """Evaluate ``model`` on ``val_loader`` and return the mean flow2 EPE.

    Runs with autograd disabled, so callers need not wrap it in
    ``torch.no_grad()``. For each of the first ``len(output_writers)`` batches:
    * logs the first sample's predicted flow to the matching writer at step
      ``epoch``;
    * when ``epoch == args.start_epoch``, also logs that sample's ground-truth
      flow and its two input images (with the input mean added back).
    Reads the module globals ``args`` and ``device``.
    """
    global args
    batch_time = AverageMeter()
    flow2_EPEs = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (inputs, target) in enumerate(val_loader):
        target = target.to(device)
        inputs = torch.cat(inputs, 1).to(device)

        # NOTE(review): unlike train(), the output is used directly as a single
        # flow tensor here (no output[0] scale selection) - confirm the model
        # returns only the finest flow in eval mode.
        output = model(inputs)
        flow2_EPE = args.div_flow * realEPE(output, target, sparse=args.sparse)
        # record EPE
        flow2_EPEs.update(flow2_EPE.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i < len(output_writers):  # log first output of first batches
            writer = output_writers[i]
            if epoch == args.start_epoch:
                # Add back the per-channel mean removed by the input transform
                # so images display in [0, 1].
                mean_values = torch.tensor(
                    [0.45, 0.432, 0.411], dtype=inputs.dtype
                ).view(3, 1, 1)
                writer.add_image(
                    "GroundTruth", flow2rgb(args.div_flow * target[0], max_value=10), 0
                )
                writer.add_image(
                    "Inputs", (inputs[0, :3].cpu() + mean_values).clamp(0, 1), 0
                )
                writer.add_image(
                    "Inputs", (inputs[0, 3:].cpu() + mean_values).clamp(0, 1), 1
                )
            writer.add_image(
                "FlowNet Outputs",
                flow2rgb(args.div_flow * output[0], max_value=10),
                epoch,
            )

        if i % args.print_freq == 0:
            print(
                "Test: [{0}/{1}]\t Time {2}\t EPE {3}".format(
                    i, len(val_loader), batch_time, flow2_EPEs
                )
            )

    print(" * EPE {:.3f}".format(flow2_EPEs.avg))
    return flow2_EPEs.avg
if __name__ == "__main__":
    # Run training/evaluation only when executed as a script, not on import.
    main()