remove full weight for block 0
yingtongxiong committed Oct 17, 2023
1 parent 5c38cb6 commit 5abe519
Showing 2 changed files with 85 additions and 117 deletions.
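
For orientation, here is a condensed before/after sketch of the change, reconstructed from the diff below (argument lists are abbreviated and this is not the literal code): previously, block 0 could keep full, unsharded nn.Linear weights behind the parallel.block_0_full_weight flag; after this commit every block, including block 0, builds the sharded FSTPLinear, and block 0's full weights are instead materialized by an all-gather started from a hook on the embedding layer.

# Condensed sketch of the pattern this commit removes and the one it keeps
# (abbreviated; see the actual diff below).

# Before: block 0 optionally held full, unsharded weights.
if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
    self.w1 = nn.Linear(in_features, hidden_features, bias, device=device, dtype=dtype)
else:
    self.w1 = FSTPLinear(
        in_features, hidden_features, process_group, bias,
        sequence_parallel=gpc.config.parallel.sequence_parallel,
        device=device, dtype=dtype,
    )

# After: every block, including block 0, uses the sharded FSTPLinear;
# the same applies to w2, w3, Wqkv and out_proj.
self.w1 = FSTPLinear(
    in_features, hidden_features, process_group, bias,
    sequence_parallel=gpc.config.parallel.sequence_parallel,
    device=device, dtype=dtype,
)
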
152 changes: 69 additions & 83 deletions internlm/model/linear.py
@@ -12,6 +12,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.model.embedding import Embedding1D
from internlm.model.utils import (
Silu,
all_gather_raw,
@@ -255,56 +256,33 @@ def __init__(

hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

if block_idx == 0 and gpc.config.parallel.block_0_full_weight:
self.w1 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w2 = nn.Linear(
in_features,
hidden_features,
bias,
device=device,
dtype=dtype,
)
self.w3 = nn.Linear(
hidden_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
)
else:
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w1 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w2 = FSTPLinear(
in_features,
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
self.w3 = FSTPLinear(
hidden_features,
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

def forward(self, x):
w1_o = self.w1(x)
@@ -458,6 +436,7 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None
self.module_name_index = dict() # key: FSTP module; value: the index of the module's name in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = []
self.embedding = []

self.reduce_scatter_handlers = {}

@@ -505,6 +484,8 @@ def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None
continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
elif isinstance(children, Embedding1D):
self.embedding.append(children)

def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index]
@@ -532,7 +513,6 @@ def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
self._all_gather_block_weight(block_index + 1)
# print(f"_all_gather_block_weight for block {block_index+1}", flush=True)

def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
block_index = self.block_to_index[block]
@@ -548,6 +528,10 @@ def _pre_forward_hook_for_block(block: nn.Module, inputs: Any):
handles = self.block_handles[block]
for handle in handles:
handle.wait()

def _pre_forward_hook_for_embedding(module: nn.Module, inputs: Any, output):
self._all_gather_block_weight(0)


def _post_forward_hook_for_block(block: nn.Module, input, output):
block_index = self.block_to_index[block]
@@ -557,11 +541,10 @@ def _post_forward_hook_for_block(block: nn.Module, input, output):
for module in fsdp_modules:
del self.FSTP_global_weights[module]

def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any,):
block_index = self.module_to_index[module]
if block_index != 0:
handler = self.FSTP_global_handle[module]
handler.wait()
handler = self.FSTP_global_handle[module]
handler.wait()

def _post_forward_hook_for_module(module: nn.Module, input, output):
if module in self.FSTP_global_weights:
@@ -593,7 +576,7 @@ def _pre_backward_hook_for_block(block: nn.Module, grad_output):
# if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index)
# start the all-gather for next block
if block_index - 1 > 0:
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)

def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
@@ -613,38 +596,38 @@ def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
def _pre_backward_hook_for_module(module: nn.Module, grad_output):
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if block_index != 0:
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module

if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight

# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()

if block_index - 1 >= 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()

if block_index - 1 > 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
@@ -655,6 +638,9 @@ def _post_backward_hook_for_module(module, grad_input, grad_output):
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]

for embedding in self.embedding:
embedding.register_forward_hook(_pre_forward_hook_for_embedding)

for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
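
Taken together, the hooks above implement the prefetch scheme that replaces block 0's full weights. Below is a simplified sketch of the forward-path flow, paraphrased from the hooks in this file; handler is a hypothetical stand-in for the object that self refers to above, and the hook names are abbreviated.

# Simplified forward-path prefetch flow after this commit (a paraphrase of the
# hooks above, not the literal implementation; `handler` stands in for `self`).

# 1. When the embedding finishes its forward pass, asynchronously all-gather
#    block 0's sharded FSTP weights (this replaces keeping them in full).
def embedding_forward_hook(module, inputs, output):
    handler._all_gather_block_weight(0)

# 2. Every FSTP module waits on its own all-gather handle before computing;
#    block 0 no longer needs a special case because step 1 already started
#    its gather.
def module_pre_forward_hook(module, inputs):
    handler.FSTP_global_handle[module].wait()

# 3. The out_proj pre-forward hook of block i kicks off the gather for
#    block i + 1, overlapping communication with block i's compute.
def out_proj_pre_forward_hook(module, inputs):
    i = handler.module_to_index[module]
    if i + 1 < gpc.config.NUM_LAYER:
        handler._all_gather_block_weight(i + 1)

The backward pass follows the same idea: the prefetch window now reaches block 0 (the condition changed from block_index - 1 > 0 to block_index - 1 >= 0), and the per-module pre-backward hook no longer skips waiting on handles for block 0.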

50 changes: 16 additions & 34 deletions internlm/model/multi_head_attention.py
@@ -205,23 +205,14 @@ def __init__(

# note: bias here should be changed to True
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
Wqkv_cls = nn.Linear
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
self.Wqkv = Wqkv_cls(
embed_dim,
3 * embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577

inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -235,23 +226,14 @@ def __init__(

# output projection always has the bias (for now)
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
if block_idx == 0 and tp_mode != "origin_tp" and gpc.config.parallel.block_0_full_weight:
out_proj_cls = nn.Linear
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
bias=False,
**factory_kwargs,
)
else:
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
self.out_proj = out_proj_cls(
embed_dim,
embed_dim,
process_group,
bias=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign tp attribute so that internlm know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
for name in ["out_proj", "Wqkv"]:

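The attention changes mirror the MLP changes in linear.py: the block_idx == 0 branch that constructed plain nn.Linear projections is gone, so the projection class now depends only on tp_mode. A condensed sketch (reconstructed from the diff above, not the literal code):

# After this commit the projection class is chosen purely by tp_mode; there is
# no block 0 special case (condensed from the diff above).
Wqkv_cls = ColumnParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
out_proj_cls = RowParallelLinearTorch if tp_mode == "origin_tp" else FSTPLinear
# Block 0's Wqkv and out_proj are therefore sharded like every other block's,
# and their full weights are produced on demand by the all-gather hooks
# registered in internlm/model/linear.py.

A side effect (an inference from the removed branches, not something the diff states explicitly): the parallel.block_0_full_weight flag is no longer read by these two modules.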