impl reduce scatter async
chenxun.p committed Oct 17, 2023
1 parent d1af0d6 commit 229cc5c
Showing 4 changed files with 71 additions and 19 deletions.
1 change: 1 addition & 0 deletions internlm/core/scheduler/no_pipeline_scheduler.py
@@ -194,6 +194,7 @@ def forward_backward_step(
_output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
)
engine.optimizer.reset_reduce_bucket()

if return_loss:
loss += _loss
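For orientation, here is a minimal sketch (toy objects, hypothetical names, not the repository's real classes) of what the new engine.optimizer.reset_reduce_bucket() call is for: it runs once per micro-batch, right after the forward/backward step, so any reduce-scatter collectives launched asynchronously during that backward pass are waited on before the next micro-batch starts.

class _ToyHandle:
    def wait(self):
        pass  # stands in for a torch.distributed work handle


class _ToyOptimizer:
    def __init__(self):
        # (async_handle, reduced_grad) pairs filled in during backward
        self.pending = [(_ToyHandle(), "grad shard")]

    def reset_reduce_bucket(self):
        for handle, _grad in self.pending:
            handle.wait()  # block until the async reduce-scatter has finished
        self.pending.clear()


optimizer = _ToyOptimizer()
optimizer.reset_reduce_bucket()  # called after each _train_one_batch(...)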
23 changes: 19 additions & 4 deletions internlm/model/linear.py
@@ -329,6 +329,8 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name

self.reduce_scatter_handlers = {}

# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
@@ -337,16 +339,22 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _, children in _chunk.named_children():
for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
self.block_module[idx] = {}
for _, sub in block.named_children():
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
if isinstance(child, FSTPLinear):

_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")

self.FSTP_modules.append(child)
self.module_block[child] = idx
self.block_module[idx][index] = child
@@ -450,6 +458,8 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}

self.reduce_scatter_handlers = {}

# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
@@ -458,7 +468,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model

for _, children in _chunk.named_children():
for _chunk_name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList):
for idx, block in enumerate(children):
index = 0
@@ -467,7 +477,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
self.block_to_index[block] = idx
self.index_to_block[idx] = block
self.index_to_fsdp_modules[idx] = []
for _, sub in block.named_children():
for _sub_name, sub in block.named_children():
sub_modules = list(sub.children())
if len(sub_modules) > 0:
for name, child in sub.named_children():
@@ -485,6 +495,11 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
self.index_to_fsdp_modules[idx].append(child)
self.module_name_index[child] = index
index = index + 1

_full_name = f"{_chunk_name}.{idx}.{_sub_name}.{name}"
setattr(child.weight, "_fstp_reduce_scatter_str", f"{_full_name}.weight")
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
else:
continue

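The tagging above gives every FSTP linear's weight and bias a unique dotted name, which the backward pass later uses as the key for its async reduce-scatter handle. A self-contained sketch of that pass over a toy module tree (hypothetical helper name; plain nn.Linear stands in for FSTPLinear):

import torch.nn as nn


def tag_fstp_params(blocks: nn.ModuleList, chunk_name: str = "model") -> None:
    """Attach a unique _fstp_reduce_scatter_str key to each linear's parameters."""
    for idx, block in enumerate(blocks):
        for sub_name, sub in block.named_children():
            for name, child in sub.named_children():
                if not isinstance(child, nn.Linear):  # stand-in for FSTPLinear
                    continue
                full_name = f"{chunk_name}.{idx}.{sub_name}.{name}"
                child.weight._fstp_reduce_scatter_str = f"{full_name}.weight"
                if child.bias is not None:
                    child.bias._fstp_reduce_scatter_str = f"{full_name}.bias"


# toy usage: one "block" holding an MLP with a single linear layer
blocks = nn.ModuleList([nn.ModuleDict({"mlp": nn.Sequential(nn.Linear(8, 8))})])
tag_fstp_params(blocks)
print(blocks[0]["mlp"][0].weight._fstp_reduce_scatter_str)  # model.0.mlp.0.weight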
28 changes: 17 additions & 11 deletions internlm/model/utils.py
@@ -324,9 +324,9 @@ def forward(ctx, x, weight, bias, return_residual=False, process_group=None, mod
raise RuntimeError("fused_dense only supports matrix dims <= 2M")
output = F.linear(total_x, total_weight, total_bias)
if ctx.compute_weight_gradient:
ctx.save_for_backward(x, weight)
ctx.save_for_backward(x, weight, bias)
else:
ctx.save_for_backward(weight)
ctx.save_for_backward(weight, bias)
return output if not return_residual else (output, x)

@staticmethod
@@ -340,10 +340,10 @@ def backward(ctx, grad_output, *args):
all_gather_handler = ctx.all_gather_handler
module = ctx.module
if ctx.compute_weight_gradient:
x, weight = ctx.saved_tensors
x, weight, bias = ctx.saved_tensors
total_x = x
else:
(weight,) = ctx.saved_tensors
weight, bias = ctx.saved_tensors
total_x = None
batch_shape = grad_output.shape[:-1]
batch_dim = batch_shape.numel()
@@ -368,9 +368,15 @@ def backward(ctx, grad_output, *args):
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
)
if world_size > 1:
grad_weight, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
assert hasattr(weight, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = torch.empty(grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:], dtype=grad_weight.dtype, device=grad_weight.device)
if grad_bias is not None:
grad_bias, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
assert hasattr(bias, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = torch.empty(grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:], dtype=grad_bias.dtype, device=grad_bias.device)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
@@ -384,11 +390,11 @@ def backward(ctx, grad_output, *args):
else:
grad_input = None

if ctx.needs_input_grad[1]:
if world_size > 1:
handle_grad_weight.wait()
if grad_bias is not None:
handle_grad_bias.wait()
# if ctx.needs_input_grad[1]:
# if world_size > 1:
# handle_grad_weight.wait()
# if grad_bias is not None:
# handle_grad_bias.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None


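The backward changes above launch the gradient reduce-scatter with async_op=True, file the (handle, sharded_result) pair under the parameter's _fstp_reduce_scatter_str key, and hand autograd a correctly shaped placeholder so the backward pass can return without blocking; the optimizer waits on the handle later. A sketch of that pattern (hypothetical function name; torch.distributed.reduce_scatter_tensor, available in PyTorch >= 1.13, stands in for the repository's reduce_scatter_raw wrapper; assumes an initialized process group and a first dimension divisible by the world size):

import torch
import torch.distributed as dist


def reduce_scatter_grad_async(full_grad: torch.Tensor, process_group, handlers: dict, key: str) -> torch.Tensor:
    world_size = dist.get_world_size(process_group)
    # sharded output buffer: the first dimension is split across ranks
    shard = torch.empty(
        full_grad.shape[0] // world_size,
        *full_grad.shape[1:],
        dtype=full_grad.dtype,
        device=full_grad.device,
    )
    handle = dist.reduce_scatter_tensor(shard, full_grad, group=process_group, async_op=True)
    handlers[key] = (handle, shard)  # consumer calls handle.wait() and then uses shard
    # return a placeholder of the sharded shape; the real gradient arrives via handlers
    return torch.empty_like(shard)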
38 changes: 34 additions & 4 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -65,6 +65,8 @@ def __init__(
hysteresis = grad_scal_cfg.hysteresis
max_scale = grad_scal_cfg.max_scale

self._fstp_handler = gpc.config.fstp_handler

# Zero related args
reduce_bucket_size = zero_cfg.reduce_bucket_size
clip_grad_norm = zero_cfg.clip_grad_norm
@@ -301,8 +303,7 @@ def _define_and_attach(param, reduce_rank=None):
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def reduce_grad_hook(*args): # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_func()
reduction_func()

accum_grad_obj.register_hook(reduce_grad_hook)

@@ -322,6 +323,20 @@ def belongs_to_current_rank(self, param) -> bool:
group_id = getattr(param, "group_id")
return tensor_rank == gpc.get_local_rank(self._broadcast_parallel_mode[group_id])

def reset_reduce_bucket(self) -> None:
for bucket in self._bucket_store:
for rank, params in bucket._params.items():
for _param in params:
if not hasattr(_param, "_fstp_reduce_scatter_str"):
continue

key = getattr(_param, "_fstp_reduce_scatter_str")
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad = _grad

bucket.reset_by_rank(rank)

def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
param_size = param.numel()

@@ -332,11 +347,26 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None):
current_bucket = self._bucket_store[group_id]

if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
# wait reduce scatter communication
params = current_bucket.get_param(reduce_rank)
for _param in params:
if not hasattr(_param, "_fstp_reduce_scatter_str"):
continue

key = getattr(_param, "_fstp_reduce_scatter_str")
comm_handle, _grad = self._fstp_handler.reduce_scatter_handlers[key]
comm_handle.wait()
_param.grad = _grad

# reduce grad
if self.skip_grad_reduce is False:
self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank, last_bucket=False)
else:
current_bucket.reset_by_rank(reduce_rank)

# the param must not be reduced to ensure correctness
is_param_reduced = self._param_store.is_param_reduced(param)
if is_param_reduced:
if is_param_reduced and self.skip_grad_reduce is False:
msg = (
f"Parameter of size ({param.size()}) has already been reduced, "
+ "duplicate reduction will lead to arithmetic incorrectness"
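On the optimizer side, reset_reduce_bucket() and the bucket-overflow path above share the same drain pattern: for every parameter carrying a _fstp_reduce_scatter_str tag, look up the stored handle, wait on it, and swap the placeholder gradient for the sharded result before the bucket is reduced or reset. A compact sketch of that pattern (hypothetical helper, not the optimizer's real API):

def drain_fstp_reduce_scatter(params, reduce_scatter_handlers: dict) -> None:
    """Wait on pending async reduce-scatters and attach the sharded gradients."""
    for param in params:
        key = getattr(param, "_fstp_reduce_scatter_str", None)
        if key is None:
            continue
        handle, sharded_grad = reduce_scatter_handlers[key]
        handle.wait()               # make sure the collective has completed
        param.grad = sharded_grad   # replace the placeholder gradient shard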
