Merge pull request #4 from jiazhan-msft/dev/jiazhan/add_phi3_dedicated
Dev/jiazhan/add phi3 dedicated
jiazhan-msft authored Aug 27, 2024
2 parents 4be6309 + afe0ed3 commit 9f426b4
Showing 3 changed files with 175 additions and 2 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -48,7 +48,7 @@
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
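With this mapping in place, the "Phi3ForCausalLM" architecture string found in a checkpoint's config resolves to the dedicated implementation in phi3.py instead of falling back to LlamaForCausalLM. A quick way to confirm the resolution, assuming the ModelRegistry helper present in this tree:

from vllm.model_executor.models import ModelRegistry

# Expected to report the new module, e.g. vllm.model_executor.models.phi3
cls = ModelRegistry.load_model_cls("Phi3ForCausalLM")
print(cls.__module__, cls.__name__)
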
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llama.py
@@ -177,7 +177,7 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k, num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor)
+        q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
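The num_orig_input_tokens_tensor argument is Phi-3 specific, so the shared Llama attention returns to the plain rotary-embedding call and the conditional handling moves into Phi3Attention in the new file below. A minimal, self-contained sketch of that override pattern, using dummy classes rather than the real vLLM ones:

class _EchoRope:
    # Stand-in rotary embedding: accepts the optional kwarg, returns q and k unchanged.
    def __call__(self, positions, q, k, num_orig_input_tokens_tensor=None):
        return q, k


class _SharedAttention:
    rotary_emb = _EchoRope()
    rope_scaling = None

    def apply_rope(self, positions, q, k, attn_metadata=None):
        # Base path: the plain call restored in llama.py above.
        return self.rotary_emb(positions, q, k)


class _LongRopeAttention(_SharedAttention):
    rope_scaling = {"type": "longrope"}  # placeholder, not the real Phi3Config value

    def apply_rope(self, positions, q, k, attn_metadata=None):
        if self.rope_scaling is None:
            return super().apply_rope(positions, q, k)
        # Only the scaled variant forwards the original input-token counts.
        return self.rotary_emb(
            positions, q, k,
            num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor)
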
173 changes: 173 additions & 0 deletions vllm/model_executor/models/phi3.py
@@ -0,0 +1,173 @@
# coding=utf-8
# Adapted from llama.py

"""Inference-only Phi3 model code inherit from Llama.py"""

from typing import Optional

import torch
from transformers import Phi3Config

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.models.llama import (LlamaAttention,
                                              LlamaDecoderLayer,
                                              LlamaForCausalLM, LlamaModel)

from .utils import make_layers

class Phi3Attention(LlamaAttention):
def __init__(
self,
config: Phi3Config,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=config.rope_theta,
rope_scaling=config.rope_scaling,
max_position_embeddings=config.max_position_embeddings,
quant_config=quant_config,
bias=bias,
cache_config=cache_config,
prefix=prefix)

self.rope_scaling = config.rope_scaling


def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.rope_scaling is None:
            q, k = self.rotary_emb(positions, q, k)
        else:
            # When rope scaling is configured, the rotary embedding also
            # takes the original input-token counts.
            q, k = self.rotary_emb(
                positions, q, k,
                num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor)

attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output


class Phi3DecoderLayer(LlamaDecoderLayer):

def __init__(
self,
config: Phi3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix
)
self.self_attn = Phi3Attention(
config=config,
quant_config=quant_config,
bias=getattr(config, "attention_bias", False) or getattr(
config, "bias", False),
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)



class Phi3Model(LlamaModel):

def __init__(
self,
config: Phi3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
config=config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config,
prefix=prefix
)

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Phi3DecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")


class Phi3ForCausalLM(LlamaForCausalLM):

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
config: Phi3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__(
config=config,
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config
)

self.model = Phi3Model(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
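
With the dedicated implementation registered, a Phi-3 checkpoint can be exercised end to end through vLLM's offline API. A minimal sketch, assuming a checkpoint such as microsoft/Phi-3-mini-4k-instruct is reachable locally or on the Hub:

from vllm import LLM, SamplingParams

# The checkpoint's architecture name ("Phi3ForCausalLM") now maps to
# vllm/model_executor/models/phi3.py via the registry change above.
llm = LLM(model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Briefly explain rotary position embeddings."], params)
print(outputs[0].outputs[0].text)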
