Enable gptqmodel #35012

Merged · 63 commits · Jan 15, 2025

Changes from 1 commit
Commits (63)
4c567b3
gptqmodel
jiqing-feng Nov 29, 2024
1d8f83e
fix format
jiqing-feng Nov 29, 2024
9f44604
update readme
jiqing-feng Dec 2, 2024
62cd0dd
Merge branch 'main' into gptq
jiqing-feng Dec 2, 2024
8c88315
gptqmodel need use checkpoint_format (#1)
LRL-ModelCloud Dec 3, 2024
ef0fb56
Revert quantizer_gptq.py (#2)
LRL-ModelCloud Dec 4, 2024
0191322
Merge branch 'main' into gptq
jiqing-feng Dec 4, 2024
0655960
limit gptqmodel and optimum version
jiqing-feng Dec 4, 2024
be914ea
fix format
jiqing-feng Dec 4, 2024
aa9a5c6
fix warning
jiqing-feng Dec 4, 2024
a4bc251
fix version check
jiqing-feng Dec 4, 2024
9ae979b
revert unrelated changes
jiqing-feng Dec 4, 2024
a73a8c2
enable gptqmodel tests
jiqing-feng Dec 4, 2024
c18a5f1
fix requires gptq
jiqing-feng Dec 4, 2024
27ac615
Fix Transformer compat (#3)
ZX-ModelCloud Dec 5, 2024
d3ad24b
Merge branch 'main' into gptq
jiqing-feng Dec 7, 2024
3972d2e
fix format
jiqing-feng Dec 10, 2024
2612dd7
Merge branch 'main' into gptq
jiqing-feng Dec 10, 2024
99b2ed7
fix format again
jiqing-feng Dec 10, 2024
ac14b9f
update gptqmodel version (#6)
ZX-ModelCloud Dec 16, 2024
0276854
fix unit test (#5)
ZX-ModelCloud Dec 19, 2024
8bde513
Merge branch 'main' into gptq
jiqing-feng Dec 19, 2024
4ffc7d1
backend is loading_attibutes (#7)
LRL-ModelCloud Dec 20, 2024
5474f89
fix format and tests
jiqing-feng Dec 20, 2024
f9e7e45
Merge branch 'main' into gptq
jiqing-feng Dec 20, 2024
99b5f14
fix memory check
jiqing-feng Dec 20, 2024
331b56a
Merge branch 'main' into gptq
jiqing-feng Dec 23, 2024
409f6a2
fix device mismatch
jiqing-feng Dec 23, 2024
c996a41
fix result check
jiqing-feng Dec 23, 2024
84e972c
Merge branch 'main' into gptq
jiqing-feng Dec 23, 2024
dbf68e8
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
f4c2ad3
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
9185f8b
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
8d69ba4
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
226953a
Merge branch 'main' into gptq
MekkCyber Dec 24, 2024
65ee44b
update tests
jiqing-feng Dec 24, 2024
34d0ec0
review: update docs (#10)
Qubitium Dec 24, 2024
9d71301
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
153121a
review: update docs (#12)
Qubitium Dec 24, 2024
b270b2d
update tests for gptqmodel
jiqing-feng Dec 24, 2024
7120899
update document (#9)
ZX-ModelCloud Dec 24, 2024
a7fcfd7
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
8e36a0e
typo
Qubitium Dec 24, 2024
0aef2df
doc note for asymmetric quant
Qubitium Dec 24, 2024
31a6baa
typo with apple silicon(e)
Qubitium Dec 24, 2024
d7c8890
typo for marlin
Qubitium Dec 24, 2024
db33fd5
Merge branch 'main' into gptq
jiqing-feng Dec 25, 2024
945f663
column name revert: review
Qubitium Dec 26, 2024
fc7b971
Merge branch 'main' into gptq
jiqing-feng Dec 27, 2024
6cb77d5
Merge branch 'main' into gptq
jiqing-feng Dec 30, 2024
2234122
Merge branch 'main' into gptq
Qubitium Jan 3, 2025
d07ed96
Merge branch 'main' into gptq
jiqing-feng Jan 9, 2025
a20dfd3
Merge branch 'main' into gptq
jiqing-feng Jan 9, 2025
91d12cc
doc rocm support
Qubitium Jan 9, 2025
1ec6fe7
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
7d2b708
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
8c2a8b3
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
053e0ad
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
d3bfbb0
Update docs/source/en/quantization/overview.md
Qubitium Jan 10, 2025
1d883ec
Update docs/source/en/quantization/overview.md
Qubitium Jan 10, 2025
2806f71
Merge branch 'main' into gptq
Qubitium Jan 10, 2025
25169bd
Merge branch 'main' into gptq
jiqing-feng Jan 10, 2025
5ea104a
Merge branch 'main' into gptq
jiqing-feng Jan 14, 2025
Fix Transformer compat (#3)
* revert quantizer_gptq.py change

* pass **kwargs

* add meta info

* cleanup

* cleanup

* Update quantization_config.py

* hf_select_quant_linear pass checkpoint_format and meta

* fix GPTQTestCUDA

* Update test_gptq.py

* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2

* cleanup

* add backend

* cleanup

* cleanup

* no need check exllama version

* Update quantization_config.py

* lower checkpoint_format and backend

* check none

* cleanup

* Update quantization_config.py

* fix self.use_exllama == False

* spell

* fix unittest

* fix unittest

---------

Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
3 people authored Dec 5, 2024
commit 27ac615f3f085880c95aff4026c30f5c2d332574
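The diff below threads three GPTQConfig fields (`checkpoint_format`, `meta`, `backend`) through to gptqmodel's `hf_select_quant_linear()`. As a minimal usage sketch of those fields (not part of this PR's diff; the model id, dataset, and version strings are illustrative placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Hypothetical example; model id and meta versions are placeholders.
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantization_config = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=tokenizer,
    group_size=128,
    checkpoint_format="gptq",  # "gptq" (v1) is readable by both gptqmodel and auto-gptq
    backend="auto",            # gptqmodel kernel selection; leave as None for auto-gptq
    meta={"quantizer": ["optimum:1.23.0", "gptqmodel:1.4.0"]},  # normally recorded by the tooling; versions illustrative
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config,
)
```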
45 changes: 35 additions & 10 deletions src/transformers/utils/quantization_config.py
@@ -25,7 +25,9 @@

from packaging import version

from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
from .import_utils import is_auto_gptq_available
from ..utils import (is_auto_awq_available, is_hqq_available, is_torch_available, is_gptqmodel_available,
is_torchao_available, logging)


if is_torch_available():
@@ -577,8 +579,14 @@ class GPTQConfig(QuantizationConfigMixin):
quantization using inputs that have passed through the previously quantized layers.
checkpoint_format (`str`, *optional*, defaults to `"gptq"`):
GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only.
meta (`Dict[str, any]`, *optional*):
Properties, such as tooling:version, that do not directly contribute to quantization or quant inference are stored in meta,
e.g. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
backend (`str`, *optional*):
Controls which gptq kernel is used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, the only
valid values are None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16.
Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. Auto-gptq only.
model_seqlen (`int`, *optional*):
The maximum sequence length that the model can take.
block_name_to_quantize (`str`, *optional*):
@@ -618,7 +626,9 @@ def __init__(
desc_act: bool = False,
sym: bool = True,
true_sequential: bool = True,
checkpoint_format: Optional[str] = "gptq",
checkpoint_format: str = "gptq",
meta: Optional[Dict[str, any]] = None,
backend: Optional[str] = None,
use_cuda_fp16: bool = False,
model_seqlen: Optional[int] = None,
block_name_to_quantize: Optional[str] = None,
@@ -641,6 +651,9 @@ def __init__(
self.desc_act = desc_act
self.sym = sym
self.true_sequential = true_sequential
self.checkpoint_format = checkpoint_format.lower()
self.meta = meta
self.backend = backend.lower() if isinstance(backend, str) else backend
self.use_cuda_fp16 = use_cuda_fp16
self.model_seqlen = model_seqlen
self.block_name_to_quantize = block_name_to_quantize
@@ -653,7 +666,6 @@ def __init__(
self.disable_exllama = kwargs.pop("disable_exllama", None)
self.cache_block_outputs = cache_block_outputs
self.modules_in_block_to_quantize = modules_in_block_to_quantize
self.checkpoint_format = checkpoint_format
self.post_init()

def get_loading_attributes(self):
@@ -690,6 +702,17 @@ def post_init(self):
['wikitext2','c4','c4-new'], but we found {self.dataset}"""
)

# make sure backend is back/forward compatible with both gptqmodel (full) and auto-gptq (partial)
if is_gptqmodel_available():
# convert auto-gptq control into gptqmodel backend
if self.backend is None:
self.backend = "auto_trainable" if self.use_exllama == False else "auto"
else:
# convert gptqmodel backend `auto_trainable` into auto-gptq control
if self.backend == "auto_trainable":
self.use_exllama = False

# auto-gptq specific kernel control logic
if self.disable_exllama is None and self.use_exllama is None:
# New default behaviour
self.use_exllama = True
@@ -723,19 +746,21 @@ def post_init(self):
"speed using exllamav2 kernel by setting `exllama_config`."
)
elif self.exllama_config["version"] == ExllamaVersion.TWO:
optimum_version = version.parse(importlib.metadata.version("optimum"))
autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
raise ValueError(
f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
)
if is_auto_gptq_available():
optimum_version = version.parse(importlib.metadata.version("optimum"))
autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
raise ValueError(
f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
)
if self.modules_in_block_to_quantize is not None:
optimum_version = version.parse(importlib.metadata.version("optimum"))
if optimum_version < version.parse("1.15.0"):
raise ValueError(
"You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ."
)


def to_dict(self):
config_dict = super().to_dict()
config_dict.pop("disable_exllama", None)
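The `post_init()` hunk above keeps `backend` and the legacy `use_exllama` flag in sync between the two libraries. A short illustration of that mapping, assuming a config built on a machine with the relevant backend installed (the commented values are what the hunk implies, not output from this PR's CI):

```python
from transformers import GPTQConfig

# With gptqmodel installed and no explicit backend:
# the legacy auto-gptq flag use_exllama=False maps to the trainable kernel path.
cfg = GPTQConfig(bits=4, use_exllama=False)
# cfg.backend -> "auto_trainable"; every other case defaults to "auto"

# With only auto-gptq installed:
# the gptqmodel-style backend "auto_trainable" is translated back into use_exllama=False.
cfg = GPTQConfig(bits=4, backend="auto_trainable")
# cfg.use_exllama -> False
```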
43 changes: 31 additions & 12 deletions tests/quantization/gptq/test_gptq.py
@@ -18,7 +18,7 @@

import pytest

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig, AutoConfig
from transformers.testing_utils import (
is_torch_available,
require_accelerate,
@@ -84,12 +84,14 @@ class GPTQTest(unittest.TestCase):
input_text = "Hello my name is"

EXPECTED_OUTPUTS = set()
# flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions
EXPECTED_OUTPUTS.add("Hello my name is Katie, I am a 22 year")

# this seems a little small considering that we are doing 4bit quant but we have a small model and we don't quantize the embeddings
EXPECTED_RELATIVE_DIFFERENCE = 2.06183008

bits = 4
sym = True
group_size = 128
desc_act = False
use_exllama = False
@@ -112,21 +114,23 @@ def setUpClass(cls):
cls.mem_fp16 = cls.model_fp16.get_memory_footprint()

cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
cls.config = AutoConfig.from_pretrained(cls.model_name)

quantization_config = GPTQConfig(
cls.quantization_config = GPTQConfig(
bits=cls.bits,
dataset=cls.dataset,
tokenizer=cls.tokenizer,
group_size=cls.group_size,
desc_act=cls.desc_act,
sym=cls.sym,
use_exllama=cls.use_exllama,
)

cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
torch_dtype=torch.float16,
device_map=cls.device_map,
quantization_config=quantization_config,
quantization_config=cls.quantization_config,
)

def test_memory_footprint(self):
@@ -167,14 +171,21 @@ def test_quantized_layers_class(self):
"""
if is_gptqmodel_available():
from gptqmodel.utils.importer import hf_select_quant_linear

if hasattr(self.config, "quantization_config"):
checkpoint_format = self.config.quantization_config.get("checkpoint_format")
meta = self.config.quantization_config.get("meta")
else:
checkpoint_format = "gptq"
meta = None
QuantLinear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=True,
sym=self.sym,
device_map=self.device_map,
pack=False,
checkpoint_format=checkpoint_format,
meta=meta,
backend=self.quantization_config.backend,
)
elif is_auto_gptq_available():
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
@@ -187,7 +198,7 @@ def test_quantized_layers_class(self):
disable_exllama=not self.use_exllama,
disable_exllamav2=True,
)
self.assertTrue(self.quantized_model.model.layers[0].mlp.gate_proj.__class__ == QuantLinear)
self.assertEqual(self.quantized_model.model.layers[0].mlp.gate_proj.__class__, QuantLinear)

def check_inference_correctness(self, model):
r"""
@@ -205,13 +216,13 @@ def check_inference_correctness(self, model):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def check_quantized_layers_type(self, model, value):
self.assertTrue(model.model.layers[0].mlp.gate_proj.QUANT_TYPE == value)
self.assertEqual(model.model.layers[0].mlp.gate_proj.QUANT_TYPE, value)

def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
if self.device_map != "cpu":
if self.device_map is None:
self.check_inference_correctness(self.quantized_model.to(0))
else:
self.check_inference_correctness(self.quantized_model)
@@ -235,7 +246,7 @@ def test_serialization(self):
tmpdirname, device_map=self.device_map
)
else:
quant_type = "ipex" if self.device_map == "cpu" else "cuda"
quant_type = "ipex" if self.device_map == "cpu" else "exllama"
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
tmpdirname, device_map=self.device_map
)
@@ -259,6 +270,12 @@ class GPTQTestCUDA(GPTQTest):
EXPECTED_RELATIVE_DIFFERENCE = 2.06183008
device_map = {"": 0}

@classmethod
def setUpClass(cls):
super().setUpClass()
# flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions
cls.EXPECTED_OUTPUTS.add("Hello my name is Katie. I am a 20 year")

def test_change_loading_attributes(self):
"""
Test the serialization of the model and the loading of the quantized weights works with another config file
@@ -302,6 +319,7 @@ class GPTQTestActOrderExllama(unittest.TestCase):
"""

EXPECTED_OUTPUTS = set()
# flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions
EXPECTED_OUTPUTS.add("Hello, how are you ? I'm doing good, thanks for asking.")
# 4bit + act_order + 128g
model_name = "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ"
@@ -338,7 +356,7 @@ def check_inference_correctness(self, model):
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_quantized_layers_type(self):
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllama")
self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE, "exllama")

def test_generate_quality(self):
"""
@@ -377,6 +395,7 @@ class GPTQTestExllamaV2(unittest.TestCase):
"""

EXPECTED_OUTPUTS = set()
# flaky test: gptqmodel and auto-gptq are not output equivalent nor is string compare deterministic even between transformer/torch versions
EXPECTED_OUTPUTS.add("Hello, how are you ? I'm doing good, thanks for asking.")
# 4bit + act_order + 128g
model_name = "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ"
@@ -397,7 +416,7 @@ def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)

def test_quantized_layers_type(self):
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllamav2")
self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE, "exllama" if is_gptqmodel_available() else "exllamav2")

def check_inference_correctness(self, model):
"""