fix the bug so that the sequence parallel norm is all-reduced when overlap is False (InternLM#534)
yingtongxiong authored Dec 12, 2023
1 parent d904730 commit 432bd5e
Showing 1 changed file with 10 additions and 10 deletions.
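Background for this fix: under sequence parallelism, each tensor-parallel rank sees a different shard of the sequence, so the gradient a single rank computes for a shared norm weight (e.g., RMSNorm) is only a partial sum; the full gradient is the sum over the tensor-parallel group. Before this commit the all-reduce hook was only attached when gradient overlap was enabled, so with overlap disabled the norm weights were updated from incomplete gradients. A minimal sketch of the required collective, written against plain torch.distributed (allreduce_norm_grad and tp_group are illustrative names, not the repo's API):

import torch
import torch.distributed as dist

def allreduce_norm_grad(param: torch.Tensor, tp_group: dist.ProcessGroup) -> None:
    # Each TP rank holds a partial gradient for the shared norm weight;
    # summing across the tensor-parallel group recovers the full gradient.
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)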
20 changes: 10 additions & 10 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -219,10 +219,7 @@ def __init__(
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False

-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()

     @property
     def zero_local_rank(self):
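The hunk above removes the overlap guard: _attach_reduction_hook() now runs unconditionally, and the overlap condition moves inside the hook-registration logic (second hunk below). For readers unfamiliar with get_grad_accumulate_object, it presumably follows the standard PyTorch trick of grabbing a parameter's AccumulateGrad node so a hook fires right after the gradient is written; a hedged sketch of that pattern (assumed to match the repo's helper, not copied from it):

import torch

def get_grad_accumulate_object(param):
    # expand_as(param) creates a view whose grad_fn points back at the
    # parameter's AccumulateGrad node via next_functions[0][0].
    return param.expand_as(param).grad_fn.next_functions[0][0]

param = torch.nn.Parameter(torch.randn(4))
accum_grad_obj = get_grad_accumulate_object(param)
accum_grad_obj.register_hook(lambda *args: print("grad accumulated"))
(param * 2.0).sum().backward()  # the hook fires once param.grad is ready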
@@ -321,12 +318,15 @@ def reduce_grad_hook_sp(*args):  # pylint: disable=W0613

                         # if sequence_parallel is True,
                         # the grad of norm should be all-reduce across the tp process group
-                        if gpc.config.parallel.sequence_parallel is True:
-                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                                accum_grad_obj_sp = get_grad_accumulate_object(param)
-                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
-
-                        accum_grad_obj.register_hook(reduce_grad_hook)
+                        if (
+                            gpc.config.parallel.sequence_parallel is True
+                            and hasattr(param, IS_SEQUENCE_PARALLEL)
+                            and getattr(param, IS_SEQUENCE_PARALLEL) is True
+                        ):
+                            accum_grad_obj.register_hook(reduce_grad_hook_sp)
+
+                        if self._overlap_sync_grad:
+                            accum_grad_obj.register_hook(reduce_grad_hook)

                     _define_and_attach(param, reduce_rank)

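Net effect of the second hunk: the nested if that created a separate accum_grad_obj_sp is collapsed into one condition, the sequence-parallel hook is registered on the same accumulator object, and only the bucket-reduction hook stays gated on self._overlap_sync_grad. A hedged sketch of what the two hooks plausibly do (hook bodies are assumptions inferred from the surrounding code and commit message, not the repo's implementation):

import torch.distributed as dist

def make_hooks(param, tp_group, reduction_func, should_skip):
    def reduce_grad_hook_sp(*args):
        # always attached for sequence-parallel params after this commit:
        # all-reduce the norm grad across the tensor-parallel group
        if not should_skip() and param.grad is not None:
            dist.all_reduce(param.grad, group=tp_group)

    def reduce_grad_hook(*args):
        # attached only when overlap_sync_grad is True: hand the grad to the
        # ZeRO bucket so reduction overlaps with the rest of backward
        if not should_skip():
            reduction_func(param)

    return reduce_grad_hook_sp, reduce_grad_hook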