
Commit

remove gptq support
LucasWilkinson committed Aug 30, 2024
1 parent 735259b commit cff188c
Showing 3 changed files with 74 additions and 112 deletions.
99 changes: 74 additions & 25 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -1,16 +1,18 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@@ -161,29 +163,24 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:

del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
act_reordering=self.quant_config.desc_act
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)

# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
@@ -264,20 +261,72 @@ def create_weights(
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)

# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device

self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)

# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.zp = marlin_make_empty_g_idx(device)

# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)

# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias)
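
For context on what process_weights_after_loading is repacking: an AutoGPTQ checkpoint stores each linear weight as a packed integer tensor (qweight) plus per-group scales, with g_idx mapping every input channel to its quantization group (a permuted mapping when desc_act is set, which is why the sort-and-repack step above exists). The snippet below is a minimal reference dequantization, not code from this commit; it assumes the standard 4-bit AutoGPTQ layout (eight values packed per int32 along the input dimension, lowest bits first) and symmetric quantization under the uint4b8 convention, so the packed qzeros can be ignored.

import torch

def dequantize_gptq_ref(qweight: torch.Tensor,  # int32, (K // 8, N): eight 4-bit values per int32
                        scales: torch.Tensor,   # fp16/bf16, (K // group_size, N)
                        g_idx: torch.Tensor,    # int32, (K,): quantization group of each input channel
                        bias: int = 8) -> torch.Tensor:
    """Slow reference dequantization for symmetric 4-bit GPTQ (uint4b8)."""
    num_packed, n = qweight.shape
    k = num_packed * 8
    # Unpack the eight nibbles of every int32 along the input (K) dimension, lowest bits first.
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=qweight.device)
    nibbles = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF  # (K // 8, 8, N)
    q = nibbles.reshape(k, n).to(scales.dtype)
    # Select each channel's group scale via g_idx and subtract the bias of 8.
    return scales[g_idx.long()] * (q - bias)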

This file was deleted.

2 changes: 0 additions & 2 deletions vllm/scalar_type.py
@@ -27,8 +27,6 @@ class scalar_types:
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)

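The two entries removed here are the 2-bit and 3-bit biased-unsigned "gptq" types; the surviving uint4b8 and uint8b128 follow the same naming convention, where the trailing bN denotes a bias subtracted from the stored unsigned value so the represented range is roughly symmetric around zero. A plain-Python illustration of that convention (not the vLLM ScalarType API):

def biased_uint_range(size_bits: int, bias: int):
    # A stored unsigned value s in [0, 2**size_bits) represents s - bias.
    return (0 - bias, (2**size_bits - 1) - bias)

assert biased_uint_range(4, 8) == (-8, 7)        # uint4b8
assert biased_uint_range(8, 128) == (-128, 127)  # uint8b128
assert biased_uint_range(2, 2) == (-2, 1)        # removed uint2b2
assert biased_uint_range(3, 4) == (-4, 3)        # removed uint3b4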
