feat(model/overlap_handler.py): optimize reduce scatter mem pool

huangting4201 committed Oct 23, 2023
1 parent b20f47a commit e7f9f1d
Showing 3 changed files with 20 additions and 21 deletions.
35 changes: 18 additions & 17 deletions internlm/model/overlap_handler.py
@@ -125,37 +125,38 @@ def get_reduce_scatter_memory(self, key):

         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
-            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
+            self.reduce_scatter_memory_pool[key] = []
 
         # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
-            self.reduce_scatter_memory_pool[key]["data"].append(
+        if len(self.reduce_scatter_memory_pool[key]) == 0:
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
             return_idx = 0
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
         else:  # if not empty
-            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
-                if used is False:
-                    self.reduce_scatter_memory_pool[key]["used"][index] = True
+            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+                if mem_item.idle is True:
+                    self.reduce_scatter_memory_pool[key][index].idle = False
                     return_idx = index
-                    return return_idx
+                    return self.reduce_scatter_memory_pool[key][return_idx]
             # if the memory pool is all used
-            length = len(self.reduce_scatter_memory_pool[key]["data"])
-            self.reduce_scatter_memory_pool[key]["data"].append(
+            cur_len = len(self.reduce_scatter_memory_pool[key])
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = length
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+            return_idx = cur_len
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
 
-    def release_reduce_scatter_memory(self, size, index):
-        self.reduce_scatter_memory_pool[size]["used"][index] = False
+    def release_reduce_scatter_memory(self, key, index):
+        self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
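Note: the change replaces the parallel `"data"`/`"used"` bookkeeping with a flat list of cached tensors, each tagged with `idle` and `index` attributes, so `get_reduce_scatter_memory` can hand the buffer itself back to the caller. Below is a minimal standalone sketch of that pooling pattern; `TensorPool`, its `get`/`release` names, and the CPU/float32 defaults are illustrative stand-ins for the handler's `get_reduce_scatter_memory`/`release_reduce_scatter_memory`, not the repository's API.

```python
# Minimal sketch of the reworked pool (assumed simplification, not the
# repository's API): each cached tensor carries its own `idle` / `index`
# attributes instead of living in parallel "data" / "used" lists.
import torch


class TensorPool:
    def __init__(self):
        self.pool = {}  # key (a shape tuple) -> list of cached tensors

    def get(self, key):
        bucket = self.pool.setdefault(key, [])
        # Reuse the first idle tensor, if any.
        for tensor in bucket:
            if tensor.idle:
                tensor.idle = False
                return tensor
        # Otherwise grow the pool by one freshly allocated buffer.
        tensor = torch.zeros(key, dtype=torch.float32).contiguous()
        tensor.idle = False          # torch.Tensor accepts ad-hoc attributes
        tensor.index = len(bucket)   # position inside this key's bucket
        bucket.append(tensor)
        return tensor

    def release(self, key, index):
        # Mark the buffer reusable; the memory itself stays cached.
        self.pool[key][index].idle = True


pool = TensorPool()
buf = pool.get((4, 8))           # first call allocates
pool.release((4, 8), buf.index)  # hand the buffer back
assert pool.get((4, 8)) is buf   # second call reuses the same storage
```

Returning the tensor, rather than an index into the pool's internals, is what lets the two call sites below shrink.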
4 changes: 1 addition & 3 deletions internlm/model/utils.py
@@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
-    setattr(output, "index", index)
+    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )
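Note: with the pool returning a ready-to-use, already-tagged tensor, the reduce-scatter call site no longer indexes into `reduce_scatter_memory_pool` or re-attaches `index` itself. A hedged sketch of that simplified flow, with `pool` standing in for `gpc.fstp_handler` and assuming an initialized process group:

```python
# Sketch of the simplified call site; `pool` is a stand-in for
# gpc.fstp_handler and is assumed to provide get(size) -> tensor.
import torch
import torch.distributed as dist


def reduce_scatter_with_pool(input_, process_group, pool, async_op=False):
    world_size = dist.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    size = (input_.shape[0] // world_size, *input_.shape[1:])
    output = pool.get(size)  # pooled output buffer, already tagged with .index
    handle = dist.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
```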
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
@@ -350,7 +350,7 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona
             _param.grad.add_(_grad)
 
             # release cuda memory.
-            gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+            gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
             self._fstp_handler.reduce_scatter_handlers[_key] = None
 
         bucket.reset_by_rank(reduce_rank)
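Note: on the optimizer side the only change is the keyword rename (`size` → `key`) to match the handler's new signature; the release contract stays `(key, index)`, where both values ride on the gradient tensor the pool handed out. A hypothetical sketch of that release step (`pool`, `param`, and `grad` are illustrative names, not the optimizer's API):

```python
# Hypothetical accumulate-then-release step mirroring
# _accum_grads_store_in_bucket; `pool` follows the TensorPool sketch above.
def accumulate_and_release(param, grad, pool):
    param.grad.add_(grad)  # accumulate the reduce-scattered shard
    # Return the pooled buffer: its shape is the key, and its .index was
    # attached by the pool's get path.
    pool.release(key=tuple(grad.size()), index=grad.index)
```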
