Commit

feat(model/overlap_handler.py): fix lint error
huangting4201 committed Oct 23, 2023
1 parent f6a5086 commit 0d693cf
Showing 4 changed files with 26 additions and 19 deletions.
1 change: 0 additions & 1 deletion internlm/model/moe.py
@@ -53,7 +53,6 @@ def __init__(
device=None,
dtype=None,
):

super().__init__()

assert (
40 changes: 24 additions & 16 deletions internlm/model/overlap_handler.py
@@ -10,7 +10,10 @@
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
from internlm.model.utils import (
all_gather_raw_bias_memory_pool,
all_gather_raw_memory_pool,
)
from internlm.utils.common import get_current_device


@@ -25,7 +28,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
self.fstp_modules = []
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
self.fstp_global_handle = dict() # key: fstp module; value: module global all-gather op handle
self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle
self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle
self.module_to_index = dict() # key: fstp module; value: transformer block index
self.index_to_fstp_modules = dict() # key: transformer block index; value: fsdp modules
self.head = []
@@ -77,13 +80,13 @@ def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()

return self.zero_const_pool[size]

def _initialize_module_shape(self):
hidden_size = gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.MLP_RATIO
mlp_hidden_size = int(hidden_size * mlp_ratio)
mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)

self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
self.module_shape["out_proj"] = (hidden_size, hidden_size)
self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
@@ -96,7 +99,7 @@ def _initialize_memory_pool(self) -> None:
self.all_gather_bias_memory_pool = []
self.reduce_scatter_memory_pool = {}
self.module_shape = {}

self._initialize_module_shape()
dtype = gpc.config.model.get("dtype", torch.half)
device = get_current_device()
@@ -107,10 +110,14 @@ def _initialize_memory_pool(self) -> None:
weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
self.all_gather_memory_pool.append(weight) # containing two groups of block weight

def clear_memory_pool(self) -> None:
self.zero_const_pool = {}
self.reduce_scatter_memory_pool = {}

def get_all_gather_memory(self, module):
block_index = self.module_to_index[module]
return self.all_gather_memory_pool[block_index % 2][module._fstp_name]

def get_bias_memory(self, module: nn.Module):
block_index = self.module_to_index[module]
# if the bias memory pool is empty or the module has not been allocated memory
@@ -119,19 +126,20 @@ def get_bias_memory(self, module: nn.Module):
for _ in range(2):
weight = {}
weight[module._fstp_name] = torch.zeros(
self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()
self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()
self.all_gather_bias_memory_pool.append(weight)
elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
for i in range(2):
self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device()).contiguous()

self.module_shape[module._fstp_name][0],
dtype=gpc.config.model.get("dtype", torch.half),
device=get_current_device(),
).contiguous()

return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]


def get_reduce_scatter_memory(self, key):
return_idx = 0
@@ -170,7 +178,7 @@ def get_reduce_scatter_memory(self, key):

def release_reduce_scatter_memory(self, key, index):
self.reduce_scatter_memory_pool[key][index].idle = True

def _all_gather_block_weight_memory_pool(self, block_index: int):
fstp_modules = self.index_to_fstp_modules[block_index]
for module in fstp_modules:
@@ -182,7 +190,7 @@ def _all_gather_block_weight_memory_pool(self, block_index: int):
module=module,
)
self.bias_global_handle[module] = bias_handle

weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
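Aside: get_all_gather_memory and get_bias_memory in the diff above both index the pool with block_index % 2, i.e. a double-buffering scheme where two pre-allocated groups of buffers are reused by alternating transformer blocks. A minimal standalone sketch of that pattern (the module names, shapes, and helper name below are illustrative placeholders, not the handler's real API):

import torch

# Illustrative double-buffered all-gather pool; shapes are placeholders.
module_shape = {"Wqkv": (12, 4), "out_proj": (4, 4)}

# Two groups of buffers: even-indexed blocks use group 0, odd-indexed blocks
# use group 1, so block i can overwrite the buffers of block i - 2 once that
# earlier block's all-gathered weights are no longer needed.
all_gather_memory_pool = [
    {name: torch.zeros(shape) for name, shape in module_shape.items()}
    for _ in range(2)
]

def get_all_gather_memory(block_index: int, fstp_name: str) -> torch.Tensor:
    # Alternate between the two buffer groups based on the block index.
    return all_gather_memory_pool[block_index % 2][fstp_name]

# Blocks 0 and 2 share a buffer; blocks 0 and 1 do not.
assert get_all_gather_memory(0, "Wqkv") is get_all_gather_memory(2, "Wqkv")
assert get_all_gather_memory(0, "Wqkv") is not get_all_gather_memory(1, "Wqkv")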
1 change: 1 addition & 0 deletions internlm/model/utils.py
@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
)
return handle


def all_gather_raw_bias_memory_pool(
input_: Tensor,
process_group: ProcessGroup,
3 changes: 1 addition & 2 deletions train.py
@@ -298,8 +298,7 @@ def main(args):
prof.step()

if gpc.fstp_handler is not None:
gpc.fstp_handler.zero_const_pool = {}
gpc.fstp_handler.reduce_scatter_memory_pool = {}
gpc.fstp_handler.clear_memory_pool()
# torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
torch.cuda.reset_peak_memory_stats()

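The train.py hunk above swaps the two ad-hoc dict resets for the clear_memory_pool() helper introduced in overlap_handler.py. A minimal sketch of how that per-step cleanup sits in a training loop, assuming a stand-in handler class and a placeholder loop rather than the real train.py:

import torch

class FakeFSTPHandler:
    # Stand-in exposing only the two pools that clear_memory_pool() resets.
    def __init__(self):
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

    def clear_memory_pool(self) -> None:
        # Equivalent to the two manual resets this commit removes from train.py.
        self.zero_const_pool = {}
        self.reduce_scatter_memory_pool = {}

fstp_handler = FakeFSTPHandler()

for _ in range(3):  # placeholder for the real training loop
    # ... forward / backward / optimizer step would run here ...
    if fstp_handler is not None:
        fstp_handler.clear_memory_pool()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()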
