diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 53416761..7e357234 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -275,15 +275,12 @@ def is_first_rank(self, parallel_mode: ParallelMode): def is_rank_for_log(self): """Returns a boolean value indicating whether the current device should print log.""" - # is_log_rank = ( - # self.is_first_rank(ParallelMode.DATA) - # and self.is_first_rank(ParallelMode.TENSOR) - # and self.is_last_rank(ParallelMode.PIPELINE) - # ) is_log_rank = ( - self.is_first_rank(ParallelMode.WEIGHT) + self.is_first_rank(ParallelMode.TENSOR) + and self.is_first_rank(ParallelMode.WEIGHT) and self.is_first_rank(ParallelMode.DATA) and self.is_first_rank(ParallelMode.WEIGHT_DATA) + and self.is_last_rank(ParallelMode.PIPELINE) ) return is_log_rank diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index c81b09d0..65473a6b 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -287,7 +287,7 @@ def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: dis def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): # pylint: disable=W0613 fstp_modules = self.index_to_fstp_modules[self.num_blocks - 1] - if module in fstp_modules: + for module in fstp_modules: self._all_gather_module_weight(module) _wait_handle(module) diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 54e75ccb..7ef0cb81 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -139,7 +139,8 @@ def split_params_into_different_groups_for_optimizer_with_new_partition_strategy pgroup["optimizer_mode"] = ParallelMode.ZERO1 # param groups may contain empty groups, such as fp32 - param_groups.extend(new_groups.values()) + if len(new_groups["embed_head"]["params"]) > 0: + param_groups.extend(new_groups.values()) # print(f"ht debug params_groups after split default len:{len(param_groups[0]['params'])}", flush=True) # print(f"ht debug params_groups after split embed_head len:{len(param_groups[1]['params'])}", flush=True)