Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
pcmoritz committed May 21, 2024
1 parent 28199d8 commit 434d757
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 90 deletions.
3 changes: 2 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ def get_hidden_size(self) -> int:

def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type=='deepseek_v2':
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
# FlashAttention suports only head_size 32, 64, 128, 256, we need to pad head_size 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"):
Expand Down
37 changes: 24 additions & 13 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def fused_topk(
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
Expand All @@ -332,13 +332,13 @@ def fused_topk(
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
Expand All @@ -351,15 +351,25 @@ def fused_topk(
)
del token_expert_indicies # Not used. Will be used in the future.
else:
scores = torch.softmax(gating_output, dim = -1)
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores,
k=topk_group,
dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token,
-1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
Expand Down Expand Up @@ -523,7 +533,8 @@ def fused_moe(
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize, num_expert_group, topk_group)
renormalize, num_expert_group,
topk_group)
return fused_experts(hidden_states,
w1,
w2,
Expand Down
20 changes: 10 additions & 10 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,9 @@ def __init__(
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale))
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor)
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)

Expand All @@ -505,7 +506,8 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float)) * self.extrapolation_factor
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
Expand All @@ -522,7 +524,6 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
print("Cache shape", cache.shape)
return cache


def forward(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -562,7 +563,8 @@ def forward(
query = query_rot
key = key_rot
return query, key



_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down Expand Up @@ -627,11 +629,9 @@ def get_rope(
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
rotary_emb = DeepseekScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor,
**extra_kwargs)
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, **extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
Expand Down
2 changes: 0 additions & 2 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),

"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),

"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
Expand Down
156 changes: 92 additions & 64 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ def __init__(

self.experts = nn.ModuleList([
DeepseekV2MLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
Expand Down Expand Up @@ -153,40 +153,48 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True,
num_expert_group = self.config.n_group,
topk_group=self.config.topk_group) * self.routed_scaling_factor
final_hidden_states = fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True,
num_expert_group=self.config.n_group,
topk_group=self.config.topk_group) * self.routed_scaling_factor
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
import math
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekV2Attention(nn.Module):

def __init__(
self,
config: PretrainedConfig,
hidden_size: int, num_heads: int,
qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int,
q_lora_rank: int, kv_lora_rank: int,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_idx = None,
layer_idx=None,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand All @@ -206,44 +214,48 @@ def __init__(
self.max_position_embeddings = max_position_embeddings

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size, self.q_lora_rank,
bias=False, quant_config=quant_config
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank, self.num_heads * self.qk_head_dim,
bias=False, quant_config=quant_config
)
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config)
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size, self.num_heads * self.qk_head_dim,
bias=False, quant_config=quant_config
)

self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim,
bias=False, quant_config=quant_config
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config)

self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
self.kv_lora_rank +
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False, quant_config=quant_config
)
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim, self.hidden_size,
bias=False, quant_config=quant_config
)
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config)
rope_scaling['type'] = 'deepseek_yarn'
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False
)
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)

if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
Expand All @@ -255,7 +267,7 @@ def __init__(
# self.qk_head_dim,
# self.scaling,
# num_kv_heads=self.num_heads)

# TODO, support head_size 192
self.attn = Attention(self.num_local_heads,
256,
Expand All @@ -272,28 +284,41 @@ def forward(
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim:] = k_pe
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value = 0).view(-1, self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value = 0).view(-1, self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value = 0).view(-1, self.num_local_heads * 256)
q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = attn_output.view(-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(-1, self.num_local_heads * self.v_head_dim)
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
output, _ = self.o_proj(attn_output)
return output

Expand All @@ -319,13 +344,14 @@ def __init__(
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_idx = layer_idx,
layer_idx=layer_idx,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
Expand Down Expand Up @@ -390,7 +416,9 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekV2DecoderLayer(config, layer_idx, quant_config=quant_config)
DeepseekV2DecoderLayer(config,
layer_idx,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down

0 comments on commit 434d757

Please sign in to comment.