-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
109 lines (92 loc) · 2.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import argparse
import numpy as np
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import transforms
from DGMR.utils.config import convert
from DGMR.DGMRTrainer import DGMRTrainer
def _build_trainer(cfg):
    """Seed all RNGs and build the Lightning ``Trainer`` shared by every run mode.

    Seeding happens first so that any model constructed afterwards gets
    reproducible initial weights; callers must build the model only after
    calling this helper.

    Args:
        cfg: experiment config object (as returned by ``convert``) exposing
            ``SETTINGS``, ``TENSORBOARD``, ``PARAMS`` and ``TRAIN`` sections.

    Returns:
        A ``Trainer`` that runs on GPU (``cfg.SETTINGS.NUM_GPUS`` devices) and
        logs to TensorBoard under ``cfg.TENSORBOARD.SAVE_DIR``.
    """
    # workers=True also seeds dataloader worker processes.
    seed_everything(cfg.SETTINGS.RNG_SEED, workers=True)
    tb_logger = TensorBoardLogger(
        cfg.TENSORBOARD.SAVE_DIR,
        name=cfg.TENSORBOARD.NAME,
        version=cfg.TENSORBOARD.VERSION,
        log_graph=False,
        default_hp_metric=True,
    )
    # max_epochs / check_val_every_n_epoch only matter for fit(); they are
    # kept identical across modes so every mode shares one Trainer setup.
    return Trainer(
        accelerator="gpu",
        devices=cfg.SETTINGS.NUM_GPUS,
        max_epochs=cfg.PARAMS.EPOCH,
        check_val_every_n_epoch=cfg.TRAIN.VAL_PERIOD,
        enable_progress_bar=True,
        logger=tb_logger,
    )


def validate(cfg):
    """Run a validation pass of ``DGMRTrainer`` using weights from a checkpoint.

    Args:
        cfg: experiment config; ``cfg.SETTINGS.IMPORT_CKPT_PATH`` is forwarded
            unchanged as Lightning's ``ckpt_path``.

    Returns:
        Whatever ``Trainer.validate`` returns (the logged validation metrics).
    """
    trainer = _build_trainer(cfg)
    model = DGMRTrainer(cfg)
    # NOTE(review): when IMPORT_CKPT_PATH is None/empty, Lightning's own
    # ckpt_path fallback rules apply — confirm for the pinned version.
    return trainer.validate(model, ckpt_path=cfg.SETTINGS.IMPORT_CKPT_PATH)


def train(cfg):
    """Fit ``DGMRTrainer``, resuming from a checkpoint when one is configured.

    Args:
        cfg: experiment config; a truthy ``cfg.SETTINGS.IMPORT_CKPT_PATH``
            resumes training from that checkpoint, a falsy value (``None``,
            ``""``) starts from scratch.
    """
    trainer = _build_trainer(cfg)
    model = DGMRTrainer(cfg)
    # ``or None`` maps empty/falsy paths to Lightning's "no checkpoint"
    # default, exactly like calling fit(model) without ckpt_path.
    trainer.fit(model, ckpt_path=cfg.SETTINGS.IMPORT_CKPT_PATH or None)


def test(cfg):
    """Run the test loop of ``DGMRTrainer`` using weights from a checkpoint.

    Args:
        cfg: experiment config; ``cfg.SETTINGS.IMPORT_CKPT_PATH`` is forwarded
            unchanged as Lightning's ``ckpt_path``.

    Returns:
        Whatever ``Trainer.test`` returns (the logged test metrics).
    """
    trainer = _build_trainer(cfg)
    model = DGMRTrainer(cfg)
    # NOTE(review): when IMPORT_CKPT_PATH is None/empty, Lightning's own
    # ckpt_path fallback rules apply — confirm for the pinned version.
    return trainer.test(model, ckpt_path=cfg.SETTINGS.IMPORT_CKPT_PATH)
def main(args):
    """Load the experiment config and dispatch to the requested run mode.

    Args:
        args: parsed command-line namespace with ``config`` (path to the YAML
            experiment config) and ``mode`` (``"train"``, ``"validate"`` or
            ``"test"``).

    Raises:
        ValueError: if ``args.mode`` is not a supported mode (for example
            ``None`` when ``--mode`` was omitted). Previously such a call
            silently did nothing.
    """
    runners = {"train": train, "validate": validate, "test": test}
    runner = runners.get(args.mode)
    if runner is None:
        # Fail before parsing the config so a bad mode is reported immediately.
        raise ValueError(
            f"unsupported mode {args.mode!r}; expected one of {sorted(runners)}"
        )
    cfg = convert(args.config)
    runner(cfg)
if __name__ == "__main__":
    # Both options are mandatory: without --config there is nothing to load,
    # and without --mode there is nothing to run.
    parser = argparse.ArgumentParser(
        description="Train, validate or test a DGMR model from a YAML config.")
    parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="a config file in yaml format to control experiment")
    parser.add_argument(
        "-m", "--mode", type=str, required=True,
        choices=["train", "validate", "test"],
        help="which stage to run")
    args = parser.parse_args()
    main(args)