GPTQ Fused MoE class
ElizaWszola committed Sep 3, 2024
1 parent bef6b53 commit db1f07e
Showing 2 changed files with 156 additions and 2 deletions.
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,11 +1,12 @@
from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, GPTQFusedMoE)
from vllm.triton_utils import HAS_TRITON

__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
+    "GPTQFusedMoE",
]

if HAS_TRITON:
155 changes: 154 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -498,4 +498,157 @@ def _load_fp8_scale(self, param: torch.nn.Parameter,
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
-                param_data[expert_id] = loaded_weight
+                param_data[expert_id] = loaded_weight


class GPTQFusedMoE(torch.nn.Module):
    """GPTQFusedMoE layer for GPTQ MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w3, and w2 for gate_proj, up_proj, and
    down_proj. We copy that naming convention here and handle any
    remapping in the load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model.
        top_k: Number of experts selected for each token.
        hidden_size: Input hidden state size of the transformer.
        intermediate_size: Intermediate size of the experts.
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce on the output of the layer.
        renormalize: Whether to renormalize the logits in the fused_moe
            kernel.
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size = intermediate_size
        self.intermediate_size_per_partition = (intermediate_size //
                                                self.tp_size)
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        # Grouped top-k routing is not supported by this layer.
        assert (not use_grouped_topk and num_expert_group is None
                and topk_group is None)

        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedFusedMoEMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )
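
    # Construction sketch (illustrative): create_weights is expected to
    # register the per-expert GPTQ tensors this layer loads into, e.g.
    # w13_qweight / w13_scales / w13_qzeros / w13_g_idx and their w2
    # counterparts, matching the names handled by weight_loader below.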

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int) -> None:
        if ("_qweight" in weight_name or "_scales" in weight_name
                or "_qzeros" in weight_name):
            if "w13" in weight_name:
                # w13 packs the gate (w1) and up (w3) projections along the
                # last dim: w1 fills the first half, w3 the second half.
                shard_size = loaded_weight.size()[-1]
                if shard_id == "w1":
                    param.data[expert_id, :, :shard_size] = loaded_weight
                elif shard_id == "w2" or shard_id == "w3":
                    param.data[expert_id, :, shard_size:] = loaded_weight
                else:
                    raise ValueError(f"Invalid shard_id: {shard_id}: "
                                     "must be w1, w2, or w3.")
            elif "w2" in weight_name:
                param.data[expert_id][:] = loaded_weight
            else:
                raise ValueError(f"Invalid weight name: {weight_name}: "
                                 "must contain 'w13' or 'w2'.")
        elif "_g_idx" in weight_name:
            if "w13" not in weight_name and "w2" not in weight_name:
                raise ValueError(f"Invalid weight name: {weight_name}: "
                                 "must contain 'w13' or 'w2'.")
            param.data[expert_id] = loaded_weight
        else:
            raise ValueError(f"Invalid weight name: {weight_name}.")
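
    # Layout note (illustrative): each expert's fused w13 tensor has last
    # dimension 2 * shard_size, so after both shards are loaded it holds
    # [ w1 (gate) columns | w3 (up) columns ] side by side, while w2 (down)
    # is stored unfused.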

    @staticmethod
    def select_experts(hidden_states: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int,
                       use_grouped_topk: bool,
                       renormalize: bool,
                       topk_group: Optional[int] = None,
                       num_expert_group: Optional[int] = None):
        # Only standard top-k routing is supported (no grouped top-k).
        assert (not use_grouped_topk and topk_group is None
                and num_expert_group is None)
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

        topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                            gating_output=router_logits,
                                            topk=top_k,
                                            renormalize=renormalize)

        return topk_weights, topk_ids
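
    # Shape sketch (illustrative): router_logits is (num_tokens,
    # num_experts); fused_topk softmaxes it and returns topk_weights and
    # topk_ids, each of shape (num_tokens, top_k); with renormalize=True
    # the selected weights are rescaled to sum to 1 per token.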

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=False,
            topk_group=None,
            num_expert_group=None)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states
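
    # Usage sketch (illustrative): hidden_states of shape
    # (num_tokens, hidden_size) and router_logits of shape
    # (num_tokens, num_experts) go in; a (num_tokens, hidden_size) tensor
    # comes out, all-reduced across tensor-parallel ranks when
    # reduce_results=True.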

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, str]]:
        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]
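
To make the remapping concrete, here is a minimal standalone sketch of the same comprehension (duplicated outside vLLM so it runs on its own; the checkpoint projection names "w1"/"w2"/"w3" are Mixtral-style assumptions) together with the tuples it yields:

from typing import List, Tuple

# Standalone copy of GPTQFusedMoE.make_expert_params_mapping, for
# illustration only.
def make_expert_params_mapping(
        ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int) -> List[Tuple[str, str, int, str]]:
    return [
        # (param_name, weight_name, expert_id, shard_id)
        ("experts.w13_" if weight_name
         in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
         f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
        for expert_id in range(num_experts)
        for shard_id, weight_name in [
            ("w1", ckpt_gate_proj_name),
            ("w2", ckpt_down_proj_name),
            ("w3", ckpt_up_proj_name),
        ]
    ]

for entry in make_expert_params_mapping("w1", "w2", "w3", num_experts=2):
    print(entry)
# ('experts.w13_', 'experts.0.w1.', 0, 'w1')   <- gate shard of expert 0
# ('experts.w2_', 'experts.0.w2.', 0, 'w2')    <- down shard of expert 0
# ('experts.w13_', 'experts.0.w3.', 0, 'w3')   <- up shard of expert 0
# ... and the same three entries for expert 1.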
