fix(overlap_handler.py): fix clear weight error when activation ckpt is True
huangting4201 committed Dec 22, 2023
1 parent bf3c01a commit e6dcaa2
Showing 4 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion internlm/core/naive_amp.py
@@ -81,7 +81,7 @@ def _convert_to_fp16(self, input_: Any):
 
     def _convert_to_fp32(self, input_: Any):
         """Converts the input to fp32 if it is a Tensor of dtype float16."""
-        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
+        if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16):
             input_ = input_.float()
         return input_
 
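For reference, the widened check above simply treats bf16 like fp16 when upcasting. A minimal sketch of the same logic lifted out of NaiveAMPModel into a free function (plain PyTorch; the standalone function is illustrative, not the class method itself):

    import torch
    from torch import Tensor
    from typing import Any

    def _convert_to_fp32(input_: Any):
        # Upcast half-precision tensors (fp16 or, after this commit, bf16) to fp32.
        if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16):
            input_ = input_.float()
        return input_

    # A bf16 tensor is now upcast, matching the existing fp16 behaviour.
    x = torch.ones(2, dtype=torch.bfloat16)
    assert _convert_to_fp32(x).dtype == torch.float32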
7 changes: 6 additions & 1 deletion internlm/initialize/launch.py
@@ -11,6 +11,7 @@
 
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
+from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
@@ -318,7 +319,11 @@ def args_sanity_check():
         "intern",
     ], "invalid sp mode, only ['none', 'megatron', 'flash-attn', 'intern'] is supported"
     # adapt to old version's sequence parallel config
-    if gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]:
+    if (
+        gpc.config.parallel["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]
+        and gpc.is_initialized(ParallelMode.TENSOR)
+        and gpc.get_world_size(ParallelMode.TENSOR) > 1
+    ):
         gpc.config.parallel.sequence_parallel = True
 
     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
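The added conditions mean the legacy `sp` setting only switches on sequence_parallel when a tensor-parallel group is actually initialized and spans more than one rank. A hedged sketch of that guard with stand-in objects (the _FakeGpc class and helper name below are illustrative; the real check goes through internlm's gpc and ParallelMode.TENSOR inside a launched job):

    # Stand-in context object; only the control flow mirrors the diff.
    class _FakeGpc:
        def __init__(self, tensor_world_size):
            self._tp = tensor_world_size

        def is_initialized(self, mode):
            return self._tp is not None

        def get_world_size(self, mode):
            return self._tp

    def legacy_sp_enables_sequence_parallel(parallel_cfg: dict, gpc) -> bool:
        # Mirror of the new condition: the old-style `sp` value only turns on
        # sequence_parallel when a tensor-parallel group exists and has size > 1.
        return (
            parallel_cfg["tensor"].get("sp", None) in ["megatron", "flash-attn", "intern"]
            and gpc.is_initialized("tensor")
            and gpc.get_world_size("tensor") > 1
        )

    # With tensor parallel size 1 the flag now stays off, which is the point of the guard.
    assert legacy_sp_enables_sequence_parallel({"tensor": {"sp": "megatron"}}, _FakeGpc(1)) is False
    assert legacy_sp_enables_sequence_parallel({"tensor": {"sp": "megatron"}}, _FakeGpc(2)) is True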
2 changes: 1 addition & 1 deletion internlm/model/overlap_handler.py
@@ -292,7 +292,7 @@ def _pre_forward_hook_for_block(block: nn.Module, inputs: Any): # pylint: disab
 
        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            _clear_handle(module)
-           if not self.model_checkpoint:
+           if not (self.model_checkpoint and self.is_forward is False):
                _clear_weight(module)
 
        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
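This one-line condition is the core of the commit: with activation checkpointing enabled, the post-forward hook also fires while the forward is re-run inside backward (self.is_forward is False), where the gathered weight is presumably still needed for the following gradient computation and must not be freed; after the normal forward it can now be cleared as usual. A toy sketch of just that decision, assuming placeholder clear helpers rather than the real overlap handler class:

    from typing import Any

    import torch.nn as nn

    class _OverlapHandlerSketch:
        # Toy stand-in modelling only the clear-weight decision changed in this commit.

        def __init__(self, model_checkpoint: bool):
            self.model_checkpoint = model_checkpoint
            # True during the normal forward pass; flipped to False while activation
            # checkpointing recomputes the forward inside the backward pass.
            self.is_forward = True

        def _clear_handle(self, module: nn.Module) -> None:
            pass  # placeholder: release the pending all-gather handle

        def _clear_weight(self, module: nn.Module) -> None:
            pass  # placeholder: free the gathered full weight

        def _post_forward_hook_for_module(self, module: nn.Module, inputs: Any, output: Any) -> None:
            self._clear_handle(module)
            # Clear the weight unless checkpointing is currently recomputing the forward.
            if not (self.model_checkpoint and self.is_forward is False):
                self._clear_weight(module)

    handler = _OverlapHandlerSketch(model_checkpoint=True)
    handler.is_forward = False  # checkpoint recompute phase: the weight is kept
    handler._post_forward_hook_for_module(nn.Linear(4, 4), None, None)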
2 changes: 2 additions & 0 deletions internlm/train/training_internlm.py
@@ -608,6 +608,8 @@ def record_current_batch_training_metrics(
             tflops_list_2.append(tflops_2)
         if batch_count == gpc.config.data.total_steps - 1:
             print(tgs_list, flush=True)
+            if len(tgs_list) <= 0:
+                return
             avg_tgs = sum(tgs_list) / len(tgs_list)
             for tgs in tgs_list.copy():
                 if abs(tgs - avg_tgs) > 400:
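The new early return means an empty tgs_list at the last step no longer crashes the average computation with a ZeroDivisionError. The same guard in isolation (function name is illustrative):

    def average_tgs(tgs_list):
        # Average tokens/GPU/second over the recorded steps; None when nothing was
        # recorded, which is the case the added guard protects against.
        if len(tgs_list) <= 0:
            return None
        return sum(tgs_list) / len(tgs_list)

    assert average_tgs([]) is None                  # previously a ZeroDivisionError
    assert average_tgs([4000.0, 4400.0]) == 4200.0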
