-
Notifications
You must be signed in to change notification settings - Fork 116
/
Copy pathmain.py
139 lines (114 loc) · 4.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import argparse
import pprint
from imagenet_inat.data import dataloader
from imagenet_inat.run_networks import model
import warnings
import yaml
from imagenet_inat.utils import source_import, get_value
# Root directory of the data for each supported base dataset name.  Both
# datasets currently live under the same './data' folder.
data_root = dict.fromkeys(('ImageNet', 'iNaturalist18'), './data')
def _parse_bool_flag(value):
    """Interpret a command-line string as a boolean.

    Only the usual "false" spellings (and the empty string) map to False.
    Every other value maps to True, so invocations that passed an arbitrary
    value just to switch the option on keep working.
    """
    false_spellings = ('', '0', 'false', 'f', 'no', 'n', 'off')
    return str(value).lower() not in false_spellings


# Command-line interface of the script.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default=None, type=str)
parser.add_argument('--test', default=False, action='store_true')
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--test_open', default=False, action='store_true')
# Previously the raw string was stored, so '--output_logits False' was truthy
# and still enabled the feature.  The value is now converted to a real bool;
# the bare flag '--output_logits' (no value) is accepted and means True.
parser.add_argument('--output_logits', nargs='?', const=True, default=False,
                    type=_parse_bool_flag)
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--save_feat', type=str, default='')
parser.add_argument('--val_as_train', default=False, action='store_true')
args = parser.parse_args()
def split2phase(split):
    """Map a split name to the phase name passed to the data loader.

    'train' becomes 'train_val' when the ``--val_as_train`` flag is set;
    every other split (and 'train' without the flag) is returned unchanged.
    """
    use_train_val = split == 'train' and args.val_as_train
    return 'train_val' if use_train_val else split
def update(config, args):
    """Apply command-line overrides to the loaded configuration.

    ``config['model_dir']`` and ``config['training_opt']['batch_size']`` are
    each replaced by ``get_value(<current config value>, <command-line value>)``.
    The dictionary is modified in place and also returned.
    """
    # NOTE(review): get_value presumably prefers the command-line value when
    # one was supplied -- confirm against imagenet_inat.utils.
    config['model_dir'] = get_value(config['model_dir'], args.model_dir)

    train_cfg = config['training_opt']
    train_cfg['batch_size'] = get_value(train_cfg['batch_size'],
                                        args.batch_size)
    return config
# ============================================================================
# LOAD CONFIGURATIONS
# ============================================================================
# Fail with a usage message instead of an opaque TypeError from open(None).
if args.cfg is None:
    parser.error('--cfg is required')

with open(args.cfg) as f:
    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and can construct arbitrary Python objects on older releases;
    # safe_load handles plain configuration files on every version.
    config = yaml.safe_load(f)
config = update(config, args)

# Open-set testing implies test mode.
test_open = args.test_open
test_mode = args.test or test_open
output_logits = args.output_logits

training_opt = config['training_opt']
# NOTE(review): not referenced anywhere in the visible code; the name
# (sic, "relatin") is kept unchanged in case something else relies on it.
relatin_opt = config['memory']
dataset = training_opt['dataset']

# Make sure the log directory exists before anything writes to it.
os.makedirs(training_opt['log_dir'], exist_ok=True)

# data_root is keyed by the base dataset name, i.e. without a trailing '_LT'.
# str.rstrip('_LT') would strip any trailing run of the characters '_', 'L'
# and 'T' rather than the suffix, so the suffix is removed explicitly.
data_root_key = dataset[:-len('_LT')] if dataset.endswith('_LT') else dataset
print('Loading dataset from: %s' % data_root[data_root_key])
pprint.pprint(config)
# ============================================================================
# MAIN LOOP
# ============================================================================
# data_root is keyed by the base dataset name, i.e. without a trailing '_LT'.
# str.rstrip('_LT') would strip any trailing run of the characters '_', 'L'
# and 'T' rather than the suffix, so the suffix is removed explicitly.
dataset_root = data_root[
    dataset[:-len('_LT')] if dataset.endswith('_LT') else dataset]

if not test_mode:
    # ------------------------------ training ------------------------------
    # Optional sampler description handed to the data loader.  It defaults to
    # None so that a sampler config with an unrecognised type no longer
    # leaves sampler_dic unbound (NameError at the load_data call below).
    sampler_dic = None
    sampler_defs = training_opt['sampler']
    if sampler_defs:
        sampler_type = sampler_defs['type']
        if sampler_type == 'ClassAwareSampler':
            sampler_dic = {
                'sampler': source_import(sampler_defs['def_file']).get_sampler(),
                'params': {'num_samples_cls': sampler_defs['num_samples_cls']},
            }
        elif sampler_type in ('MixedPrioritizedSampler',
                              'ClassPrioritySampler'):
            # Every sampler option except the bookkeeping keys is forwarded
            # as a sampler parameter.
            sampler_dic = {
                'sampler': source_import(sampler_defs['def_file']).get_sampler(),
                'params': {k: v for k, v in sampler_defs.items()
                           if k not in ('type', 'def_file')},
            }
        else:
            warnings.warn('Unrecognised sampler type %r; training without '
                          'a sampler.' % (sampler_type,))

    splits = ['train', 'train_plain', 'val']
    # A separate 'test' split is only loaded for datasets other than these.
    if dataset not in ('iNaturalist18', 'ImageNet'):
        splits.append('test')

    data = {x: dataloader.load_data(data_root=dataset_root,
                                    dataset=dataset,
                                    phase=split2phase(x),
                                    batch_size=training_opt['batch_size'],
                                    sampler_dic=sampler_dic,
                                    num_workers=training_opt['num_workers'])
            for x in splits}

    training_model = model(config, data, test=False)
    training_model.train()

else:
    # ------------------------------ testing -------------------------------
    warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data",
                            UserWarning)

    print('Under testing phase, we load training data simply to calculate '
          'training data number for each class.')

    # iNaturalist variants and plain ImageNet are evaluated on 'val';
    # every other dataset is evaluated on 'test'.
    if 'iNaturalist' in dataset or dataset == 'ImageNet':
        splits = ['train', 'val']
        test_split = 'val'
    else:
        splits = ['train', 'val', 'test']
        test_split = 'test'
    splits.append('train_plain')

    data = {x: dataloader.load_data(data_root=dataset_root,
                                    dataset=dataset,
                                    phase=x,
                                    batch_size=training_opt['batch_size'],
                                    sampler_dic=None,
                                    test_open=test_open,
                                    num_workers=training_opt['num_workers'],
                                    shuffle=False)
            for x in splits}

    training_model = model(config, data, test=True)
    # NOTE(review): explicit checkpoint loading is disabled here; presumably
    # model(..., test=True) loads the weights itself -- confirm.
    # training_model.load_model()
    # training_model.load_model(args.model_dir)

    # '--save_feat <split>' both enables feature saving and selects the split
    # that is evaluated.
    saveit = args.save_feat in ('train_plain', 'val', 'test')
    if saveit:
        test_split = args.save_feat

    training_model.eval(phase=test_split, openset=test_open, save_feat=saveit)

    if output_logits:
        training_model.output_logits(openset=test_open)

print('ALL COMPLETED.')