feat(model/linear.py): block-grained backward
huangting4201 committed Oct 17, 2023
1 parent 0d1fa03 commit d1af0d6
Showing 2 changed files with 46 additions and 40 deletions.
9 changes: 5 additions & 4 deletions configs/7B_sft.py
@@ -5,7 +5,7 @@
 HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 8
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -57,7 +57,7 @@
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -161,10 +161,11 @@
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=-1, fsdp=False),
+    tensor=dict(size=8, mode="fstp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
+    block_0_full_weight=True,
 )
 
 cudnn_deterministic = False
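Note on the config change: the inline comment removed from the `tensor` line above still captures the operative constraint, namely that `mode="fstp"` only works with `sequence_parallel=True`, and a `zero1` size of -1 conventionally means the ZeRO-1 group spans the whole data-parallel group. Below is a minimal sketch of a validator for the updated `parallel` section; `check_parallel_config` is a hypothetical helper for illustration, not part of this commit.

# Hypothetical helper, not part of this commit: sanity-checks the `parallel`
# section shown in the diff above.

def check_parallel_config(parallel: dict) -> None:
    tensor = parallel.get("tensor", dict(size=1, mode="origin_tp"))
    assert tensor["mode"] in ("origin_tp", "fstp"), "tensor mode must be 'origin_tp' or 'fstp'"
    if tensor["mode"] == "fstp":
        # The constraint from the comment removed in this diff: fstp needs sequence parallelism.
        assert parallel.get("sequence_parallel", False), "fstp mode requires sequence_parallel=True"
    zero1 = parallel.get("zero1", dict(size=-1, fsdp=False))
    # size=-1 is read here as "the ZeRO-1 group spans the whole data-parallel group".
    assert zero1["size"] == -1 or zero1["size"] >= 1, "zero1 size must be -1 or a positive integer"


check_parallel_config(
    dict(
        zero1=dict(size=-1, fsdp=False),
        tensor=dict(size=8, mode="fstp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        sequence_parallel=True,
        block_0_full_weight=True,
    )
)

The new block_0_full_weight=True flag appears to pair with the hook changes in internlm/model/linear.py below, which stop prefetching weights for block 0 during backward.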
77 changes: 41 additions & 36 deletions internlm/model/linear.py
@@ -559,21 +559,21 @@ def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
 
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             block_index = self.block_to_index[block]
-            if block_index == gpc.config.NUM_LAYER - 1:
-                # all gather weight for the last block
-                fsdp_modules = self.index_to_fsdp_modules[block_index]
-                for module in fsdp_modules:
-                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handle.wait()
-                    self.FSTP_global_weights[module] = total_weight
-            else:
-                # wait handle for current block
-                handles = self.block_handles[block]
-                for handle in handles:
-                    handle.wait()
+            # if block_index == gpc.config.NUM_LAYER - 1:
+            #     # all gather weight for the last block
+            #     fsdp_modules = self.index_to_fsdp_modules[block_index]
+            #     for module in fsdp_modules:
+            #         total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            #         weight_handle.wait()
+            #         self.FSTP_global_weights[module] = total_weight
+            # else:
+            #     # wait handle for current block
+            #     handles = self.block_handles[block]
+            #     for handle in handles:
+            #         handle.wait()
 
             # start the all-gather for next block
-            if block_index - 1 >= 0:
+            if block_index - 1 > 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
@@ -588,36 +588,41 @@ def _pre_backward_hook_for_module(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
             if block_index != 0:
-                if name_index == 4:
-                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler.wait()
-                    self.FSTP_global_weights[module] = total_weight
-
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
-                else:
+                # if name_index == 4:
+                #     total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                #     weight_handler.wait()
+                #     self.FSTP_global_weights[module] = total_weight
+
+                #     # start the all-gather for next module
+                #     next_module = self.block_module[block_index][name_index - 1]
+                #     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #         next_module.weight, self.process_group, async_op=True
+                #     )
+                #     self.FSTP_global_handle[next_module] = weights_handler
+                # else:
+                #     handler = self.FSTP_global_handle[module]
+                #     handler.wait()
+                #     if name_index != 0:
+                #         next_module = self.block_module[block_index][name_index - 1]
+                #         self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #             next_module.weight, self.process_group, async_op=True
+                #         )
+                #         self.FSTP_global_handle[next_module] = weights_handler
+                if module in self.FSTP_global_handle:
+                    handler = self.FSTP_global_handle[module]
+                    handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
 
         def _post_backward_hook_for_module(module, grad_input, grad_output):
            if module in self.FSTP_global_weights:
                del self.FSTP_global_weights[module]
            if module in self.FSTP_global_handle:
                del self.FSTP_global_handle[module]
 
-        # for block in self.FSTP_blocks:
-        #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
-        #     block.register_forward_hook(_post_forward_hook_for_block)
-        #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-        #     block.register_full_backward_hook(_post_backward_hook_for_block)
+        for block in self.FSTP_blocks:
+            # block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            # block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            # block.register_full_backward_hook(_post_backward_hook_for_block)
 
         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
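For readers following the hook logic: the first hunk registers `_pre_backward_hook_for_block` through `register_full_backward_pre_hook`, so that when autograd reaches a block it kicks off the asynchronous all-gather of the previous block's sharded weights, overlapping that communication with the current block's gradient computation; the guard also changes from `>= 0` to `> 0`, so block 0 is never prefetched (matching `block_0_full_weight=True` in the config). Below is a minimal, self-contained sketch of that block-grained prefetch pattern. It is an illustration, not the commit's code: names such as `BlockPrefetcher` are invented, plain `torch.distributed` stands in for InternLM's `all_gather_raw`, the wait-for-current-block step (commented out of the block hook in this commit and moved down to module granularity) is kept here so the sketch is complete on its own, and how the gathered weights are consumed by the FSTP linears is omitted.

# Illustrative sketch only (assumed names; not from this commit).
import torch
import torch.distributed as dist
import torch.nn as nn


class BlockPrefetcher:
    """Prefetch the previous block's sharded weights while the current block runs backward."""

    def __init__(self, blocks: list[nn.Module], process_group) -> None:
        self.blocks = list(blocks)
        self.process_group = process_group
        self.block_to_index = {block: i for i, block in enumerate(self.blocks)}
        self.handles = {}   # block index -> async all-gather work handles
        self.gathered = {}  # block index -> gathered full weights
        for block in self.blocks:
            # Fires right before gradients for `block` are computed (PyTorch >= 2.0).
            block.register_full_backward_pre_hook(self._pre_backward_hook)

    def _all_gather_block_weight(self, index: int) -> None:
        """Launch async all-gathers for every sharded parameter of block `index`."""
        world = dist.get_world_size(self.process_group)
        works, outs = [], []
        for param in self.blocks[index].parameters():
            out = torch.empty(world * param.numel(), dtype=param.dtype, device=param.device)
            works.append(
                dist.all_gather_into_tensor(
                    out, param.data.flatten(), group=self.process_group, async_op=True
                )
            )
            outs.append(out)
        self.handles[index] = works
        self.gathered[index] = outs

    def _pre_backward_hook(self, block: nn.Module, grad_output) -> None:
        index = self.block_to_index[block]
        # Make sure this block's prefetched weights have arrived (launched by block index + 1).
        for work in self.handles.pop(index, []):
            work.wait()
        # Backward walks blocks in reverse order, so start gathering the previous block now.
        if index - 1 > 0:  # block 0 keeps a full weight copy and is never prefetched
            self._all_gather_block_weight(index - 1)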

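The second hunk moves the same idea down to the individual FSTP linears inside a block: each module's backward pre-hook waits on its own prefetched all-gather (if one is pending) and launches the gather for the previous module in that block, and a backward post-hook frees the gathered weight as soon as the module's gradients are done. A compact sketch under the same assumptions as above (`ModulePrefetcher` and its helpers are invented names, plain `torch.distributed` replaces `all_gather_raw`, and block 0 is skipped to mirror `block_0_full_weight`):

# Illustrative sketch only (assumed names; not from this commit).
import torch
import torch.distributed as dist
import torch.nn as nn


class ModulePrefetcher:
    def __init__(self, block_modules: dict[int, list[nn.Linear]], process_group) -> None:
        # block_modules[i] lists block i's sharded linears in forward order.
        self.block_modules = block_modules
        self.process_group = process_group
        self.module_to_index = {}
        self.module_name_index = {}
        self.full_weights = {}  # module -> gathered full weight
        self.handles = {}       # module -> pending async all-gather handle
        for block_index, modules in block_modules.items():
            for name_index, module in enumerate(modules):
                self.module_to_index[module] = block_index
                self.module_name_index[module] = name_index
                module.register_full_backward_pre_hook(self._pre_backward_hook)
                module.register_full_backward_hook(self._post_backward_hook)

    def _launch_all_gather(self, module: nn.Linear) -> None:
        """Start an async all-gather of `module`'s weight shard along dim 0."""
        world = dist.get_world_size(self.process_group)
        out = torch.empty(
            (world * module.weight.shape[0], *module.weight.shape[1:]),
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        self.handles[module] = dist.all_gather_into_tensor(
            out, module.weight.data, group=self.process_group, async_op=True
        )
        self.full_weights[module] = out

    def _pre_backward_hook(self, module: nn.Module, grad_output) -> None:
        block_index = self.module_to_index[module]
        name_index = self.module_name_index[module]
        if block_index != 0:  # block 0 keeps full weights, so nothing to wait on or gather
            if module in self.handles:  # wait for this module's prefetched weight
                self.handles.pop(module).wait()
            if name_index != 0:  # prefetch the module that backward will visit next
                self._launch_all_gather(self.block_modules[block_index][name_index - 1])

    def _post_backward_hook(self, module: nn.Module, grad_input, grad_output) -> None:
        # Free the gathered weight and handle as soon as this module's grads are computed.
        self.full_weights.pop(module, None)
        self.handles.pop(module, None)

Where the real handler stores gathered weights in FSTP_global_weights for the custom linear's backward to read, this sketch just keeps them in a dict; the cleanup in _post_backward_hook mirrors _post_backward_hook_for_module in the diff.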