# train_parkinglot.py
# (forked from open-mmlab/mmsegmentation)
# build dataset
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
import mmcv
import copy
import os, time
import torch
from mmseg.utils import collect_env, get_root_logger
from mmcv.utils import Config, DictAction, get_git_hash
from mmcv.runner import init_dist
from mmseg import __version__
# from mmcv import Config
from mmseg.apis import set_random_seed
from mmcv.utils.config import Config as ConfigRoot, ConfigDict
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
base = 'data/parkinglot/'
work_dir = 'checkpoints/Parkinglot_ocr_hr_norm_cw/'
# img_table_file = base+'img_anno.csv'
# dataset_name= 'parkinglot'
# model_name = 'OCR_HRNet_Parkinglot'
# img_dir = 'images/'
# ann_dir = 'labels/'
#load palette
# palette = eval(open(base+'color.json', 'r').read())
# import config
config_file = 'configs/ocrnet/ocrnet_hr48_parkinglot_config.py'  # Using OCR+HRNet
# config_file = 'configs/hrnet/parkinglot.py'  # Using HRNetV2
cfg = Config.fromfile(config_file)
cfg.work_dir = work_dir
# cfg.load_from = 'checkpoints/ocrnet_hr48_512x1024_160k_cityscapes_20200602_191037-dfbf1b0c.pth'
cfg.load_from = 'checkpoints/Parkinglot_ocr_hr_norm_cw/latest.pth'
# cfg.resume_from = 'checkpoints/Parkinglot_ocr_hr_norm_cw/latest.pth'
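# Note: `load_from` only initialises the model weights from the checkpoint;
# `resume_from` (commented out above) would additionally restore the optimizer
# state and the iteration counter of the runner.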
cfg.runner = dict(type='IterBasedRunner', max_iters=80000)
#GPU
n_gpu = torch.cuda.device_count()
print(f'Found {n_gpu} GPUs')
cfg.gpu_ids = range(n_gpu)
if n_gpu > 1:
    cfg.norm_cfg = dict(type='SyncBN', requires_grad=True)
    cfg.data.samples_per_gpu = 8
    cfg.dist_params = dict(backend='nccl')
    print('Found multiple GPUs, SyncBN enabled!')
else:
    cfg.norm_cfg = dict(type='BN', requires_grad=True)
    cfg.data.samples_per_gpu = 4
    print('Using single GPU with batch size 4')
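# Hedged sketch: `cfg.norm_cfg` set above is only a top-level value; it is not
# automatically pushed into the norm layers already spelled out in the model
# config. If the base config was written for SyncBN and a single-GPU run fails
# on it, a walker like the (assumed) helper below can rewrite every `norm_cfg`
# entry under `cfg.model` -- left commented out so the default behaviour is
# unchanged.
# def _apply_norm_cfg(node):
#     if isinstance(node, (dict, ConfigDict)):
#         if 'norm_cfg' in node:
#             node['norm_cfg'] = cfg.norm_cfg
#         for v in node.values():
#             _apply_norm_cfg(v)
#     elif isinstance(node, (list, tuple)):
#         for v in node:
#             _apply_norm_cfg(v)
# _apply_norm_cfg(cfg.model)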
# init distributed env first, since logger depends on the dist info.
distributed = len(cfg.gpu_ids) > 1
if distributed:
    init_dist('pytorch', **cfg.dist_params)
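# Usage note (assumed invocation, adjust to your environment): `init_dist('pytorch', ...)`
# expects RANK/WORLD_SIZE and the master address/port to be set by the PyTorch
# launcher, e.g.
#   python -m torch.distributed.launch --nproc_per_node=<n_gpu> train_parkinglot.py
# Running this script directly only works for the single-GPU branch above.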
# adjust learning rate
cfg.optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
# In MMSegmentation, you may add the following lines to the config to scale the
# LR of the heads relative to the backbone (e.g. lr_mult=10 for 10x the backbone LR).
# cfg.optimizer=dict(
# type='AdamW', lr=0.0001, weight_decay=0.00001,
# type='SGD', lr=0.0003, momentum=0.9, weight_decay=0.00001,
# paramwise_cfg = dict(
# custom_keys={
# 'head': dict(lr_mult=2)})
# )
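# Optional sketch (values are illustrative assumptions, not from the original
# config): with the lowered SGD learning rate above, a poly decay with a short
# linear warm-up is a common pairing in mmseg iteration-based schedules.
# cfg.lr_config = dict(
#     policy='poly', power=0.9, min_lr=1e-4, by_epoch=False,
#     warmup='linear', warmup_iters=1000, warmup_ratio=1e-3)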
def update_config(obj, path='cfg'):
    """Recursively walk the config and patch pipeline entries in place."""
    if isinstance(obj, (ConfigRoot, ConfigDict, dict)):
        check_children = False
        if 'type' in obj:
            if obj.type == 'Normalize':
                obj.mean = cfg.img_norm_cfg.mean
                obj.std = cfg.img_norm_cfg.std
                print(f'updated `Normalize` at {path} -> {obj}')
            elif obj.type == 'Resize':
                obj.img_scale = (1800, 1800)
                print(f'updated `Resize` at {path} -> {obj}')
            # elif obj.type == 'RandomCrop':
            #     obj.crop_size = (1024, 1024)
            #     print(f'updated `RandomCrop` at {path}')
            # elif obj.type == 'Pad':
            #     obj.size = (1024, 1024)
            #     print(f'updated `Pad` at {path}')
            elif obj.type == 'MultiScaleFlipAug':
                obj.img_scale = (1800, 1800)
                print(f'updated `MultiScaleFlipAug` at {path} -> {obj}')
            else:
                check_children = True
        else:
            check_children = True
        if check_children:
            for k, v in obj.items():
                update_config(v, path + '.' + k)
    elif isinstance(obj, list):
        for i, obj2 in enumerate(obj):
            update_config(obj2, f'{path}[{i}]')
    elif type(obj) not in [str, tuple, int, bool, float, range]:
        # leaf of an unexpected type -- print it so it can be inspected
        print(path, obj)
update_config(cfg)
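# Lightweight sanity check: print the pipeline the walker above just patched, so
# the Normalize/Resize/MultiScaleFlipAug edits can be verified before anything
# heavy is built (`cfg.data.train.pipeline` is reused further down for the val
# dataset as well).
print('train pipeline:', cfg.data.train.pipeline)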
# create the work_dir before the file logger writes into it
mmcv.mkdir_or_exist(os.path.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# init the meta dict to record some important information such as
# environment info and seed, which will be logged
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
meta['env_info'] = env_info
meta['exp_name'] = os.path.basename(config_file)
# log some basic info
logger.info(f'Distributed training: {distributed}')
# set random seeds
seed = 0
if seed is not None:
    logger.info(f'Set random seed to {seed}, deterministic: True')
    set_random_seed(seed, deterministic=True)
cfg.seed = seed
meta['seed'] = seed
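# Note: with deterministic=True, mmseg's set_random_seed also turns
# torch.backends.cudnn.deterministic on and cudnn.benchmark off, overriding the
# benchmark=True set near the top of this script (reproducibility over speed).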
# get gflops for model
# os.system('python tools/get_flops.py configs/hrnet/parkinglot.py --shape 1024 512')
# train and eval
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import inference_segmentor, init_segmentor, train_segmentor
# Build the dataset
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
    val_dataset = copy.deepcopy(cfg.data.val)
    val_dataset.pipeline = cfg.data.train.pipeline
    datasets.append(build_dataset(val_dataset))
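# Log what was actually built; these CLASSES are also written into the
# checkpoint meta below.
logger.info(f'Loaded {len(datasets[0])} training samples, '
            f'classes: {datasets[0].CLASSES}')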
if cfg.checkpoint_config is not None:
    # save mmseg version, config file content and class names in
    # checkpoints as meta data
    cfg.checkpoint_config.meta = dict(
        mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
        # config=cfg.pretty_text,
        CLASSES=datasets[0].CLASSES,
        PALETTE=datasets[0].PALETTE)
# save config
cfg.dump(os.path.join(cfg.work_dir, 'config.py'))
# with open(work_dir+'config.py', 'w') as f:
# f.write(cfg.pretty_text)
# log model info
model = build_segmentor(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
# logger.info(model)
# model = init_segmentor(cfg, cfg.load_from, device='cuda:0')
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
model.PALETTE = datasets[0].PALETTE
train_segmentor(
    model, datasets, cfg,
    distributed=distributed,
    validate=True,
    timestamp=timestamp,
    meta=meta)
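# After training, the inference helpers imported above can be used for a quick
# smoke test of the final checkpoint. The image path below is an assumption --
# point it at any image from the parkinglot dataset before uncommenting.
# ckpt = os.path.join(cfg.work_dir, 'latest.pth')
# seg_model = init_segmentor(cfg, ckpt, device='cuda:0')
# test_img = base + 'images/example.jpg'  # hypothetical sample image
# result = inference_segmentor(seg_model, test_img)
# seg_model.show_result(test_img, result, palette=datasets[0].PALETTE,
#                       out_file=os.path.join(cfg.work_dir, 'demo_result.png'))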