-
Notifications
You must be signed in to change notification settings - Fork 2
/
whole-in-one.py
49 lines (38 loc) · 1.04 KB
/
whole-in-one.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import yaml
import json
import argparse
from attrdict import AttrDict
from dkt.dataloader import Preprocess
from dkt import trainer
from dkt.utils import setSeeds
import torch
import wandb
from train import main as t_main
from inference import main as i_main
def save_light_config(args):
    """Write a trimmed ("light") copy of the experiment config as JSON.

    The file goes to ``f"{output_dir}{task_name}/exp_config.json"``. This is
    plain string concatenation, kept identical to the original script, so
    ``output_dir`` is expected to end with a path separator. The ``wandb``
    section is never saved. When ``model == "lgbm"`` only the ``lgbm``
    sub-config is saved; otherwise every top-level key except ``wandb`` and
    ``lgbm`` is saved. Missing ``wandb`` / ``lgbm`` sections are tolerated.

    Args:
        args: Mapping loaded from the YAML config (an ``AttrDict`` in this
            script; a plain ``dict`` works too). It is NOT modified.

    Returns:
        The path of the JSON file that was written.
    """
    save_path = f"{args['output_dir']}{args['task_name']}/exp_config.json"

    if args.get("model") == "lgbm":
        config = args["lgbm"]
    else:
        # Build a filtered copy instead of popping keys, so the caller's
        # config object is left untouched and absent sections don't raise.
        config = {k: v for k, v in args.items() if k not in ("wandb", "lgbm")}

    # The target directory may not exist yet; create it rather than crash.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Explicit UTF-8 because ensure_ascii=False may emit non-ASCII text;
    # `with` guarantees the file handle is flushed and closed.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
    return save_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--conf",
        default="./conf.yml",
        help="Path to the YAML configuration file.",
    )
    term_args = parser.parse_args()

    # NOTE(review): the config path comes from the CLI and is assumed to be a
    # trusted local file; prefer yaml.safe_load if that may not hold.
    with open(term_args.conf, encoding="utf-8") as f:
        cf = yaml.load(f, Loader=yaml.FullLoader)
    args = AttrDict(cf)

    os.makedirs(args.model_dir, exist_ok=True)

    # Train, then run inference with the same configuration.
    t_main(args)
    i_main(args)

    # Save a light copy of the config next to the experiment outputs.
    save_light_config(args)