support torch1.13
yingtongxiong committed Oct 19, 2023
1 parent 4742271 commit 1056fe2
Showing 4 changed files with 50 additions and 7 deletions.
44 changes: 43 additions & 1 deletion internlm/model/linear.py
@@ -522,7 +522,49 @@ def _all_gather_block_weight_memory_pool(self, block_index: int):
             # self.FSTP_global_weights[module] = total_weight
             self.FSTP_global_handle[module] = weight_handle
             # self.block_handles[block].append(weight_handle)
+
+    def _pre_backward_hook_for_module_memory_pool(self, module: nn.Module, grad_output):
+        block_index = self.module_to_index[module]
+        name_index = self.module_name_index[module]
+
+        if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
+            # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+            weight_handler = self.FSTP_global_handle[module]
+            weight_handler.wait()
+            # self.FSTP_global_weights[module] = total_weight
+
+            # start the all-gather for next module
+            next_module = self.block_module[block_index][name_index - 1]
+            next_name = self.module_name[name_index - 1]
+            weights_handler = all_gather_raw_memory_pool(
+                next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=next_name
+            )
+            self.FSTP_global_handle[next_module] = weights_handler
+        elif name_index == 0:
+            handler = self.FSTP_global_handle[module]
+            handler.wait()
+
+            if block_index - 1 >= 0:
+                next_module = self.block_module[block_index - 1][4]
+                name = self.module_name[4]
+                weights_handler = all_gather_raw_memory_pool(
+                    next_module.weight, self.process_group, async_op=True, block_index=block_index - 1, module_name=name,
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+        else:
+            handler = self.FSTP_global_handle[module]
+            handler.wait()
+            if name_index != 0:
+                next_module = self.block_module[block_index][name_index - 1]
+                name = self.module_name[name_index - 1]
+                weights_handler = all_gather_raw_memory_pool(
+                    next_module.weight, self.process_group, async_op=True, block_index=block_index, module_name=name
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            # if module in self.FSTP_global_handle:
+            #     handler = self.FSTP_global_handle[module]
+            #     handler.wait()

     def _register_sync_parameters_hook(self) -> None:
         """
         register pre_forward_hook and pre_backward_hook for FSTP block.
@@ -724,5 +766,5 @@ def _post_backward_hook_for_module(module, grad_input, grad_output):
             module.register_forward_pre_hook(_pre_forward_hook_for_module)
             module.register_forward_hook(_post_forward_hook_for_module)
             # module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
-            module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
+            # module.register_full_backward_pre_hook(_pre_backward_hook_for_module_memory_pool)
             module.register_full_backward_hook(_post_backward_hook_for_module)
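
The added method above is the heart of the change: each transformer block holds five FSTP linear modules (indices 0-4), and backward visits them in reverse. The hook waits on the in-flight all-gather for the current module's weight, then launches the all-gather for the module backward will need next, hopping to module 4 of the previous block once index 0 is reached. A minimal sketch of that wait-then-prefetch order, with print stubs standing in for the real handle bookkeeping and all_gather_raw_memory_pool:

    NUM_MODULES = 5  # linear modules per block, per the hard-coded indices 0 and 4

    def wait(block: int, mod: int) -> None:
        # stands in for self.FSTP_global_handle[module].wait()
        print(f"wait   all-gather: block={block} module={mod}")

    def launch(block: int, mod: int) -> None:
        # stands in for all_gather_raw_memory_pool(..., async_op=True)
        print(f"launch all-gather: block={block} module={mod}")

    def pre_backward(block: int, mod: int) -> None:
        wait(block, mod)                        # weight must be gathered before grads
        if mod > 0:
            launch(block, mod - 1)              # prefetch previous module in this block
        elif block > 0:
            launch(block - 1, NUM_MODULES - 1)  # cross into the previous block

    # backward walks blocks and their modules in reverse order
    for block in reversed(range(2)):
        for mod in reversed(range(NUM_MODULES)):
            pre_backward(block, mod)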
7 changes: 4 additions & 3 deletions internlm/model/utils.py
@@ -354,15 +354,16 @@ def forward(ctx, x, weight, bias, return_residual=False, process_group=None, mod
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):
+        module = ctx.module
+        gpc.config.fstp_handler._pre_backward_hook_for_module_memory_pool(module, None)
+        block_index = ctx.block_index
+        module_name = ctx.module_name
         grad_output = grad_output.contiguous()
         if ctx.return_residual:
             (grad_input,) = args
             grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         all_gather_handler = ctx.all_gather_handler
-        module = ctx.module
-        block_index = ctx.block_index
-        module_name = ctx.module_name
 
         if ctx.compute_weight_gradient:
             x, weight, bias = ctx.saved_tensors
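
This call is what actually restores torch 1.13 compatibility: nn.Module.register_full_backward_pre_hook only exists from PyTorch 2.0, so rather than registering the hook (now commented out in linear.py above), the handler method is invoked by hand at the top of the custom autograd Function's backward. A hedged sketch of how the registration could instead be guarded by feature detection; register_pre_backward is a hypothetical helper, not part of the commit:

    import torch.nn as nn

    def register_pre_backward(module: nn.Module, hook) -> bool:
        # torch >= 2.0 exposes this hook point; torch 1.13 does not
        if hasattr(module, "register_full_backward_pre_hook"):
            module.register_full_backward_pre_hook(hook)
            return True
        # caller must invoke `hook` manually inside backward(), as this
        # commit does through gpc.config.fstp_handler
        return False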
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
@@ -691,7 +691,7 @@ def step(self, closure=None):
         except torch.cuda.OutOfMemoryError as e:
             print(e, flush=True)
             print(torch.cuda.memory_summary(), flush=True)
-            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+            # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
 
         return res
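
The snapshot dump in the out-of-memory handler is disabled, presumably because torch.cuda.memory._dump_snapshot is a private API of the torch >= 2.0 memory profiler and is absent in torch 1.13. A guarded helper (an assumption, not something the commit adds) would keep the diagnostic alive on newer builds:

    import torch

    def maybe_dump_snapshot(path: str) -> None:
        # private API: present in torch >= 2.0, absent in 1.13
        dump = getattr(torch.cuda.memory, "_dump_snapshot", None)
        if dump is not None:
            dump(path)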

4 changes: 2 additions & 2 deletions train.py
@@ -195,7 +195,7 @@ def main(args):
     # start iterating the train data and begin training
     for batch_count in range(train_state.batch_count, total_steps):
         empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
-        torch.cuda.memory._record_memory_history()
+        # torch.cuda.memory._record_memory_history(enabled=True)
         start_time = time.time()
         timer("one-batch").start()

@@ -299,7 +299,7 @@ def main(args):

         if gpc.config.fstp_handler is not None:
             gpc.config.fstp_handler.zero_const_pool = {}
-            torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
+            # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()
 
     ckpt_manager.wait_async_upload_finish()
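
Both train.py edits follow the same pattern: _record_memory_history and _dump_snapshot are private profiling hooks whose availability and signatures vary across releases (note the enabled=True argument appearing only in the commented-out replacement line), so the commit simply disables them. If recording were to be kept where supported, a feature-detected call mirroring the commit's own kwargs might look like this; maybe_record_memory_history is a hypothetical helper:

    import torch

    def maybe_record_memory_history() -> None:
        # private API; skip entirely on builds that lack it. The kwarg
        # mirrors the commented-out call above and may differ elsewhere.
        record = getattr(torch.cuda.memory, "_record_memory_history", None)
        if record is not None:
            record(enabled=True)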
