Skip to content

Commit

Permalink
fix(overlap_handler.py): fix hook error and param group split
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Dec 21, 2023
1 parent e9cd521 commit e0cafb0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
9 changes: 3 additions & 6 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,12 @@ def is_first_rank(self, parallel_mode: ParallelMode):

def is_rank_for_log(self):
"""Returns a boolean value indicating whether the current device should print log."""
# is_log_rank = (
# self.is_first_rank(ParallelMode.DATA)
# and self.is_first_rank(ParallelMode.TENSOR)
# and self.is_last_rank(ParallelMode.PIPELINE)
# )
is_log_rank = (
self.is_first_rank(ParallelMode.WEIGHT)
self.is_first_rank(ParallelMode.TENSOR)
and self.is_first_rank(ParallelMode.WEIGHT)
and self.is_first_rank(ParallelMode.DATA)
and self.is_first_rank(ParallelMode.WEIGHT_DATA)
and self.is_last_rank(ParallelMode.PIPELINE)
)
return is_log_rank

Expand Down
2 changes: 1 addition & 1 deletion internlm/model/overlap_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: dis

def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): # pylint: disable=W0613
fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1]
if module in fstp_modules:
for module in fstp_modules:
self._all_gather_module_weight(module)
_wait_handle(module)

Expand Down
3 changes: 2 additions & 1 deletion internlm/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def split_params_into_different_groups_for_optimizer_with_new_partition_strategy
pgroup["optimizer_mode"] = ParallelMode.ZERO1

# param groups may contain empty groups, such as fp32
param_groups.extend(new_groups.values())
if len(new_groups["embed_head"]["params"]) > 0:
param_groups.extend(new_groups.values())

# print(f"ht debug params_groups after split default len:{len(param_groups[0]['params'])}", flush=True)
# print(f"ht debug params_groups after split embed_head len:{len(param_groups[1]['params'])}", flush=True)
Expand Down

0 comments on commit e0cafb0

Please sign in to comment.