diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 73e20025e77a3..51c54d800a761 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,8 +10,9 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe, - single_marlin_moe) +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_marlin_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -63,11 +64,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -166,11 +167,11 @@ def test_fused_marlin_moe( quant_type = scalar_types.uint4b8 dtype = torch.float16 - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device='cuda', dtype=dtype) + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -218,27 +219,32 @@ def test_fused_marlin_moe( g_idx2 = stack_and_dev(g_idx2_l) sort_indices2 = stack_and_dev(sort_indices2_l) - score = torch.randn((m, e), device='cuda', dtype=dtype) - triton_output = fused_moe(a, - w_ref1.transpose(1, 2).contiguous(), - w_ref2.transpose(1, 2).contiguous(), - score, - topk, - renormalize=False) - marlin_output = fused_marlin_moe(a, - qweight1, - qweight2, - score, - g_idx1, - g_idx2, - sort_indices1, - sort_indices2, - topk, - renormalize=False, - w1_scale=scales1, - w2_scale=scales2) - - assert (compute_max_diff(marlin_output, triton_output) < 4e-2) + score = torch.randn((m, e), device="cuda", dtype=dtype) + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_moe_marlin( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk, + renormalize=False, + w1_scale=scales1, + w2_scale=scales2, + num_bits=4, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 # TODO: make sure this test works @@ -275,8 +281,8 @@ def test_single_marlin_moe( quant_type = scalar_types.uint4b8 dtype = torch.float16 - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w = torch.randn((e, n, k), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", 
dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 w_ref_l = [] qweights_l = [] @@ -300,7 +306,7 @@ def test_single_marlin_moe( g_idx = stack_and_dev(g_idx_l) sort_indices = stack_and_dev(sort_indices_l) - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) marlin_output = single_marlin_moe(a, qweight, scales, @@ -311,4 +317,4 @@ def test_single_marlin_moe( renormalize=False) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) - assert (compute_max_diff(marlin_output, torch_output) < 1e-2) + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9b18d7e645cac..10979bc244686 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -304,7 +304,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 080ecb5cfe0ba..73315d8e71fcd 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,5 +1,5 @@ -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe, - single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin, single_marlin_moe) from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEMethodBase) from vllm.triton_utils import HAS_TRITON @@ -7,12 +7,11 @@ __all__ = [ "FusedMoE", "FusedMoEMethodBase", - "fused_marlin_moe", + "fused_moe_marlin", "single_marlin_moe", ] if HAS_TRITON: - from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ff4d7dd4cae3c..613d67e64bff5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -666,181 +666,3 @@ def fused_moe( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale) - - -def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - rand_perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, -) -> torch.Tensor: - """ - This function computes a Marlin MoE MMM using weights w - and top-k gating mechanism. It is meant for testing and debugging. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w (torch.Tensor): The first set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. 
- - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w and w2. Defaults to False. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" - assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w.is_contiguous(), "Expert weights must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - M, K = hidden_states.shape - E = w.shape[0] - N = w.shape[2] // 2 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - # This might not be an optimal config for a single MMM - get_config_func = functools.partial(try_get_optimal_moe_config, - w.shape, - w.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = (N // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m, - True, False) - - return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) - - -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py new file mode 100644 index 0000000000000..469b20ccf24de --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_marlin.py @@ -0,0 +1,227 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops + +from .fused_moe import (fused_topk, moe_align_block_size, + try_get_optimal_moe_config) + + +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + rand_perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, +) -> torch.Tensor: + """ + This function computes a Marlin MoE MMM using weights w + and top-k gating mechanism. It is meant for testing and debugging. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w (torch.Tensor): The first set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. 
+ - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w and w2. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // 2 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m, + True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_moe_marlin( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
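# Illustrative sketch with assumed example sizes: the Marlin-packed MoE weights
# store w1 (K -> 2N) as (E, K // 16, 2N * (num_bits // 2)) and w2 (N -> K) as
# (E, N // 16, K * (num_bits // 2)), so the checks below recover the hidden
# size as K == w1.shape[1] * 16 == w2.shape[2] // (num_bits // 2) and the
# intermediate size as N == w2.shape[1] * 16.
K, N, num_bits = 4096, 14336, 4  # hypothetical Mixtral-like sizes, 4-bit GPTQ
w1_shape = (8, K // 16, 2 * N * (num_bits // 2))
w2_shape = (8, N // 16, K * (num_bits // 2))
assert w1_shape[1] * 16 == K and w2_shape[2] // (num_bits // 2) == K
assert w2_shape[1] * 16 == N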
+ assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + rand_perm1, + workspace, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + rand_perm2, + workspace, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ce9cba1a6c01f..160f6948648af 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,21 +1,15 @@ -import enum from abc import abstractmethod -from enum import Enum from typing import List, Optional, Tuple import torch -from vllm import _custom_ops as ops from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe.fused_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -24,294 +18,57 @@ class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + 
**extra_weight_attrs, + ): raise NotImplementedError @abstractmethod - def apply(self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, top_k: int, renormalize: bool, - use_grouped_topk: bool) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: raise NotImplementedError -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -class MarlinFusedMoEMethod(FusedMoEMethodBase): - """MoE Marlin method with quantization.""" - - def __init__(self, quant_config: GPTQMarlinConfig) -> None: - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size: int, - params_dtype: torch.dtype, **extra_weight_attrs): - # Currently assuming is_k_full is always True - # (input size per partition is the same as full input size) - # Supports only sym for now (no zp) - if self.quant_config.group_size != -1: - scales_size13 = hidden_size // self.quant_config.group_size - scales_size2 = intermediate_size // self.quant_config.group_size - else: - scales_size13 = 1 - scales_size2 = 1 - # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.quant_config.pack_factor, - 2 * intermediate_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_qweight", w13_qweight) - set_weight_attrs(w13_qweight, extra_weight_attrs) - # down_proj (row parallel) - w2_qweight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size // self.quant_config.pack_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_qweight", w2_qweight) - set_weight_attrs(w2_qweight, extra_weight_attrs) - # up_proj scales - w13_scales = torch.nn.Parameter(torch.empty(num_experts, - scales_size13, - 2 * intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_scales", w13_scales) - set_weight_attrs(w13_scales, extra_weight_attrs) - # down_proj scales - w2_scales = torch.nn.Parameter(torch.empty(num_experts, - scales_size2, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - 
layer.marlin_state = GPTQMarlinState.REPACK - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - num_experts = layer.w13_g_idx.shape[0] - w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_g_idx[e]).to(torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - w2_g_idx_sort_indices[e]] - replace_tensor("w13_g_idx", w13_sorted_g_idx) - replace_tensor("w2_g_idx", w2_sorted_g_idx) - replace_tensor("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - else: - # Reset g_idx related tensors - num_experts = layer.w13_g_idx.shape[0] - device = layer.w13_g_idx.device - layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), - dtype=torch.int32, - device=device), - requires_grad=False, - ) - # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - 
layer.w13_qweight, - layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], - self.quant_config.quant_type.size_bits, - ) - replace_tensor("w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1] * self.quant_config.pack_factor, - layer.w2_qweight.shape[2], - self.quant_config.quant_type.size_bits, - ) - replace_tensor("w2_qweight", marlin_w2_qweight) - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_scales, - x.shape[1], - layer.w13_scales.shape[2], - self.quant_config.group_size, - self.quant_config.quant_type.size_bits, - ) - replace_tensor("w13_scales", marlin_w13_scales) - marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_scales, - layer.w2_scales.shape[1] * self.quant_config.pack_factor, - x.shape[1], - self.quant_config.group_size, - self.quant_config.quant_type.size_bits, - ) - replace_tensor("w2_scales", marlin_w2_scales) - return fused_marlin_moe(x, - layer.w13_qweight, - layer.w2_qweight, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_scales, - w2_scale=layer.w2_scales) - - class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -320,10 +77,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -435,6 +192,7 @@ def __init__( get_tensor_model_parallel_world_size()) self.top_k = top_k self.num_experts = num_experts + self.intermediate_size = intermediate_size self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -444,12 +202,9 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group - self.quant_method: Optional[QuantizeMethodBase] = None - if quant_config is None: - self.quant_method = UnquantizedFusedMoEMethod() - elif isinstance(quant_config, GPTQMarlinConfig): - self.quant_method = MarlinFusedMoEMethod(quant_config) + 
self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedFusedMoEMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None @@ -460,28 +215,32 @@ def __init__( hidden_size=hidden_size, intermediate_size=self.intermediate_size_per_partition, params_dtype=params_dtype, - weight_loader=self.weight_loader) - - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - is_quantized: bool = False): + weight_loader=self.weight_loader, + ) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + is_quantized: bool = False, + ): param_data = param.data if is_quantized: - if "_qweight" in weight_name or "_scales" in weight_name: + if ("_qweight" in weight_name or "_scales" in weight_name + or "_qzeros" in weight_name): if "w13" in weight_name: - shard_size = self.intermediate_size_per_partition - if shard_id == 0: + shard_size = loaded_weight.size()[-1] + if shard_id == "w1": param_data[expert_id, :, :shard_size] = loaded_weight - elif shard_id == 1: + elif shard_id == "w3" or shard_id == "w2": param_data[expert_id, :, shard_size:] = loaded_weight else: raise ValueError(f"Invalid shard_id: {shard_id}: " - "must be 0 or 1.") + "must be 0, 1, or 2.") elif "w2" in weight_name: param_data[expert_id][:] = loaded_weight else: @@ -585,8 +344,8 @@ def forward(self, hidden_states: torch.Tensor, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group) + num_expert_group=self.num_expert_group, + topk_group=self.topk_group) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -599,18 +358,100 @@ def make_expert_params_mapping( cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, ckpt_up_proj_name: str, num_experts: int) -> List[Tuple[str, str, int, str]]: - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ + ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name ] + return ([ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_scale" + if weight_name in gate_up else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_weight" + if weight_name in gate_up else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_scales" + if weight_name in gate_up else "experts.w2_scales", + f"experts.{expert_id}.{weight_name}.scales", + 
expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.a13_scale" + if weight_name in gate_up else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", + expert_id, + f"a{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the qweights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_qweight" + if weight_name in gate_up else "experts.w2_qweight", + f"experts.{expert_id}.{weight_name}.qweight", + expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the g_idx and g_idx_sort_indices scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_g_idx" + if weight_name in gate_up else "experts.w2_g_idx", + f"experts.{expert_id}.{weight_name}.g_idx", + expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the g_idx and g_idx_sort_indices scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_qzeros" + if weight_name in gate_up else "experts.w2_qzeros", + f"experts.{expert_id}.{weight_name}.qzeros", + expert_id, + f"w{shard_id + 1}", + ) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ]) + + # return [ + # # (param_name, weight_name, expert_id, shard_id) + # ("experts.w13_" if weight_name + # in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + # f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + # for expert_id in range(num_experts) for shard_id, weight_name in [ + # ("w1", ckpt_gate_proj_name), + # ("w2", ckpt_down_proj_name), + # ("w3", ckpt_up_proj_name), + # ] + # ] def _load_fp8_scale(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fef..153bccc303ef1 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -157,7 +157,7 @@ def quantize_and_call_weight_loader(param: torch.nn.Parameter, layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) else: raise ValueError( - f"Shard id must be in [0,1,2] but got {shard_id}") + f"Shard id must be in ['w1','w2','w3'] but got {shard_id}") weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fd7682a1c0f51..2939c57ac95de 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -466,10 +466,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts diff --git 
a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..5e5fb64af8c32 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,26 @@ -from typing import Any, Dict, List, Optional +import enum +from enum import Enum +from typing import Any, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_marlin import ( + fused_moe_marlin) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -24,6 +32,11 @@ logger = init_logger(__name__) +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" @@ -33,8 +46,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -109,11 +128,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -179,7 +201,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +322,8 @@ def process_weights_after_loading(self, 
layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +332,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +354,249 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + else: + scales_size13 = 1 + scales_size2 = 1 + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + 
hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=(layer.intermediate_size if self.quant_config.desc_act else + layer.intermediate_size_per_partition), + 
size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + layer.marlin_state = GPTQMarlinState.READY + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + return fused_moe_marlin( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, + ) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..699d5f1844146 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 32cafa845a6e3..11ba0070ef735 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -37,7 +37,7 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 587d2f26a2d5e..935d3d58eb5e7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" +import logging from typing import Iterable, List, Optional, Tuple import torch @@ -50,6 +51,8 @@ from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers +logger = logging.getLogger(__name__) + class MixtralMoE(nn.Module): """A tensor-parallel MoE implementation for Mixtral that shards each expert @@ -60,45 +63,52 @@ class MixtralMoE(nn.Module): across ranks. 
""" - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = torch.float16, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size - + self.params_dtype = params_dtype # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. - orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + orig_shape, orig_type = hidden_states.shape, hidden_states.dtype + hidden_states = hidden_states.view(-1, self.hidden_size).to( + self.params_dtype) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape) + return final_hidden_states.view(orig_shape).to(orig_type) class MixtralAttention(nn.Module): @@ -159,12 +169,14 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -202,14 +214,18 @@ def __init__( rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + tp_size=get_tensor_model_parallel_world_size(), + params_dtype=torch.float16, + prefix=f"{prefix}.block_sparse_moe", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -256,8 +272,8 @@ def __init__( ) -> None: super().__init__() self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + 
lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -272,7 +288,8 @@ def __init__( lambda prefix: MixtralDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -293,9 +310,13 @@ def forward( residual = intermediate_tensors["residual"] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -340,7 +361,6 @@ def __init__( self.config = config self.lora_config = lora_config - self.model = MixtralModel(config, cache_config, quant_config, @@ -420,14 +440,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -447,17 +468,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) # Skip layers on other devices. + name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + is_quantized=True, + ) break else: # Skip loading extra bias for GPTQ models. diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index c13d3c378eee9..4c1c74aee61b4 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" +import logging import re from typing import Iterable, List, Optional, Tuple @@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +logger = logging.getLogger(__name__) + class MixtralMLP(nn.Module): @@ -371,7 +374,6 @@ def __init__( # TODO check runs with dtype=float16 self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn) - self.config = config self.quant_config = quant_config self.model = MixtralModel(config, self.use_fused_moe, cache_config,