feat(model/overlap_handler.py): move handler to gpc

huangting4201 committed Oct 23, 2023
1 parent 85ad917 commit b20f47a
Showing 6 changed files with 23 additions and 41 deletions.
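The diffs below apply one mechanical change throughout: the FSTP overlap handler is no longer stored on the runtime config (gpc.config.fstp_handler) but directly on the global context object (gpc.fstp_handler), and the memory-pool helpers now take the module itself instead of a separate block_index/module_name pair. A minimal sketch of the first half of that move, with a hypothetical DummyContext and DummyOverlapHandler standing in for the real gpc singleton (from internlm.core.context import global_context as gpc) and FSTPOverlapHandler:

# Hypothetical stand-ins; only the attribute move is illustrated, not the real handler setup.
class DummyContext:
    def __init__(self):
        self.config = {"parallel": {"tensor": {"sp": "intern", "intern_overlap": True}}}
        self.fstp_handler = None  # runtime-only object now lives on the context, not in config


class DummyOverlapHandler:
    pass


gpc = DummyContext()

# Old style (before this commit): gpc.config.fstp_handler = DummyOverlapHandler()
# New style (after this commit): set the handler as an attribute of the context itself.
if gpc.config["parallel"]["tensor"]["sp"] == "intern" and gpc.config["parallel"]["tensor"]["intern_overlap"]:
    gpc.fstp_handler = DummyOverlapHandler()

# Downstream code (optimizer, train loop) now checks the context attribute directly.
assert gpc.fstp_handler is not None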
5 changes: 1 addition & 4 deletions internlm/model/linear.py
@@ -352,16 +352,13 @@ def __init__(
 
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
-        block_index = gpc.config.fstp_handler.module_to_index[self]
         return fstp_fused_dense_func(
             x,
             self.weight,
             self.bias,
             process_group=self.process_group,
             module=self,
-            handler=gpc.config.fstp_handler,
-            block_index=block_index,
-            module_name=self._fstp_name,
+            handler=gpc.fstp_handler,
         )
 
 
16 changes: 6 additions & 10 deletions internlm/model/overlap_handler.py
@@ -116,8 +116,9 @@ def _initialize_memory_pool(self) -> None:
 
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight
 
-    def get_all_gather_memory(self, index, module_name):
-        return self.all_gather_memory_pool[index % 2][module_name]
+    def get_all_gather_memory(self, module):
+        block_index = self.module_to_index[module]
+        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
 
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
@@ -163,8 +164,7 @@ def _all_gather_block_weight_memory_pool(self, block_index: int):
                 module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(module, "_fstp_name"),
+                module=module,
             )
             self.fstp_global_handle[module] = weight_handle
 
@@ -192,13 +192,11 @@ def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
            first_backward_module = self.fstp_modules[-1]
-            block_index = self.module_to_index[first_backward_module]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
                 self.process_group,
                 async_op=True,
-                block_index=block_index,
-                module_name=getattr(first_backward_module, "_fstp_name"),
+                module=first_backward_module,
             )
             self.fstp_global_handle[first_backward_module] = weight_handle
 
@@ -211,13 +209,11 @@ def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             module_index = self.fstp_modules.index(module)
             if module_index - 1 >= 0:
                 next_module = self.fstp_modules[module_index - 1]
-                block_index = self.module_to_index[next_module]
                 weight_handle = all_gather_raw_memory_pool(
                     next_module.weight,
                     self.process_group,
                     async_op=True,
-                    block_index=block_index,
-                    module_name=getattr(next_module, "_fstp_name"),
+                    module=next_module,
                 )
                 self.fstp_global_handle[next_module] = weight_handle
 
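With the signature change above, callers no longer compute a block index or fetch _fstp_name themselves; the handler derives both from the module. A simplified, self-contained sketch of that lookup (TinyOverlapHandler, the index mapping, and the buffer shapes are illustrative, not the real FSTPOverlapHandler):

import torch
from torch import nn


class TinyOverlapHandler:
    """Illustrative stand-in for FSTPOverlapHandler's all-gather pool; names mirror the diff."""

    def __init__(self, modules):
        self.module_to_index = {m: i for i, m in enumerate(modules)}
        # two pool groups, so buffers for adjacent blocks can be double-buffered
        self.all_gather_memory_pool = [
            {m._fstp_name: torch.empty(4, 4) for m in modules},
            {m._fstp_name: torch.empty(4, 4) for m in modules},
        ]

    def get_all_gather_memory(self, module):
        block_index = self.module_to_index[module]
        return self.all_gather_memory_pool[block_index % 2][module._fstp_name]


linears = [nn.Linear(4, 4) for _ in range(3)]
for i, m in enumerate(linears):
    m._fstp_name = f"linear_{i}"  # the real handler attaches _fstp_name when it registers modules

handler = TinyOverlapHandler(linears)
buf = handler.get_all_gather_memory(linears[2])  # caller passes only the module now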
29 changes: 9 additions & 20 deletions internlm/model/utils.py
@@ -7,13 +7,12 @@
 import torch
 import torch.nn.functional as F
 from flash_attn.utils.distributed import all_reduce_raw
-from torch import Tensor
+from torch import Tensor, nn
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
@@ -131,11 +130,10 @@ def all_gather_raw_memory_pool(
     process_group: ProcessGroup,
     async_op: bool = False,
     gather_dim: int = 0,
-    block_index: int = None,
-    module_name: str = None,
+    module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
-        gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name),
+        gpc.fstp_handler.get_all_gather_memory(module=module),
         input_.contiguous(),
         group=process_group,
         async_op=async_op,
@@ -166,8 +164,8 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.config.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.config.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
+    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
+    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
     setattr(output, "index", index)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
@@ -469,16 +467,12 @@ def forward(
         process_group=None,
         module=None,
         overlap_handler=None,
-        block_index=None,
-        module_name=None,
     ):
         ctx.compute_weight_gradient = weight.requires_grad
         ctx.return_residual = return_residual
         ctx.process_group = process_group
         ctx.overlap_handler = overlap_handler
         ctx.module = module
-        ctx.block_index = block_index
-        ctx.module_name = module_name
 
         if torch.is_autocast_enabled():
             x = x.to(dtype=torch.get_autocast_gpu_dtype())
@@ -488,7 +482,7 @@
         if world_size > 1:
             # do all_gather for weight and bias before actual computation
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -531,8 +525,7 @@ def backward(ctx, grad_output, *args):
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         overlap_handler = ctx.overlap_handler
-        block_index = ctx.block_index
-        module_name = ctx.module_name
+        module = ctx.module
 
         if ctx.compute_weight_gradient:
             x, weight, bias = ctx.saved_tensors
@@ -547,7 +540,7 @@
         world_size = gpc.get_world_size(ParallelMode.TENSOR)
         if world_size > 1:
             if overlap_handler is not None:
-                total_weight = gpc.config.fstp_handler.get_all_gather_memory(block_index, module_name)
+                total_weight = gpc.fstp_handler.get_all_gather_memory(module=module)
             else:
                 total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
                 handle_weight.wait()
@@ -669,16 +662,12 @@ def fstp_fused_dense_func(
     process_group=None,
     module=None,
     handler=None,
-    block_index=None,
-    module_name=None,
 ):
     dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
         x.dtype == torch.float32 and torch.is_autocast_enabled()
     )
     if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
-        return FSTPFusedDenseFunc.apply(
-            x, weight, bias, return_residual, process_group, module, handler, block_index, module_name
-        )
+        return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, handler)
     else:
         assert process_group is None
         out = F.linear(x, weight, bias)
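Inside utils.py the same idea shows up in the autograd function: FSTPFusedDenseFunc now stashes only ctx.module and asks the handler for that module's pre-gathered weight in both forward and backward. A compact sketch of that ctx.module pattern, using a hypothetical handler and a plain matmul rather than the fused dense kernel:

import torch


class FakeHandler:
    """Hypothetical handler that maps a module to its pre-gathered weight buffer."""

    def __init__(self):
        self.buffers = {}

    def get_all_gather_memory(self, module):
        return self.buffers[module]


fake_handler = FakeHandler()


class ModuleKeyedMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, module, handler):
        ctx.module = module            # only the module is stashed; no block_index / module_name
        ctx.handler = handler
        ctx.save_for_backward(x, weight)
        total_weight = handler.get_all_gather_memory(module) if handler is not None else weight
        return x @ total_weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        handler, module = ctx.handler, ctx.module
        # the same module-keyed lookup is reused in backward
        total_weight = handler.get_all_gather_memory(module) if handler is not None else weight
        grad_x = grad_output @ total_weight
        grad_w = grad_output.t() @ x
        return grad_x, grad_w, None, None


lin = torch.nn.Linear(4, 4, bias=False)
fake_handler.buffers[lin] = lin.weight.detach()  # pretend this is the all-gathered full weight
x = torch.randn(2, 4, requires_grad=True)
out = ModuleKeyedMatmul.apply(x, lin.weight, lin, fake_handler)
out.sum().backward()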
4 changes: 2 additions & 2 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -68,7 +68,7 @@ def __init__(
 
         self._fstp_handler = None
         if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-            self._fstp_handler = gpc.config.fstp_handler
+            self._fstp_handler = gpc.fstp_handler
 
         # Zero related args
         reduce_bucket_size = zero_cfg.reduce_bucket_size
@@ -350,7 +350,7 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona
                 _param.grad.add_(_grad)
 
                 # release cuda memory.
-                gpc.config.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
                 self._fstp_handler.reduce_scatter_handlers[_key] = None
 
         bucket.reset_by_rank(reduce_rank)
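The reduce-scatter path pairs with this optimizer hook: utils.py tags each pooled output tensor with its slot index via setattr(output, "index", index), and the hook above releases that slot once the gradient has been accumulated. A simplified, self-contained sketch of that tag-and-release pattern (TinyReduceScatterPool and the sizes are hypothetical, not the real handler):

import torch


class TinyReduceScatterPool:
    """Hypothetical simplification of the handler's reduce_scatter_memory_pool bookkeeping."""

    def __init__(self):
        self.pool = {}  # size -> {"data": [tensors], "used": [bool]}

    def get_reduce_scatter_memory(self, size):
        entry = self.pool.setdefault(size, {"data": [], "used": []})
        for i, used in enumerate(entry["used"]):
            if not used:  # reuse a free slot of the right size if one exists
                entry["used"][i] = True
                return i
        entry["data"].append(torch.empty(size))
        entry["used"].append(True)
        return len(entry["data"]) - 1

    def release_reduce_scatter_memory(self, size, index):
        self.pool[size]["used"][index] = False


pool = TinyReduceScatterPool()
size = (2, 4)
index = pool.get_reduce_scatter_memory(size)
output = pool.pool[size]["data"][index]
setattr(output, "index", index)  # same trick as reduce_scatter_raw_memory_pool
# ... later, once the gradient held in `output` has been consumed:
pool.release_reduce_scatter_memory(size=tuple(output.size()), index=output.index)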
4 changes: 2 additions & 2 deletions internlm/train/training_internlm.py
@@ -108,9 +108,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
-    gpc.config.fstp_handler = None
+    gpc.fstp_handler = None
     if gpc.config.parallel["tensor"]["sp"] == "intern" and gpc.config.parallel["tensor"]["intern_overlap"] is True:
-        gpc.config.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        gpc.fstp_handler = FSTPOverlapHandler(model, gpc.get_group(ParallelMode.TENSOR))
 
     return model
 
6 changes: 3 additions & 3 deletions train.py
@@ -297,9 +297,9 @@ def main(args):
 
             prof.step()
 
-            if gpc.config.fstp_handler is not None:
-                gpc.config.fstp_handler.zero_const_pool = {}
-                gpc.config.fstp_handler.reduce_scatter_memory_pool = {}
+            if gpc.fstp_handler is not None:
+                gpc.fstp_handler.zero_const_pool = {}
+                gpc.fstp_handler.reduce_scatter_memory_pool = {}
             # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()
 
