-
Notifications
You must be signed in to change notification settings - Fork 90
/
Copy pathtrain_sampler.py
122 lines (98 loc) · 4.05 KB
/
train_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import logging
import os
import os.path as osp
import random
import time
import torch
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs
def main():
# options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.')
args = parser.parse_args()
opt = parse(args.opt, is_train=True)
# mkdir and loggers
make_exp_dirs(opt)
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
logger = get_root_logger(
logger_name='base', log_level=logging.INFO, log_file=log_file)
logger.info(dict2str(opt))
# initialize tensorboard logger
tb_logger = None
if opt['use_tb_logger'] and 'debug' not in opt['name']:
tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])
# convert to NoneDict, which returns None for missing keys
opt = dict_to_nonedict(opt)
# set up data loader
train_dataset = DeepFashionAttrSegmDataset(
img_dir=opt['train_img_dir'],
segm_dir=opt['segm_dir'],
pose_dir=opt['pose_dir'],
ann_dir=opt['train_ann_file'],
xflip=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=opt['batch_size'],
shuffle=True,
num_workers=opt['num_workers'],
persistent_workers=True,
drop_last=True)
logger.info(f'Number of train set: {len(train_dataset)}.')
opt['max_iters'] = opt['num_epochs'] * len(
train_dataset) // opt['batch_size']
val_dataset = DeepFashionAttrSegmDataset(
img_dir=opt['train_img_dir'],
segm_dir=opt['segm_dir'],
pose_dir=opt['pose_dir'],
ann_dir=opt['val_ann_file'])
val_loader = torch.utils.data.DataLoader(
dataset=val_dataset, batch_size=opt['batch_size'], shuffle=False)
logger.info(f'Number of val set: {len(val_dataset)}.')
test_dataset = DeepFashionAttrSegmDataset(
img_dir=opt['test_img_dir'],
segm_dir=opt['segm_dir'],
pose_dir=opt['pose_dir'],
ann_dir=opt['test_ann_file'])
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=opt['batch_size'], shuffle=False)
logger.info(f'Number of test set: {len(test_dataset)}.')
current_iter = 0
model = create_model(opt)
data_time, iter_time = 0, 0
current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)
for epoch in range(opt['num_epochs']):
lr = model.update_learning_rate(epoch, current_iter)
for _, batch_data in enumerate(train_loader):
data_time = time.time() - data_time
current_iter += 1
model.feed_data(batch_data)
model.optimize_parameters()
iter_time = time.time() - iter_time
if current_iter % opt['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': [lr]})
log_vars.update({'time': iter_time, 'data_time': data_time})
log_vars.update(model.get_current_log())
msg_logger(log_vars)
data_time = time.time()
iter_time = time.time()
if epoch % opt['val_freq'] == 0 and epoch != 0:
save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}' # noqa
os.makedirs(save_dir, exist_ok=opt['debug'])
model.inference(val_loader, save_dir)
save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}' # noqa
os.makedirs(save_dir, exist_ok=opt['debug'])
model.inference(test_loader, save_dir)
# save model
model.save_network(
model._denoise_fn,
f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
if __name__ == '__main__':
main()