
Enable gptqmodel #35012

Merged · 63 commits · Jan 15, 2025
Changes from 4 commits
Commits
4c567b3
gptqmodel
jiqing-feng Nov 29, 2024
1d8f83e
fix format
jiqing-feng Nov 29, 2024
9f44604
update readme
jiqing-feng Dec 2, 2024
62cd0dd
Merge branch 'main' into gptq
jiqing-feng Dec 2, 2024
8c88315
gptqmodel need use checkpoint_format (#1)
LRL-ModelCloud Dec 3, 2024
ef0fb56
Revert quantizer_gptq.py (#2)
LRL-ModelCloud Dec 4, 2024
0191322
Merge branch 'main' into gptq
jiqing-feng Dec 4, 2024
0655960
limit gptqmodel and optimum version
jiqing-feng Dec 4, 2024
be914ea
fix format
jiqing-feng Dec 4, 2024
aa9a5c6
fix warning
jiqing-feng Dec 4, 2024
a4bc251
fix version check
jiqing-feng Dec 4, 2024
9ae979b
revert unrelated changes
jiqing-feng Dec 4, 2024
a73a8c2
enable gptqmodel tests
jiqing-feng Dec 4, 2024
c18a5f1
fix requires gptq
jiqing-feng Dec 4, 2024
27ac615
Fix Transformer compat (#3)
ZX-ModelCloud Dec 5, 2024
d3ad24b
Merge branch 'main' into gptq
jiqing-feng Dec 7, 2024
3972d2e
fix format
jiqing-feng Dec 10, 2024
2612dd7
Merge branch 'main' into gptq
jiqing-feng Dec 10, 2024
99b2ed7
fix format again
jiqing-feng Dec 10, 2024
ac14b9f
update gptqmodel version (#6)
ZX-ModelCloud Dec 16, 2024
0276854
fix unit test (#5)
ZX-ModelCloud Dec 19, 2024
8bde513
Merge branch 'main' into gptq
jiqing-feng Dec 19, 2024
4ffc7d1
backend is loading_attibutes (#7)
LRL-ModelCloud Dec 20, 2024
5474f89
fix format and tests
jiqing-feng Dec 20, 2024
f9e7e45
Merge branch 'main' into gptq
jiqing-feng Dec 20, 2024
99b5f14
fix memory check
jiqing-feng Dec 20, 2024
331b56a
Merge branch 'main' into gptq
jiqing-feng Dec 23, 2024
409f6a2
fix device mismatch
jiqing-feng Dec 23, 2024
c996a41
fix result check
jiqing-feng Dec 23, 2024
84e972c
Merge branch 'main' into gptq
jiqing-feng Dec 23, 2024
dbf68e8
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
f4c2ad3
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
9185f8b
Update src/transformers/quantizers/quantizer_gptq.py
jiqing-feng Dec 24, 2024
8d69ba4
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
226953a
Merge branch 'main' into gptq
MekkCyber Dec 24, 2024
65ee44b
update tests
jiqing-feng Dec 24, 2024
34d0ec0
review: update docs (#10)
Qubitium Dec 24, 2024
9d71301
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
153121a
review: update docs (#12)
Qubitium Dec 24, 2024
b270b2d
update tests for gptqmodel
jiqing-feng Dec 24, 2024
7120899
update document (#9)
ZX-ModelCloud Dec 24, 2024
a7fcfd7
Merge branch 'main' into gptq
jiqing-feng Dec 24, 2024
8e36a0e
typo
Qubitium Dec 24, 2024
0aef2df
doc note for asymmetric quant
Qubitium Dec 24, 2024
31a6baa
typo with apple silicon(e)
Qubitium Dec 24, 2024
d7c8890
typo for marlin
Qubitium Dec 24, 2024
db33fd5
Merge branch 'main' into gptq
jiqing-feng Dec 25, 2024
945f663
column name revert: review
Qubitium Dec 26, 2024
fc7b971
Merge branch 'main' into gptq
jiqing-feng Dec 27, 2024
6cb77d5
Merge branch 'main' into gptq
jiqing-feng Dec 30, 2024
2234122
Merge branch 'main' into gptq
Qubitium Jan 3, 2025
d07ed96
Merge branch 'main' into gptq
jiqing-feng Jan 9, 2025
a20dfd3
Merge branch 'main' into gptq
jiqing-feng Jan 9, 2025
91d12cc
doc rocm support
Qubitium Jan 9, 2025
1ec6fe7
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
7d2b708
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
8c2a8b3
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
053e0ad
Update docs/source/en/quantization/gptq.md
Qubitium Jan 10, 2025
d3bfbb0
Update docs/source/en/quantization/overview.md
Qubitium Jan 10, 2025
1d883ec
Update docs/source/en/quantization/overview.md
Qubitium Jan 10, 2025
2806f71
Merge branch 'main' into gptq
Qubitium Jan 10, 2025
25169bd
Merge branch 'main' into gptq
jiqing-feng Jan 10, 2025
5ea104a
Merge branch 'main' into gptq
jiqing-feng Jan 14, 2025
11 changes: 10 additions & 1 deletion docs/source/en/quantization/gptq.md
@@ -24,10 +24,19 @@ Try GPTQ quantization with PEFT in this [notebook](https://colab.research.google

The [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) library implements the GPTQ algorithm, a post-training quantization technique where each row of the weight matrix is quantized independently to find the version of the weights that minimizes error. The weights are quantized to int4 but restored to fp16 on the fly during inference. This can reduce memory usage by almost 4x because the int4 weights are dequantized in a fused kernel rather than in the GPU's global memory, and you can also expect a speedup in inference because a lower bitwidth takes less time to communicate.

[AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) is being replaced by [GPTQModel](https://github.com/ModelCloud/GPTQModel); auto_gptq will be deprecated in the future.

Before you begin, make sure the following libraries are installed:

```bash
pip install auto-gptq
```
or
```bash
pip install gptqmodel
```

```bash
pip install --upgrade accelerate optimum transformers
```

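For context, the snippet below is a minimal sketch of quantizing a model through this GPTQ integration; the model id and calibration dataset are illustrative placeholders rather than anything prescribed by this PR, and either backend (auto-gptq or gptqmodel) is assumed to be installed.

```py
# Minimal sketch: quantize a model through transformers' GPTQ integration.
# The model id and calibration dataset are illustrative placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit quantization calibrated on the "c4" dataset via optimum's GPTQ quantizer.
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

# Whichever backend is installed (auto-gptq or gptqmodel) is picked up through optimum.
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", quantization_config=gptq_config
)
quantized_model.save_pretrained("opt-125m-gptq")
```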
@@ -110,7 +119,7 @@ Only 4-bit models are supported, and we recommend deactivating the ExLlama kerne

</Tip>

The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2), then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.
The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2) or GPTQModel, then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.

```py
import torch
...
```
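As a complementary sketch of the CPU path described above, the example below disables the ExLlama kernels through `use_exllama=False`; the checkpoint name is only a placeholder for any 4-bit GPTQ model.

```py
# Sketch: run a prequantized GPTQ checkpoint on CPU with the ExLlama kernels disabled.
# The checkpoint name is an illustrative placeholder.
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

quantization_config = GPTQConfig(bits=4, use_exllama=False)
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-GPTQ",  # any 4-bit GPTQ checkpoint
    torch_dtype=torch.float16,
    device_map="cpu",
    quantization_config=quantization_config,
)
```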
8 changes: 7 additions & 1 deletion docs/source/en/quantization/overview.md
@@ -53,7 +53,7 @@ Use the table below to help you decide which quantization method to use.
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| GGUF / GGML (llama.cpp) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1 - 8 | 🔴 | [See GGUF section](../gguf) | [See GGUF section](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [GPTQ](./gptq) | 🔴 | 🟡 *** | 🟢 | 🟢 | 🔴 | 🟡 *** | 🔴 | 2 - 3 - 4 - 8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1 - 8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2 / 4 / 8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
@@ -72,3 +72,9 @@ We value your feedback to help identify bugs before the full release! Check out
\** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships.

</Tip>

<Tip>

\*** GPTQ only supports 4-bit on Intel CPU / GPU.

</Tip>
21 changes: 13 additions & 8 deletions src/transformers/quantizers/quantizer_gptq.py
@@ -22,7 +22,7 @@
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from ..utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging
from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin


@@ -35,11 +35,11 @@
class GptqHfQuantizer(HfQuantizer):
"""
Quantizer of the GPTQ method - for GPTQ the quantizer support calibration of the model through
`auto_gptq` package. Quantization is done under the hood for users if they load a non-prequantized model.
`auto_gptq` or `gptqmodel` package. Quantization is done under the hood for users if they load a non-prequantized model.
"""

requires_calibration = False
required_packages = ["optimum", "auto_gptq"]
required_packages = ["optimum", "auto_gptq", "gptqmodel"]
optimum_quantizer = None

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
@@ -49,16 +49,21 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum())

def validate_environment(self, *args, **kwargs):
gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
gptq_supports_cpu = (
is_auto_gptq_available()
and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
) or is_gptqmodel_available()
if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError("GPU is required to quantize or run quantize model.")
elif not (is_optimum_available() and is_auto_gptq_available()):
elif not (is_optimum_available() and (is_auto_gptq_available() or is_gptqmodel_available())):
raise ImportError(
"Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)"
"Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq or gptqmodel library (`pip install auto-gptq` or `pip install gptqmodel`)"
)
elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
elif is_auto_gptq_available() and version.parse(importlib.metadata.version("auto_gptq")) < version.parse(
"0.4.2"
):
raise ImportError(
"You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`"
"You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel`"
)
Member:
can you add a message mentioning that autogptq will be deprecated? I think we can do it two transformers versions from now. For optimum, maybe we can deprecate it a bit later than in transformers to make sure that we can still revert if there is a big issue.

Contributor Author:
done.

Member:
Don't forget that users need to use the latest version of optimum with gptqmodel.

Contributor Author:
I have limited the optimum and gptqmodel versions. The version limits can be changed after gptqmodel and optimum are released.


def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
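As a rough illustration of the version limits mentioned in the review thread above, a gate of this shape could be used; the minimum versions are placeholders, not the values this PR pins.

```py
# Illustrative version gate for the optimum / gptqmodel pairing.
# The minimum versions below are placeholders, not the values pinned by this PR.
import importlib.metadata

from packaging import version

MIN_OPTIMUM = version.parse("1.23.99")    # placeholder
MIN_GPTQMODEL = version.parse("1.4.99")   # placeholder

if version.parse(importlib.metadata.version("optimum")) < MIN_OPTIMUM:
    raise ImportError("Using gptqmodel requires a newer optimum: `pip install --upgrade optimum`")
if version.parse(importlib.metadata.version("gptqmodel")) < MIN_GPTQMODEL:
    raise ImportError("Please upgrade gptqmodel: `pip install --upgrade gptqmodel`")
```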
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -143,6 +143,7 @@
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_gptqmodel_available,
is_grokadamw_available,
is_hqq_available,
is_in_notebook,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -139,6 +139,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
_gptqmodel_available = _is_package_available("gptqmodel")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
@@ -1005,6 +1006,10 @@ def is_auto_gptq_available():
return _auto_gptq_available


def is_gptqmodel_available():
return _gptqmodel_available


def is_eetq_available():
return _eetq_available

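To show how the new helper is meant to be consumed, here is a small usage sketch; the backend-selection logic is illustrative, not the exact code used by the quantizer above.

```py
# Sketch: pick a GPTQ backend based on the new availability helpers.
from transformers.utils import is_auto_gptq_available, is_gptqmodel_available

if is_gptqmodel_available():
    backend = "gptqmodel"
elif is_auto_gptq_available():
    backend = "auto_gptq"
else:
    raise ImportError(
        "Install gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`)."
    )

print(f"Using GPTQ backend: {backend}")
```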