From e5c1a8131c970fbb42540b518c8e37d3d0b150e8 Mon Sep 17 00:00:00 2001
From: DhruvaBansal00
Date: Tue, 6 Aug 2024 17:20:26 -0700
Subject: [PATCH] Refactoring for maintainability

---
 .../layers/fused_moe/__init__.py              |  18 +-
 .../layers/fused_moe/fused_moe.py             | 102 +---
 .../layers/fused_moe/fused_moe_gptq.py        | 138 +++++
 vllm/model_executor/layers/fused_moe/layer.py | 482 ++++++------------
 .../layers/quantization/gptq_marlin.py        | 356 ++++++++++++-
 vllm/model_executor/models/mixtral_quant.py   | 144 ++----
 6 files changed, 665 insertions(+), 575 deletions(-)
 create mode 100644 vllm/model_executor/layers/fused_moe/fused_moe_gptq.py

diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 080ecb5cfe0ba..2b982b7ab9f86 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,21 +1,23 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe,
-                                                            single_marlin_moe)
-from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                        FusedMoEMethodBase)
+from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq
+from vllm.model_executor.layers.fused_moe.fused_moe import single_marlin_moe
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
 from vllm.triton_utils import HAS_TRITON
 
 __all__ = [
     "FusedMoE",
     "FusedMoEMethodBase",
-    "fused_marlin_moe",
+    "fused_moe_gptq",
     "single_marlin_moe",
 ]
 
 if HAS_TRITON:
-    from vllm.model_executor.layers.fused_moe.fused_moe import (
-        fused_experts, fused_moe, fused_topk, get_config_file_name,
-        grouped_topk)
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        fused_experts,
+        fused_moe,
+        fused_topk,
+        get_config_file_name,
+        grouped_topk,
+    )
 
     __all__ += [
         "fused_moe",
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 64e47ad803232..9ae5859c4da0c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -704,104 +704,4 @@ def single_marlin_moe(
         g_idx, rand_perm, workspace, M, N, K, True, E, topk, block_size_m,
         True, False)
 
-    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
-
-
-def fused_marlin_moe(hidden_states: torch.Tensor,
-                     w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     gating_output: torch.Tensor,
-                     g_idx1: torch.Tensor,
-                     g_idx2: torch.Tensor,
-                     rand_perm1: torch.Tensor,
-                     rand_perm2: torch.Tensor,
-                     topk: int,
-                     renormalize: bool,
-                     override_config: Optional[Dict[str, Any]] = None,
-                     use_fp8: bool = False,
-                     w1_scale: Optional[torch.Tensor] = None,
-                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    """
-    This function computes a Mixture of Experts (MoE) layer using two sets of
-    weights, w1 and w2, and top-k gating mechanism.
-
-    Parameters:
-    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
-    - w1 (torch.Tensor): The first set of expert weights.
-    - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
-    - override_config (Optional[Dict[str, Any]]): Optional override
-        for the kernel configuration.
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
- - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py new file mode 100644 index 0000000000000..15c11fc0b668e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe_gptq.py @@ -0,0 +1,138 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +import torch + +from typing import Any, Dict, Optional +from vllm import _custom_ops as ops + +from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config + + +def fused_moe_gptq( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. 
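+
+    An illustrative call (the tensor names here are placeholders for the
+    repacked tensors that GPTQMarlinMoEMethod prepares; the qweights are
+    GPTQ-packed int32 tensors, not dense weights):
+
+        output = fused_moe_gptq(x, w13_qweight, w2_qweight, router_logits,
+                                g_idx1, g_idx2, perm1, perm2, topk=2,
+                                renormalize=True,
+                                w1_scale=w13_scales, w2_scale=w2_scales)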
+
+    Parameters:
+    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
+    - w1 (torch.Tensor): The first set of expert weights.
+    - w2 (torch.Tensor): The second set of expert weights.
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
+    - topk (int): The number of top-k experts to select.
+    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - g_idx1 (torch.Tensor): The act_order group indices for w1.
+    - g_idx2 (torch.Tensor): The act_order group indices for w2.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
+    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w1.
+    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        w2.
+
+    Returns:
+    - torch.Tensor: The output tensor after applying the MoE layer.
+    """
+    # Check constraints.
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
+    assert hidden_states.shape[1] == w2.shape[2] // 2, "Hidden size mismatch w2"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+    M, K = hidden_states.shape
+    E = w1.shape[0]
+    N = w2.shape[1] * 16
+
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+        is_marlin=True,
+    )
+    config = get_config_func(M)
+
+    block_size_m = config["BLOCK_SIZE_M"]
+
+    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
+
+    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
+    workspace = torch.zeros(
+        max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
+    )
+
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
+        hidden_states,
+        w1,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w1_scale,
+        g_idx1,
+        rand_perm1,
+        workspace,
+        M,
+        2 * N,
+        K,
+        True,
+        E,
+        topk,
+        block_size_m,
+        True,
+        False,
+    )
+
+    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+
+    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
+        intermediate_cache2,
+        w2,
+        sorted_token_ids,
+        topk_weights,
+        topk_ids,
+        w2_scale,
+        g_idx2,
+        rand_perm2,
+        workspace,
+        M,
+        K,
+        N,
+        True,
+        E,
+        topk,
+        block_size_m,
+        False,
+        True,
+    )
+
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 564a316b4894a..913d6a93b0cd5 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -6,16 +6,17 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -24,300 +25,63 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     @abstractmethod
-    def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
         raise NotImplementedError
 
     @abstractmethod
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool = True,
-              use_grouped_topk: bool = False,
-              num_expert_group: Optional[int] = None,
-              topk_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+    ) -> torch.Tensor:
         raise NotImplementedError
 
 
-class GPTQMarlinState(Enum):
-    REPACK = enum.auto()
-    READY = enum.auto()
-
-
-class MarlinFusedMoEMethod(FusedMoEMethodBase):
-    """MoE Marlin method with quantization."""
-
-    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
-        self.quant_config = quant_config
-
-    def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
-        # Currently assuming is_k_full is always True
-        # (input size per partition is the same as full input size)
-        # Supports only sym for now (no zp)
-        if self.quant_config.group_size != -1:
-            scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = intermediate_size // self.quant_config.group_size
-        else:
-            scales_size13 = 1
-            scales_size2 = 1
-        # Fused gate_up_proj (column parallel)
-        w13_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            hidden_size // self.quant_config.pack_factor,
-            2 * intermediate_size,
-            dtype=torch.int32),
-                                         requires_grad=False)
-        layer.register_parameter("w13_qweight", w13_qweight)
-        set_weight_attrs(w13_qweight, extra_weight_attrs)
-        # down_proj (row parallel)
-        w2_qweight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            intermediate_size // self.quant_config.pack_factor,
-            hidden_size,
-            dtype=torch.int32),
-                                        requires_grad=False)
-        layer.register_parameter("w2_qweight", w2_qweight)
-        set_weight_attrs(w2_qweight, extra_weight_attrs)
-        # up_proj scales
-        w13_scales = torch.nn.Parameter(torch.empty(num_experts,
-                                                    scales_size13,
-                                                    2 * intermediate_size,
-                                                    dtype=params_dtype),
-                                        requires_grad=False)
-        layer.register_parameter("w13_scales", w13_scales)
-        set_weight_attrs(w13_scales, extra_weight_attrs)
-        # down_proj scales
-        w2_scales = torch.nn.Parameter(torch.empty(num_experts,
-                                                   scales_size2,
-                                                   hidden_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
-        layer.register_parameter("w2_scales", w2_scales)
layer.register_parameter("w2_scales", w2_scales) - set_weight_attrs(w2_scales, extra_weight_attrs) - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - layer.marlin_state = GPTQMarlinState.REPACK - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - num_experts = layer.w13_g_idx.shape[0] - w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_g_idx[e]).to(torch.int32) - 
-                        w13_g_idx_sort_indices[e]]
-                    w2_sorted_g_idx[e] = layer.w2_g_idx[e][
-                        w2_g_idx_sort_indices[e]]
-                replace_tensor("w13_g_idx", w13_sorted_g_idx)
-                replace_tensor("w2_g_idx", w2_sorted_g_idx)
-                replace_tensor("w13_g_idx_sort_indices",
-                               w13_g_idx_sort_indices)
-                replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
-            else:
-                # Reset g_idx related tensors
-                num_experts = layer.w13_g_idx.shape[0]
-                device = layer.w13_g_idx.device
-                layer.w13_g_idx = torch.nn.Parameter(
-                    torch.empty((num_experts, 0),
-                                dtype=torch.int32,
-                                device=device),
-                    requires_grad=False,
-                )
-                layer.w2_g_idx = torch.nn.Parameter(
-                    torch.empty((num_experts, 0),
-                                dtype=torch.int32,
-                                device=device),
-                    requires_grad=False,
-                )
-                layer.w13_g_idx_sort_indices = torch.nn.Parameter(
-                    torch.empty((num_experts, 0),
-                                dtype=torch.int32,
-                                device=device),
-                    requires_grad=False,
-                )
-                layer.w2_g_idx_sort_indices = torch.nn.Parameter(
-                    torch.empty((num_experts, 0),
-                                dtype=torch.int32,
-                                device=device),
-                    requires_grad=False,
-                )
-            # Repack weights
-            marlin_w13_qweight = ops.gptq_marlin_moe_repack(
-                layer.w13_qweight,
-                layer.w13_g_idx_sort_indices,
-                layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
-                layer.w13_qweight.shape[2],
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("w13_qweight", marlin_w13_qweight)
-            marlin_w2_qweight = ops.gptq_marlin_moe_repack(
-                layer.w2_qweight,
-                layer.w2_g_idx_sort_indices,
-                layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
-                layer.w2_qweight.shape[2],
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("w2_qweight", marlin_w2_qweight)
-            # Repack scales
-            marlin_w13_scales = marlin_moe_permute_scales(
-                layer.w13_scales,
-                x.shape[1],
-                layer.w13_scales.shape[2],
-                self.quant_config.group_size,
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("w13_scales", marlin_w13_scales)
-            marlin_w2_scales = marlin_moe_permute_scales(
-                layer.w2_scales,
-                layer.w2_scales.shape[1] * self.quant_config.pack_factor,
-                x.shape[1],
-                self.quant_config.group_size,
-                self.quant_config.weight_bits,
-            )
-            replace_tensor("w2_scales", marlin_w2_scales)
-        return fused_marlin_moe(x,
-                                layer.w13_qweight,
-                                layer.w2_qweight,
-                                router_logits,
-                                layer.w13_g_idx,
-                                layer.w2_g_idx,
-                                layer.w13_g_idx_sort_indices,
-                                layer.w2_g_idx_sort_indices,
-                                top_k,
-                                renormalize=renormalize,
-                                w1_scale=layer.w13_scales,
-                                w2_scale=layer.w2_scales)
-
-
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
-    def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
-
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
-                                        requires_grad=False)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        2 * intermediate_size,
+                        hidden_size,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
-                                       requires_grad=False)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        intermediate_size,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
@@ -332,9 +96,17 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
     ) -> torch.Tensor:
-        return self.forward(x, layer.w13_weight, layer.w2_weight,
-                            router_logits, top_k, renormalize,
-                            use_grouped_topk, num_expert_group, topk_group)
+        return self.forward(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            router_logits,
+            top_k,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+        )
 
     def forward_cuda(
         self,
@@ -349,16 +121,19 @@ def forward_cuda(
         topk_group: Optional[int],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
-        return fused_moe(x,
-                         w1,
-                         w2,
-                         router_logits,
-                         top_k,
-                         renormalize=renormalize,
-                         inplace=True,
-                         use_grouped_topk=use_grouped_topk,
-                         num_expert_group=num_expert_group,
-                         topk_group=topk_group)
+
+        return fused_moe(
+            x,
+            w1,
+            w2,
+            router_logits,
+            top_k,
+            renormalize=renormalize,
+            inplace=True,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
 
     def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
@@ -377,6 +152,7 @@ def forward_tpu(
         topk_group: Optional[int],
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
+
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
@@ -386,7 +162,7 @@ class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
 
-    This layer contains both MergedColumnParallel weights (gate_up_proj / 
+    This layer contains both MergedColumnParallel weights (gate_up_proj /
     w13) and RowParallelLinear weights (down_proj/ w2).
 
     Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
@@ -438,12 +214,9 @@ def __init__(
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
 
-        self.quant_method: Optional[QuantizeMethodBase] = None
-
-        if quant_config is None:
-            self.quant_method = UnquantizedFusedMoEMethod()
-        elif isinstance(quant_config, GPTQMarlinConfig):
-            self.quant_method = MarlinFusedMoEMethod(quant_config)
+        if quant_config is None:
+            self.quant_method: Optional[
+                QuantizeMethodBase] = UnquantizedFusedMoEMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
@@ -454,15 +227,18 @@ def __init__(
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
-
-    def weight_loader(self,
-                      param: torch.nn.Parameter,
-                      loaded_weight: torch.Tensor,
-                      weight_name: str,
-                      shard_id: int,
-                      expert_id: int,
-                      is_quantized: bool = False):
+            weight_loader=self.weight_loader,
+        )
+
+    def weight_loader(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: int,
+        expert_id: int,
+        is_quantized: bool = False,
+    ):
         param_data = param.data
 
         if is_quantized:
@@ -491,8 +267,8 @@ def weight_loader(self,
         else:
             # Input scales can be loaded directly and should be equal.
             if "input_scale" in weight_name:
-                if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                                   loaded_weight).abs() > 1e-5:
+                if (param_data[expert_id] != 1 and
+                        (param_data[expert_id] - loaded_weight).abs() > 1e-5):
                     raise ValueError(
                         "input_scales of w1 and w3 of a layer "
                         f"must be equal. But got {param_data[expert_id]} "
@@ -546,7 +322,8 @@ def forward(self, hidden_states: torch.Tensor,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             num_expert_group=self.num_expert_group,
-            topk_group=self.topk_group)
+            topk_group=self.topk_group,
+        )
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -556,37 +333,70 @@ def forward(self, hidden_states: torch.Tensor,
 
     @classmethod
     def make_expert_params_mapping(
-            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
-            ckpt_up_proj_name: str,
-            num_experts: int) -> List[Tuple[str, str, int, int]]:
-
+        cls,
+        ckpt_gate_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_up_proj_name: str,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, int]]:
         gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
         gate_down_up = [
             ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
        ]
 
-        return [
+        return ([
             # These are the weight scales for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in gate_up else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
+            (
+                "experts.w13_scale"
+                if weight_name in gate_up else "experts.w2_scale",
+                f"experts.{expert_id}.{weight_name}.weight_scale",
+                expert_id,
+                shard_id,
+            ) for expert_id in range(num_experts)
             for shard_id, weight_name in enumerate(gate_down_up)
         ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight"
-             if weight_name in gate_up else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(num_experts)
+            (
+                "experts.w13_weight"
+                if weight_name in gate_up else "experts.w2_weight",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+                shard_id,
+            ) for expert_id in range(num_experts)
             for shard_id, weight_name in enumerate(gate_down_up)
         ] + [
             # These are the weight scales for the experts
             # (param_name, weight_name, expert_id, shard_id)
-            ("experts.a13_scale"
-             if weight_name in gate_up else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(num_experts)
+            (
+                "experts.a13_scale"
+                if weight_name in gate_up else "experts.a2_scale",
+                f"experts.{expert_id}.{weight_name}.input_scale",
+                expert_id,
+                shard_id,
+            ) for expert_id in range(num_experts)
             for shard_id, weight_name in enumerate(gate_down_up)
-        ]
+        ] + [
+            # These are the qweights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                "experts.w13_qweight"
+                if weight_name in gate_up else "experts.w2_qweight",
+                f"experts.{expert_id}.{weight_name}.qweight",
+                expert_id,
+                shard_id,
+            ) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the g_idx tensors for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                "experts.w13_g_idx"
+                if weight_name in gate_up else "experts.w2_g_idx",
+                f"experts.{expert_id}.{weight_name}.g_idx",
+                expert_id,
+                shard_id,
+            ) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ])
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index bdcc9c3b4f0c5..f58a89c8e4bb9 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -1,29 +1,53 @@ -from typing import Any, Dict, List, Optional - +from typing import Any, Dict, List, Optional, Union +import enum +from enum import Enum import torch from torch.nn.parameter import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe.fused_moe_gptq import fused_moe_gptq +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_gptq_marlin_supported, verify_marlin_supports_shape) + apply_gptq_marlin_linear, + check_gptq_marlin_supported, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, + replace_tensor, + verify_gptq_marlin_supported, + verify_marlin_supports_shape, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -95,11 +119,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -118,15 +145,15 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): return False # If we cannot find the info needed in the config, cannot convert. 
-        if (num_bits is None or group_size is None or sym is None
-                or desc_act is None):
+        if num_bits is None or group_size is None or sym is None or desc_act is None:
             return False
 
         return check_gptq_marlin_supported(
             num_bits=num_bits,
             group_size=group_size,
             is_sym=sym,
-            min_capability=cls.get_min_capability())
+            min_capability=cls.get_min_capability(),
+        )
 
 
 class GPTQMarlinLinearMethod(LinearMethodBase):
@@ -163,7 +190,8 @@ def create_weights(
             output_size_per_partition=output_size_per_partition,
             input_size_per_partition=input_size_per_partition,
             input_size=input_size,
-            group_size=group_size)
+            group_size=group_size,
+        )
 
         # Determine sharding
         if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
@@ -293,7 +321,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             perm=layer.g_idx_sort_indices,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
-            num_bits=self.quant_config.weight_bits)
+            num_bits=self.quant_config.weight_bits,
+        )
         replace_tensor(layer, "qweight", marlin_qweight)
 
         # Permute scales from autogptq format to marlin format.
@@ -302,7 +331,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=(layer.input_size
                     if self.quant_config.desc_act else
                     layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
-            group_size=self.quant_config.group_size)
+            group_size=self.quant_config.group_size,
+        )
         replace_tensor(layer, "scales", marlin_scales)
 
     def apply(
@@ -323,4 +353,284 @@ def apply(
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             is_k_full=layer.is_k_full,
-            bias=bias)
+            bias=bias,
+        )
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Currently assuming is_k_full is always True
+        # (input size per partition is the same as full input size)
+        # Supports only sym for now (no zp)
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            scales_size2 = intermediate_size // self.quant_config.group_size
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size13,
+                        2 * intermediate_size,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts,
+                        scales_size2,
+                        hidden_size,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + layer.marlin_state = GPTQMarlinState.REPACK + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + if layer.marlin_state == GPTQMarlinState.REPACK: + layer.marlin_state = GPTQMarlinState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int, + ): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, + num_bits: int, + ): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort( + 
+                    w13_sorted_g_idx[e] = layer.w13_g_idx[e][
+                        w13_g_idx_sort_indices[e]]
+                    w2_sorted_g_idx[e] = layer.w2_g_idx[e][
+                        w2_g_idx_sort_indices[e]]
+                replace_tensor("w13_g_idx", w13_sorted_g_idx)
+                replace_tensor("w2_g_idx", w2_sorted_g_idx)
+                replace_tensor("w13_g_idx_sort_indices",
+                               w13_g_idx_sort_indices)
+                replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+            else:
+                # Reset g_idx related tensors
+                num_experts = layer.w13_g_idx.shape[0]
+                device = layer.w13_g_idx.device
+                layer.w13_g_idx = torch.nn.Parameter(
+                    torch.empty((num_experts, 0),
+                                dtype=torch.int32,
+                                device=device),
+                    requires_grad=False,
+                )
+                layer.w2_g_idx = torch.nn.Parameter(
+                    torch.empty((num_experts, 0),
+                                dtype=torch.int32,
+                                device=device),
+                    requires_grad=False,
+                )
+                layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                    torch.empty((num_experts, 0),
+                                dtype=torch.int32,
+                                device=device),
+                    requires_grad=False,
+                )
+                layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                    torch.empty((num_experts, 0),
+                                dtype=torch.int32,
+                                device=device),
+                    requires_grad=False,
+                )
+            # Repack weights
+            marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+                layer.w13_qweight,
+                layer.w13_g_idx_sort_indices,
+                layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+                layer.w13_qweight.shape[2],
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("w13_qweight", marlin_w13_qweight)
+            marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+                layer.w2_qweight,
+                layer.w2_g_idx_sort_indices,
+                layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+                layer.w2_qweight.shape[2],
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("w2_qweight", marlin_w2_qweight)
+            # Repack scales
+            marlin_w13_scales = marlin_moe_permute_scales(
+                layer.w13_scales,
+                x.shape[1],
+                layer.w13_scales.shape[2],
+                self.quant_config.group_size,
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("w13_scales", marlin_w13_scales)
+            marlin_w2_scales = marlin_moe_permute_scales(
+                layer.w2_scales,
+                layer.w2_scales.shape[1] * self.quant_config.pack_factor,
+                x.shape[1],
+                self.quant_config.group_size,
+                self.quant_config.weight_bits,
+            )
+            replace_tensor("w2_scales", marlin_w2_scales)
+        return fused_moe_gptq(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            router_logits,
+            layer.w13_g_idx,
+            layer.w2_g_idx,
+            layer.w13_g_idx_sort_indices,
+            layer.w2_g_idx_sort_indices,
+            top_k,
+            renormalize=renormalize,
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+        )
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 85dafd55bbcf8..cdfd24874b974 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -21,7 +21,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""Inference-only Mixtral model.""" -import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -35,7 +34,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -96,13 +94,10 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - use_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.use_fused_moe = use_fused_moe - self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() self.num_total_experts = config.num_local_experts @@ -118,26 +113,14 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - if self.use_fused_moe: - params_dtype = torch.float16 - self.experts = FusedMoE(num_experts=self.num_total_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=self.tp_size) - else: - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, @@ -149,34 +132,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits, _ = self.gate(hidden_states) - if self.use_fused_moe: - ret = self.experts(hidden_states.half(), router_logits) - return ret.bfloat16() - else: - routing_weights = F.softmax(router_logits, - dim=1, - dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum( - dim=-1, keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if 
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states).view(
+            num_tokens, hidden_dim)
 
 
 class MixtralAttention(nn.Module):
@@ -261,7 +238,6 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
        config: MixtralConfig,
-        use_fused_moe: bool,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -278,7 +254,6 @@ def __init__(
             cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
-                                           use_fused_moe=use_fused_moe,
                                            quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -319,7 +294,6 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        use_fused_moe: bool,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -333,7 +307,6 @@ def __init__(
         )
         self.layers = nn.ModuleList([
             MixtralDecoderLayer(config,
-                                use_fused_moe,
                                 cache_config,
                                 quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
@@ -370,12 +343,10 @@ def __init__(
         super().__init__()
 
         # TODO check runs with dtype=float16
-        self.use_fused_moe = (config.torch_dtype != torch.float8_e4m3fn)
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, self.use_fused_moe, cache_config,
-                                  quant_config)
+        self.model = MixtralModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
@@ -436,50 +407,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if name.endswith(".bias") and name not in params_dict:
                 continue
-
-            if self.use_fused_moe:
-                if ("block_sparse_moe.experts." in name
-                        and ".w1." not in name and ".w2." not in name
-                        and ".w3." not in name
-                        and name not in params_dict):
-                    continue
-
-                if (".qzeros" in name):
-                    continue
-
-                shard_id = None
-                expert_id = 0
-
-                has_any_numbered = (".qweight" in name or ".scales" in name
-                                    or ".g_idx" in name)
-                if (has_any_numbered and (".w1." in name)):
-                    name = name.replace(".w1.", ".w13_")
-                    shard_id = 0
-                if (has_any_numbered and (".w2." in name)):
-                    name = name.replace(".w2.", ".w2_")
-                    shard_id = 0
-                if (has_any_numbered and (".w3." in name)):
-                    name = name.replace(".w3.", ".w13_")
-                    shard_id = 1
-
-                exp_string = re.search(r"\.experts\.\d+.", name)
-                if exp_string:
-                    exp_string = exp_string.group(0)
-                    expert_id = int(exp_string.split(".")[2])
-                    name = name.replace(exp_string, ".experts.")
-
-            else:
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-
-            param = params_dict[name]
-
-            if self.use_fused_moe and shard_id is not None:
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight, name, shard_id,
-                              expert_id, True)
-            else:
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            if ("block_sparse_moe.experts." in name
+                    and name not in params_dict):
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)