
Commit

Merge pull request #253 from eliebak/add-load-lr-flag
Allow resuming from a checkpoint without loading the LR-scheduler or optimizer state.
NouamaneTazi authored Dec 3, 2024
2 parents: 4cf36a8 + 3034bd2, commit fdd5151
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/nanotron/config/config.py
@@ -159,6 +159,8 @@ class CheckpointsArgs:
     save_initial_state: Optional[bool] = False
     save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[xPath] = None
+    load_lr_scheduler: Optional[bool] = True
+    load_optimizer: Optional[bool] = True
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
     def __post_init__(self):
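
Both flags default to True, so existing configs keep restoring optimizer and LR-scheduler state on resume. Below is a minimal sketch of opting out when building the config directly in Python; the field names come from the diff above, while checkpoints_path and checkpoint_interval are assumed to be the class's other required fields and the paths are placeholders.

from nanotron.config.config import CheckpointsArgs

# Sketch: resume model weights from a checkpoint but rebuild the optimizer and
# LR schedule from the config instead of restoring their saved state.
checkpoints = CheckpointsArgs(
    checkpoints_path="checkpoints",             # assumed required field, placeholder path
    checkpoint_interval=500,                    # assumed required field
    resume_checkpoint_path="checkpoints/1000",  # placeholder checkpoint to resume from
    load_optimizer=False,                       # skip restoring optimizer state
    load_lr_scheduler=False,                    # skip restoring LR-scheduler state (and iteration metadata)
)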
12 changes: 8 additions & 4 deletions src/nanotron/trainer.py
@@ -190,7 +190,7 @@ def __init__(
             optimizer_args=self.config.optimizer,
             parallel_context=self.parallel_context,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer:
             load_optimizer(
                 optimizer=self.optimizer,
                 parallel_context=self.parallel_context,
@@ -206,7 +206,7 @@
             lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
             total_training_steps=self.config.tokens.train_steps,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
                 is_zero=self.config.optimizer.zero_stage,
@@ -215,7 +215,7 @@
             )
 
         # Define iteration start state
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
             checkpoint_metadata = load_meta(
                 parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
             )
@@ -553,7 +553,11 @@ def training_step(
         handle = None
 
         # Move optimizer states back to GPU before optimizer step
-        if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step:
+        if (
+            self.init_checkpoint_path is not None
+            and self.config.checkpoints.load_optimizer
+            and self.iteration_step == self.initial_iter_step
+        ):
             state_dict_to_device(self.optimizer.state_dict(), "cuda")
 
         before_optim_step_sanity_checks(
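
Taken together, the trainer changes reduce to the pattern sketched below. This is a paraphrase for illustration, not code from the repository: the restore_* callables stand in for load_optimizer, load_lr_scheduler, and load_meta. Note that the saved iteration start state is gated by the same load_lr_scheduler flag, so turning that flag off also skips load_meta.

# Paraphrase of the guarded resume logic above (hypothetical helper, not part of nanotron).
def restore_checkpoint_states(init_checkpoint_path, checkpoints_cfg,
                              restore_optimizer, restore_scheduler, restore_meta):
    if init_checkpoint_path is None:
        return  # fresh run: nothing to restore
    if checkpoints_cfg.load_optimizer:
        restore_optimizer(init_checkpoint_path)   # optimizer state (e.g. Adam moments)
    if checkpoints_cfg.load_lr_scheduler:
        restore_scheduler(init_checkpoint_path)   # position in the LR schedule
        restore_meta(init_checkpoint_path)        # iteration start state shares this flag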
