diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 046e6e5a53652..34589c5bdb3c8 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -152,7 +152,8 @@ struct MacheteKernelTemplate { int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); - int const group_size = maybe_group_size.value_or(K); + int group_size = maybe_group_size.value_or(K); + group_size = (group_size == -1) ? K : group_size; int const scale_k = (K + group_size - 1) / group_size; TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index 686dd68bd52bb..df78312997fb0 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // clang-format on // Allocate output - torch::Tensor D = torch::empty_like(B); + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); prepack_B(stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..ba1e519d64370 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -7,10 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -231,7 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -239,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. marlin_zp = awq_to_marlin_zero_points( @@ -247,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..17fd688a09f10 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -2,13 +2,10 @@ import torch -from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -46,23 +43,32 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + act_reordering=False + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + # If group_size is -1, we are in channelwise case. channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size @@ -71,12 +77,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, # scales across all gpus. partition_scales = (row_parallel and not channelwise) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -123,62 +123,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name=None) # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..fed393dd97771 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,16 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + check_marlin_supported, marlin_repeat_scales_on_all_ranks, + verify_marlin_supported) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -163,24 +161,29 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + act_reordering=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -261,55 +264,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. - marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -317,16 +280,4 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py new file mode 100644 index 0000000000000..d8b2de6141d63 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py @@ -0,0 +1,85 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.parameter import (ModelWeightParameter, + PackedvLLMParameter) + +from .MPLinearKernel import * + + +class GPTQLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if c.act_type != torch.half: + return False, f"Act type {c.act_type} currently not supported by GPTQLinearKernel" + + if c.zero_points: + return False, "Zero points currently not supported by GPTQLinearKernel" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + def transform_w_q(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, PackedvLLMParameter): + x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0) + return ops.machete_prepack_B(x.t().contiguous().t(), + self.config.weight_type) + + def transform_w_s(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, ModelWeightParameter): + x = x.permute_layout(input_dim=0, output_dim=1) + return x.contiguous() + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000000000..185e40c251310 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + act_reordering: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + # note assumes that (if the they are not ModelWeightParameters) + # `getattr(layer, w_q_name)` is: + # {input_dim = 0, output_dim = 1, packed_dim = 0} + # `getattr(layer, w_s_name)` is: + # {input_dim = 0, output_dim = 1} + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + replace_parameter(layer, name, fn(getattr(layer, name))) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], Optional[torch.Tensor]]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py new file mode 100644 index 0000000000000..275c767a8c1be --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py @@ -0,0 +1,88 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import * + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.act_reordering: + return False, "Act reordering currently not supported by Machete" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + print(w_s) + print(c.group_size) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py new file mode 100644 index 0000000000000..fb59f6d4352aa --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py @@ -0,0 +1,128 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import * + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.act_reordering, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.act_reordering: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading`` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000000000..22172771e5b64 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,44 @@ +from typing import List, Optional, Type + +from vllm.platforms import current_platform + +from .MacheteLinearKernel import MacheteLinearKernel +from .MarlinLinearKernel import MarlinLinearKernel +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index e69de29bb2d1d..e60f0c79ac1f7 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000000000..c38bd8955f457 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,33 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new) + mod.register_parameter(name, torch.nn.Parameter(new)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000000000..18e1332050cdd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..e83b4eacf8f38 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -118,6 +118,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -146,6 +159,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -221,17 +239,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index c6cfab7892efa..6d84a5ecb3492 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -320,6 +320,68 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_( + param: BasevLLMParameter, + input_dim: int, + output_dim: int, + **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters where either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index eb491dd1554a8..373151a5311e5 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -27,6 +27,8 @@ class scalar_types: float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) uint4b8 = ScalarType.uint(4, 8) uint8b128 = ScalarType.uint(8, 128)