From 000796acdf1e6184eeb36272c5ddd6ffbc41fac3 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 18:03:11 +0000 Subject: [PATCH 1/5] add awq moe --- .../model_executor/layers/quantization/awq.py | 191 +++++++++++++++++- vllm/model_executor/model_loader/utils.py | 2 +- 2 files changed, 188 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 410b3cb5321cb..e564b18e7d323 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,14 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Callable, Any, Dict, List, Optional import torch - +from torch.nn import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -64,9 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQLinearMethod"]: + prefix: str) -> Optional["QuantizedMethodBase"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -170,3 +180,176 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + self.num_bits = self.quant_config.weight_bits + self.packed_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": "group", + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1], + layer.w13_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1], + layer.w2_qweight.shape[2] * self.packed_factor, + self.num_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.group_size + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] , + size_n=layer.w2_scales.shape[2] * self.packed_factor, + group_size=self.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits) \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 2bfe6ea09bd62..995bb253db8a1 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From e8289ae95dbe7f100898935997906791b01adcc2 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 26 Sep 2024 19:39:56 +0000 Subject: [PATCH 2/5] update --- .../model_executor/layers/quantization/awq.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index e564b18e7d323..b82714bd5ba55 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -18,6 +18,7 @@ marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -181,13 +182,11 @@ def apply(self, out.add_(bias) return out.reshape(out_shape) + class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - self.num_bits = self.quant_config.weight_bits - self.packed_factor = self.quant_config.pack_factor - self.group_size = self.quant_config.group_size def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -255,61 +254,60 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) + marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1], - layer.w13_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1], - layer.w2_qweight.shape[2] * self.packed_factor, - self.num_bits, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], - group_size=self.group_size + group_size=self.quant_config.group_size, ) - + # for @eliza: why do we need to apply a pack factor to the scales? + # they're not in packed format? replace_tensor(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] , - size_n=layer.w2_scales.shape[2] * self.packed_factor, - group_size=self.group_size, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) @@ -352,4 +350,5 @@ def apply( g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits) \ No newline at end of file + num_bits=self.quant_config.weight_bits, + ) From 0385aa85eabc5005e706f112e96f370ae3e28326 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 27 Sep 2024 17:07:25 +0000 Subject: [PATCH 3/5] update awq --- vllm/_custom_ops.py | 14 ++++++++ .../layers/fused_moe/fused_moe.py | 2 +- .../model_executor/layers/quantization/awq.py | 36 +++++++++++++------ .../layers/quantization/gptq_marlin.py | 1 + .../layers/quantization/utils/marlin_utils.py | 15 ++++++++ 5 files changed, 57 insertions(+), 11 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 77c46584ef530..8ce01b2d82532 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -317,6 +317,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return output +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..1a98666204f93 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -443,7 +443,7 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index b82714bd5ba55..ba912aa6552d3 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -12,11 +12,10 @@ PackedvLLMParameter) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) class AWQConfig(QuantizationConfig): @@ -276,7 +275,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: requires_grad=False, ) - marlin_w13_qweight = ops.gptq_marlin_moe_repack( + marlin_w13_qweight = ops.awq_marlin_moe_repack( layer.w13_qweight, layer.w13_g_idx_sort_indices, size_k=layer.w13_qweight.shape[1], @@ -285,7 +284,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( + marlin_w2_qweight = ops.awq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, size_k=layer.w2_qweight.shape[1], @@ -294,23 +293,38 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Why does this take the intermediate size for size_k? marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - # for @eliza: why do we need to apply a pack factor to the scales? - # they're not in packed format? + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_k=layer.intermediate_size_per_partition, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) replace_tensor(layer, "w2_scales", marlin_w2_scales) + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + def apply( self, layer: torch.nn.Module, @@ -346,6 +360,8 @@ def apply( router_logits, topk_weights, topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, g_idx1=layer.w13_g_idx, g_idx2=layer.w2_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index dd46f0ce5a39c..04bea28ec4630 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -554,6 +554,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) replace_tensor(layer, "w2_qweight", marlin_w2_qweight) # Repack scales + # Why does this take the intermediate size for size_k? marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, size_k=layer.intermediate_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..db8ec78f937ee 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -188,6 +188,7 @@ def marlin_moe_permute_scales( device=s.device, dtype=s.dtype, ) + for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) return output @@ -238,6 +239,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) def replace_tensor(layer: torch.nn.Module, name: str, From 3d125547c775e3048e4c327f2a5dbb272f490a8b Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 15:16:38 +0000 Subject: [PATCH 4/5] move to marlin; clean-up --- .../model_executor/layers/quantization/awq.py | 208 +----------------- .../layers/quantization/awq_marlin.py | 206 ++++++++++++++++- vllm/model_executor/model_loader/utils.py | 4 +- 3 files changed, 204 insertions(+), 214 deletions(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ba912aa6552d3..30380ec0407c5 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,22 +1,14 @@ -from typing import Callable, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter + from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, moe_awq_to_marlin_zero_points, - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) - class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -72,11 +64,9 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": return cls(weight_bits, group_size, zero_point) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizedMethodBase"]: + prefix: str) -> Optional["AWQLinearMethod"]: if isinstance(layer, LinearBase): return AWQLinearMethod(self) - elif isinstance(layer, FusedMoE): - return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,192 +169,4 @@ def apply(self, pack_factor) if bias is not None: out.add_(bias) - return out.reshape(out_shape) - - -class AWQMoEMethod(FusedMoEMethodBase): - - def __init__(self, quant_config: AWQConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": "group", - }) - - w13_qweight = Parameter(torch.empty(num_experts, - hidden_size, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qweight", w13_qweight) - set_weight_attrs(w13_qweight, extra_weight_attrs) - - w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qweight", w2_qweight) - set_weight_attrs(w2_qweight, extra_weight_attrs) - - num_groups_w13 = hidden_size // self.quant_config.group_size - num_groups_w2 = intermediate_size // self.quant_config.group_size - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_scales = Parameter(torch.empty(num_experts, - num_groups_w13, - intermediate_size * 2, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) - - w2_scales = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - - # WEIGHT_ZERO_POINT - # Allocate 2 zero points for w1 and w3 respectively. - w13_qzeros = Parameter(torch.empty(num_experts, - num_groups_w13, - 2 * intermediate_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qzeros", w13_qzeros) - set_weight_attrs(w13_qzeros, extra_weight_attrs) - - w2_qzeros = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qzeros", w2_qzeros) - set_weight_attrs(w2_qzeros, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - num_experts = layer.w13_qweight.shape[0] - device = layer.w13_qweight.device - - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = ops.awq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - size_k=layer.w13_qweight.shape[1], - size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) - - marlin_w2_qweight = ops.awq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - size_k=layer.w2_qweight.shape[1], - size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits, - ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) - - # Why does this take the intermediate size for size_k? - marlin_w13_scales = marlin_moe_permute_scales( - s=layer.w13_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w13_scales.shape[2], - group_size=self.quant_config.group_size, - ) - - replace_tensor(layer, "w13_scales", marlin_w13_scales) - - marlin_w2_scales = marlin_moe_permute_scales( - s=layer.w2_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w2_scales.shape[2], - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) - - marlin_w13_zp = moe_awq_to_marlin_zero_points( - layer.w13_qzeros, - size_k=layer.w13_qzeros.shape[1], - size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w13_qzeros", marlin_w13_zp) - - marlin_w2_zp = moe_awq_to_marlin_zero_points( - layer.w2_qzeros, - size_k=layer.w2_qzeros.shape[1], - size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) - replace_tensor(layer, "w2_qzeros", marlin_w2_zp) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - fused_marlin_moe) - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) - - return fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - layer.w13_scales, - layer.w2_scales, - router_logits, - topk_weights, - topk_ids, - w1_zeros=layer.w13_qzeros, - w2_zeros=layer.w2_qzeros, - g_idx1=layer.w13_g_idx, - g_idx2=layer.w2_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.quant_config.weight_bits, - ) + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..9704b1adbce55 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,16 +1,21 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch +from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, replace_tensor, + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -34,12 +39,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool, self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits - if weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {weight_bits}. " + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " f"Supported num_bits = {self.TYPE_MAP.keys()}") - self.quant_type = self.TYPE_MAP[weight_bits] + self.quant_type = self.TYPE_MAP[self.weight_bits] verify_marlin_supported(self.quant_type, group_size=self.group_size, @@ -97,10 +103,12 @@ def override_quantization_method(cls, hf_quant_cfg, return None def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["AWQMarlinLinearMethod"]: + prefix: str) -> Optional["QuantizeMethodBase"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -270,4 +278,182 @@ def apply( quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) \ No newline at end of file + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP, + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_tensor(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_tensor(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 995bb253db8a1..792c359a559a9 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,7 +23,9 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. - mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin", "awq"] + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq", "awq_marlin" + ] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported From b54b633cbf21ae4a2b600b96be3f04603d9d5c9a Mon Sep 17 00:00:00 2001 From: Dipika Date: Mon, 30 Sep 2024 16:35:23 +0000 Subject: [PATCH 5/5] fix typo; add test --- tests/weight_loading/models-large.txt | 1 + vllm/model_executor/layers/quantization/awq_marlin.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 2f5c6c5a117f3..8ab7f05d7d1b2 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -2,3 +2,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main +awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 9704b1adbce55..5c689f03925a1 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -293,7 +293,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, "is_transposed": True, "quant_method": - FusedMoeWeightScaleSupported.GROUP, + FusedMoeWeightScaleSupported.GROUP.value, }) w13_qweight = Parameter(torch.empty(num_experts,