Add PyTorch Tensor Parallel support for Mistral (huggingface#34927)
add base tp support
VladOS95-cyber authored and BernardZach committed Dec 6, 2024
1 parent 5788004 commit bf74b85
Showing 2 changed files with 14 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/transformers/models/mistral/configuration_mistral.py
@@ -97,6 +97,16 @@ class MistralConfig(PretrainedConfig):

     model_type = "mistral"
     keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `MistralModel`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }

     def __init__(
         self,
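The `base_model_tp_plan` strings describe Megatron-style sharding: "colwise" splits a Linear layer's output features across ranks (the q/k/v and gate/up projections), while "rowwise" splits its input features (o_proj and down_proj), so each attention and MLP block needs only one all-reduce at the end. Below is a rough, hypothetical sketch of how such a plan corresponds to PyTorch's DTensor tensor-parallel styles when applied to a single decoder layer; the actual wiring inside transformers may differ, and `decoder_layer` / `world_size` are assumed to come from the calling code.

# Sketch only: maps the "colwise"/"rowwise" strings above onto PyTorch's
# tensor-parallel styles for one MistralDecoderLayer. Run under torchrun.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh = init_device_mesh("cuda", (world_size,))  # 1-D mesh over the TP ranks
layer_plan = {
    "self_attn.q_proj": ColwiseParallel(),   # shard output features (heads)
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),   # shard input features, all-reduce the output
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
}
parallelize_module(decoder_layer, mesh, layer_plan)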
7 changes: 4 additions & 3 deletions src/transformers/models/mistral/modeling_mistral.py
Expand Up @@ -227,9 +227,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
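Reshaping with -1 instead of self.num_heads / self.num_key_value_heads lets the same code run when the q/k/v projections are column-sharded: each rank then produces only its local slice of heads, and the head count is inferred from the sharded projection output. A small numeric sketch with hypothetical Mistral-7B-like shapes:

# Hypothetical shapes: 32 query heads of dim 128, 4-way tensor parallelism.
num_heads, head_dim, tp_size = 32, 128, 4
local_q_features = num_heads * head_dim // tp_size  # 1024 output features of the sharded q_proj
local_heads = local_q_features // head_dim           # 8 = what view(bsz, q_len, -1, head_dim) infers
assert local_heads == num_heads // tp_size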
@@ -983,6 +983,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

 class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}

     def __init__(self, config):
         super().__init__(config)
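With a plan registered on the config and on MistralForCausalLM ("colwise_rep" presumably shards lm_head column-wise and replicates its output so every rank sees the full logits), tensor-parallel loading is expected to look roughly like the sketch below, assuming the tp_plan="auto" loading path of recent transformers releases; the mistralai/Mistral-7B-v0.1 checkpoint and the 4-GPU count are placeholders.

# Sketch of multi-GPU inference using the new TP plan.
# Launch with: torchrun --nproc-per-node 4 run_mistral_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")

inputs = tokenizer("Tensor parallelism shards each layer across GPUs.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))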
