From 93924e11c82c32530b72b5462405685e23aae1a8 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 30 Aug 2024 21:05:51 +0000
Subject: [PATCH] formatting + fixes

---
 .../layers/quantization/gptq_marlin.py   |  9 +++++----
 .../kernels/MacheteLinearKernel.py       |  2 +-
 vllm/model_executor/parameter.py         | 20 ++++++++------------
 3 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 94eb3f301541a..d22a476ad7431 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -8,10 +8,11 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
-    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@@ -285,7 +286,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.quant_config.desc_act:
             g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
             layer.g_idx_sort_indices = g_idx_sort_indices
-            replace_tensor(layer, "g_idx", g_idx)
+            replace_parameter(layer, "g_idx", g_idx)
         else:
             layer.g_idx = marlin_make_empty_g_idx(device)
             layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
@@ -300,7 +301,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qweight", marlin_qweight)
+        replace_parameter(layer, "qweight", marlin_qweight)
 
         # Permute scales from autogptq format to marlin format.
         marlin_scales = marlin_permute_scales(
@@ -309,7 +310,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=self.quant_config.group_size)
-        replace_tensor(layer, "scales", marlin_scales)
+        replace_parameter(layer, "scales", marlin_scales)
 
     def apply(
         self,
diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
index 275c767a8c1be..f19e7b0e6b9c8 100644
--- a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
@@ -49,7 +49,7 @@ def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
-                                          self.config.weight_type)
+                                           self.config.weight_type)
             return x
 
         def transform_w_s(x):
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 04bb602458ae7..832e5221ecb16 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -327,12 +327,8 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
             marlin_tile_size=self.marlin_tile_size)
 
 
-def permute_param_layout_(
-    param: BasevLLMParameter,
-    input_dim: int,
-    output_dim: int,
-    **kwargs
-) -> BasevLLMParameter:
+def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
+                          output_dim: int, **kwargs) -> BasevLLMParameter:
     """
     Permute a parameter's layout to the specified input and output dimensions,
     useful for forcing the parameter into a known layout, for example, if I need
@@ -343,7 +339,7 @@ def permute_param_layout_(
     to ensure x is in the correct layout (permuting it to the correct layout if
     required, asserting if it cannot get it to the correct layout)
     """
-    
+
     curr_input_dim = getattr(param, "input_dim", None)
     curr_output_dim = getattr(param, "output_dim", None)
 
@@ -351,7 +347,7 @@ def permute_param_layout_(
         assert param.data.dim() == 2,\
             "permute_param_layout_ only supports 2D parameters where either "\
             "input_dim or output_dim is not set"
-    
+
     # if one of the dimensions is not set, set it to the opposite of the other
     # we can only do this since we asserted the parameter is 2D above
     if curr_input_dim is None:
@@ -362,9 +358,9 @@ def permute_param_layout_(
         assert curr_input_dim is not None,\
             "either input or output dim must be set"
         curr_output_dim = (curr_input_dim + 1) % 2
-    
+
     # create permutation from the current layout to the layout with
-    # self.input_dim at input_dim and self.output_dim at output_dim preserving 
+    # self.input_dim at input_dim and self.output_dim at output_dim preserving
     # other dimensions
     perm = [
         i for i in range(param.data.dim())
@@ -372,7 +368,7 @@ def permute_param_layout_(
     ]
     perm.insert(input_dim, curr_input_dim)
     perm.insert(output_dim, curr_output_dim)
-    
+
     if "packed_dim" in kwargs:
         assert hasattr(param, "packed_dim") and\
             param.packed_dim == perm[kwargs["packed_dim"]],\
@@ -385,7 +381,7 @@ def permute_param_layout_(
     param._output_dim = output_dim
     if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
         param._packed_dim = kwargs["packed_dim"]
-    
+
     return param
 
 
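Reviewer note (not part of the patch): the gptq_marlin.py hunks switch from
marlin_utils.replace_tensor to replace_parameter, now imported from
vllm.model_executor.layers.quantization.utils. For readers unfamiliar with the
pattern, the sketch below is a hypothetical stand-in for that helper, not the
real vLLM implementation: after Marlin repacking, the transformed tensor
generally no longer matches the shape or parameter subclass of the originally
registered weight, so it has to be re-registered on the layer under the same
attribute name (e.g. "qweight", "scales", "g_idx").

    import torch

    def replace_parameter_sketch(mod: torch.nn.Module, name: str,
                                 new: torch.Tensor) -> None:
        # Drop the old (possibly subclassed, differently shaped) parameter
        # and re-register the repacked tensor as a plain frozen Parameter
        # under the same attribute name. The real vLLM helper may differ.
        assert isinstance(getattr(mod, name), torch.Tensor)
        mod.register_parameter(
            name, torch.nn.Parameter(new, requires_grad=False))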
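Reviewer note (not part of the patch): the permutation that
permute_param_layout_ builds can be exercised in isolation. build_layout_perm
below is a hypothetical standalone mirror of the perm construction in the
parameter.py diff (keep every dim that is neither the current input nor output
dim, then insert those two at their requested positions), using a plain tensor
in place of a BasevLLMParameter:

    import torch

    def build_layout_perm(ndim: int, curr_input_dim: int,
                          curr_output_dim: int, input_dim: int,
                          output_dim: int) -> list:
        # dims that are neither input nor output keep their relative order
        perm = [
            i for i in range(ndim)
            if i not in (curr_input_dim, curr_output_dim)
        ]
        perm.insert(input_dim, curr_input_dim)
        perm.insert(output_dim, curr_output_dim)
        return perm

    # A weight stored as (output, input) forced into (input, output) layout,
    # mirroring the transform_w_q call above (input_dim=0, output_dim=1):
    w = torch.arange(6).reshape(3, 2)  # dim 0 = output, dim 1 = input
    perm = build_layout_perm(w.dim(), curr_input_dim=1, curr_output_dim=0,
                             input_dim=0, output_dim=1)
    assert perm == [1, 0]
    assert w.permute(*perm).shape == (2, 3)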