fix reduce scatter async bug
chenxun.p committed Oct 17, 2023
1 parent 229cc5c commit 6682f5d
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions internlm/model/utils.py
@@ -371,12 +371,12 @@ def backward(ctx, grad_output, *args):
                 grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
                 assert hasattr(weight, "_fstp_reduce_scatter_str")
                 all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
-                grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
+                grad_weight = torch.zeros(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
                 if grad_bias is not None:
                     grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
                     assert hasattr(bias, "_fstp_reduce_scatter_str")
                     all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
-                    grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
+                    grad_bias = torch.zeros(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
         else:
             grad_weight = None
             grad_bias = grad_output if ctx.needs_input_grad[2] else None
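
For context, a minimal sketch of the producer side of this pattern: the weight/bias gradient is reduce-scattered asynchronously, the communication handle and the real output are stashed in a registry for later, and a zero-filled placeholder of the scattered shape is returned in the meantime. The helpers below are assumptions modeled on the names in this diff (`reduce_scatter_raw`, the handler dict), not the exact InternLM API; they use plain `torch.distributed` calls.

# Hypothetical sketch of the producer side (names modeled on the diff above).
import torch
import torch.distributed as dist

def reduce_scatter_raw(tensor, process_group, async_op=False):
    # Reduce-scatter along dim 0: each rank receives a 1/world_size slice of the sum.
    world_size = dist.get_world_size(process_group)
    output = torch.empty(
        tensor.shape[0] // world_size, *tensor.shape[1:],
        dtype=tensor.dtype, device=tensor.device,
    )
    handle = dist.reduce_scatter_tensor(
        output, tensor.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle

def launch_async_grad_reduce_scatter(grad, process_group, handlers, key):
    # Kick off the collective without blocking the rest of the backward pass.
    grad_async, handle = reduce_scatter_raw(grad, process_group, async_op=True)
    handlers[key] = (handle, grad_async)
    # Return a zero-filled placeholder of the scattered shape; starting from
    # zeros keeps any later '+=' accumulation into this buffer well defined.
    world_size = dist.get_world_size(process_group)
    return torch.zeros(
        grad.shape[0] // world_size, *grad.shape[1:],
        dtype=grad.dtype, device=grad.device,
    )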
4 changes: 2 additions & 2 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -333,7 +333,7 @@ def reset_reduce_bucket(self) -> None:
                 key = getattr(_param, "_fstp_reduce_scatter_str")
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
-                _param.grad = _grad
+                _param.grad += _grad

             bucket.reset_by_rank(rank)

@@ -356,7 +356,7 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
                 key = getattr(_param, "_fstp_reduce_scatter_str")
                 comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
                 comm_handle.wait()
-                _param.grad = _grad
+                _param.grad += _grad

         # reduce grad
         if self.skip_grad_reduce is False:
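
And the consumer side, as in the two hunks above: before the optimizer reduces or resets a bucket, it waits on the pending communication handle and accumulates the reduce-scattered result into the parameter's gradient. A sketch under the same assumptions (the handler-dict layout and the `_fstp_reduce_scatter_str` tag come from the diff; the wrapping function is illustrative).

# Hypothetical consumer-side sketch; mirrors the wait-then-accumulate step above.
def drain_pending_reduce_scatters(params, handlers):
    for _param in params:
        if not hasattr(_param, "_fstp_reduce_scatter_str"):
            continue
        key = getattr(_param, "_fstp_reduce_scatter_str")
        comm_handle, _grad = handlers.pop(key)
        comm_handle.wait()    # block until the async reduce-scatter has finished
        _param.grad += _grad  # accumulate rather than assign, so anything already
                              # in .grad (e.g. earlier micro-steps) is preserved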
