diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index c50334f6..8a8c8926 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -159,6 +159,8 @@ class CheckpointsArgs:
     save_initial_state: Optional[bool] = False
     save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[xPath] = None
+    load_lr_scheduler: Optional[bool] = True
+    load_optimizer: Optional[bool] = True
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
     def __post_init__(self):
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 4ed830de..94b03c6e 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -190,7 +190,7 @@ def __init__(
             optimizer_args=self.config.optimizer,
             parallel_context=self.parallel_context,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer:
             load_optimizer(
                 optimizer=self.optimizer,
                 parallel_context=self.parallel_context,
@@ -206,7 +206,7 @@ def __init__(
             lr_scheduler_args=self.config.optimizer.learning_rate_scheduler,
             total_training_steps=self.config.tokens.train_steps,
         )
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
             load_lr_scheduler(
                 lr_scheduler=self.lr_scheduler,
                 is_zero=self.config.optimizer.zero_stage,
@@ -215,7 +215,7 @@
             )
 
         # Define iteration start state
-        if self.init_checkpoint_path is not None:
+        if self.init_checkpoint_path is not None and self.config.checkpoints.load_lr_scheduler:
             checkpoint_metadata = load_meta(
                 parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
             )
@@ -553,7 +553,11 @@ def training_step(
             handle = None
 
         # Move optimizer states back to GPU before optimizer step
-        if self.init_checkpoint_path is not None and self.iteration_step == self.initial_iter_step:
+        if (
+            self.init_checkpoint_path is not None
+            and self.config.checkpoints.load_optimizer
+            and self.iteration_step == self.initial_iter_step
+        ):
             state_dict_to_device(self.optimizer.state_dict(), "cuda")
 
         before_optim_step_sanity_checks(
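
For context, a minimal usage sketch of the new flags. The field names `load_optimizer`, `load_lr_scheduler`, and `resume_checkpoint_path` come from the diff above; `checkpoints_path` and `checkpoint_interval` are assumed pre-existing `CheckpointsArgs` fields not shown in this hunk. With both flags set to `False`, the trainer still loads model weights from the checkpoint, but re-initializes the optimizer and LR scheduler and starts the iteration counter from scratch.

```python
# Hypothetical sketch (not part of the diff): resume weights from a checkpoint
# while re-initializing optimizer and LR-scheduler state.
from nanotron.config.config import CheckpointsArgs

checkpoints = CheckpointsArgs(
    checkpoints_path="./checkpoints",             # assumed pre-existing field
    checkpoint_interval=1000,                     # assumed pre-existing field
    resume_checkpoint_path="./checkpoints/5000",  # weights are still loaded from here
    load_optimizer=False,       # trainer skips load_optimizer(...) in __init__
    load_lr_scheduler=False,    # trainer skips load_lr_scheduler(...) and load_meta(...),
                                # so training restarts from the initial iteration step
)
```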