diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e45d2a59f4247..bc974b10e8fb5 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -23,10 +23,7 @@ #include #include -// #include "marlin_moe_ops.h" - #include -// #include template inline std::string str(T x) { diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e265bccc81c39..c708db2754675 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -260,9 +260,10 @@ def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] output = torch.empty((num_experts, size_k // 16, size_n * 2), device=b_q_weight.device, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index dc8a05c80a128..cf316c0f9afa4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -737,17 +737,6 @@ def fused_marlin_moe(hidden_states: torch.Tensor, E = w1.shape[0] N = w2.shape[1] * 16 - # print("hidden_states shape:", hidden_states) - # print("w1 shape:", w1) - # print("w2 shape:", w2) - # print("gating_output shape:", gating_output) - # print("g_idx1 shape:", g_idx1) - # print("g_idx2 shape:", g_idx2) - # print("w1_scale shape:", w1_scale) - # print("w2_scale shape:", w2_scale) - - # raise ValueError("stop") - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 32edcfb7d3155..b705c7ec4be4d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,19 +1,19 @@ from abc import abstractmethod -from typing import Optional, List +from typing import List, Optional import torch from vllm import _custom_ops as ops - from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe, fused_marlin_moe -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, GPTQMarlinState) +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_marlin_moe, + fused_moe) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinState) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -36,6 +36,7 @@ def apply(self, renormalize: bool = True) -> torch.Tensor: raise NotImplementedError + class MarlinFusedMoEMethod(FusedMoEMethodBase): """MoE Marlin method with quantization.""" @@ -50,38 +51,30 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # (input size per partition is the same as full input size) # Supports only sym for now (no zp) - #TODO scales g_idx etc. 
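For reviewers who want to sanity-check shapes, here is a minimal sketch of how the gptq_marlin_moe_repack wrapper added in _custom_ops.py above is expected to be driven. Only the signature and the (num_experts, size_k // 16, size_n * 2) output shape come from the hunk; the concrete sizes, the pack factor of 8 (4-bit GPTQ), the zero-filled weights, and the empty perm tensor for the non-act_order case are illustrative assumptions, and running it requires a CUDA build of vLLM's _C extension.

import torch
from vllm import _custom_ops as ops

# Hypothetical per-expert GPTQ weight: 8 experts, k=4096, n=14336, 4-bit
# (pack_factor 8, so the packed k dimension is size_k // 8).
num_experts, size_k, size_n, num_bits = 8, 4096, 14336, 4
b_q_weight = torch.zeros((num_experts, size_k // 8, size_n),
                         dtype=torch.int32, device="cuda")
# Without act_order the per-expert sort indices are empty, matching the
# (num_experts, 0) reset tensors created in MarlinFusedMoEMethod.apply.
perm = torch.empty((num_experts, 0), dtype=torch.int32, device="cuda")

marlin_qweight = ops.gptq_marlin_moe_repack(b_q_weight, perm,
                                            size_k, size_n, num_bits)
assert marlin_qweight.shape == (num_experts, size_k // 16, size_n * 2)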
- #also do marlin transformations - - # print("*") - # print("group_size:", self.quant_config.group_size) - # print("hidden_size:", hidden_size) - # print("intermediate_size:", intermediate_size) - if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size - scales_size2 = intermediate_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size else: scales_size13 = 1 scales_size2 = 1 # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size // self.quant_config.pack_factor, - 2 * intermediate_size, - # 2 * intermediate_size, - # hidden_size // self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w13_qweight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) # down_proj (row parallel) - w2_qweight = torch.nn.Parameter(torch.empty(num_experts, - intermediate_size // self.quant_config.pack_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) + w2_qweight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) @@ -111,7 +104,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - # print("w13g shape:", w13_g_idx.shape) layer.register_parameter("w13_g_idx", w13_g_idx) set_weight_attrs(w13_g_idx, extra_weight_attrs) @@ -134,7 +126,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( @@ -145,7 +138,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.marlin_state = GPTQMarlinState.REPACK @@ -157,9 +151,6 @@ def apply(self, top_k: int, renormalize: bool = True) -> torch.Tensor: - # print("1", layer.w13_scales) - - # TODO translate qweights into Marlin format if layer.marlin_state == GPTQMarlinState.REPACK: layer.marlin_state = GPTQMarlinState.READY @@ -182,30 +173,31 @@ def get_scale_perms(num_bits: int): [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single - def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): + def marlin_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): scale_perm, scale_perm_single = get_scale_perms(num_bits) if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s - def marlin_moe_permute_scales(s: torch.Tensor, size_k: 
int, size_n: int, - group_size: int, num_bits: int): + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): num_experts = s.shape[0] output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) + device=s.device, + dtype=s.dtype) for e in range(num_experts): output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size, num_bits) return output - # print("2", layer.w13_scales) - # Process act_order if self.quant_config.desc_act: # Get sorting based on g_idx @@ -215,50 +207,50 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int, w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort( + layer.w2_g_idx[e]).to(torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] replace_tensor("w13_g_idx", w13_sorted_g_idx) replace_tensor("w2_g_idx", w2_sorted_g_idx) - replace_tensor("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_tensor("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) replace_tensor("w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty(0, dtype=torch.int), + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty(0, dtype=torch.int), + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty(0, dtype=torch.int), + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty(0, dtype=torch.int), + torch.empty((num_experts, 0), + dtype=torch.int32, + device=device), requires_grad=False, ) - # print("3", layer.w13_scales) - - # print("*") - # print("hidden:", x.shape) - # print("w13 before:", layer.w13_qweight.shape) - # print("w2 before:", layer.w2_qweight.shape) - # print("w13 args:", layer.w13_qweight.shape[1] - # * self.quant_config.pack_factor, - # layer.w13_qweight.shape[2]) - # print("w2 args:", layer.w2_qweight.shape[1] - # * self.quant_config.pack_factor, - # layer.w2_qweight.shape[2]) - - # print("weight type:", layer.w13_qweight.dtype) - # Repack weights marlin_w13_qweight = ops.gptq_marlin_moe_repack( layer.w13_qweight, @@ -276,15 +268,6 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int, self.quant_config.weight_bits, ) replace_tensor("w2_qweight", marlin_w2_qweight) - - # print("w13 after:", marlin_w13_qweight.shape) - # print("w2 after:", marlin_w2_qweight.shape) - - # print("w13 scales before:", layer.w13_scales.shape) - # print("w2 scales before:", layer.w2_scales.shape) - # print("w13 args:", x.shape[1], layer.w13_scales.shape[2]) - # print("w2 args:", layer.w2_scales.shape[1] * 
self.quant_config.pack_factor, - # x.shape[1]) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( @@ -305,9 +288,6 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int, ) replace_tensor("w2_scales", marlin_w2_scales) - # print("w13 scales after:", marlin_w13_scales.shape) - # print("w2 scales after:", marlin_w2_scales.shape) - return fused_marlin_moe(x, layer.w13_qweight, layer.w2_qweight, @@ -321,6 +301,7 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales) + class UnquantizedFusedMoEMethod(FusedMoEMethodBase): """MoE method without quantization.""" @@ -362,7 +343,6 @@ def apply(self, inplace=True) -# TODO should work from this class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -409,15 +389,17 @@ def __init__( self.reduce_results = reduce_results self.renormalize = renormalize - # TODO we need to rewrite to QuantizedFusedMoEMethod + self.quant_method: Optional[QuantizeMethodBase] = None + if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + self.quant_method = UnquantizedFusedMoEMethod() else: - # TODO assert GPTQ quant config - self.quant_method: Optional[QuantizeMethodBase] = ( - MarlinFusedMoEMethod(quant_config)) - # self.quant_method = quant_config.get_quant_method(self) + if not isinstance(quant_config, GPTQMarlinConfig): + raise ValueError("Fused quantized MoE layer must use " + "GPTQMarlinConfig, but " + f"{quant_config.__class__.__name__} found.") + self.quant_method = MarlinFusedMoEMethod(quant_config) + assert self.quant_method is not None self.quant_method.create_weights( @@ -428,50 +410,44 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: int, expert_id: int, is_quantized: bool = False): + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + is_quantized: bool = False): param_data = param.data - # print("param_data shape:", param_data.shape) - # print("loaded weight shape:", loaded_weight.shape) - # TODO why is param_data[expert_id].shape == 7 * loaded_weight.shape? 
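As a side note on the act_order path in MarlinFusedMoEMethod.apply above: each expert's g_idx is argsorted and both the sorted g_idx and the permutation are kept, so the repack kernel can shuffle the corresponding weight rows. A toy, self-contained version of that loop (sizes and values made up for illustration):

import torch

# Two experts, eight input channels, fake group indices per channel.
g_idx = torch.tensor([[1, 0, 1, 0, 0, 1, 0, 1],
                      [0, 2, 1, 1, 0, 2, 1, 0]], dtype=torch.int32)
sort_indices = torch.empty_like(g_idx)
sorted_g_idx = torch.empty_like(g_idx)
for e in range(g_idx.shape[0]):
    perm = torch.argsort(g_idx[e])           # int64 permutation per expert
    sort_indices[e] = perm.to(torch.int32)   # kernels expect int32 indices
    sorted_g_idx[e] = g_idx[e][perm]
# Each row of sorted_g_idx is now non-decreasing; sort_indices holds the
# per-expert permutation passed on to the repack step.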
- - # print(param_data[expert_id]) if is_quantized: if "_qweight" in weight_name or "_scales" in weight_name: - # if "_scales" in weight_name: - # print("scales:", loaded_weight) if "w13" in weight_name: shard_size = self.intermediate_size_per_partition if shard_id == 0: - param_data[expert_id, :, :shard_size] = loaded_weight - # if "_scales" in weight_name: - # print("param:", param_data[expert_id, :, :shard_size]) + param_data[expert_id, :, :shard_size] = loaded_weight elif shard_id == 1: param_data[expert_id, :, shard_size:] = loaded_weight - # if "_scales" in weight_name: - # print("param:", param_data[expert_id, :, shard_size:]) else: - ValueError("wrong shard:", shard_id) + raise ValueError(f"Invalid shard_id: {shard_id}: " + "must be 0 or 1.") elif "w2" in weight_name: param_data[expert_id][:] = loaded_weight - # if "_scales" in weight_name: - # print("param:", param_data[expert_id][:]) else: - ValueError("what is this?", weight_name) + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") elif "_g_idx" in weight_name: if "w13" not in weight_name and "w2" not in weight_name: - ValueError("what is this?", weight_name) + raise ValueError(f"Invalid weight name: {weight_name}: " + "must contain 'w13' or 'w2'.") param_data[expert_id] = loaded_weight else: - ValueError("what is this?", weight_name) + raise ValueError(f"Invalid weight name: {weight_name}.") else: # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. # Follow up PR to enable fp8 for other MoE models. if "input_scale" in weight_name or "w2.weight_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - - loaded_weight).abs() > 1e-5: + loaded_weight).abs() > 1e-5: raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " @@ -493,11 +469,11 @@ def weight_loader(self, param: torch.nn.Parameter, # w1, gate_proj case: Load into first shard of w13. if shard_id == 0: param_data[expert_id, - 0:shard_size, :] = loaded_weight[shard, :] + 0:shard_size, :] = loaded_weight[shard, :] # w3, up_proj case: Load into second shard of w13. elif shard_id == 2: param_data[expert_id, shard_size:2 * - shard_size, :] = loaded_weight[shard, :] + shard_size, :] = loaded_weight[shard, :] # w2, down_proj case: Load into only shard of w2. elif shard_id == 1: param_data[expert_id, :, :] = loaded_weight[:, shard] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 18c23564fcda5..1656460df26c8 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -222,51 +222,6 @@ def extra_repr(self) -> str: s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" return s - -class FusedLinearMarlin(LinearBase): - - """ - Args: - input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. 
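For orientation, a minimal illustration of the layout that the quantized weight_loader above fills: the fused w13_* parameters hold gate_proj (shard 0) and up_proj (shard 1) side by side along the last dimension, while w2_* is written whole per expert. The tensor sizes and the helper name below are invented for the example; only the slicing pattern mirrors the loader.

import torch

num_experts, packed_k, intermediate = 2, 4, 3
w13_qweight = torch.zeros(num_experts, packed_k, 2 * intermediate,
                          dtype=torch.int32)

def load_w13(expert_id: int, shard_id: int, loaded: torch.Tensor) -> None:
    # Mirrors param_data[expert_id, :, :shard_size] / [..., shard_size:].
    if shard_id == 0:        # checkpoint w1 (gate_proj)
        w13_qweight[expert_id, :, :intermediate] = loaded
    elif shard_id == 1:      # checkpoint w3 (up_proj)
        w13_qweight[expert_id, :, intermediate:] = loaded
    else:
        raise ValueError(f"Invalid shard_id: {shard_id}: must be 0 or 1.")

load_w13(0, 0, torch.full((packed_k, intermediate), 1, dtype=torch.int32))
load_w13(0, 1, torch.full((packed_k, intermediate), 2, dtype=torch.int32))
# Expert 0 now has the gate_proj block in columns [0, 3) and the up_proj
# block in columns [3, 6) of w13_qweight.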
- """ - - def __init__(self, - input_size13: int, - input_size2: int, - quant_config: Optional[QuantizationConfig] = None): - # calling with inputsize13 is a bit of an ugly workaround, - # it's not used for anything - super().__init__(input_size13, input_size13, False, None, - quant_config) - self.input_size13 = input_size13 - self.input_size2 = input_size2 - self.output_size13 = input_size2 - self.output_size2 = input_size13 - - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, self.input_size13, self.input_size2, - self.output_size13, self.output_size2, - self.input_size13, self.input_size2, - self.params_dtype) - - self.register_parameter("bias", None) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert self.quant_method is not None - output = self.quant_method.apply(self, x, None) - return output, None - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", output_features={self.output_size}" - s += f", bias={self.bias is not None}" - return s class ColumnParallelLinear(LinearBase): diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 4c9ebc16c7367..a6284d0ed7b1b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -8,7 +8,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - FusedLinearMarlin, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -152,8 +151,6 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method( self, layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: - if isinstance(layer, FusedLinearMarlin): - return GPTQMarlinFusedLinearMethod(self) if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): return GPTQMarlinLinearMethod(self) @@ -389,10 +386,6 @@ def apply( out_shape = x.shape[:-1] + (part_size_n, ) - #TODO should make the new implementation also depend on repacking / not repacking here - # otherwise we lose 2x time doing superfluous computations - # maybe also repack q1/q3 separately before merging, depending on how fast it is compared to q13 - if layer.marlin_state == GPTQMarlinState.REPACK: layer.marlin_state = GPTQMarlinState.READY @@ -428,8 +421,6 @@ def replace_tensor(name, new_t): requires_grad=False, ) - # print("do repack", layer.qweight.shape, layer.g_idx_sort_indices.shape) - # Repack weights marlin_qweight = ops.gptq_marlin_repack( layer.qweight, @@ -455,9 +446,6 @@ def replace_tensor(name, new_t): ) replace_tensor("scales", marlin_scales) - # else: - # print("do not repack") - output = ops.gptq_marlin_gemm( reshaped_x, layer.qweight, @@ -476,472 +464,3 @@ def replace_tensor(name, new_t): output.add_(bias) # In-place add return output.reshape(out_shape) - -class GPTQMarlinFusedLinearMethod(LinearMethodBase): - """Linear method for fused GPTQ Marlin. - - Args: - quant_config: The GPTQ Marlin quantization config. 
- """ - - def __init__(self, quant_config: GPTQMarlinConfig) -> None: - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition13: int, - input_size_per_partition2: int, - output_size_per_partition13: int, - output_size_per_partition2: int, - input_size13: int, - input_size2: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - # Normalize group_size - if self.quant_config.group_size != -1: - group_size13 = self.quant_config.group_size - group_size2 = self.quant_config.group_size - else: - group_size13 = input_size13 - group_size2 = input_size2 - - # Validate dtype - if params_dtype not in [torch.float16, torch.bfloat16]: - raise ValueError(f"The params dtype must be float16 " - f"or bfloat16, but got {params_dtype}") - - # Validate output_size_per_partition - if output_size_per_partition13 % self.quant_config.min_thread_n != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition13} is not divisible by " - f" min_thread_n = {self.quant_config.min_thread_n}.") - if output_size_per_partition2 % self.quant_config.min_thread_n != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition2} is not divisible by " - f" min_thread_n = {self.quant_config.min_thread_n}.") - - # Validate input_size_per_partition - if input_size_per_partition13 % self.quant_config.min_thread_k != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition13} is not divisible " - f"by min_thread_k = {self.quant_config.min_thread_k}.") - if input_size_per_partition2 % self.quant_config.min_thread_k != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition2} is not divisible " - f"by min_thread_k = {self.quant_config.min_thread_k}.") - - if (group_size13 < input_size13 - and input_size_per_partition13 % group_size13 != 0): - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition13}" - f" is not divisible by group_size = {group_size13}.") - if (group_size2 < input_size2 - and input_size_per_partition2 % group_size2 != 0): - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition2}" - f" is not divisible by group_size = {group_size2}.") - - # Detect sharding of scales/zp - - # By default, no sharding over "input dim" - scales_and_zp_size13 = input_size13 // group_size13 - scales_and_zp_size2 = input_size2 // group_size2 - scales_and_zp_input_dim13 = None - scales_and_zp_input_dim2 = None - - if self.quant_config.desc_act: - # Act-order case - assert self.quant_config.group_size != -1 - - is_k_full = (input_size_per_partition13 == input_size13 and - input_size_per_partition2 == input_size2) - - else: - # No act-order case - - # K is always full due to full alignment with - # group-size and shard of scales/zp - is_k_full = True - - # If this is a row-parallel case, then shard scales/zp - if (input_size13 != input_size_per_partition13 - and self.quant_config.group_size != -1): - scales_and_zp_size13 = input_size_per_partition13 // group_size13 - scales_and_zp_input_dim13 = 0 - if (input_size2 != input_size_per_partition2 - and self.quant_config.group_size != -1): - scales_and_zp_size2 = input_size_per_partition2 // group_size2 - scales_and_zp_input_dim2 = 0 - - # Init buffers - - # Quantized weights - qweight1 = Parameter( - torch.empty( - input_size_per_partition13 // self.quant_config.pack_factor, - output_size_per_partition13, - 
dtype=torch.int32, - ), - requires_grad=False, - ) - qweight2 = Parameter( - torch.empty( - input_size_per_partition2 // self.quant_config.pack_factor, - output_size_per_partition2, - dtype=torch.int32, - ), - requires_grad=False, - ) - qweight3 = Parameter( - torch.empty( - input_size_per_partition13 // self.quant_config.pack_factor, - output_size_per_partition13, - dtype=torch.int32, - ), - requires_grad=False, - ) - qweight13 = Parameter( - torch.empty( - input_size_per_partition13 // self.quant_config.pack_factor * 2, - output_size_per_partition13, - dtype=torch.int32, - ), - requires_grad=False, - ) - qweight_attrs = { - **extra_weight_attrs, - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - } - - set_weight_attrs(qweight1, qweight_attrs) - set_weight_attrs(qweight2, qweight_attrs) - set_weight_attrs(qweight3, qweight_attrs) - set_weight_attrs(qweight13, qweight_attrs) - - # Activation order - g_idx13 = Parameter( - torch.empty( - input_size_per_partition13, - dtype=torch.int32, - ), - requires_grad=False, - ) - g_idx2 = Parameter( - torch.empty( - input_size_per_partition2, - dtype=torch.int32, - ), - requires_grad=False, - ) - g_idx_attrs = { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - } - # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx13, g_idx_attrs) - set_weight_attrs(g_idx2, g_idx_attrs) - - g_idx_sort_indices13 = Parameter( - torch.empty( - g_idx13.shape, - dtype=torch.int32, - ), - requires_grad=False, - ) - g_idx_sort_indices2 = Parameter( - torch.empty( - g_idx2.shape, - dtype=torch.int32, - ), - requires_grad=False, - ) - - # Scales - scales1 = Parameter( - torch.empty( - scales_and_zp_size13, - output_size_per_partition13, - dtype=params_dtype, - ), - requires_grad=False, - ) - scales2 = Parameter( - torch.empty( - scales_and_zp_size2, - output_size_per_partition2, - dtype=params_dtype, - ), - requires_grad=False, - ) - scales3 = Parameter( - torch.empty( - scales_and_zp_size13, - output_size_per_partition13, - dtype=params_dtype, - ), - requires_grad=False, - ) - scales13 = Parameter( - torch.empty( - scales_and_zp_size13 * 2, - output_size_per_partition13, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - scales1, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim13, - "output_dim": 1, - }, - ) - set_weight_attrs( - scales2, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim2, - "output_dim": 1, - }, - ) - set_weight_attrs( - scales3, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim13, - "output_dim": 1, - }, - ) - set_weight_attrs( - scales13, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim13, - "output_dim": 1, - }, - ) - - # No zero-point support - - # Allocate marlin workspace - # TODO we'll need multiple output sizes per partition (take max) - max_workspace_size = ( - max(output_size_per_partition13, output_size_per_partition2) // - self.quant_config.min_thread_n) * self.quant_config.max_parallel - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - requires_grad=False) - - layer.register_parameter("qweight1", qweight1) - layer.register_parameter("qweight2", qweight2) - layer.register_parameter("qweight3", qweight3) - layer.register_parameter("qweight13", qweight13) - layer.register_parameter("g_idx13", g_idx13) - layer.register_parameter("g_idx2", g_idx2) - layer.register_parameter("scales1", scales1) - 
layer.register_parameter("scales2", scales2) - layer.register_parameter("scales3", scales3) - layer.register_parameter("scales13", scales13) - layer.register_parameter("g_idx_sort_indices13", g_idx_sort_indices13) - layer.register_parameter("g_idx_sort_indices2", g_idx_sort_indices2) - layer.g_idx_sort_indices13 = g_idx_sort_indices13 - layer.g_idx_sort_indices2 = g_idx_sort_indices2 - layer.workspace = workspace - layer.input_size_per_partition13 = input_size_per_partition13 - layer.input_size_per_partition2 = input_size_per_partition2 - layer.output_size_per_partition13 = output_size_per_partition13 - layer.output_size_per_partition2 = output_size_per_partition2 - layer.input_size13 = input_size13 - layer.input_size2 = input_size2 - layer.is_k_full = is_k_full - layer.marlin_state = GPTQMarlinState.REPACK - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # reshaped_x = x.reshape(-1, x.shape[-1]) - - # size_m = reshaped_x.shape[0] - part_size_k = layer.input_size_per_partition13 - part_size_n = layer.input_size_per_partition2 - full_size_k = layer.input_size13 - full_size_n = layer.input_size2 - - # out_shape = x.shape[:-1] + (part_size_n, ) - - #TODO should make the new implementation also depend on repacking / not repacking here - # otherwise we lose 2x time doing superfluous computations - # maybe also repack q1/q3 separately before merging, depending on how fast it is compared to q13 - - if layer.marlin_state == GPTQMarlinState.REPACK: - layer.marlin_state = GPTQMarlinState.READY - - # Newly generated tensors need to replace existing tensors that are - # already registered as parameters by vLLM (and won't be freed) - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - cur_device = layer.qweight1.device - - # Process act_order - if self.quant_config.desc_act: - # Get sorting based on g_idx - g_idx_sort_indices13 = torch.argsort(layer.g_idx13).to(torch.int) - g_idx_sort_indices2 = torch.argsort(layer.g_idx2).to(torch.int) - - sorted_g_idx13 = layer.g_idx13[g_idx_sort_indices13] - sorted_g_idx2 = layer.g_idx2[g_idx_sort_indices2] - - replace_tensor("g_idx13", sorted_g_idx13) - replace_tensor("g_idx2", sorted_g_idx2) - replace_tensor("g_idx_sort_indices13", g_idx_sort_indices13) - replace_tensor("g_idx_sort_indices2", g_idx_sort_indices2) - - else: - # Reset g_idx related tensors - layer.g_idx13 = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx2 = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices13 = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - layer.g_idx_sort_indices2 = Parameter( - torch.empty(0, dtype=torch.int, device=cur_device), - requires_grad=False, - ) - - # print("do repack", layer.qweight1.shape, layer.qweight2.shape, layer.qweight3.shape) - - layer_qweight13 = torch.cat((layer.qweight1, layer.qweight3), 1) - - # print("*") - # print("hidden:", x.shape) - # print("w13 before:", layer_qweight13.shape) - # print("w2 before:", layer.qweight2.shape) - # print("w13 args:", part_size_k, layer_qweight13.shape[1]) - # print("w2 args:", part_size_n, part_size_k) - - # Repack weights - # marlin_qweight1 = ops.gptq_marlin_repack( - # layer.qweight1, - # 
layer.g_idx_sort_indices13, - # part_size_k, - # part_size_n, - # self.quant_config.weight_bits, - # ) - # replace_tensor("qweight1", marlin_qweight1) - marlin_qweight2 = ops.gptq_marlin_repack( - layer.qweight2, - layer.g_idx_sort_indices2, - part_size_n, - part_size_k, - self.quant_config.weight_bits, - ) - replace_tensor("qweight2", marlin_qweight2) - # marlin_qweight3 = ops.gptq_marlin_repack( - # layer.qweight3, - # layer.g_idx_sort_indices13, - # part_size_k, - # part_size_n, - # self.quant_config.weight_bits, - # ) - # replace_tensor("qweight3", marlin_qweight3) - - # print("13:", layer_qweight13.shape, part_size_k * 2, part_size_n) - marlin_qweight13 = ops.gptq_marlin_repack( - layer_qweight13, - layer.g_idx_sort_indices13, - part_size_k, - layer_qweight13.shape[1], - self.quant_config.weight_bits, - ) - replace_tensor("qweight13", marlin_qweight13) - - # print("w13 after:", marlin_qweight13.shape) - # print("w2 after:", marlin_qweight2.shape) - - # print("done repack", layer.get_parameter("qweight1").shape, - # layer.get_parameter("qweight2").shape, - # layer.get_parameter("qweight3").shape) - - # Permute scales - scales_size_k = part_size_k - scales_size_n = part_size_n - if self.quant_config.desc_act: - scales_size_k = full_size_k - scales_size_n = full_size_n - - layer_scales13 = torch.cat((layer.scales1, layer.scales3), 1) - - # print("w13 scales before:", layer_scales13.shape) - # print("w2 scales before:", layer.scales2.shape) - # print("w13 args:", part_size_k, layer_qweight13.shape[1]) - # print("w2 args:", layer.scales2.shape[0] * 8, layer.scales2.shape[1]) - - # marlin_scales1 = marlin_permute_scales( - # layer.scales1, - # scales_size_k, - # scales_size_n, - # self.quant_config.group_size, - # self.quant_config.weight_bits, - # ) - # replace_tensor("scales1", marlin_scales1) - marlin_scales2 = marlin_permute_scales( - layer.scales2, - layer.scales2.shape[0] * 8, - layer.scales2.shape[1], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("scales2", marlin_scales2) - # marlin_scales3 = marlin_permute_scales( - # layer.scales3, - # scales_size_k, - # scales_size_n, - # self.quant_config.group_size, - # self.quant_config.weight_bits, - # ) - # replace_tensor("scales3", marlin_scales3) - marlin_scales13 = marlin_permute_scales( - layer_scales13, - part_size_k, - layer_qweight13.shape[1], - self.quant_config.group_size, - self.quant_config.weight_bits, - ) - replace_tensor("scales13", marlin_scales13) - - # print("w13 scales after:", marlin_scales13.shape) - # print("w2 scales after:", marlin_scales2.shape) - - # raise ValueError("stop") - - # else: - # print("do not repack") - - # return output.reshape(out_shape) - return None #the computation is done elsewhere - diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 27e5ca5bc81b8..7cb6a79156d81 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" +import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -28,25 +29,20 @@ import torch.nn.functional as F from torch import nn from transformers import MixtralConfig -import re -from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import fused_marlin_moe, FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, - FusedLinearMarlin, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_permute_scales_numbits) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -63,48 +59,36 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - experimental_fused_moe: bool, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.num_experts = num_experts self.ffn_dim = intermediate_size self.hidden_dim = hidden_size - self.experimental_fused_moe = experimental_fused_moe - # TODO - # print("hidden dim:", self.hidden_dim, "ffn_dim:", self.ffn_dim) - if self.experimental_fused_moe: - self.ws = FusedLinearMarlin(self.hidden_dim, self.ffn_dim, - quant_config=quant_config) - else: - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - quant_config=quant_config) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - quant_config=quant_config) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + quant_config=quant_config) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if self.experimental_fused_moe: - current_hidden_states = self.ws(hidden_states) - return current_hidden_states - else: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states class MixtralMoE(nn.Module): @@ -113,13 +97,11 @@ def __init__( self, config: MixtralConfig, experimental_fused_moe: bool, - old_code: bool, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.experimental_fused_moe = experimental_fused_moe - self.old_code = 
old_code self.quant_config = quant_config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -136,18 +118,7 @@ def __init__( raise ValueError( f"Rank {self.rank} has no experts assigned to it.") - if self.old_code: - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - self.experimental_fused_moe, - quant_config=quant_config) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - else: - # TODO type + if self.experimental_fused_moe: params_dtype = torch.float16 self.experts = FusedMoE(num_experts=self.num_total_experts, top_k=self.top_k, @@ -158,6 +129,16 @@ def __init__( renormalize=True, quant_config=quant_config, tp_size=self.tp_size) + else: + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, @@ -169,127 +150,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) if self.experimental_fused_moe: - - if not self.old_code: - return self.experts(hidden_states.half(), router_logits).bfloat16() - - qweight13_l = [] - scales13_l = [] - qweight2_l = [] - scales2_l = [] - g_idx13_l = [] - g_idx2_l = [] - g_idx_sort_idx13_l = [] - g_idx_sort_idx2_l = [] - - for i in range(len(self.experts)): - current_expert = self.experts[i].ws - current_expert(hidden_states) - # print("get weights") - # w1_qw = current_expert.get_parameter("qweight1").int() - # w3_qw = current_expert.get_parameter("qweight3").int() - # w1_s = current_expert.get_parameter("scales1").half() - # w3_s = current_expert.get_parameter("scales3").half() - w2_qw = current_expert.get_parameter("qweight2").int() - w2_s = current_expert.get_parameter("scales2").half() - w13_qw = current_expert.get_parameter("qweight13").int() - w13_s = current_expert.get_parameter("scales13").half() - # w1_qw = self.experts[i].w1.get_parameter("qweight").int() - # w3_qw = self.experts[i].w3.get_parameter("qweight").int() - # w1_s = self.experts[i].w1.get_parameter("scales").half() - # w3_s = self.experts[i].w3.get_parameter("scales").half() - # w2_qw = self.experts[i].w2.get_parameter("qweight").int() - # w2_s = self.experts[i].w2.get_parameter("scales").half() - if self.quant_config.desc_act: - # g_idx13 = self.experts[i].w1.get_parameter("g_idx") - # g_idx2 = self.experts[i].w2.get_parameter("g_idx") - g_idx13 = current_expert.get_parameter("g_idx13") - g_idx2 = current_expert.get_parameter("g_idx2") - g_idx_sort_idx13 = current_expert.get_parameter("g_idx_sort_indices13") - g_idx_sort_idx2 = current_expert.get_parameter("g_idx_sort_indices2") - else: - g_idx13 = torch.empty(0, device=w13_qw.device) - g_idx2 = torch.empty(0, device=w2_qw.device) - g_idx_sort_idx13 = torch.empty(0, device=w13_qw.device) - g_idx_sort_idx2 = torch.empty(0, device=w2_qw.device) - # g_idx_sort_idx13 = torch.argsort(g_idx13).int() - # g_idx_sort_idx2 = torch.argsort(g_idx2).int() - - # w13_qw = torch.cat((w1_qw, w3_qw), 0) - # w13_s = torch.cat((w1_s, w3_s), 0) - # w13_qw = torch.cat((w1_qw, w3_qw), 1) - # w13_s = torch.cat((w1_s, w3_s), 1) - # size_k = hidden_states.shape[1] - # size_n = w13_qw.shape[1] - # print("do repack 13", w13_qw.shape, g_idx_sort_idx13.shape) - # w13_qw = 
ops.gptq_marlin_repack(w13_qw, g_idx_sort_idx13, size_k, - # size_n, - # self.quant_config.weight_bits) - # w13_s = marlin_permute_scales_numbits( - # w13_s, size_k, size_n, self.quant_config.group_size, - # self.quant_config.weight_bits) - - # size_k = w2_qw.shape[0] * 8 - # size_n = w2_qw.shape[1] - # print("do repack 2", w2_qw.shape, g_idx_sort_idx2.shape) - # w2_qw = ops.gptq_marlin_repack(w2_qw, g_idx_sort_idx2, size_k, - # size_n, - # self.quant_config.weight_bits) - # w2_s = marlin_permute_scales_numbits(w2_s, size_k, size_n, - # self.quant_config.group_size, - # self.quant_config.weight_bits) - - qweight13_l.append(w13_qw) - scales13_l.append(w13_s) - qweight2_l.append(w2_qw) - scales2_l.append(w2_s) - g_idx13_l.append(g_idx13) - g_idx2_l.append(g_idx2) - g_idx_sort_idx13_l.append(g_idx_sort_idx13) - g_idx_sort_idx2_l.append(g_idx_sort_idx2) - - qweight13 = torch.stack(qweight13_l, dim=0).to(qweight13_l[0].device) - scales13 = torch.stack(scales13_l, dim=0).to(scales13_l[0].device) - qweight2 = torch.stack(qweight2_l, dim=0).to(qweight2_l[0].device) - scales2 = torch.stack(scales2_l, dim=0).to(scales2_l[0].device) - g_idx13 = torch.stack(g_idx13_l, dim=0).to(g_idx13_l[0].device) - g_idx2 = torch.stack(g_idx2_l, dim=0).to(g_idx2_l[0].device) - g_idx_sort_idx13 = torch.stack(g_idx_sort_idx13_l, - dim=0).to(g_idx_sort_idx13_l[0].device) - g_idx_sort_idx2 = torch.stack(g_idx_sort_idx2_l, - dim=0).to(g_idx_sort_idx2_l[0].device) - - final_hidden_states = fused_marlin_moe( - hidden_states.half(), - qweight13, - qweight2, - router_logits, - g_idx13, - g_idx2, - g_idx_sort_idx13, - g_idx_sort_idx2, - self.top_k, - renormalize=True, - w1_scale=scales13, - w2_scale=scales2, - ) - - return final_hidden_states.bfloat16() - + return self.experts(hidden_states.half(), router_logits).bfloat16() else: - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights = F.softmax(router_logits, + dim=1, + dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) + self.top_k, + dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) final_hidden_states = None for expert_idx in self.expert_indicies: expert_layer = self.experts[expert_idx] expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) + expert_weights = (routing_weights * expert_mask).sum( + dim=-1, keepdim=True) current_hidden_states = expert_layer(hidden_states).mul_( expert_weights) @@ -299,7 +175,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states.add_(current_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states).view( - num_tokens, hidden_dim) + num_tokens, hidden_dim) + class MixtralAttention(nn.Module): @@ -384,7 +261,6 @@ def __init__( self, config: MixtralConfig, experimental_fused_moe: bool, - old_code: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -400,10 +276,10 @@ def __init__( rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config) - self.block_sparse_moe = MixtralMoE(config=config, - experimental_fused_moe=experimental_fused_moe, - old_code=old_code, - quant_config=quant_config) + self.block_sparse_moe = MixtralMoE( + config=config, + experimental_fused_moe=experimental_fused_moe, + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = 
RMSNorm(config.hidden_size, @@ -444,7 +320,6 @@ def __init__( self, config: MixtralConfig, experimental_fused_moe: bool, - old_code: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -459,7 +334,6 @@ def __init__( self.layers = nn.ModuleList([ MixtralDecoderLayer(config, experimental_fused_moe, - old_code, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) @@ -495,15 +369,17 @@ def __init__( ) -> None: super().__init__() + # TODO have a better way to set this. + # Needs some testing/improving? self.experimental_fused_moe = True - self.old_code = False self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, self.experimental_fused_moe, self.old_code, cache_config, quant_config) + self.model = MixtralModel(config, self.experimental_fused_moe, + cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -522,7 +398,7 @@ def forward( def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + sampling_metadata) return logits def sample( @@ -541,25 +417,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - # weight names: [w[0] for w in weights] - # 'model.layers.0.block_sparse_moe.experts.0.w1.bias', - # 'model.layers.0.block_sparse_moe.experts.0.w1.g_idx', - # 'model.layers.0.block_sparse_moe.experts.0.w1.qweight', - # 'model.layers.0.block_sparse_moe.experts.0.w1.qzeros', - # 'model.layers.0.block_sparse_moe.experts.0.w1.scales', - # 'model.layers.0.block_sparse_moe.experts.0.w2.bias', - # 'model.layers.0.block_sparse_moe.experts.0.w2.g_idx', - # 'model.layers.0.block_sparse_moe.experts.0.w2.qweight', - # 'model.layers.0.block_sparse_moe.experts.0.w2.qzeros', - # 'model.layers.0.block_sparse_moe.experts.0.w2.scales', - # 'model.layers.0.block_sparse_moe.experts.0.w3.bias', - # 'model.layers.0.block_sparse_moe.experts.0.w3.g_idx', - # 'model.layers.0.block_sparse_moe.experts.0.w3.qweight', - # 'model.layers.0.block_sparse_moe.experts.0.w3.qzeros', - # 'model.layers.0.block_sparse_moe.experts.0.w3.scales', - # 'model.layers.0.block_sparse_moe.experts.1.w1.bias' - # ... - params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: @@ -580,98 +437,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - if self.old_code: - if self.experimental_fused_moe: - if("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." not in name - and name not in params_dict): - continue - - if (".qzeros" in name): - continue - - has_weight_or_scale = (".qweight" in name or ".scales" in name) - has_g_idx = ".g_idx" in name - if (has_weight_or_scale and ".w1." in name): - name = name.replace(".w1.", ".ws.") - name += "1" - if ((has_weight_or_scale or has_g_idx) and ".w2." in name): - name = name.replace(".w2.", ".ws.") - name += "2" - if (has_weight_or_scale and ".w3." in name): - name = name.replace(".w3.", ".ws.") - name += "3" - if (has_g_idx and ".w1." 
in name): - name = name.replace(".w1.", ".ws.") - name += "13" - if (has_g_idx and ".w3." in name): - name = name.replace(".w3.", ".ws.") - name += "13" - - else: - if("block_sparse_moe.experts." in name - and name not in params_dict): - continue - - param = params_dict[name] - # print("load", name, "into", param.shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if self.experimental_fused_moe: + if ("block_sparse_moe.experts." in name + and ".w1." not in name and ".w2." not in name + and ".w3." not in name + and name not in params_dict): + continue + + if (".qzeros" in name): + continue + + shard_id = None + expert_id = 0 + + has_any_numbered = (".qweight" in name or ".scales" in name + or ".g_idx" in name) + if (has_any_numbered and (".w1." in name)): + name = name.replace(".w1.", ".w13_") + shard_id = 0 + if (has_any_numbered and (".w2." in name)): + name = name.replace(".w2.", ".w2_") + shard_id = 0 + if (has_any_numbered and (".w3." in name)): + name = name.replace(".w3.", ".w13_") + shard_id = 1 + + exp_string = re.search(r"\.experts\.\d+.", name) + if exp_string: + exp_string = exp_string.group(0) + expert_id = int(exp_string.split(".")[2]) + name = name.replace(exp_string, ".experts.") else: + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue - if self.experimental_fused_moe: - if("block_sparse_moe.experts." in name - and ".w1." not in name and ".w2." not in name - and ".w3." not in name - and name not in params_dict): - continue - - if (".qzeros" in name): - continue - - shard_id = None - expert_id = 0 - - # print("process:", name) - - has_any_numbered = (".qweight" in name or ".scales" in name or ".g_idx" in name) - if (has_any_numbered and (".w1." in name)): - name = name.replace(".w1.", ".w13_") - shard_id = 0 - if (has_any_numbered and (".w2." in name)): - name = name.replace(".w2.", ".w2_") - shard_id = 0 - if (has_any_numbered and (".w3." in name)): - name = name.replace(".w3.", ".w13_") - shard_id = 1 - - exp_string = re.search(r"\.experts\.\d+.", name) - if exp_string: - exp_string = exp_string.group(0) - # print("Exp string:", exp_string) - expert_id = int(exp_string.split(".")[2]) - # print("I found:", expert_shard, "in", name) - name = name.replace(exp_string, ".experts.") - - else: - if("block_sparse_moe.experts." in name - and name not in params_dict): - continue - - param = params_dict[name] - - # print("load", name, "into", param.shape) - if shard_id is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # print("load:", name, "with shard", shard_id) - weight_loader(param, loaded_weight, name, shard_id, expert_id, True) - else: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - # print("load:", name, "without shard") - weight_loader(param, loaded_weight) - + param = params_dict[name] + + if self.experimental_fused_moe and shard_id is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight, name, shard_id, + expert_id, True) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)
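To make the checkpoint-name handling in load_weights above easier to follow, here is a small standalone reproduction of the renaming it applies before dispatching to the fused weight_loader: per-expert GPTQ tensors w1/w3 are folded into the w13_* parameters (shards 0 and 1) and w2 into w2_*, with the expert index pulled out of the name by the same regex. The function name is invented for the example.

import re

def remap_expert_weight_name(name: str):
    shard_id = None
    if ".w1." in name:
        name, shard_id = name.replace(".w1.", ".w13_"), 0
    elif ".w3." in name:
        name, shard_id = name.replace(".w3.", ".w13_"), 1
    elif ".w2." in name:
        name, shard_id = name.replace(".w2.", ".w2_"), 0
    expert_id = 0
    m = re.search(r"\.experts\.\d+.", name)
    if m:
        expert_id = int(m.group(0).split(".")[2])
        name = name.replace(m.group(0), ".experts.")
    return name, shard_id, expert_id

# e.g. "model.layers.0.block_sparse_moe.experts.3.w3.qweight"
#  -> ("model.layers.0.block_sparse_moe.experts.w13_qweight", 1, 3)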