fix all-reduce norm grad
yingtongxiong committed Oct 25, 2023
1 parent 949a0a1 commit 1655a90
Showing 4 changed files with 37 additions and 3 deletions.
2 changes: 2 additions & 0 deletions internlm/core/context/__init__.py
@@ -1,5 +1,6 @@
from .parallel_context import (
    IS_TENSOR_PARALLEL,
    IS_SEQUENCE_PARALLEL,
    Config,
    ParallelContext,
    global_context,
@@ -29,6 +30,7 @@
__all__ = [
    "Config",
    "IS_TENSOR_PARALLEL",
    "IS_SEQUENCE_PARALLEL",
    "global_context",
    "ParallelContext",
    "ParallelMode",
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
@@ -25,6 +25,7 @@
from .random import add_seed, get_seeds, set_mode

IS_TENSOR_PARALLEL = "is_tensor_parallel"
IS_SEQUENCE_PARALLEL = "is_sequence_parallel"

logger = get_logger(__file__)

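For context (an aside, not part of the diff): this string is used as an attribute key set on individual parameters, so a later pass (here, the optimizer) can discover which tensors need sequence-parallel handling. A minimal standalone sketch with made-up parameter names:

import torch

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # same key string as above

# Hypothetical parameters, for illustration only.
norm_weight = torch.nn.Parameter(torch.ones(4096))
other_weight = torch.nn.Parameter(torch.zeros(4096))
setattr(norm_weight, IS_SEQUENCE_PARALLEL, True)

# Later, any component can filter for the tagged parameters.
sp_params = [p for p in (norm_weight, other_weight) if getattr(p, IS_SEQUENCE_PARALLEL, False)]
assert len(sp_params) == 1 and sp_params[0] is norm_weight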
12 changes: 11 additions & 1 deletion internlm/model/modeling_internlm.py
@@ -9,7 +9,7 @@
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import IS_TENSOR_PARALLEL, IS_SEQUENCE_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
@@ -134,6 +134,12 @@ def __init__(
        for _, param in self.mlp.named_parameters():
            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                setattr(param, IS_TENSOR_PARALLEL, True)
        for param in self.norm1.parameters():
            if gpc.config.parallel.sequence_parallel is True:
                setattr(param, IS_SEQUENCE_PARALLEL, True)
        for param in self.norm2.parameters():
            if gpc.config.parallel.sequence_parallel is True:
                setattr(param, IS_SEQUENCE_PARALLEL, True)

        self.dropout2 = nn.Dropout(drop_rate)
        self.use_swiglu = use_swiglu
@@ -356,6 +362,10 @@ def __init__(
                normal_(std=0.0052)(param)
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)
            for param in self.norm.parameters():
                if gpc.config.parallel.sequence_parallel is True:
                    setattr(param, IS_SEQUENCE_PARALLEL, True)

        self.parallel_output = parallel_output

def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
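Why these parameters get tagged (an illustration, not part of the commit): under sequence parallelism each tensor-parallel rank runs the norm layers on only its slice of the sequence, so every rank holds a partial gradient for the norm weights, and summing those partial gradients across ranks recovers the full-sequence gradient. A single-process sanity check of that identity, using LayerNorm as a stand-in for the model's RMSNorm:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tokens = torch.randn(4, 8)                     # 4 tokens, hidden size 8
w_full = torch.ones(8, requires_grad=True)
F.layer_norm(tokens, (8,), weight=w_full).sum().backward()

w_shard = torch.ones(8, requires_grad=True)
for shard in tokens.chunk(2, dim=0):           # two "ranks", half the sequence each
    F.layer_norm(shard, (8,), weight=w_shard).sum().backward()  # grads accumulate (sum)

assert torch.allclose(w_full.grad, w_shard.grad)  # sum of per-shard grads == full-sequence grad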
25 changes: 23 additions & 2 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -8,7 +8,7 @@
import torch.distributed as dist
from torch.optim import Optimizer

from internlm.core.context import Config, ParallelMode
from internlm.core.context import Config, ParallelMode, IS_SEQUENCE_PARALLEL
from internlm.core.context import global_context as gpc
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
@@ -296,15 +296,36 @@ def _define_and_attach(param, reduce_rank=None):
                param=param,
                reduce_rank=reduce_rank,
            )

            def reduction_sp_func():
                handle = reduce_tensor(
                    param.grad,
                    dtype=None,
                    dst_rank=reduce_rank,
                    parallel_mode=ParallelMode.TENSOR,
                )
                handle.wait()

            # define hook
            # NOT IMPORTANT BUT GOOD TO KNOW:
            # args here is not grad, but allow_unreachable and accumulate_grad
            def reduce_grad_hook(*args):  # pylint: disable=W0613
                if self.skip_grad_reduce is False:
                    reduction_func()

            # define hook for sequence_parallel
            def reduce_grad_hook_sp(*args):
                if self.skip_grad_reduce is False:
                    reduction_sp_func()

            accum_grad_obj.register_hook(reduce_grad_hook)

            # if sequence_parallel is True, the grad of norm should be all-reduced across the tp process group
            if gpc.config.parallel.sequence_parallel is True:
                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
                    accum_grad_obj_sp = get_grad_accumulate_object(param)
                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)

        _define_and_attach(param, reduce_rank)

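To see the net effect in isolation (a sketch under assumptions, not the repository's reduce_tensor helper): the new hook amounts to summing the gradient of every parameter tagged is_sequence_parallel across the tensor-parallel process group. Assuming a tp_group handle, an after-backward version would look like the following; the commit instead registers per-parameter hooks so the reduction overlaps with the backward pass:

import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads(model: torch.nn.Module, tp_group) -> None:
    # Sum the gradients of tagged parameters over the tensor-parallel group,
    # mirroring what reduce_grad_hook_sp triggers for each parameter.
    for param in model.parameters():
        if getattr(param, "is_sequence_parallel", False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)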
