Add Mistral support for Shardformer (#5103)
1 parent 9d5e04d · commit 2e04af1
Showing 4 changed files with 255 additions and 2 deletions.
colossalai/shardformer/modeling/mistral.py
@@ -0,0 +1,77 @@
import warnings
from typing import List, Optional, Tuple

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.utils import logging


def get_mistral_flash_attention_forward():
    from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def forward(
        self: MistralAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # expand grouped-query K/V heads so they match the number of query heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # ColoAttention expects (bsz, q_len, num_heads, head_dim)
        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)

        flash_attention_mask = None
        attn_mask_type = AttnMaskType.causal
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
            attn_mask_type = AttnMaskType.paddedcausal

        attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
        attn_output = attention(
            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
        )

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    return forward
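
The factory above returns a drop-in replacement for MistralAttention.forward; in this commit it is registered through the policy file that follows (via append_or_create_method_replacement). As a point of reference only, a minimal sketch of binding it by hand, with the checkpoint name purely illustrative:

# Sketch, not part of the commit: attach the replacement forward to every
# MistralAttention module directly instead of going through Shardformer policies.
import types

from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralAttention

model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative checkpoint
flash_forward = get_mistral_flash_attention_forward()
for module in model.modules():
    if isinstance(module, MistralAttention):
        # MethodType binds the function so that `self` resolves to this attention module
        module.forward = types.MethodType(flash_forward, module)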
colossalai/shardformer/policies/mistral.py
@@ -0,0 +1,175 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from ..modeling.mistral import get_mistral_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["MistralPolicy", "MistralModelPolicy", "MistralForCausalLMPolicy", "MistralForSequenceClassificationPolicy"]

class MistralPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel

        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn(
                "Mistral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }

            policy[MistralDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=Linear1D_Row,
                    ),
                ],
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key=MistralModel,
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key=MistralDecoderLayer,
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=MistralModel,
            )

        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={
                    "forward": get_mistral_flash_attention_forward(),
                },
                policy=policy,
                target_key=MistralAttention,
            )

        return policy

    def postprocess(self):
        return self.model


class MistralModelPolicy(MistralPolicy):
    def __init__(self) -> None:
        super().__init__()


class MistralForCausalLMPolicy(MistralPolicy):
    def module_policy(self):
        from transformers import MistralForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                MistralForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
                        )
                    ]
                )
            }
            policy.update(new_item)

        return policy


class MistralForSequenceClassificationPolicy(MistralPolicy):
    def module_policy(self):
        from transformers import MistralForSequenceClassification

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for sequence classification
            new_item = {
                MistralForSequenceClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
                        )
                    ]
                )
            }
            policy.update(new_item)
        return policy
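
These policies plug into Shardformer's usual workflow. A hedged usage sketch, assuming the ShardConfig/ShardFormer entry points used by the other Shardformer policies (ShardFormer.optimize returning the sharded model and shared parameters); the distributed launch and checkpoint name are illustrative:

# Sketch under the assumption that ShardFormer.optimize(model, policy) is available,
# as for the other Shardformer policies; launch and checkpoint details are illustrative.
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import MistralForCausalLM

colossalai.launch_from_torch(config={})  # initialize the distributed environment
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative checkpoint

shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
    enable_flash_attention=True,
)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model, policy=MistralForCausalLMPolicy())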