-
Notifications
You must be signed in to change notification settings - Fork 392
/
train.py
104 lines (84 loc) · 4.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
'''
-----------------------------------------------------------------------------
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
'''
import argparse
import os
import imaginaire.config
from imaginaire.config import Config, recursive_update_strict, parse_cmdline_arguments
from imaginaire.utils.cudnn import init_cudnn
from imaginaire.utils.distributed import init_dist, get_world_size, master_only_print as print, is_master
from imaginaire.utils.gpu_affinity import set_affinity
from imaginaire.trainers.utils.logging import init_logging
from imaginaire.trainers.utils.get_trainer import get_trainer
from imaginaire.utils.set_random_seed import set_random_seed
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches the
            behaviour of calling this function without arguments.

    Returns:
        A ``(args, cfg_cmd)`` tuple. ``args`` is the namespace holding the
        options declared below; ``cfg_cmd`` is the list of argument strings
        that were not recognised here, which ``main`` forwards to
        ``parse_cmdline_arguments`` as config overrides.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--config', help='Path to the training config file.', required=True)
    parser.add_argument('--logdir', help='Dir for saving logs and models.', default=None)
    parser.add_argument('--checkpoint', default=None, help='Checkpoint path.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    # The LOCAL_RANK environment variable, when present, is a string; argparse
    # runs string defaults through ``type`` so the result is still an int.
    parser.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0))
    parser.add_argument('--single_gpu', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--profile', action='store_true')
    parser.add_argument('--show_pbar', action='store_true')
    parser.add_argument('--wandb', action='store_true', help="Enable using Weights & Biases as the logger")
    parser.add_argument('--wandb_name', default='default', type=str)
    parser.add_argument('--resume', action='store_true')
    # parse_known_args keeps unrecognised arguments instead of erroring, so
    # they can be returned to the caller alongside the parsed namespace.
    args, cfg_cmd = parser.parse_known_args(argv)
    return args, cfg_cmd
def main():
    """Run a training job from the command line.

    Parses the command-line options, loads the config file and applies the
    command-line overrides to it, optionally initializes distributed
    training, seeds the random number generators, sets up logging, builds
    the trainer with its train/val data loaders, loads a checkpoint,
    initializes Weights & Biases and finally runs and finalizes training.
    """
    args, cfg_cmd = parse_args()
    set_affinity(args.local_rank)
    cfg = Config(args.config)

    # Arguments not recognised by parse_args are applied on top of the config file.
    cfg_cmd = parse_cmdline_arguments(cfg_cmd)
    recursive_update_strict(cfg, cfg_cmd)

    # If args.single_gpu is set to True, we will disable distributed data parallel.
    if not args.single_gpu:
        # This disables the NCCL timeout. The first variable used to be
        # misspelled "NCLL_BLOCKING_WAIT", so the setting was never picked up.
        os.environ["NCCL_BLOCKING_WAIT"] = "0"
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
        cfg.local_rank = args.local_rank
        init_dist(cfg.local_rank, rank=-1, world_size=-1)
    # `print` is imported as master_only_print at the top of the file.
    print(f"Training with {get_world_size()} GPUs.")

    # Set random seed by rank.
    set_random_seed(args.seed, by_rank=True)

    # Global arguments.
    imaginaire.config.DEBUG = args.debug

    # Create log directory for storing training results.
    cfg.logdir = init_logging(args.config, args.logdir, makedir=True)

    # Print and save the final config (master process only).
    if is_master():
        cfg.print_config()
        cfg.save_config(cfg.logdir)

    # Initialize cudnn.
    init_cudnn(cfg.cudnn.deterministic, cfg.cudnn.benchmark)

    # Initialize data loaders and models.
    trainer = get_trainer(cfg, is_inference=False, seed=args.seed)
    trainer.set_data_loader(cfg, split="train")
    trainer.set_data_loader(cfg, split="val")
    trainer.checkpointer.load(args.checkpoint, args.resume, load_sch=True, load_opt=True)

    # Initialize Wandb. Logging is disabled in debug mode or when --wandb is not given.
    trainer.init_wandb(cfg,
                       project=args.wandb_name,
                       mode="disabled" if args.debug or not args.wandb else "online",
                       resume=args.resume,
                       use_group=True)

    trainer.mode = 'train'
    # Start training.
    trainer.train(cfg,
                  trainer.train_data_loader,
                  single_gpu=args.single_gpu,
                  profile=args.profile,
                  show_pbar=args.show_pbar)

    # Finalize training.
    trainer.finalize(cfg)
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()