Skip to content

Commit

Permalink
LR Schedule same name as optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
TJ-Solergibert committed Nov 26, 2024
1 parent ef835e8 commit 51bd072
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
17 changes: 12 additions & 5 deletions src/nanotron/serialize/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool):
return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


def lr_scheduler_filename(parallel_context: ParallelContext):
"""The lr_scheduler is the same for all processes."""
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}.pt"
def lr_scheduler_filename(parallel_context: ParallelContext, is_zero: bool):
if is_zero is True:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_dp-{dist.get_rank(parallel_context.dp_pg)}-of-{parallel_context.dp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"
else:
return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt"


def save_optimizer(
Expand Down Expand Up @@ -105,18 +107,22 @@ def convert_to_string(input_item):

def save_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
"""Saves lr scheduler states"""
if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0:
# this is Zero-0, so only DP-0 saves the optimizer states
return

root_folder = root_folder / "lr_scheduler"
root_folder.mkdir(exist_ok=True, parents=True)

# We dump the optimizer state using `torch.save`
torch.save(
lr_scheduler.state_dict(),
root_folder / lr_scheduler_filename(parallel_context),
root_folder / lr_scheduler_filename(parallel_context, is_zero),
)


Expand Down Expand Up @@ -310,10 +316,11 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -

def load_lr_scheduler(
lr_scheduler,
is_zero,
parallel_context: ParallelContext,
root_folder: Path,
):
root_folder = root_folder / "lr_scheduler"

state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context))
state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context, is_zero))
lr_scheduler.load_state_dict(state_dict)
7 changes: 2 additions & 5 deletions src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
if self.init_checkpoint_path is not None:
load_lr_scheduler(
lr_scheduler=self.lr_scheduler,
is_zero=self.config.optimizer.zero_stage,
parallel_context=self.parallel_context,
root_folder=self.init_checkpoint_path,
)
Expand Down Expand Up @@ -864,11 +865,7 @@ def save_checkpoint(self) -> Path:
dist.get_rank(self.parallel_context.dp_pg) == 0
), # We only save the weights on DP==0
should_save_optimizer=True,
should_save_lr_scheduler=bool(
dist.get_rank(self.parallel_context.dp_pg) == 0
and dist.get_rank(self.parallel_context.tp_pg) == 0
and dist.get_rank(self.parallel_context.expert_pg) == 0
), # We only save the lr_scheduler on DP==0 && TP==0 && EP==0
should_save_lr_scheduler=True,
should_save_config=bool(
dist.get_rank(self.parallel_context.world_pg) == 0
), # We only save the config on world_rank==0
Expand Down

0 comments on commit 51bd072

Please sign in to comment.