addressed additional PR changes
achew010 committed Jul 4, 2024
1 parent 477a26b commit b42d401
Showing 40 changed files with 151 additions and 54 deletions.
4 changes: 2 additions & 2 deletions plugins/accelerated-peft/pyproject.toml
@@ -26,13 +26,13 @@ classifiers=[

[project.optional-dependencies]
flash-attn = ["flash-attn"]
auto_gptq = ["auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git"]
auto_gptq = ["auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git@ea829c7bbe83561c2b1de26795b6592992373ef7"] # known working commitid

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]

[tool.hatch.build.targets.wheel]
only-include = ["src/"]
only-include = ["src/fms_acceleration_peft"]

[tool.hatch.metadata]
allow-direct-references = true
5 changes: 2 additions & 3 deletions plugins/accelerated-peft/requirements.txt
@@ -1,9 +1,8 @@
# decided not to have this as a requirement for now
# fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework

# put this in here because there is a breaking FSDP api change that
# is fixed after peft > 0.10
accelerate <= 0.29
# Needs a lower bound due to the `accelerate.load_checkpoint_in_model` function used in gptqmodel
accelerate >= 0.29

# bitsandbytes for the BNB plugin
bitsandbytes
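
For context on the new lower bound (not part of the diff): `load_checkpoint_in_model` is accelerate's helper for streaming checkpoint weights into an already-built module, and the comment above notes that the vendored gptqmodel code relies on it. A minimal sketch of the usual usage pattern is below; the model name and checkpoint path are hypothetical, and this is an illustration rather than the plugin's actual loading code.

```python
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical model and checkpoint location, for illustration only.
config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_empty_weights():
    # Build the module structure on the meta device without allocating weights.
    model = AutoModelForCausalLM.from_config(config)

# Stream the real weights from a checkpoint folder into the empty skeleton on CPU.
load_checkpoint_in_model(model, "path/to/checkpoint", device_map={"": "cpu"})
```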
@@ -47,7 +47,10 @@ def __init__(self, configurations: Dict[str, Dict], use_external_lib: bool = Fal
self._check_config_equal(
key="peft.quantization.auto_gptq.from_quantized", value=True
)
self.use_external_lib = use_external_lib and importlib.util.find_spec("autogptq") is not None
self.use_external_lib = use_external_lib

if self.use_external_lib:
assert importlib.util.find_spec("auto_gptq") is not None, "Unable to use external library, auto_gptq module not found."

def model_loader(self, model_name: str, **kwargs):
# guarded imports
@@ -61,9 +64,9 @@ def model_loader(self, model_name: str, **kwargs):
QuantLinear,
)
else:
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.utils import Backend
from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (
from .gptqmodel import GPTQModel, QuantizeConfig
from .gptqmodel.utils import Backend
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (
QuantLinear,
)

@@ -126,22 +129,21 @@ def model_loader(self, model_name: str, **kwargs):

# this is an HF method that checks if the low_cpu_mem mode is enabled
# via HF accelerate
if is_fsdp_enabled():
if self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
kwargs["low_cpu_mem_usage"] = True
if is_fsdp_enabled() and self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
kwargs["low_cpu_mem_usage"] = True

# NOTE: need to set the device map as below because we want to use AutoGPTQ for training.
# device_map is for inference only
@@ -253,7 +255,7 @@ def augmentation(
replace_module_peft,
)
else:
from gptqmodel.utils.peft import get_gptq_peft_model
from .gptqmodel.utils.peft import get_gptq_peft_model


(peft_config,) = modifiable_args # unpack modifiable args
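
As background for the patching step in this file (not part of the diff): `_patch_target_module` is a local helper in the plugin, and the sketch below shows the usual way such a helper swaps a symbol both in its defining module and in a consumer module that imported it by value. This is an assumed implementation for illustration, not the plugin's actual code.

```python
import importlib


def patch_target_module(to_patch: str, replace_with, target_module: str = None):
    """Illustrative monkey-patch helper (assumed behaviour, not the plugin's code)."""
    module_path, symbol = to_patch.rsplit(".", 1)
    # Overwrite the symbol in the module where it is defined ...
    defining_module = importlib.import_module(module_path)
    setattr(defining_module, symbol, replace_with)
    # ... and, optionally, in a module that imported it by value, since that copy
    # would otherwise keep pointing at the original function.
    if target_module is not None:
        consumer = importlib.import_module(target_module)
        setattr(consumer, symbol, replace_with)
```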
@@ -0,0 +1,18 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
@@ -0,0 +1,25 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .auto import MODEL_MAP, GPTQModel
from .base import BaseGPTQModel
from .dbrx import DbrxGPTQ
from .dbrx_converted import DbrxConvertedGPTQ
from .gemma import GemmaGPTQ
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -0,0 +1,15 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
@@ -0,0 +1,19 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .config import (FORMAT, FORMAT_FIELD_CODE, FORMAT_FIELD_JSON,
QUANT_CONFIG_FILENAME, QUANT_METHOD, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig)
from .gptq import GPTQ
from .quantizer import Quantizer, quantize
@@ -0,0 +1,16 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .backend import Backend, get_backend
@@ -21,15 +21,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
###############################################################################
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union

import torch
from peft import PeftConfig, PeftModel, PeftType, get_peft_model
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ

from ..models.base import BaseGPTQModel
@@ -100,7 +99,6 @@ def find_all_linear_names(
results.add(res)
return list(results)


@contextmanager
def hijack_peft_mappings():
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
@@ -139,16 +137,19 @@ def get_gptq_peft_model(
if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
peft_config = GPTQLoraConfig(**peft_config.to_dict())

# this hijack is needed because `get_peft_model` uses PeftModelForCausalLM, which inherits from
# PeftModel and in turn relies on PEFT_TYPE_TO_MODEL_MAPPING to initialize its base LoraModel
with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
except:
except Exception as exc:
raise NotImplementedError(
f"{model.__class__.__name__} not support {peft_config.peft_type.value} peft type yet."
)
f"{model.__class__.__name__} not support \
{peft_config.peft_type.value} peft type yet."
) from exc

return peft_model

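
For readers unfamiliar with the mapping hijack referenced in this file (not part of the diff): `get_peft_model` looks up config and model classes in PEFT's registries, so temporarily rebinding the LORA entries makes it build the GPTQ-aware variants. Below is a minimal, self-contained sketch of the pattern using a stand-in config class rather than the library's `GPTQLoraConfig`.

```python
from contextlib import contextmanager

from peft import PeftType
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.tuners.lora import LoraConfig


class StandInLoraConfig(LoraConfig):
    """Stand-in for a GPTQ-aware LoraConfig subclass (illustration only)."""


@contextmanager
def swap_lora_config_mapping():
    # Rebind the LORA entry for the duration of the block, then restore the
    # original class even if an exception is raised inside the block.
    original = PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA]
    PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = StandInLoraConfig
    try:
        yield
    finally:
        PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = original
```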
3 changes: 0 additions & 3 deletions plugins/accelerated-peft/src/gptqmodel/__init__.py

This file was deleted.

10 changes: 0 additions & 10 deletions plugins/accelerated-peft/src/gptqmodel/models/__init__.py

This file was deleted.

Empty file.

This file was deleted.

1 change: 0 additions & 1 deletion plugins/accelerated-peft/src/gptqmodel/utils/__init__.py

This file was deleted.

28 changes: 24 additions & 4 deletions plugins/accelerated-peft/tests/test_gptqmodel.py
@@ -1,3 +1,20 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

import pytest # pylint: disable=import-error
import torch
from typing import List
@@ -14,7 +31,7 @@
BS = 1
SEQLEN = 128

LOSS_TOLERANCE = 1e-3
LOSS_TOLERANCE = 0.1
ALLCLOSE_RTOL = 1e-3
ALLCLOSE_ATOL = 1e-4

@@ -60,12 +77,12 @@ class TrainArgs:

# quantization function to manage the loading and quantizing of a pretrained model
# using the external or local autogptq implementation
def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs, device, use_external_lib=False):
def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs, device, torch_dtype, use_external_lib=False):
if use_external_lib:
from auto_gptq import AutoGPTQForCausalLM as GPTQModel, BaseQuantizeConfig as QuantizeConfig
quantize_kwargs = {"use_triton": True}
else:
from gptqmodel import GPTQModel, QuantizeConfig
from fms_acceleration_peft.gptqmodel import GPTQModel, QuantizeConfig
quantize_kwargs = {}

quantize_config = QuantizeConfig(
@@ -76,6 +93,7 @@ def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs,
model_name,
quantize_config = quantize_config,
config = config,
torch_dtype = getattr(torch, torch_dtype),
).to(device)
# quantize model; the examples should be a list of dicts whose keys can only be "input_ids"
model.quantize(calibration_dataset, **quantize_kwargs)
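
A hypothetical call to the helper above, using the local gptqmodel path and made-up quantization settings; the model name, quantization kwargs, and calibration data are illustrative assumptions, not values from the test.

```python
import torch
from transformers import AutoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed small test model
config = AutoConfig.from_pretrained(model_name)
# Calibration examples in the format noted above: a list of dicts keyed by "input_ids".
calibration_dataset = [
    {"input_ids": torch.randint(0, config.vocab_size, (1, 128))} for _ in range(4)
]
quantized_model = quantize_model(
    model_name,
    config,
    calibration_dataset,
    quant_config_kwargs={"bits": 4, "group_size": 64},  # assumed values
    device="cuda",
    torch_dtype="float16",
    use_external_lib=False,
)
```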
@@ -184,6 +202,7 @@ def test_quantizing_pretrained_model_outputs_match(
calibration_dataset,
quant_config_kwargs,
device,
FLOAT16,
use_external_lib=True
)
refactored_model = quantize_model(
@@ -192,6 +211,7 @@ def test_quantizing_pretrained_model_outputs_match(
calibration_dataset,
quant_config_kwargs,
device,
FLOAT16,
use_external_lib=False
)

@@ -228,7 +248,7 @@ def test_quantizing_pretrained_model_outputs_match(
refactored_logits = refactored_model(input_ids).logits

# Measure the distribution error with KL divergence loss
loss_fn = torch.nn.KLDivLoss(reduction="mean")
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
# input should be a distribution in the log space
input = torch.nn.functional.log_softmax(refactored_logits, dim=1)
# target must be prob distribution
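
Aside on the reduction change above: per PyTorch's documentation, `KLDivLoss` with `reduction="mean"` averages over every element and does not return the true KL divergence, while `reduction="batchmean"` divides the summed divergence by the batch size, matching the mathematical definition. A small self-contained sketch of the comparison pattern, with random logits standing in for the two models' outputs and the softmax taken over the vocabulary dimension:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative shapes: (batch, sequence length, vocabulary size).
reference_logits = torch.randn(1, 128, 32000)
refactored_logits = reference_logits + 1e-3 * torch.randn_like(reference_logits)

loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
# The input must be log-probabilities, the target a plain probability distribution.
log_probs = F.log_softmax(refactored_logits, dim=-1)
target_probs = F.softmax(reference_logits, dim=-1)
loss = loss_fn(log_probs, target_probs)
print(float(loss))  # small value when the two distributions are close
```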
