main_training_mamba.py
import math
import os
from pathlib import Path

import fire
import torch
import torch.optim as optim
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.modules.block import Block
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import LambdaLR

from fms_fsdp import config
from fms_fsdp.utils.checkpointing_utils import Checkpointer
from fms_fsdp.utils.config_utils import get_model_config, update_config
from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
from fms_fsdp.utils.train_utils import (
    get_policies,
    get_profiler,
    setup,
    setup_environ_flags,
    train,
)


def main(**kwargs):
    # get configs
    cfg = config.train_config()
    update_config(cfg, **kwargs)

    # ensure reproducibility
    torch.cuda.manual_seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # torchrun specific
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if rank == 0:
        print(f"--> running with these configs {cfg}")

    # some setups
    setup()
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    setup_environ_flags()
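    # point Triton's kernel cache at a per-local-rank directory so ranks
    # compiling kernels at the same time do not write into a shared cache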
    os.environ["TRITON_CACHE_DIR"] = os.path.join(
        Path.home(), ".triton", "cache", str(local_rank)
    )

    # get policy
    block = Block
    (
        mixed_precision_policy,
        wrapping_policy,
        sharding_strategy_policy,
        apply_selective_ac,
        param_init_fn,
    ) = get_policies(cfg, rank, block)
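    # get_policies reads cfg to build the FSDP mixed-precision and sharding
    # settings, an auto-wrap policy built around the Mamba Block class, the
    # selective activation-checkpointing helper used below, and a param_init_fn
    # FSDP uses to materialize parameters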

    # get model
    config_data = get_model_config(cfg.model_variant)
    mamba_config = MambaConfig(**config_data)
    model = MambaLMHeadModel(mamba_config)

    if rank == 0:
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> model has {total_params / 1e6} Million params\n")

    # get data loader
    if rank == 0:
        print("Constructing datasets...")
    if not cfg.use_dummy_dataset:
        train_loader = get_data_loader(cfg, rank, world_size)
    else:
        train_loader = get_dummy_loader(cfg, rank, world_size)
    if rank == 0:
        print("Datasets constructed!")

    # FSDP
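    # use_orig_params is tied to cfg.use_torch_compile because torch.compile over
    # an FSDP-wrapped module needs the original parameters exposed;
    # limit_all_gathers enables FSDP's rate limiter, bounding how many parameter
    # all-gathers are in flight at once to keep peak memory down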
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=sharding_strategy_policy,
        use_orig_params=cfg.use_torch_compile,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        param_init_fn=param_init_fn,
    )

    # fsdp activation checkpointing
    if cfg.fsdp_activation_checkpointing:
        if rank == 0:
            print("--> applying FSDP activation checkpointing...")
        apply_selective_ac(model, p=cfg.selective_checkpointing)

    # torch compile
    if cfg.use_torch_compile:
        if rank == 0:
            print("--> enabling torch compile...")
        # the default accumulated_cache_size_limit=64 is not enough for the 70b model, so we raise it to 128 here
        torch._dynamo.config.accumulated_cache_size_limit = 128
        model = torch.compile(model)

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
    )

    # optionally load from checkpoint (when continuing pretraining)
    checkpointer = Checkpointer(
        cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
    )
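    # cfg.ckpt_load_path may point at a single checkpoint file or at a run
    # directory; for a directory, checkpoints are read from its checkpoints/ subfolder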
    model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
        model,
        optimizer,
        None,
        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
        if not os.path.isfile(cfg.ckpt_load_path)
        else cfg.ckpt_load_path,
        strict=False,
    )
    if not is_resuming:
        start_step = 0
        # Override loaded optim hyperparams with the current values
        for g in optimizer.param_groups:
            g["initial_lr"] = cfg.learning_rate

    # LR schedule
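    # both schedules return a multiplier in [0, 1] applied to the base learning
    # rate: the annealing stage decays linearly from 1 to 0 over num_steps;
    # otherwise a short quadratic warmup (at most 2000 steps) is followed by
    # cosine decay down to 10% of the peak LR by num_steps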
    # linear decay for annealing
    if cfg.training_stage == "annealing":
        schedule = lambda x: 1 - x / cfg.num_steps
    else:
        # cosine decay
        warmup_interval = min(2000, cfg.num_steps // 20)
        schedule = lambda x: min(
            1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
            0.1
            + 0.5
            * (1 - 0.1)
            * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
        )
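    # shift by start_step so a resumed run continues the schedule where it left off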
    scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))

    # profiler
    profiler = get_profiler(cfg, rank)

    # Train
    if rank == 0:
        print(f"Training for {cfg.num_steps} steps")
    train(
        cfg,
        model,
        local_rank,
        rank,
        train_loader,
        optimizer,
        scheduler,
        profiler,
        checkpointer,
        start_step,
        tokens_seen,
    )
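    # after training, write a final single-file checkpoint of the model
    # (see Checkpointer.save_single_file)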
    checkpointer.save_single_file(cfg.num_steps, model)

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    fire.Fire(main)