forked from sxyu/svox2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
config_generator.py
83 lines (60 loc) · 2.21 KB
/
config_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import itertools
from pathlib import Path
import json
from copy import deepcopy
ROOT_PATH = 'opt/configs/auto_tune/llff_bg_less_trunc_zero_lv/'

# (start, final) parameter templates. A generated config is skipped when its
# start value is strictly smaller than its final value. Templates are compared
# after whitespace stripping, so it does not matter whether the "text" entries
# in config.json end with a newline or not.
CHECK_PAIRS = [
    ('lr_surface = {}', 'lr_surface_final = {}'),
    ('fake_sample_std = {}', 'fake_sample_std_final = {}'),
    ('lr_fake_sample_std = {}', 'lr_fake_sample_std_final = {}'),
    ('lr_sigma = {}', 'lr_sigma_final = {}'),
    ('lr_alpha = {}', 'lr_alpha_final = {}'),
]


def _load_source_config(tune_conf, root_dir):
    """Return the shared header (source config wrapped in banners) for every output.

    Returns "" when ``tune_conf`` has no 'source_conf' key. The special value
    'CONF_FOLDER' means ``<root_dir>/source.yaml``; any other value is used as
    a path (relative paths resolve against the current working directory).
    """
    if 'source_conf' not in tune_conf:
        return ""
    if tune_conf['source_conf'] == 'CONF_FOLDER':
        source_conf_path = root_dir / 'source.yaml'
    else:
        source_conf_path = Path(tune_conf['source_conf'])
    with source_conf_path.open('r', encoding='utf-8') as f:
        body = f.read()
    return (
        "########## Source Config ##########\n"
        + body
        + "\n########## Tuned Config ##########\n"
    )


def _violates_schedule(record):
    """Return True if any CHECK_PAIRS (start, final) pair present in record has start < final."""
    return any(
        start in record and final in record and record[start] < record[final]
        for start, final in CHECK_PAIRS
    )


def generate_configs(tune_conf, source_config=""):
    """Return the text of every valid config for the parameter grid in ``tune_conf``.

    ``tune_conf['params']`` is a list of dicts, each with a 'text' format
    template (filled via ``str.format`` with one value) and a 'values' list.
    The cartesian product of all 'values' lists is enumerated in
    ``itertools.product`` order. Each config is ``source_config`` followed by
    one formatted line per parameter. Combinations rejected by
    ``_violates_schedule`` are dropped.
    """
    params = tune_conf['params']
    value_lists = [param['values'] for param in params]
    configs = []
    for combo in itertools.product(*value_lists):
        # Keyed by the stripped template so it can be matched against CHECK_PAIRS.
        record = {}
        parts = [source_config]
        for param, value in zip(params, combo):
            record[param['text'].strip()] = value
            parts.append(param['text'].format(value) + "\n")
        if _violates_schedule(record):
            continue
        configs.append("".join(parts))
    return configs


def main(root_path=ROOT_PATH):
    """Write one ``NNNN.yaml`` per valid combination described by ``<root_path>/config.json``.

    Each written path is printed, and the list of written paths is returned.
    NOTE: files left over from a previous run with higher indices are not
    removed.
    """
    root_dir = Path(root_path)
    root_dir.mkdir(parents=True, exist_ok=True)
    with (root_dir / 'config.json').open('r', encoding='utf-8') as f:
        tune_conf = json.load(f)
    source_config = _load_source_config(tune_conf, root_dir)
    written = []
    for i, config in enumerate(generate_configs(tune_conf, source_config)):
        filepath = root_dir / f"{i:04d}.yaml"
        filepath.write_text(config, encoding='utf-8')
        print(filepath)
        written.append(filepath)
    return written


if __name__ == "__main__":
    main()