support reduce scatter memory pool
yingtongxiong committed Oct 20, 2023
1 parent 4742271 commit ed72327
Showing 7 changed files with 74 additions and 17 deletions.
2 changes: 1 addition & 1 deletion configs/20B_sft.py
@@ -57,7 +57,7 @@
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=20,
skip_batches="",
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
8 changes: 4 additions & 4 deletions configs/30B_sft.py
@@ -5,7 +5,7 @@
HIDDEN_SIZE = 6144
NUM_ATTENTION_HEAD = 48
MLP_RATIO = 8 / 3
-NUM_LAYER = 40
+NUM_LAYER = 60
VOCAB_SIZE = 103168

MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -51,7 +51,7 @@
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=4,
+    micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
@@ -161,8 +161,8 @@
sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="origin_tp", overlap=False),
+    zero1=dict(size=4, fsdp=False),
+    tensor=dict(size=8, mode="fstp", overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)
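Net effect on the 30B config: 60 layers, micro_bsz of 2, a ZeRO-1 group size of 4, and tensor parallelism switched from origin_tp to fstp with communication overlap enabled. Assembled from the hunks above, the resulting parallel block reads:

parallel = dict(
    zero1=dict(size=4, fsdp=False),
    tensor=dict(size=8, mode="fstp", overlap=True),
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)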
2 changes: 1 addition & 1 deletion configs/7B_sft.py
@@ -162,7 +162,7 @@
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, mode="fstp"),
+    tensor=dict(size=8, mode="fstp", overlap=True),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=True,
)
52 changes: 50 additions & 2 deletions internlm/model/utils.py
@@ -14,6 +14,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device

logger = get_logger(__file__)

@@ -148,6 +149,18 @@ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bo
async_op=async_op)
return output, handle

+def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    size = (input_.shape[0] // world_size, *input_.shape[1:])
+    index = check_reduce_scatter_memory_pool(size)
+    output = gpc.config.reduce_scatter_memory[size]['data'][index]
+    setattr(output, "index", index)
+    handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
+                                                     group=process_group,
+                                                     async_op=async_op)
+    return output, handle


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFunc(torch.autograd.Function):
@@ -404,12 +417,13 @@ def backward(ctx, grad_output, *args):
# assert hasattr(bias, "_fstp_all_reduce_str")
# all_gather_handler.all_reduce_handlers[bias._fstp_all_reduce_str] = (handle_grad_bias, grad_bias_async)
# grad_bias = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
-grad_weight_async, handle_grad_weight = reduce_scatter_raw(grad_weight, process_group, async_op=True)
+
+grad_weight_async, handle_grad_weight = reduce_scatter_raw_memory_pool(grad_weight, process_group, async_op=True)
assert hasattr(weight, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[weight._fstp_reduce_scatter_str] = (handle_grad_weight, grad_weight_async)
grad_weight = all_gather_handler.get_zero_by_shape((grad_weight.shape[0]//torch.distributed.get_world_size(process_group), *grad_weight.shape[1:]), dtype=grad_weight.dtype, device=grad_weight.device)
if grad_bias is not None:
-grad_bias_async, handle_grad_bias = reduce_scatter_raw(grad_bias, process_group, async_op=True)
+grad_bias_async, handle_grad_bias = reduce_scatter_raw_memory_pool(grad_bias, process_group, async_op=True)
assert hasattr(bias, "_fstp_reduce_scatter_str")
all_gather_handler.reduce_scatter_handlers[bias._fstp_reduce_scatter_str] = (handle_grad_bias, grad_bias_async)
grad_bias = all_gather_handler.get_zero_by_shape((grad_bias.shape[0]//torch.distributed.get_world_size(process_group), *grad_bias.shape[1:]), dtype=grad_bias.dtype, device=grad_bias.device)
@@ -521,3 +535,37 @@ def Silu(w1_o, w2_o):


Silu = torch.jit.script(Silu)

+def check_reduce_scatter_memory_pool(key):
+
+    return_idx = 0
+
+    # if key not in dict
+    if key not in gpc.config.reduce_scatter_memory:
+        gpc.config.reduce_scatter_memory[key] = {'data': [], 'used': []}
+
+    # if the data is empty
+    if len(gpc.config.reduce_scatter_memory[key]['data']) == 0:
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = 0
+        return return_idx
+    else:  # if not empty
+        for index, used in enumerate(gpc.config.reduce_scatter_memory[key]['used']):
+            if used == False:
+                gpc.config.reduce_scatter_memory[key]['used'][index] = True
+                return_idx = index
+                return return_idx
+        # if the memory pool is all used
+        length = len(gpc.config.reduce_scatter_memory[key]['data'])
+        gpc.config.reduce_scatter_memory[key]['data'].append(torch.zeros(key,
+                                                                         dtype=gpc.config.model.get("dtype", torch.half),
+                                                                         device=get_current_device()).contiguous())
+        gpc.config.reduce_scatter_memory[key]['used'].append(True)
+        return_idx = length
+        return return_idx
+
+def release_reduce_scatter_memory_pool(size, index):
+    gpc.config.reduce_scatter_memory[size]['used'][index] = False
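Taken together, the new helpers implement a shape-keyed buffer pool for reduce-scatter outputs: check_reduce_scatter_memory_pool returns the index of a free pre-allocated tensor for the requested output shape (growing the pool when every slot is busy), reduce_scatter_raw_memory_pool launches the asynchronous reduce-scatter into that pooled tensor and stashes the slot index on it, and release_reduce_scatter_memory_pool marks the slot reusable once the gradient has been consumed. Below is a minimal standalone sketch of the same acquire/reduce-scatter/release pattern; BufferPool and reduce_scatter_pooled are illustrative names for this sketch, not part of the InternLM API.

import torch
import torch.distributed as dist

class BufferPool:
    """Shape-keyed pool of reusable output buffers (illustrative only)."""

    def __init__(self, dtype=torch.half, device="cuda"):
        self.dtype, self.device = dtype, device
        self.pool = {}  # shape -> {"data": [tensors], "used": [bools]}

    def acquire(self, shape):
        entry = self.pool.setdefault(tuple(shape), {"data": [], "used": []})
        for idx, used in enumerate(entry["used"]):
            if not used:  # reuse a free buffer of this shape
                entry["used"][idx] = True
                return idx, entry["data"][idx]
        # all buffers of this shape are busy (or none exist yet): grow the pool by one
        entry["data"].append(torch.zeros(shape, dtype=self.dtype, device=self.device))
        entry["used"].append(True)
        return len(entry["data"]) - 1, entry["data"][-1]

    def release(self, shape, idx):
        self.pool[tuple(shape)]["used"][idx] = False

def reduce_scatter_pooled(pool, input_, process_group, async_op=True):
    world_size = dist.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    out_shape = (input_.shape[0] // world_size, *input_.shape[1:])
    idx, output = pool.acquire(out_shape)
    output.index = idx  # remember the slot so the consumer can release it later
    handle = dist.reduce_scatter_tensor(output, input_.contiguous(),
                                        group=process_group, async_op=async_op)
    return output, handle

Because buffers are keyed by the reduce-scatter output shape and only flagged free rather than deleted, gradients of same-shaped layers reuse the same memory across backward passes instead of triggering a fresh allocation for every reduce-scatter.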
8 changes: 5 additions & 3 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -10,7 +10,7 @@

from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.utils import split_forward_gather_backward
+from internlm.model.utils import split_forward_gather_backward, release_reduce_scatter_memory_pool
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -353,7 +353,8 @@ def reset_reduce_bucket(self) -> None:
comm_handle.wait()
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
-del _grad
+# del _grad
+release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
@@ -395,7 +396,8 @@ def _wait_reduce_scatter_and_accumulate_grad(self, param, reduce_rank=None):
comm_handle.wait()
_param.grad.add_(_grad)
# self._fstp_handler.reduce_scatter_handlers[key] = None
-del _grad
+# del _grad
+release_reduce_scatter_memory_pool(size=tuple(_grad.size()),index=_grad.index)
del self._fstp_handler.reduce_scatter_handlers[key]
self._fstp_handler.reduce_scatter_handlers[key] = None
assert key in self._fstp_handler.reduce_scatter_handlers
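On the optimizer side, the former del _grad is replaced by an explicit release: after the communication handle is waited on and the scattered gradient is accumulated into the parameter's gradient, the pooled buffer is handed back via the index attribute that reduce_scatter_raw_memory_pool attached to it. A sketch of that consume-and-release step, reusing the illustrative BufferPool from the sketch above (the handler bookkeeping around it is omitted):

def consume_scattered_grad(pool, param, comm_handle, scattered_grad):
    """Wait for the async reduce-scatter, accumulate, then recycle the pooled buffer."""
    comm_handle.wait()                    # the scattered gradient is only valid after this
    param.grad.add_(scattered_grad)       # accumulate the local gradient shard
    pool.release(tuple(scattered_grad.size()), scattered_grad.index)  # mark the slot free for reuse

The buffer itself is never freed, only flagged as available, so each shape's pool footprint is bounded by the maximum number of reduce-scatters in flight for that shape.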
18 changes: 12 additions & 6 deletions internlm/train/training_internlm.py
@@ -51,7 +51,7 @@
from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
@@ -123,29 +123,35 @@ def initialize_model():
mlp_ratio = gpc.config.MLP_RATIO
mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-size_key = [(3 * hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (mlp_hidden_size, hidden_size), (hidden_size, hidden_size)]
+world_size = gpc.get_world_size(ParallelMode.TENSOR)
+size_key = [(3 * hidden_size // world_size, hidden_size), (mlp_hidden_size // world_size, hidden_size), (hidden_size // world_size, mlp_hidden_size), (hidden_size // world_size, hidden_size)]
module_name = ['Wqkv', 'out_proj', 'w1', 'w2', 'w3']
for i in range(2):
weight = {}
for name in module_name:
if name == 'Wqkv':
weight[name] = torch.zeros((3 * hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
-device='cuda').contiguous()
+device=get_current_device()).contiguous()
elif name == 'out_proj':
weight[name] = torch.zeros((hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
-device='cuda').contiguous()
+device=get_current_device()).contiguous()
elif name == 'w1' or name == 'w2':
weight[name] = torch.zeros((mlp_hidden_size, hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
-device='cuda').contiguous()
+device=get_current_device()).contiguous()
else:
weight[name] = torch.zeros((hidden_size, mlp_hidden_size),
dtype=gpc.config.model.get("dtype", torch.half),
-device='cuda').contiguous()
+device=get_current_device()).contiguous()
block_memory[i] = weight
+reduce_scatter_memory = {}
+for key in size_key:
+    reduce_scatter_memory[key] = {'data': [], 'used': []}
+
gpc.config.block_memory = block_memory
+gpc.config.reduce_scatter_memory = reduce_scatter_memory

return model

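initialize_model now pre-computes the per-rank reduce-scatter output shapes — each weight's first dimension divided by the tensor-parallel world size — and registers an empty pool entry for each, so the first backward pass only has to fill the data/used lists. A small sketch of how those shape keys come out for example dimensions (hidden_size=4096 is an illustrative value; the tensor group size of 8 matches the configs above, and the MLP rounding follows the code in the diff):

# Illustrative shape-key computation
hidden_size = 4096
mlp_ratio = 8 / 3
world_size = 8  # tensor-parallel group size

mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)  # round up to a multiple of 256

size_key = [
    (3 * hidden_size // world_size, hidden_size),   # Wqkv gradient shard
    (mlp_hidden_size // world_size, hidden_size),   # w1 / w2 gradient shard
    (hidden_size // world_size, mlp_hidden_size),   # w3 gradient shard
    (hidden_size // world_size, hidden_size),       # out_proj gradient shard
]
reduce_scatter_memory = {key: {"data": [], "used": []} for key in size_key}
print(size_key)  # [(1536, 4096), (1376, 4096), (512, 11008), (512, 4096)]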
1 change: 1 addition & 0 deletions train.py
@@ -299,6 +299,7 @@ def main(args):

if gpc.config.fstp_handler is not None:
gpc.config.fstp_handler.zero_const_pool = {}
+gpc.config.fstp_handler.reduce_scatter_memory = {}
torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
torch.cuda.reset_peak_memory_stats()

