
Commit

add support for q_proj without lora
zwd003 committed May 16, 2024
1 parent 2609d43 commit 2bcfba8
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions vllm/model_executor/models/deepseek_v2.py
@@ -205,15 +205,22 @@ def __init__(
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        self.q_a_proj = ReplicatedLinear(
-            self.hidden_size, self.q_lora_rank,
-            bias=False, quant_config=quant_config
-        )
-        self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-        self.q_b_proj = ColumnParallelLinear(
-            q_lora_rank, self.num_heads * self.qk_head_dim,
-            bias=False, quant_config=quant_config
-        )
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size, self.q_lora_rank,
+                bias=False, quant_config=quant_config
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank, self.num_heads * self.qk_head_dim,
+                bias=False, quant_config=quant_config
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size, self.num_heads * self.qk_head_dim,
+                bias=False, quant_config=quant_config
+            )
 
         self.kv_a_proj_with_mqa = ReplicatedLinear(
             self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False, quant_config=quant_config
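
Note: the constructor now builds one of two query paths: a low-rank factorization when q_lora_rank is set, or a single full-rank q_proj when it is not. Below is a minimal standalone sketch of the same branching; plain nn.Linear/nn.LayerNorm stand in for vLLM's tensor-parallel ReplicatedLinear/ColumnParallelLinear and its RMSNorm, and the class name SimpleMLAQueryProj is purely illustrative, not part of vLLM.

    # Sketch only: simplified stand-ins for vLLM's layers; not the real implementation.
    from typing import Optional

    import torch
    from torch import nn


    class SimpleMLAQueryProj(nn.Module):
        """Query projection with an optional low-rank ("q LoRA") bottleneck."""

        def __init__(self, hidden_size: int, num_heads: int, qk_head_dim: int,
                     q_lora_rank: Optional[int] = None):
            super().__init__()
            self.num_heads = num_heads
            self.qk_head_dim = qk_head_dim
            self.q_lora_rank = q_lora_rank
            if q_lora_rank is not None:
                # Factorized path: hidden_size -> q_lora_rank -> num_heads * qk_head_dim.
                self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
                self.q_a_layernorm = nn.LayerNorm(q_lora_rank)  # vLLM uses RMSNorm here
                self.q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim, bias=False)
            else:
                # Direct path added by this commit: one full-rank projection.
                self.q_proj = nn.Linear(hidden_size, num_heads * qk_head_dim, bias=False)
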
@@ -262,10 +269,12 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-
-        q = self.q_a_proj(hidden_states)[0]
-        q = self.q_a_layernorm(q)
-        q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
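
Note: the forward change mirrors the constructor: use whichever projection exists, then reshape to per-head queries. A sketch against the SimpleMLAQueryProj module above; the project_queries helper is illustrative, and the [0] indexing in the real code exists because vLLM's parallel linear layers return an (output, output_bias) tuple, whereas plain nn.Linear returns the tensor directly.

    def project_queries(proj: SimpleMLAQueryProj,
                        hidden_states: torch.Tensor) -> torch.Tensor:
        """Return queries shaped (num_tokens, num_heads, qk_head_dim)."""
        if proj.q_lora_rank is not None:
            q = proj.q_a_proj(hidden_states)   # (T, q_lora_rank)
            q = proj.q_a_layernorm(q)
            q = proj.q_b_proj(q)               # (T, num_heads * qk_head_dim)
        else:
            q = proj.q_proj(hidden_states)     # (T, num_heads * qk_head_dim)
        return q.view(-1, proj.num_heads, proj.qk_head_dim)


    # Toy shapes, not DeepSeek-V2's real config: both paths yield (4, 8, 32).
    with_lora = SimpleMLAQueryProj(256, num_heads=8, qk_head_dim=32, q_lora_rank=64)
    without_lora = SimpleMLAQueryProj(256, num_heads=8, qk_head_dim=32)
    x = torch.randn(4, 256)
    assert project_queries(with_lora, x).shape == project_queries(without_lora, x).shape
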
@@ -310,7 +319,7 @@ def __init__(
             qk_nope_head_dim=config.qk_nope_head_dim,
             qk_rope_head_dim=config.qk_rope_head_dim,
             v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank,
+            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
             kv_lora_rank=config.kv_lora_rank,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
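
Note: this wiring change passes None when the checkpoint's config has no q_lora_rank field, which is what selects the plain q_proj path above. The hasattr guard behaves like getattr with a default; a toy illustration, with SimpleNamespace standing in for the Hugging Face config object and 512/1536 as made-up values:

    from types import SimpleNamespace

    cfg_without = SimpleNamespace(kv_lora_rank=512)                  # no q_lora_rank field
    cfg_with = SimpleNamespace(kv_lora_rank=512, q_lora_rank=1536)

    assert getattr(cfg_without, "q_lora_rank", None) is None         # -> plain q_proj
    assert getattr(cfg_with, "q_lora_rank", None) == 1536            # -> factorized path
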
