Skip to content

Commit

Permalink
fix function name change from nemo (#71)
Browse files Browse the repository at this point in the history
Signed-off-by: Gerald Shen <[email protected]>
  • Loading branch information
gshennvm committed Jan 29, 2024
1 parent 8399e5a commit 6f38bf8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def on_load_checkpoint(self, checkpoint) -> None:
"""
# mcore uses distributed checkpointing
if "state_dict" in checkpoint and checkpoint["state_dict"]:
for index, module in enumerate(self.get_gpt_module_list()):
for index, module in enumerate(self.get_model_module_list()):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"]
else:
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def set_sync_funcs(ptl_model, forward_only):
param_sync_func = ptl_model.sync_overlap_parameters

# pipeline schedules will get these from ptl_model.model.config
for module in ptl_model.get_gpt_module_list():
for module in ptl_model.get_model_module_list():
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func
module.config.param_sync_func = param_sync_func
Expand Down

0 comments on commit 6f38bf8

Please sign in to comment.