fix none scheduler for sft
hijkzzz committed Nov 9, 2023
1 parent f07ae13 · commit 48a0a83
Showing 2 changed files with 12 additions and 9 deletions.
nemo_aligner/algorithms/supervised.py (2 additions & 1 deletion)

@@ -124,7 +124,8 @@ def train_single_step(self, batch):
         lr = self.optimizer.param_groups[0]["lr"]
 
         self.optimizer.step()
-        self.scheduler.step()
+        if self.scheduler:
+            self.scheduler.step()
 
         trainer_metrics = {}
         if grad_norm is not None:
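For readers skimming the hunk above: in an SFT run configured without a learning-rate scheduler, `self.scheduler` is `None`, so the old unconditional `self.scheduler.step()` raised an `AttributeError`; the guard turns the step into a no-op. A minimal sketch of the before/after behavior using plain PyTorch (the model and optimizer here are illustrative stand-ins, not objects from the repo):

```python
import torch

# Illustrative stand-ins; the real objects come from NeMo-Aligner's trainer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = None  # the "none scheduler" case this commit fixes (constant lr)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# Before the fix: scheduler.step() -> AttributeError: 'NoneType' object has no attribute 'step'
# After the fix: skipped cleanly when no scheduler is configured.
if scheduler:
    scheduler.step()
```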
nemo_aligner/utils/train_script_utils.py (10 additions & 8 deletions)

@@ -104,14 +104,16 @@ def init_using_ptl(ptl_trainer, ptl_model, train_dataloader, train_ds):
     ptl_trainer._checkpoint_connector._restore_modules_and_callbacks(ptl_trainer.ckpt_path)
     ptl_trainer._checkpoint_connector.restore_training_state()
     ptl_trainer._checkpoint_connector.resume_end()
-    scheduler = ptl_model._scheduler["scheduler"]
-
-    # restore the previous state of the learning rate
-    if scheduler.last_epoch > 0:
-        # NOTE: we are doing this because load_state_dict on a LRScheduler
-        # does not do anything that restores the learning rate on the optimizer
-        # stepping here will restore it properly
-        scheduler.step(scheduler.last_epoch)
+
+    if ptl_model._scheduler:
+        scheduler = ptl_model._scheduler["scheduler"]
+
+        # restore the previous state of the learning rate
+        if scheduler.last_epoch > 0:
+            # NOTE: we are doing this because load_state_dict on a LRScheduler
+            # does not do anything that restores the learning rate on the optimizer
+            # stepping here will restore it properly
+            scheduler.step(scheduler.last_epoch)
 
 
 def add_custom_checkpoint_callback(ptl_trainer, ptl_model):
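The NOTE in the hunk above is easy to verify with stock PyTorch: `load_state_dict` on an LR scheduler restores `last_epoch` but does not write the corresponding learning rate back into the optimizer's param groups, so re-stepping to `last_epoch` is what actually restores it. A self-contained sketch (StepLR is an arbitrary choice for illustration; the epoch argument to `step()` is deprecated in recent PyTorch but still recomputes the closed-form lr):

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)

# Simulate training to step 10, then checkpoint the scheduler.
for _ in range(10):
    opt.step()
    sched.step()
saved = sched.state_dict()
print(opt.param_groups[0]["lr"])  # 0.25

# Fresh run resuming from the checkpoint: last_epoch is restored,
# but the optimizer's lr is still the initial value.
opt2 = torch.optim.SGD(model.parameters(), lr=1.0)
sched2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=5, gamma=0.5)
sched2.load_state_dict(saved)
print(sched2.last_epoch)           # 10
print(opt2.param_groups[0]["lr"])  # 1.0 -- not restored by load_state_dict

# Stepping to last_epoch pushes the correct lr back into the optimizer,
# mirroring what init_using_ptl does above.
sched2.step(sched2.last_epoch)
print(opt2.param_groups[0]["lr"])  # 0.25 again
```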
