add gptqmodel support #2247 (Open)
Wants to merge 4 commits into base: main
14 changes: 14 additions & 0 deletions src/peft/import_utils.py
@@ -47,6 +47,20 @@ def is_auto_gptq_available():
)


@lru_cache
def is_gptqmodel_available():
if importlib.util.find_spec("gptqmodel") is not None:
GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.4.2")
version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
return True
else:
raise ImportError(
f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
)


@lru_cache
def is_optimum_available() -> bool:
return importlib.util.find_spec("optimum") is not None
14 changes: 10 additions & 4 deletions src/peft/tuners/adalora/model.py
@@ -25,8 +25,10 @@
_freeze_adapter,
_get_submodules,
get_auto_gptq_quant_linear,
get_gptqmodel_quant_linear,
get_quantization_config,
)
from peft.import_utils import is_gptqmodel_available
from peft.utils.integrations import gather_params_ctx

from .gptq import SVDQuantLinear
@@ -135,7 +137,7 @@ def _create_and_replace(

# If it is not an AdaLoraLayer, create a new module, else update it with new adapters
if not isinstance(target, AdaLoraLayer):
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
new_module = self._create_new_module(lora_config, adapter_name, target, self.model.hf_device_map, **kwargs)
Reviewer comment (Member): We can't assume that the model has a .hf_device_map attribute. Theoretically, the model could be any PyTorch model; it doesn't have to be a transformers model, and even transformers models don't necessarily have the attribute. Therefore, we have to check for the attribute here and, if it doesn't exist, not pass it (or pass None).
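A minimal sketch of the guard being requested; the helper name _resolve_device_map is hypothetical and not part of this PR:

    from typing import Any, Optional

    def _resolve_device_map(model: Any) -> Optional[dict]:
        # hf_device_map is only set when the model was loaded with a device map
        # (e.g. via transformers/accelerate); a plain torch.nn.Module won't have it.
        return getattr(model, "hf_device_map", None)

_create_new_module(...) could then receive _resolve_device_map(self.model) instead of self.model.hf_device_map.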

if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
@@ -150,7 +152,7 @@ def _create_and_replace(
)

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
def _create_new_module(lora_config, adapter_name, target, device_map, **kwargs):
# avoid eager bnb import
if is_bnb_available():
import bitsandbytes as bnb
Expand All @@ -160,7 +162,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
from .bnb import SVDLinear4bit

gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

if is_gptqmodel_available():
QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map)
Reviewer comment (Member): Following from the previous comment about device_map: if gptqmodel absolutely needs this attribute and it doesn't exist on the model, let's raise an error with a helpful error message for the user.
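One possible shape for that error path, assuming gptqmodel genuinely cannot proceed without a device map; the helper name and message are suggestions, not part of the PR:

    def _require_device_map(model):
        # Fail early with an actionable message instead of letting gptqmodel hit
        # an AttributeError or a None device map further down.
        device_map = getattr(model, "hf_device_map", None)
        if device_map is None:
            raise ValueError(
                "gptqmodel needs the base model's hf_device_map, but none was found. "
                "Load the model with device_map=... (for example via accelerate), "
                "or fall back to auto-gptq."
            )
        return device_map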

else:
QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
@@ -190,7 +196,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
}
)
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
elif QuantLinear is not None and isinstance(target, QuantLinear):
new_module = SVDQuantLinear(target, adapter_name, **kwargs)
else:
if isinstance(target_base_layer, torch.nn.Linear):
14 changes: 10 additions & 4 deletions src/peft/tuners/lora/gptq.py
@@ -18,7 +18,8 @@

from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_auto_gptq_quant_linear
from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear
from peft.import_utils import is_gptqmodel_available


class QuantLinear(torch.nn.Module, LoraLayer):
@@ -106,10 +107,15 @@ def dispatch_gptq(
else:
target_base_layer = target

gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
cfg = kwargs.get("gptq_quantization_config", None)

if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
if is_gptqmodel_available():
device_map = kwargs.get("device_map", None)
quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
else:
quant_linear = get_auto_gptq_quant_linear(cfg)
Reviewer comment (Member): Would it make sense to suggest that users move to gptqmodel here, or is it too early for that?
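If that nudge is wanted, it could be as small as a warning on the auto-gptq path; this is a sketch of the idea under discussion, not something the PR currently does:

    import warnings

    # Emitted only on the auto-gptq fallback path; the wording is illustrative.
    warnings.warn(
        "PEFT is falling back to auto-gptq for this GPTQ-quantized model. "
        "Consider installing gptqmodel, the more actively maintained backend.",
        FutureWarning,
    )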


if quant_linear is not None and isinstance(target_base_layer, quant_linear):
new_module = QuantLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

2 changes: 1 addition & 1 deletion src/peft/tuners/lora/model.py
@@ -232,7 +232,7 @@ def _create_and_replace(
lora_bias=lora_config.lora_bias,
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
new_module = self._create_new_module(lora_config, adapter_name, target, device_map=self.model.hf_device_map, **kwargs)
Reviewer comment (Member): Same comment about hf_device_map.

if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -49,6 +49,7 @@
_is_valid_match,
infer_device,
get_auto_gptq_quant_linear,
get_gptqmodel_quant_linear,
get_quantization_config,
id_tensor_storage,
cast_mixed_precision_params,
78 changes: 55 additions & 23 deletions src/peft/utils/other.py
@@ -30,7 +30,7 @@
from packaging import version
from safetensors.torch import storage_ptr, storage_size

from ..import_utils import is_auto_gptq_available, is_torch_tpu_available
from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available
from .constants import (
CONFIG_NAME,
EMBEDDING_LAYER_NAMES,
@@ -607,30 +607,62 @@ def get_auto_gptq_quant_linear(gptq_quantization_config):
"""
Get the right AutoGPTQQuantLinear class based on the quantization config file
"""
if gptq_quantization_config is not None and is_auto_gptq_available():
if gptq_quantization_config is None:
return None

if is_auto_gptq_available():
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
else:
return None

desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1
AutoGPTQQuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)
return AutoGPTQQuantLinear
return None
desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
if hasattr(gptq_quantization_config, "use_exllama"):
use_exllama = gptq_quantization_config.use_exllama
else:
use_exllama = not gptq_quantization_config.disable_exllama
if hasattr(gptq_quantization_config, "exllama_config"):
exllama_version = gptq_quantization_config.exllama_config["version"]
else:
exllama_version = 1

QuantLinear = dynamically_import_QuantLinear(
use_triton=False,
desc_act=desc_act,
group_size=group_size,
bits=bits,
disable_exllama=not (use_exllama and exllama_version == 1),
disable_exllamav2=not (use_exllama and exllama_version == 2),
)

return QuantLinear


def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None):
"""
Get the right GPTQQuantLinear class based on the quantization config file
"""
if gptq_quantization_config is None:
return None

if not is_gptqmodel_available():
return None

from gptqmodel.utils.importer import hf_select_quant_linear

desc_act = gptq_quantization_config.desc_act
group_size = gptq_quantization_config.group_size
bits = gptq_quantization_config.bits
checkpoint_format = gptq_quantization_config.checkpoint_format if hasattr(gptq_quantization_config, "checkpoint_format") else "gptq"
sym = gptq_quantization_config.sym
meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None

QuantLinear = hf_select_quant_linear(bits=bits, group_size=group_size,
desc_act=desc_act, sym=sym, device_map=device_map,
checkpoint_format=checkpoint_format, meta=meta, backend="auto_trainable")

return QuantLinear


def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
12 changes: 6 additions & 6 deletions tests/test_common_gpu.py
@@ -59,6 +59,7 @@

from .testing_utils import (
require_bitsandbytes,
require_gptq,
require_multi_accelerator,
require_non_cpu,
require_torch_gpu,
@@ -79,8 +80,7 @@
from peft.tuners.vera import Linear4bit as VeraLinear4bit


@require_non_cpu
Reviewer comment (Member): I don't think we can just remove this decorator, since this class contains a bunch of tests that are unrelated to gptqmodel. From my understanding, gptqmodel can run on CPU. In that case, I would suggest moving the gptqmodel tests to a different test file and leaving this class otherwise untouched. They can be combined there with the PeftGPTQTests tests from test_gpu_examples.py.

The new test file (e.g. tests/test_gptqmodel.py) could then be run during the normal GitHub CI, which uses CPU only. However, if the tests are very slow on CPU, we need to either speed them up or not run them on the normal CI after all.
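A rough skeleton of that proposed split; the file name, class name, and test method shown are suggestions only:

    # tests/test_gptqmodel.py (hypothetical)
    import unittest

    from .testing_utils import require_gptq


    @require_gptq
    class PeftGPTQModelTests(unittest.TestCase):
        # gptqmodel also runs on CPU, so these tests can live outside the
        # GPU-only suites and be combined with the former PeftGPTQTests
        # from test_gpu_examples.py.

        def test_lora_gptq_quantization_from_pretrained_safetensors(self):
            ...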

class PeftGPUCommonTests(unittest.TestCase):
class PeftCommonTests(unittest.TestCase):
r"""
A common tester to run common operations that are performed on GPU such as generation, loading in 8bit, etc.
"""
@@ -383,7 +383,7 @@ def test_ia3_bnb_quantization_from_pretrained_safetensors(self, quantization):
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l
assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l

@pytest.mark.single_gpu_tests
@require_gptq
def test_lora_gptq_quantization_from_pretrained_safetensors(self):
r"""
Tests that the autogptq quantization using LoRA works as expected with safetensors weights.
@@ -403,19 +403,19 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):

config = LoraConfig(task_type="CAUSAL_LM")
peft_model = get_peft_model(model, config)
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

with tempfile.TemporaryDirectory() as tmp_dir:
peft_model.save_pretrained(tmp_dir)
model = AutoModelForCausalLM.from_pretrained(**kwargs)
model = PeftModel.from_pretrained(model, tmp_dir)
model = prepare_model_for_kbit_training(model)
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

# loading a 2nd adapter works, #1239
model.load_adapter(tmp_dir, "adapter2")
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

# check that both adapters are in the same layer
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
14 changes: 4 additions & 10 deletions tests/test_gpu_examples.py
@@ -78,9 +78,9 @@
from .testing_utils import (
require_aqlm,
require_auto_awq,
require_auto_gptq,
require_bitsandbytes,
require_eetq,
require_gptq,
require_hqq,
require_non_cpu,
require_non_xpu,
@@ -1370,10 +1370,9 @@ def test_causal_lm_training_multi_gpu_4bit_vera(self):
assert trainer.state.log_history[-1]["train_loss"] is not None


@require_torch_gpu
@require_auto_gptq
@require_gptq
@require_optimum
class PeftGPTQGPUTests(unittest.TestCase):
class PeftGPTQTests(unittest.TestCase):
r"""
GPTQ + peft tests
"""
@@ -1382,8 +1381,7 @@ def setUp(self):
from transformers import GPTQConfig

self.causal_lm_model_id = "marcsun13/opt-350m-gptq-4bit"
# TODO : check if it works for Exllamav2 kernels
self.quantization_config = GPTQConfig(bits=4, use_exllama=False)
self.quantization_config = GPTQConfig(bits=4, backend="auto_trainable")
self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

def tearDown(self):
@@ -1402,7 +1400,6 @@ def _check_inference_finite(self, model, batch):
assert torch.isfinite(output.logits).all()
model.train(training)

@pytest.mark.single_gpu_tests
def test_causal_lm_training(self):
r"""
Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
@@ -1456,7 +1453,6 @@ def test_causal_lm_training(self):
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

@pytest.mark.single_gpu_tests
def test_adalora_causalLM(self):
r"""
Tests the gptq training with adalora
@@ -1584,7 +1580,6 @@ def test_causal_lm_training_multi_gpu(self):
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None

@pytest.mark.single_gpu_tests
def test_non_default_adapter_name(self):
# See issue 1346
config = LoraConfig(
@@ -1665,7 +1660,6 @@ def test_offload_load(self):
offloaded_output = offloaded_lora_model(input_tokens)[0]
assert torch.allclose(output, offloaded_output, atol=1e-5)

@pytest.mark.single_gpu_tests
def test_offload_merge(self):
r"""
Test merging, unmerging, and unloading of a model with CPU- and disk- offloaded modules.
9 changes: 6 additions & 3 deletions tests/testing_utils.py
@@ -24,6 +24,7 @@
is_auto_awq_available,
is_auto_gptq_available,
is_eetq_available,
is_gptqmodel_available,
is_hqq_available,
is_optimum_available,
is_torchao_available,
@@ -91,11 +92,13 @@ def require_bitsandbytes(test_case):
return test_case


def require_auto_gptq(test_case):
def require_gptq(test_case):
"""
Decorator marking a test that requires auto-gptq. These tests are skipped when auto-gptq isn't installed.
Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
"""
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
return unittest.skipUnless(
is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq"
)(test_case)


def require_aqlm(test_case):