add gptqmodel support #2247
base: main
@@ -25,8 +25,10 @@
     _freeze_adapter,
     _get_submodules,
     get_auto_gptq_quant_linear,
+    get_gptqmodel_quant_linear,
     get_quantization_config,
 )
+from peft.import_utils import is_gptqmodel_available
 from peft.utils.integrations import gather_params_ctx

 from .gptq import SVDQuantLinear
@@ -135,7 +137,7 @@ def _create_and_replace(

         # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
         if not isinstance(target, AdaLoraLayer):
-            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
+            new_module = self._create_new_module(lora_config, adapter_name, target, self.model.hf_device_map, **kwargs)
             if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
@@ -150,7 +152,7 @@ def _create_and_replace(
             )

     @staticmethod
-    def _create_new_module(lora_config, adapter_name, target, **kwargs):
+    def _create_new_module(lora_config, adapter_name, target, device_map, **kwargs):
         # avoid eager bnb import
         if is_bnb_available():
             import bitsandbytes as bnb
@@ -160,7 +162,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
             from .bnb import SVDLinear4bit

         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
-        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
+
+        if is_gptqmodel_available():
+            QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map)
+        else:
+            QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)

Following from the previous comment about …
@@ -190,7 +196,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
                 }
             )
             new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
-        elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
+        elif QuantLinear is not None and isinstance(target, QuantLinear):
             new_module = SVDQuantLinear(target, adapter_name, **kwargs)
         else:
             if isinstance(target_base_layer, torch.nn.Linear):
@@ -18,7 +18,8 @@

 from peft.tuners.lora.layer import LoraLayer
 from peft.tuners.tuners_utils import BaseTunerLayer
-from peft.utils import get_auto_gptq_quant_linear
+from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear
+from peft.import_utils import is_gptqmodel_available


 class QuantLinear(torch.nn.Module, LoraLayer):
@@ -106,10 +107,15 @@ def dispatch_gptq(
     else:
         target_base_layer = target

-    gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
-    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
+    cfg = kwargs.get("gptq_quantization_config", None)

-    if AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
+    if is_gptqmodel_available():
+        device_map = kwargs.get("device_map", None)
+        quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
+    else:
+        quant_linear = get_auto_gptq_quant_linear(cfg)
+
+    if quant_linear is not None and isinstance(target_base_layer, quant_linear):
         new_module = QuantLinear(target, adapter_name, **kwargs)
         target.qweight = target_base_layer.qweight

Would it make sense to suggest that users move to gptqmodel here, or is it too early for that?
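If the maintainers do want such a hint, one possible shape is sketched below. This is not part of the PR: the helper name pick_gptq_quant_linear and the warning text are made up, and only the two getters and is_gptqmodel_available come from the diff above.

import warnings

from peft.import_utils import is_gptqmodel_available
from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear


def pick_gptq_quant_linear(cfg, device_map=None):
    # Prefer gptqmodel when it is installed; otherwise fall back to auto-gptq
    # and hint that gptqmodel is the recommended backend going forward.
    if is_gptqmodel_available():
        return get_gptqmodel_quant_linear(cfg, device_map=device_map)
    warnings.warn(
        "LoRA on auto-gptq still works, but consider installing gptqmodel, "
        "which is the actively maintained GPTQ backend.",
        UserWarning,
    )
    return get_auto_gptq_quant_linear(cfg)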
@@ -232,7 +232,7 @@ def _create_and_replace(
                 lora_bias=lora_config.lora_bias,
             )
         else:
-            new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
+            new_module = self._create_new_module(lora_config, adapter_name, target, device_map=self.model.hf_device_map, **kwargs)
             if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)

Same comment about hf_device_map applies here.
@@ -59,6 +59,7 @@

 from .testing_utils import (
     require_bitsandbytes,
+    require_gptq,
     require_multi_accelerator,
     require_non_cpu,
     require_torch_gpu,
@@ -79,8 +80,7 @@
 from peft.tuners.vera import Linear4bit as VeraLinear4bit


-@require_non_cpu
-class PeftGPUCommonTests(unittest.TestCase):
+class PeftCommonTests(unittest.TestCase):
     r"""
     A common tester to run common operations that are performed on GPU such as generation, loading in 8bit, etc.
     """

I don't think we can just remove this decorator, since this class contains a bunch of tests that are unrelated to gptqmodel. From my understanding, gptqmodel can be run on CPU. In that case, I would suggest moving the gptqmodel tests to a different test file and leaving this class otherwise untouched. They can be combined there with the existing GPTQ tests in a new test file (e.g. …).
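For illustration, a rough skeleton of what such a dedicated test module could look like. The file name tests/test_gptqmodel.py, the class name, and the model id are placeholders and not part of this PR; only require_gptq and the generate call mirror the diff above.

# Hypothetical tests/test_gptqmodel.py -- a sketch, not the PR's actual layout.
import torch
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

from .testing_utils import require_gptq


@require_gptq
class TestGPTQModelLora:
    # Placeholder id; any GPTQ-quantized checkpoint would do here.
    model_id = "placeholder-org/placeholder-gptq-model"

    def test_lora_generate(self):
        # No GPU-only decorator: gptqmodel-quantized models can also run on CPU.
        model = AutoModelForCausalLM.from_pretrained(self.model_id)
        peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))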
@@ -383,7 +383,7 @@ def test_ia3_bnb_quantization_from_pretrained_safetensors(self, quantization):
         assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l
         assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.ia3_l

-    @pytest.mark.single_gpu_tests
+    @require_gptq
     def test_lora_gptq_quantization_from_pretrained_safetensors(self):
         r"""
         Tests that the autogptq quantization using LoRA works as expected with safetensors weights.
@@ -403,19 +403,19 @@ def test_lora_gptq_quantization_from_pretrained_safetensors(self):

         config = LoraConfig(task_type="CAUSAL_LM")
         peft_model = get_peft_model(model, config)
-        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+        peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

         with tempfile.TemporaryDirectory() as tmp_dir:
             peft_model.save_pretrained(tmp_dir)
             model = AutoModelForCausalLM.from_pretrained(**kwargs)
             model = PeftModel.from_pretrained(model, tmp_dir)
             model = prepare_model_for_kbit_training(model)
-            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

             # loading a 2nd adapter works, #1239
             model.load_adapter(tmp_dir, "adapter2")
             model.set_adapter("adapter2")
-            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(0))
+            model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))

             # check that both adapters are in the same layer
             assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
We can't assume that the model has a .hf_device_map attribute. Theoretically, the model could be any PyTorch model, it doesn't have to be a transformers model -- and even transformers models don't necessarily have the attribute. Therefore, we have to check for this attribute here and, if it doesn't exist, not pass it (or pass None).
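A minimal, self-contained sketch of the suggested guard; resolve_device_map is a hypothetical helper name, and the point is just the getattr fallback:

import torch


def resolve_device_map(model):
    # Only return a device map if the wrapped model actually exposes one;
    # plain torch.nn.Module instances (and some transformers models) do not.
    return getattr(model, "hf_device_map", None)


# A plain PyTorch module has no hf_device_map, so the guard yields None.
plain_model = torch.nn.Linear(4, 4)
assert resolve_device_map(plain_model) is None

At the call sites in this PR, that would amount to passing device_map=getattr(self.model, "hf_device_map", None) to _create_new_module instead of accessing self.model.hf_device_map directly.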