Commit 93924e1: formatting + fixes

LucasWilkinson committed Aug 30, 2024
1 parent cff188c

Showing 3 changed files with 14 additions and 17 deletions.
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -8,10 +8,11 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
-    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@@ -285,7 +286,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if self.quant_config.desc_act:
             g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
             layer.g_idx_sort_indices = g_idx_sort_indices
-            replace_tensor(layer, "g_idx", g_idx)
+            replace_parameter(layer, "g_idx", g_idx)
         else:
             layer.g_idx = marlin_make_empty_g_idx(device)
             layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
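
For context on the act-order branch above: marlin_sort_g_idx sorts the per-row group indices so that entries belonging to the same quantization group become contiguous, returning both the sorted indices and the permutation used. A minimal sketch of such a helper, written to match its usage in this hunk (the actual implementation in vllm's marlin_utils may differ in detail):

    import torch

    def marlin_sort_g_idx(g_idx: torch.Tensor):
        # Sort group indices so entries in the same quantization group
        # are contiguous; the kernel uses the returned sort indices to
        # gather activations into the same order.
        g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
        return g_idx[g_idx_sort_indices], g_idx_sort_indices
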
@@ -300,7 +301,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
             num_bits=self.quant_config.quant_type.size_bits)
-        replace_tensor(layer, "qweight", marlin_qweight)
+        replace_parameter(layer, "qweight", marlin_qweight)
 
         # Permute scales from autogptq format to marlin format.
         marlin_scales = marlin_permute_scales(
@@ -309,7 +310,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                     layer.input_size_per_partition),
             size_n=layer.output_size_per_partition,
             group_size=self.quant_config.group_size)
-        replace_tensor(layer, "scales", marlin_scales)
+        replace_parameter(layer, "scales", marlin_scales)
 
     def apply(
         self,
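
The switch from replace_tensor to replace_parameter in the hunks above swaps the repacked tensors back onto the layer under their original attribute names. A minimal sketch of what such a helper could look like, an assumption for illustration only (the actual vllm utility may reuse existing storage or differ in detail):

    import torch

    def replace_parameter(mod: torch.nn.Module, name: str,
                          new: torch.Tensor) -> None:
        # Drop the old (possibly subclassed, loader-specific) parameter
        # and re-register the repacked data as a plain frozen Parameter.
        old = getattr(mod, name)
        assert isinstance(old, torch.nn.Parameter)
        delattr(mod, name)
        mod.register_parameter(
            name, torch.nn.Parameter(new, requires_grad=False))
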
2 changes: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
-                                                self.config.weight_type)
+                                           self.config.weight_type)
             return x
 
         def transform_w_s(x):
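
The .t().contiguous().t() pattern in this hunk is a standard trick to force a tensor into column-major storage while keeping its logical shape, which weight-prepack kernels typically expect. A quick illustration in plain PyTorch:

    import torch

    w = torch.randn(128, 256)         # row-major: stride (256, 1)
    w_cm = w.t().contiguous().t()     # same shape, column-major storage
    assert w_cm.shape == (128, 256)
    assert w_cm.stride() == (1, 128)  # stride (1, K): column-major
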
20 changes: 8 additions & 12 deletions vllm/model_executor/parameter.py
@@ -327,12 +327,8 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
             marlin_tile_size=self.marlin_tile_size)
 
 
-def permute_param_layout_(
-    param: BasevLLMParameter,
-    input_dim: int,
-    output_dim: int,
-    **kwargs
-) -> BasevLLMParameter:
+def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
+                          output_dim: int, **kwargs) -> BasevLLMParameter:
     """
     Permute a parameter's layout to the specified input and output dimensions,
     useful for forcing the parameter into a known layout, for example, if I need
@@ -343,15 +339,15 @@ def permute_param_layout_(
     to ensure x is in the correct layout (permuting it to the correct layout if
     required, asserting if it cannot get it to the correct layout)
     """
-    
+
     curr_input_dim = getattr(param, "input_dim", None)
     curr_output_dim = getattr(param, "output_dim", None)
-    
+
     if curr_input_dim is None or curr_output_dim is None:
         assert param.data.dim() == 2,\
             "permute_param_layout_ only supports 2D parameters where either "\
             "input_dim or output_dim is not set"
-    
+
     # if one of the dimensions is not set, set it to the opposite of the other
     # we can only do this since we asserted the parameter is 2D above
     if curr_input_dim is None:
@@ -362,17 +358,17 @@ def permute_param_layout_(
         assert curr_input_dim is not None,\
             "either input or output dim must be set"
         curr_output_dim = (curr_input_dim + 1) % 2
 
     # create permutation from the current layout to the layout with
-    # self.input_dim at input_dim and self.output_dim at output_dim preserving 
+    # self.input_dim at input_dim and self.output_dim at output_dim preserving
     # other dimensions
     perm = [
         i for i in range(param.data.dim())
         if i not in [curr_input_dim, curr_output_dim]
     ]
     perm.insert(input_dim, curr_input_dim)
     perm.insert(output_dim, curr_output_dim)
 
     if "packed_dim" in kwargs:
         assert hasattr(param, "packed_dim") and\
             param.packed_dim == perm[kwargs["packed_dim"]],\
@@ -385,7 +381,7 @@ def permute_param_layout_(
         param._output_dim = output_dim
     if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
         param._packed_dim = kwargs["packed_dim"]
-    
+
     return param
-    
+
 
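
Tying this back to the transform_w_q hunk above, a typical call looks like the following sketch; the shapes and dim attributes here are hypothetical, chosen only to show the effect of the permutation:

    # Hypothetical packed weight loaded output-major:
    #   param.data has shape (output_size, input_size // pack_factor),
    #   with param.input_dim == 1, output_dim == 0, packed_dim == 1.
    param = permute_param_layout_(param, input_dim=0, output_dim=1,
                                  packed_dim=0)
    # param.data is now (input_size // pack_factor, output_size) and
    # input_dim, output_dim, packed_dim are updated to 0, 1, 0.
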
