# train_vqvae.py
import argparse
import logging
import os
import os.path as osp
import random
import time
import torch
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs
def main():
    """Run the training loop configured by the ``-opt`` YAML option file.

    Steps, in order: parse the option file, create the experiment
    directories and loggers, build the train/val/test data loaders, create
    the model, then train for ``opt['num_epochs']`` epochs. On every epoch
    where ``epoch % opt['val_freq'] == 0`` the model is run on the
    validation and test loaders, the epoch with the lowest test loss so far
    is logged, and the network is written under ``opt['path']['models']``.

    Returns:
        None.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=True)

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))

    # initialize tensorboard logger (skipped when the experiment name
    # contains 'debug')
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])

    # convert to NoneDict, which returns None for missing keys
    opt = dict_to_nonedict(opt)

    # set up data loaders
    train_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['train_ann_file'],
        xflip=True)
    num_workers = opt['num_workers']
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opt['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when there are no
        # worker processes, so only request it when workers exist.
        persistent_workers=bool(num_workers and num_workers > 0),
        drop_last=True)
    logger.info(f'Number of train set: {len(train_dataset)}.')
    # NOTE(review): with drop_last=True the loader yields
    # len(train_dataset) // batch_size batches per epoch, so this value can
    # slightly exceed the real iteration count. Kept unchanged because the
    # consumers of opt['max_iters'] are not visible here.
    opt['max_iters'] = opt['num_epochs'] * len(
        train_dataset) // opt['batch_size']

    # NOTE(review): the validation set reads images from 'train_img_dir';
    # presumably the val split shares that directory - confirm.
    val_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['val_ann_file'])
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of val set: {len(val_dataset)}.')

    test_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['test_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['test_ann_file'])
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=1, shuffle=False)
    logger.info(f'Number of test set: {len(test_dataset)}.')

    current_iter = 0
    best_epoch = None
    # Start from +inf so the first evaluated epoch always becomes the best,
    # whatever the scale of the loss.
    best_loss = float('inf')

    model = create_model(opt)

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    for epoch in range(opt['num_epochs']):
        lr = model.update_learning_rate(epoch)

        # Reference timestamp for the per-iteration timers. Resetting it at
        # the start of every epoch keeps the first iteration from reporting
        # a bogus duration (and from absorbing the previous epoch's
        # validation/test pass).
        tic = time.time()
        for batch_data in train_loader:
            # time spent waiting for the batch
            data_time = time.time() - tic

            current_iter += 1

            model.optimize_parameters(batch_data, current_iter)

            # data loading + optimisation time for this iteration
            iter_time = time.time() - tic
            if current_iter % opt['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': [lr]})
                log_vars.update({'time': iter_time, 'data_time': data_time})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            tic = time.time()

        if epoch % opt['val_freq'] == 0:
            # exist_ok follows opt['debug'], so a non-debug run fails if the
            # epoch directory already exists instead of overwriting it.
            save_dir = f'{opt["path"]["visualization"]}/valset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            val_loss_total = model.inference(val_loader, save_dir)

            save_dir = f'{opt["path"]["visualization"]}/testset/epoch_{epoch:03d}'  # noqa
            os.makedirs(save_dir, exist_ok=opt['debug'])
            test_loss_total = model.inference(test_loader, save_dir)

            logger.info(f'Epoch: {epoch}, '
                        f'val_loss_total: {val_loss_total}, '
                        f'test_loss_total: {test_loss_total}.')

            # best epoch is selected on the test loss
            if test_loss_total < best_loss:
                best_epoch = epoch
                best_loss = test_loss_total

            logger.info(f'Best epoch: {best_epoch}, '
                        f'Best test loss: {best_loss: .4f}.')

            # save model
            # NOTE(review): checkpointing is placed inside the val_freq
            # branch; confirm this matches the intended cadence.
            model.save_network(f'{opt["path"]["models"]}/epoch{epoch}.pth')
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()