From 0d693cf3a182b34cc9af7b6ef640f250ff7abbda Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 23 Oct 2023 15:22:03 +0800
Subject: [PATCH] feat(model/overlap_handler.py): fix lint error

---
 internlm/model/moe.py             |  1 -
 internlm/model/overlap_handler.py | 40 ++++++++++++++++++-------------
 internlm/model/utils.py           |  1 +
 train.py                          |  3 +--
 4 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 28e5ae6e..0865097f 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -53,7 +53,6 @@ def __init__(
         device=None,
         dtype=None,
     ):
-
         super().__init__()
 
         assert (
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index f7132c3b..3f7ee055 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -10,7 +10,10 @@
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import FSTPLinear, ScaleColumnParallelLinear
-from internlm.model.utils import all_gather_raw_memory_pool, all_gather_raw_bias_memory_pool
+from internlm.model.utils import (
+    all_gather_raw_bias_memory_pool,
+    all_gather_raw_memory_pool,
+)
 from internlm.utils.common import get_current_device
 
 
@@ -25,7 +28,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> Non
         self.fstp_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.fstp_global_handle = dict()  # key: fstp module; value: module global all-gather op handle
-        self.bias_global_handle = dict() # key: fstp module; value: module bias global all-gather op handle
+        self.bias_global_handle = dict()  # key: fstp module; value: module bias global all-gather op handle
         self.module_to_index = dict()  # key: fstp module; value: transformer block index
         self.index_to_fstp_modules = dict()  # key: transformer block index; value: fsdp modules
         self.head = []
@@ -77,13 +80,13 @@ def get_zero_by_shape(self, size: tuple, dtype, device) -> torch.Tensor:
             self.zero_const_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
 
         return self.zero_const_pool[size]
-    
+
     def _initialize_module_shape(self):
         hidden_size = gpc.config.HIDDEN_SIZE
         mlp_ratio = gpc.config.MLP_RATIO
         mlp_hidden_size = int(hidden_size * mlp_ratio)
         mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256)
-    
+
         self.module_shape["Wqkv"] = (3 * hidden_size, hidden_size)
         self.module_shape["out_proj"] = (hidden_size, hidden_size)
         self.module_shape["w1"] = (mlp_hidden_size, hidden_size)
@@ -96,7 +99,7 @@ def _initialize_memory_pool(self) -> None:
         self.all_gather_bias_memory_pool = []
         self.reduce_scatter_memory_pool = {}
         self.module_shape = {}
-    
+
         self._initialize_module_shape()
         dtype = gpc.config.model.get("dtype", torch.half)
         device = get_current_device()
@@ -107,10 +110,14 @@ def _initialize_memory_pool(self) -> None:
                 weight[name] = torch.zeros(self.module_shape[name], dtype=dtype, device=device).contiguous()
             self.all_gather_memory_pool.append(weight)  # containing two groups of block weight
 
+    def clear_memory_pool(self) -> None:
+        self.zero_const_pool = {}
+        self.reduce_scatter_memory_pool = {}
+
     def get_all_gather_memory(self, module):
         block_index = self.module_to_index[module]
         return self.all_gather_memory_pool[block_index % 2][module._fstp_name]
-    
+
     def get_bias_memory(self, module: nn.Module):
         block_index = self.module_to_index[module]
         # if the bias memory pool is empty or module has been not allocated memory
@@ -119,19 +126,20 @@ def get_bias_memory(self, module: nn.Module):
             for _ in range(2):
                 weight = {}
                 weight[module._fstp_name] = torch.zeros(
-                    self.module_shape[module._fstp_name][0],
-                    dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
                 self.all_gather_bias_memory_pool.append(weight)
         elif module._fstp_name not in self.all_gather_bias_memory_pool[0]:
             for i in range(2):
                 self.all_gather_bias_memory_pool[i][module._fstp_name] = torch.zeros(
-                    self.module_shape[module._fstp_name][0],
-                    dtype=gpc.config.model.get("dtype", torch.half),
-                    device=get_current_device()).contiguous()
+                    self.module_shape[module._fstp_name][0],
+                    dtype=gpc.config.model.get("dtype", torch.half),
+                    device=get_current_device(),
+                ).contiguous()
+
         return self.all_gather_bias_memory_pool[block_index % 2][module._fstp_name]
-
 
     def get_reduce_scatter_memory(self, key):
         return_idx = 0
@@ -170,7 +178,7 @@ def get_reduce_scatter_memory(self, key):
 
     def release_reduce_scatter_memory(self, key, index):
         self.reduce_scatter_memory_pool[key][index].idle = True
-    
+
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
         for module in fstp_modules:
             if module.bias is not None:
                 bias_handle = all_gather_raw_bias_memory_pool(
                     module.bias,
                     self.process_group,
                     async_op=True,
                     module=module,
                 )
                 self.bias_global_handle[module] = bias_handle
-    
+
             weight_handle = all_gather_raw_memory_pool(
                 module.weight,
                 self.process_group,
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 8a1281e8..42a84003 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -140,6 +140,7 @@ def all_gather_raw_memory_pool(
     )
     return handle
 
+
 def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
diff --git a/train.py b/train.py
index 96dc24d1..b4f2a6d2 100644
--- a/train.py
+++ b/train.py
@@ -298,8 +298,7 @@ def main(args):
                 prof.step()
 
             if gpc.fstp_handler is not None:
-                gpc.fstp_handler.zero_const_pool = {}
-                gpc.fstp_handler.reduce_scatter_memory_pool = {}
+                gpc.fstp_handler.clear_memory_pool()
             # torch.cuda.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
             torch.cuda.reset_peak_memory_stats()