fix scheduler when not supplied (#98)
* fix scheduler when not supplied

Signed-off-by: Gerald Shen <[email protected]>

* fixup! fix scheduler when not supplied

Signed-off-by: Gerald Shen <[email protected]>

* edit changelog

Signed-off-by: Gerald Shen <[email protected]>

* remove none guard with scheduler

Signed-off-by: Gerald Shen <[email protected]>

---------

Signed-off-by: Gerald Shen <[email protected]>
gshennvm authored Feb 12, 2024
1 parent b06f6cd commit 735c17b
Showing 3 changed files with 29 additions and 16 deletions.
CHANGELOG.md (1 addition, 1 deletion)

@@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
   a dictionary from the training configuration.
 - `exp_manager.max_time_per_run` is now respected, the trainers will save and run validation before exiting if we've reached the time limit.
 - Fixed crash in PPO when using a separate reward model server (i.e., with `combine_rm_and_critic_server=False`).
-- Fixed a crash when LR scheduler was not specified
+- Fixed crash when LR scheduler is not specified
 
 ## [0.1.0] - 2023-12-04
 ### Added
nemo_aligner/algorithms/supervised.py (1 addition, 2 deletions)

@@ -135,8 +135,7 @@ def train_single_step(self, batch):
         lr = self.optimizer.param_groups[0]["lr"]
 
         self.optimizer.step()
-        if self.scheduler is not None:
-            self.scheduler.step()
+        self.scheduler.step()
 
         trainer_metrics = {}
         if grad_norm is not None:
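For context on why dropping the None guard in train_single_step is safe: after this change the scheduler slot is always populated, either with a real LR scheduler or with a do-nothing stand-in, so the hot loop can call scheduler.step() unconditionally. Below is a minimal runnable sketch of that null-object pattern; NoOpScheduler and TinyTrainer are hypothetical names for illustration only, not classes from NeMo-Aligner.

import torch


class NoOpScheduler:
    """Stand-in for a missing LR scheduler; stepping it changes nothing."""

    last_epoch = 0

    def step(self):
        ...


class TinyTrainer:
    def __init__(self, model, optimizer, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        # Substitute the null object once, so the training loop needs no None checks.
        self.scheduler = scheduler if scheduler is not None else NoOpScheduler()

    def train_single_step(self, batch):
        self.optimizer.zero_grad()
        loss = self.model(batch).pow(2).mean()
        loss.backward()
        lr = self.optimizer.param_groups[0]["lr"]
        self.optimizer.step()
        self.scheduler.step()  # safe even when no scheduler was supplied
        return {"loss": loss.item(), "lr": lr}


model = torch.nn.Linear(4, 1)
trainer = TinyTrainer(model, torch.optim.SGD(model.parameters(), lr=0.1))
print(trainer.train_single_step(torch.randn(8, 4)))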
nemo_aligner/utils/train_script_utils.py (27 additions, 13 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial

@@ -106,15 +107,26 @@ def init_using_ptl(ptl_trainer, ptl_model, train_dataloader, train_ds):
     ptl_trainer._checkpoint_connector._restore_modules_and_callbacks(ptl_trainer.ckpt_path)
     ptl_trainer._checkpoint_connector.restore_training_state()
     ptl_trainer._checkpoint_connector.resume_end()
-    if ptl_model._scheduler is not None:
-        scheduler = ptl_model._scheduler["scheduler"]
+    _, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)
 
-        # restore the previous state of the learning rate
-        if scheduler.last_epoch > 0:
-            # NOTE: we are doing this because load_state_dict on a LRScheduler
-            # does not do anything that restores the learning rate on the optimizer
-            # stepping here will restore it properly
-            scheduler.step(scheduler.last_epoch)
+    # restore the previous state of the learning rate
+    if scheduler.last_epoch > 0:
+        # NOTE: we are doing this because load_state_dict on a LRScheduler
+        # does not do anything that restores the learning rate on the optimizer
+        # stepping here will restore it properly
+        scheduler.step(scheduler.last_epoch)
+
+
+class FakeScheduler:
+    last_epoch = 0
+
+    def step(self):
+        ...
+
+
+class FakeCheckpointCallback:
+    def custom_save(self, *args, **kwargs):
+        ...
 
 
 def add_custom_checkpoint_callback(ptl_trainer, ptl_model):

@@ -126,15 +138,17 @@ def add_custom_checkpoint_callback(ptl_trainer, ptl_model):
         callback.custom_save = partial(callback.custom_save_ckpt_func, ptl_trainer, ptl_model)
         return callback
 
-    class FakeCheckpointCallback:
-        def custom_save(self, *args, **kwargs):
-            ...
-
     return FakeCheckpointCallback()
 
 
 def extract_optimizer_scheduler_from_ptl_model(ptl_model):
-    return ptl_model.optimizers().optimizer, ptl_model.lr_schedulers()
+    scheduler = ptl_model.lr_schedulers()
+    assert not isinstance(scheduler, Sequence), "multiple schedulers are not supported right now"
+
+    if scheduler is None:
+        scheduler = FakeScheduler()
+
+    return ptl_model.optimizers().optimizer, scheduler
 
 
 def init_peft(ptl_model, updated_cfg):
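The core of the fix is the FakeScheduler null object that extract_optimizer_scheduler_from_ptl_model now returns when the Lightning module reports no LR scheduler. Below is a self-contained sketch of that code path; StubPTLModel and OptimizerWrapper are hypothetical stand-ins for a real Lightning module and its optimizer wrapper, used only to exercise the scheduler-is-None branch.

from collections import namedtuple
from collections.abc import Sequence

import torch

# Hypothetical stand-in for the object returned by ptl_model.optimizers().
OptimizerWrapper = namedtuple("OptimizerWrapper", ["optimizer"])


class FakeScheduler:
    """Null object returned when no LR scheduler is configured."""

    last_epoch = 0

    def step(self):
        ...


class StubPTLModel:
    """Hypothetical module whose lr_schedulers() reports no scheduler."""

    def __init__(self):
        self._optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

    def optimizers(self):
        return OptimizerWrapper(self._optimizer)

    def lr_schedulers(self):
        return None


def extract_optimizer_scheduler_from_ptl_model(ptl_model):
    scheduler = ptl_model.lr_schedulers()
    assert not isinstance(scheduler, Sequence), "multiple schedulers are not supported right now"

    if scheduler is None:
        scheduler = FakeScheduler()  # fall back to the null object

    return ptl_model.optimizers().optimizer, scheduler


optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(StubPTLModel())
scheduler.step()  # no-op instead of an AttributeError on None
print(type(scheduler).__name__, scheduler.last_epoch)  # -> FakeScheduler 0

Because the stand-in exposes the same step() and last_epoch surface as a real scheduler, both init_using_ptl and the supervised training loop can call it unconditionally.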
