Extracted Subset of AutoGPTQ library into Accelerated-Peft Plugin (#48)
* added gptqmodel to plugin

* edited peft header

* add package build workflow

* add unit tests on extracted autogptq

* modify autogptq plugin to support both external and extracted autogptq

* addressed additional PR changes

* reintroduce support for low_cpu_mem_usage in extracted lib

* Use transformers package checking instead of importlib

* formatting

* linting

* add additional entry to requirements.txt

* fixed union type backward compatibility with py39

* Fix FOAK dequant for compatibility with local gptq package

* add benchmark comparison script

* modified comparison script

* formatted scripts/

* edited comparison script to detect difference in command args

* addressed PR edits

* updated benchmarks

* Add comment for foak kernel
achew010 authored Jul 15, 2024
1 parent 33a32f5 commit c3a069c
Showing 47 changed files with 6,012 additions and 188 deletions.
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/.pylintrc
@@ -52,7 +52,7 @@ ignore=CVS,protobufs
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
ignore-paths=.*gptqmodel/,tests/test_q4_triton.py,tests/test_triton.py

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
1 change: 1 addition & 0 deletions plugins/accelerated-peft/pyproject.toml
@@ -26,6 +26,7 @@ classifiers=[

[project.optional-dependencies]
flash-attn = ["flash-attn"]
auto_gptq = ["auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7"] # known working commitid

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]
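Note: the snippet below is an illustrative sketch added for this write-up, not part of the commit. It only shows how one might confirm that the pinned optional `auto_gptq` extra exposes the pieces the plugin imports in its external-library path; the package check mirrors the `_is_package_available` helper used in the plugin code further down.

# Illustrative sketch (not from the commit): verify the optional external
# auto_gptq install, pinned to the known working commit id above.
from transformers.utils.import_utils import _is_package_available

if _is_package_available("auto_gptq"):
    # imports used by the plugin's external-library code path
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
    from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
    print("external auto_gptq available")
else:
    print("auto_gptq not installed; the extracted gptqmodel subset can be used instead")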
12 changes: 6 additions & 6 deletions plugins/accelerated-peft/requirements.txt
@@ -1,13 +1,13 @@
# decided not to have this as a requirement for now
# fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework

# put this in here because there is a breaking FSDP api change that
# is fixed after peft > 0.10
accelerate < 0.29
# Needs a lower bound due to the `accelerate.load_checkpoint_in_model` function used in gptqmodel
accelerate >= 0.29

# bitsandbytes for the BNB plugin
bitsandbytes

# Installing from repository because "auto_gptq > 0.7.1" is not yet available
# Specifying the commit id here as recent commits to the main branch have introduced additional dependencies
auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7
# Used to manage the thread limit in functions for converting old
# GPTQ models to the new GPTQ model format that supports symmetrical=False
# https://github.com/AutoGPTQ/AutoGPTQ/pull/640
threadpoolctl
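For context, threadpoolctl exposes a context manager that caps the native BLAS/OpenMP thread pools while a CPU-heavy routine runs; a minimal sketch of that pattern follows, with a placeholder workload, since the actual format-conversion code is not part of this hunk.

# Minimal sketch of the threadpoolctl pattern the comment above refers to.
# The placeholder matmul stands in for the (not shown) GPTQ format conversion.
import numpy as np
from threadpoolctl import threadpool_limits

def convert_with_limited_threads(weights: np.ndarray, max_threads: int = 1) -> np.ndarray:
    # cap BLAS/OpenMP threads while the conversion-style workload runs
    with threadpool_limits(limits=max_threads):
        return weights @ weights.T  # placeholder for the real conversion work

converted = convert_with_limited_threads(np.ones((8, 8), dtype=np.float32))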
plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -28,15 +28,16 @@
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.import_utils import _is_package_available
import torch
import torch.distributed


class AutoGPTQAccelerationPlugin(AccelerationPlugin):

require_packages = ["auto_gptq"]
require_packages = []

def __init__(self, configurations: Dict[str, Dict]):
def __init__(self, configurations: Dict[str, Dict], use_external_lib: bool = False):
super().__init__(configurations)

# just do checking, nothing to configure at this point
@@ -47,18 +48,31 @@ def __init__(self, configurations: Dict[str, Dict]):
self._check_config_equal(
key="peft.quantization.auto_gptq.from_quantized", value=True
)
self.use_external_lib = use_external_lib

if self.use_external_lib:
assert (
_is_package_available("auto_gptq") is True
), "Unable to use external library, autogptq module not found."

def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM,
BaseQuantizeConfig,
)
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

if self.use_external_lib:
# Third Party
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM as GPTQModel,
)
from auto_gptq import BaseQuantizeConfig as QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
else:
from .gptqmodel import GPTQModel, QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.utils import Backend # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
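For orientation, here is a hedged sketch of how the constructor above might be exercised. The nested layout of the configuration dict is an assumption inferred from the dotted keys checked in `__init__`, and the import path is a guess at where the plugin is installed; the acceleration framework may build plugins differently in practice.

# Hypothetical usage sketch (not part of the commit). The nested dict layout
# is assumed from the dotted key "peft.quantization.auto_gptq.from_quantized"
# checked above; the kernel value is likewise assumed to be "triton_v2".
from fms_acceleration_peft.framework_plugin_autogptq import AutoGPTQAccelerationPlugin  # assumed install path

configurations = {
    "peft": {
        "quantization": {
            "auto_gptq": {
                "kernel": "triton_v2",   # assumed: triton v2 is the only kernel with a backward pass
                "from_quantized": True,  # checked explicitly in __init__
            }
        }
    }
}

# extracted (in-repo) gptqmodel subset; no external auto_gptq required
plugin = AutoGPTQAccelerationPlugin(configurations, use_external_lib=False)

# external AutoGPTQ library; asserts that auto_gptq is importable
# plugin = AutoGPTQAccelerationPlugin(configurations, use_external_lib=True)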
@@ -85,11 +99,11 @@ def model_loader(self, model_name: str, **kwargs):
# switching to cuda/cuda_old/triton backend."
# assume model_name points to a quantized checkpoint. Thus we load the quantization
# config directly from the checkpoint.
quantize_config = BaseQuantizeConfig.from_pretrained(model_name)
quantize_config = QuantizeConfig.from_pretrained(model_name)

# get additional parameters
torch_dtype = kwargs.get("torch_dtype", torch.float32)
low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage", False)
attn_implementation = kwargs.get("attn_implementation")

# there are some kwargs that won't be passed to AutoModel, so we need
@@ -101,54 +115,68 @@
)
AutoModelForCausalLM.from_config = _from_config # patch

if self.use_external_lib:
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage,
"use_marlin": False, # disable, cannot be used for training (no forward+backward)
"disable_exllama": True, # disable, cannot be used for training (no backward)
"use_tritonv2": True,
"trainable": True, # only support trainable mode
}
else:
kwargs = {
"low_cpu_mem_usage": low_cpu_mem_usage, # this is only used for device map
"backend": Backend.TRITON,
}

# this is a HF method that checks if the low_cpu_mem mode is enabled
# via HF accelerate
if is_fsdp_enabled():
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
low_cpu_mem_usage = True

# NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
# device_map is for inference only
# https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
# For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
# to avoid gpu consumption before train
# This approach will divert consumption to cpu memory,
# a better approach would be to load the checkpoints to meta device
# QLoRA is currently implemented by the former approach and will encounter the same issue.
# see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
device_map = {
"": (
(torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
if torch.cuda.is_available()
else None
)
}
kwargs["low_cpu_mem_usage"] = True
if self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)

# NOTE: need to set the device map as below because we want to use AutoGPTQ for training.
# For low_cpu_mem_usage = True, we have to set the device map to load checkpoints
# to "cpu" to avoid gpu consumption before training.
# This approach will divert consumption to cpu memory;
# a better approach would be to load the checkpoints to the meta device.
# QLoRA is currently implemented by the former approach and
# will encounter the same issue.
# see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262

kwargs["device_map"] = {
"": (
(
torch.cuda.current_device()
if not kwargs["low_cpu_mem_usage"]
else "cpu"
)
if torch.cuda.is_available()
else None
)
}

# currently only enable triton_v2, because the triton kernels are the only ones
# that have backwards
model = AutoGPTQForCausalLM.from_quantized(
model = GPTQModel.from_quantized(
model_name,
quantize_config=quantize_config,
torch_dtype=torch_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
use_marlin=False, # disable, cannot be used for training (no forward+backward)
disable_exllama=True, # disable, cannot be used for training (no backward)
warmup_triton=False, # disable for now as it will try to run the warmup while on CPU
use_tritonv2=True,
trainable=True, # only support trainable mode
device_map=device_map,
**kwargs,
)

# https://github.com/foundation-model-stack/fms-acceleration/pull/15
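The device-map selection added above is fairly dense; the helper below restates the same decision as standalone code, purely for illustration (the function name is not from the commit).

# Illustrative restatement of the device_map choice above (not part of the
# commit). Under FSDP with low_cpu_mem_usage, checkpoints load to "cpu" so GPU
# memory is not consumed before training starts; otherwise the current CUDA
# device is used, and None when no GPU is available.
import torch

def select_device_map(low_cpu_mem_usage: bool) -> dict:
    if not torch.cuda.is_available():
        device = None
    elif low_cpu_mem_usage:
        device = "cpu"
    else:
        device = torch.cuda.current_device()
    return {"": device}

device_map = select_device_map(low_cpu_mem_usage=True)  # -> {"": "cpu"} on a CUDA machine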
@@ -219,19 +247,24 @@ def augmentation(
):
# guarded imports
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)
if self.use_external_lib:
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)

# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
replace_module_peft,
)
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
replace_module_peft,
)
else:
# Local
from .gptqmodel.utils.peft import get_gptq_peft_model # pylint: disable=import-outside-toplevel,import-error

(peft_config,) = modifiable_args # unpack modifiable args

@@ -249,31 +282,35 @@ def augmentation(
gradient_checkpointing_kwargs=train_args.gradient_checkpointing_kwargs,
)

# These functions need to replaced due to some incompatibliites
# with newer PEFT packages.
# - on augmentation we call auto_gptq.utils.peft_utils.get_gptq_peft_model
# - this internally calls peft.utils.other.get_peft_model
# - however the problem is that peft API moves very fast, and there are incompatiblities
#
# During peft wrapping there are two key operations
# 1. LoraModel._create_new_module is called to create a LoraLinear layer that is
# compatible with the base layer. For quantized base layers, the LoraLinear
# may be different.
# 2. GPTQLoraModel._replace_module to replace the existing Linear with the LoraLinear.
# Also move to device (which may depend on how base layer is implemented)

# NOTE: GPTQLoraModel inherits from LoraModel, and the _create_new_module method is called
# on the parent. Hence _create_new_module is patched on the parent

# FIXME:
# 1. investigate using BaseGPTQForCausalLM.make_sure_compatible_with_peft
# to see if we can get around the patching

_old_create_new_module = LoraModel._create_new_module
_old_replace_module = GPTQLoraModel._replace_module
_create_new_module = partial(create_new_module_peft, target_cls=QuantLinear)
LoraModel._create_new_module = staticmethod(_create_new_module)
GPTQLoraModel._replace_module = MethodType(replace_module_peft, GPTQLoraModel)
if self.use_external_lib:
# These functions need to be replaced due to some incompatibilities
# with newer PEFT packages.
# - on augmentation we call auto_gptq.utils.peft_utils.get_gptq_peft_model
# - this internally calls peft.utils.other.get_peft_model
# - however the problem is that the peft API moves very fast, and there are incompatibilities
#
# During peft wrapping there are two key operations
# 1. LoraModel._create_new_module is called to create a LoraLinear layer that is
# compatible with the base layer. For quantized base layers, the LoraLinear
# may be different.
# 2. GPTQLoraModel._replace_module to replace the existing Linear with the LoraLinear.
# Also move to device (which may depend on how base layer is implemented)

# NOTE: GPTQLoraModel inherits from LoraModel,
# and the _create_new_module method is called
# on the parent. Hence _create_new_module is patched on the parent

# FIXME:
# 1. investigate using BaseGPTQForCausalLM.make_sure_compatible_with_peft
# to see if we can get around the patching

_old_create_new_module = LoraModel._create_new_module
_old_replace_module = GPTQLoraModel._replace_module
_create_new_module = partial(create_new_module_peft, target_cls=QuantLinear)
LoraModel._create_new_module = staticmethod(_create_new_module)
GPTQLoraModel._replace_module = MethodType(
replace_module_peft, GPTQLoraModel
)

# Install GPTQ adapters using the AutoGPTQ package (with the above patches)
model = get_gptq_peft_model(
@@ -284,9 +321,12 @@ def augmentation(
)
modifiable_args = (None,) # return a None for peft_config

# undo the patching for hygine
LoraModel._create_new_module = staticmethod(_old_create_new_module)
GPTQLoraModel._replace_module = MethodType(_old_replace_module, GPTQLoraModel)
if self.use_external_lib:
# undo the patching for hygiene
LoraModel._create_new_module = staticmethod(_old_create_new_module)
GPTQLoraModel._replace_module = MethodType(
_old_replace_module, GPTQLoraModel
)

return model, modifiable_args

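The external-library branch above temporarily swaps two PEFT internals (`LoraModel._create_new_module` and `GPTQLoraModel._replace_module`) and restores them once `get_gptq_peft_model` returns. The generic sketch below shows that patch-then-restore pattern with placeholder names; it is illustrative only, not the plugin's actual helper.

# Generic patch-then-restore sketch (illustrative; the plugin applies this idea
# to LoraModel._create_new_module and GPTQLoraModel._replace_module above).
from contextlib import contextmanager

@contextmanager
def temporarily_patched(cls, attr_name, replacement):
    original = getattr(cls, attr_name)
    setattr(cls, attr_name, replacement)
    try:
        yield
    finally:
        # undo the patching for hygiene, mirroring the plugin's cleanup step
        setattr(cls, attr_name, original)

class Demo:
    @staticmethod
    def create():
        return "original"

with temporarily_patched(Demo, "create", staticmethod(lambda: "patched")):
    assert Demo.create() == "patched"
assert Demo.create() == "original"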
plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/__init__.py (new file)
@@ -0,0 +1,19 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
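A hedged sketch of loading a quantized checkpoint through the exports above, mirroring the plugin's own `from_quantized` call. The package import path and the checkpoint id are placeholders, and keyword support beyond what the diff shows is assumed.

# Illustrative sketch only. The import path assumes the plugin package is
# installed as fms_acceleration_peft; the checkpoint id is a placeholder.
import torch
from fms_acceleration_peft.gptqmodel import GPTQModel, QuantizeConfig
from fms_acceleration_peft.gptqmodel.utils import Backend

model_name = "some-org/some-model-GPTQ"  # placeholder quantized checkpoint
quantize_config = QuantizeConfig.from_pretrained(model_name)
model = GPTQModel.from_quantized(
    model_name,
    quantize_config=quantize_config,
    torch_dtype=torch.float16,
    backend=Backend.TRITON,  # triton kernels are the only ones with a backward pass
    low_cpu_mem_usage=False,
)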
plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/models/__init__.py (new file)
@@ -0,0 +1,26 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Local
from .auto import MODEL_MAP, GPTQModel
from .base import BaseGPTQModel
from .dbrx import DbrxGPTQ
from .dbrx_converted import DbrxConvertedGPTQ
from .gemma import GemmaGPTQ
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
(remaining changed files not shown)
