From e0a35898112e7321e0744ef388483cee076403d2 Mon Sep 17 00:00:00 2001
From: 1000850000 user
Date: Wed, 4 Sep 2024 11:56:20 +0000
Subject: [PATCH] changed peft installation on parent qlinear

Signed-off-by: 1000850000 user
---
 .../src/fms_acceleration_peft/framework_plugin_autogptq.py | 2 +-
 .../src/fms_acceleration_peft/gptqmodel/utils/peft.py      | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index b679ced0..fd14871c 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -320,7 +320,7 @@ def augmentation(
         model = get_gptq_peft_model(
             model,
             peft_config=peft_config,
-            auto_find_all_linears=peft_config.target_modules == PEFT_ALL_LINEAR,
+            auto_find_all_linears=(peft_config.target_modules == PEFT_ALL_LINEAR),
             train_mode=True,  # install adapaters for training
         )
         modifiable_args = (None,)  # return a None for peft_config
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py
index e9327bf2..a6fd4b15 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py
@@ -35,7 +35,7 @@

 # Local
 from ..models.base import BaseGPTQModel
-from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as QuantLinearTriton
+from ..nn_modules.qlinear import BaseQuantLinear


 class GPTQLoraConfig(LoraConfig):
@@ -61,7 +61,7 @@ def _create_new_module(
     lora_config: LoraConfig,
     adapter_name: str,
     target: torch.nn.Module,
-    target_cls: torch.nn.Module = QuantLinearTriton,
+    target_cls: torch.nn.Module = BaseQuantLinear,
     **kwargs,
 ):
     # if the base layer module matches a supported class, dispatch the lora linear
@@ -97,7 +97,7 @@ def find_all_linear_names(
         ignore.append(lm_head_name)
     results = set()
     for n, m in model.named_modules():
-        if isinstance(m, QuantLinearTriton):
+        if isinstance(m, BaseQuantLinear):
             res = n.split(".")[-1]
             if res not in ignore:
                 results.add(res)
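
Note (illustrative, not part of the patch): the change makes both the LoRA module dispatch and the linear-name discovery match on the parent BaseQuantLinear class rather than the Triton-specific QuantLinear, so any quantized-linear backend that subclasses BaseQuantLinear is picked up. Below is a minimal, self-contained Python sketch of that effect; the subclass names and the simplified find_all_linear_names_sketch helper are hypothetical stand-ins for the library's actual classes and function.

    class BaseQuantLinear:  # parent class, as imported in the patched peft.py
        pass

    class TritonV2QuantLinear(BaseQuantLinear):  # hypothetical Triton-backed qlinear
        pass

    class ExllamaQuantLinear(BaseQuantLinear):  # hypothetical second backend
        pass

    def find_all_linear_names_sketch(named_modules, ignore=("lm_head",)):
        # Simplified mirror of find_all_linear_names: collect the leaf names of
        # every module that is an instance of the parent BaseQuantLinear class.
        results = set()
        for name, module in named_modules:
            if isinstance(module, BaseQuantLinear):  # parent-class check from the patch
                leaf = name.split(".")[-1]
                if leaf not in ignore:
                    results.add(leaf)
        return results

    # With the old isinstance(module, QuantLinearTriton) check, only the Triton
    # layer below would match; with the parent-class check, both backends are found.
    modules = [
        ("model.layers.0.q_proj", TritonV2QuantLinear()),
        ("model.layers.0.k_proj", ExllamaQuantLinear()),
    ]
    print(find_all_linear_names_sketch(modules))  # {'q_proj', 'k_proj'} (order may vary)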