Commit

addressed PR changes
achew010 committed Jul 3, 2024
1 parent 477a26b commit 61e0b56
Showing 39 changed files with 148 additions and 73 deletions.
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/pyproject.toml
@@ -32,7 +32,7 @@ auto_gptq = ["auto_gptq @ git+https://github.com/AutoGPTQ/AutoGPTQ.git"]
files = ["requirements.txt"]

[tool.hatch.build.targets.wheel]
only-include = ["src/"]
only-include = ["src/fms_acceleration_peft"]

[tool.hatch.metadata]
allow-direct-references = true
@@ -47,7 +47,10 @@ def __init__(self, configurations: Dict[str, Dict], use_external_lib: bool = Fal
self._check_config_equal(
key="peft.quantization.auto_gptq.from_quantized", value=True
)
self.use_external_lib = use_external_lib and importlib.util.find_spec("autogptq") is not None
self.use_external_lib = use_external_lib

if self.use_external_lib:
assert importlib.util.find_spec("auto_gptq") is not None, "Unable to use external library, autogptq module not found."

def model_loader(self, model_name: str, **kwargs):
# guarded imports
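
The old single-line check silently dropped back to the internal implementation whenever the external package was missing; the replacement fails loudly instead. A minimal sketch of the behavioural difference (not part of this commit, using only the standard importlib API):

import importlib.util

def old_behaviour(requested: bool) -> bool:
    # Quietly ignores the request when the external package cannot be found.
    return requested and importlib.util.find_spec("auto_gptq") is not None

def new_behaviour(requested: bool) -> bool:
    # Surfaces a hard error so the caller knows the external path is unavailable.
    if requested:
        assert importlib.util.find_spec("auto_gptq") is not None, \
            "Unable to use external library, auto_gptq module not found."
    return requested
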
@@ -61,9 +64,9 @@ def model_loader(self, model_name: str, **kwargs):
QuantLinear,
)
else:
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.utils import Backend
from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (
from .gptqmodel import GPTQModel, QuantizeConfig
from .gptqmodel.utils import Backend
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (
QuantLinear,
)

@@ -126,22 +129,21 @@ def model_loader(self, model_name: str, **kwargs):

# this is a HF method that checks if the low_cpu_mem mode is enabled
# via HF accelerate
if is_fsdp_enabled():
if self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
kwargs["low_cpu_mem_usage"] = True
if is_fsdp_enabled() and self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
)
kwargs["low_cpu_mem_usage"] = True

# NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
# device_map is for inference only
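
The `_patch_target_module` helper itself lives in `.autogptq_utils` and is not shown in this diff; the sketch below is only a guess at the monkey-patch pattern its arguments suggest (replace the function both where it is defined and where the target module re-imported it):

import importlib

def _patch_target_module_sketch(to_patch: str, replace_with, target_module: str):
    # "pkg.module.attr" -> defining module path plus attribute name.
    module_path, attr_name = to_patch.rsplit(".", 1)
    setattr(importlib.import_module(module_path), attr_name, replace_with)
    # After `from ._utils import make_sure_no_tensor_in_meta_device`, the base
    # module holds its own reference, so the re-imported name is patched too.
    setattr(importlib.import_module(target_module), attr_name, replace_with)
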
@@ -253,7 +255,7 @@ def augmentation(
replace_module_peft,
)
else:
from gptqmodel.utils.peft import get_gptq_peft_model
from .gptqmodel.utils.peft import get_gptq_peft_model


(peft_config,) = modifiable_args # unpack modifiable args
@@ -0,0 +1,18 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .models import GPTQModel
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import Backend, get_backend
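
With the package vendored under the plugin, code outside the plugin reaches it through the `fms_acceleration_peft` namespace (as the updated test later in this commit does), while modules inside the plugin use the relative `.gptqmodel` form. A small import sketch, assuming the plugin wheel is installed:

# External callers and tests:
from fms_acceleration_peft.gptqmodel import GPTQModel, QuantizeConfig
from fms_acceleration_peft.gptqmodel.utils import Backend
# Inside the plugin package, the equivalent relative imports are used:
# from .gptqmodel import GPTQModel, QuantizeConfig
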
@@ -0,0 +1,25 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .auto import MODEL_MAP, GPTQModel
from .base import BaseGPTQModel
from .dbrx import DbrxGPTQ
from .dbrx_converted import DbrxConvertedGPTQ
from .gemma import GemmaGPTQ
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -0,0 +1,15 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
@@ -0,0 +1,19 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .config import (FORMAT, FORMAT_FIELD_CODE, FORMAT_FIELD_JSON,
QUANT_CONFIG_FILENAME, QUANT_METHOD, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig)
from .gptq import GPTQ
from .quantizer import Quantizer, quantize
@@ -0,0 +1,16 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from .backend import Backend, get_backend
@@ -21,15 +21,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
###############################################################################
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple, Union

import torch
from peft import PeftConfig, PeftModel, PeftType, get_peft_model
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel
from peft.tuners.lora import LoraConfig, LoraModel
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ

from ..models.base import BaseGPTQModel
@@ -100,22 +98,6 @@ def find_all_linear_names(
results.add(res)
return list(results)


@contextmanager
def hijack_peft_mappings():
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel

try:
yield
except:
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
raise
finally:
PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel

def get_gptq_peft_model(
model: BaseGPTQModel,
peft_config: PeftConfig = None,
@@ -129,6 +111,9 @@ def get_gptq_peft_model(
if not train_mode and not model_id:
raise ValueError("model_id(where to load adapters) not specified when in inference mode.")

PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel

if train_mode:
peft_type = peft_config.peft_type
if not isinstance(peft_type, str):
@@ -139,16 +124,9 @@ def get_gptq_peft_model(
if peft_type == PeftType.LORA.value and not isinstance(peft_config, GPTQLoraConfig):
peft_config = GPTQLoraConfig(**peft_config.to_dict())

with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
except:
raise NotImplementedError(
f"{model.__class__.__name__} not support {peft_config.peft_type.value} peft type yet."
)
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)

return peft_model
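
A hedged usage sketch of the simplified flow, based only on the parameters visible in this diff (peft_config, model_id, adapter_name, train_mode); `quantized_model` is assumed to be a BaseGPTQModel loaded elsewhere, and further keyword arguments may exist in the full signature:

from peft import LoraConfig

lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])

# Training: the module-level mapping overrides above route PEFT to the
# GPTQ-aware GPTQLoraConfig / GPTQLoraModel classes.
peft_model = get_gptq_peft_model(
    quantized_model,
    peft_config=lora_config,
    train_mode=True,
)

# Inference: adapters are loaded from a saved adapter directory instead.
# peft_model = get_gptq_peft_model(quantized_model, model_id="path/to/adapters", train_mode=False)
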

3 changes: 0 additions & 3 deletions plugins/accelerated-peft/src/gptqmodel/__init__.py

This file was deleted.

10 changes: 0 additions & 10 deletions plugins/accelerated-peft/src/gptqmodel/models/__init__.py

This file was deleted.

Empty file.

This file was deleted.

1 change: 0 additions & 1 deletion plugins/accelerated-peft/src/gptqmodel/utils/__init__.py

This file was deleted.

28 changes: 24 additions & 4 deletions plugins/accelerated-peft/tests/test_gptqmodel.py
@@ -1,3 +1,20 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

import pytest # pylint: disable=import-error
import torch
from typing import List
@@ -14,7 +31,7 @@
BS = 1
SEQLEN = 128

LOSS_TOLERANCE = 1e-3
LOSS_TOLERANCE = 0.1
ALLCLOSE_RTOL = 1e-3
ALLCLOSE_ATOL = 1e-4

@@ -60,12 +77,12 @@ class TrainArgs:

# quantization function to manage the loading and quantizing of pretrained model
# using external or local autogptq
def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs, device, use_external_lib=False):
def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs, device, torch_dtype, use_external_lib=False):
if use_external_lib:
from auto_gptq import AutoGPTQForCausalLM as GPTQModel, BaseQuantizeConfig as QuantizeConfig
quantize_kwargs = {"use_triton": True}
else:
from gptqmodel import GPTQModel, QuantizeConfig
from fms_acceleration_peft.gptqmodel import GPTQModel, QuantizeConfig
quantize_kwargs = {}

quantize_config = QuantizeConfig(
@@ -76,6 +93,7 @@ def quantize_model(model_name, config, calibration_dataset, quant_config_kwargs,
model_name,
quantize_config = quantize_config,
config = config,
torch_dtype = getattr(torch, torch_dtype),
).to(device)
# quantize model, the examples should be list of dict whose keys can only be "input_ids"
model.quantize(calibration_dataset, **quantize_kwargs)
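
A hedged call sketch for the updated helper with the new torch_dtype argument; the model name and prepared objects below are placeholders for whatever the test fixtures provide, not values taken from this diff:

quantized = quantize_model(
    "hypothetical/tiny-causal-lm",   # placeholder model id, not from the diff
    config,                          # transformers config prepared by the test
    calibration_dataset,             # list of {"input_ids": ...} examples
    quant_config_kwargs,             # e.g. {"bits": 4, "group_size": 64}
    device,                          # e.g. "cuda"
    "float16",                       # resolved inside via getattr(torch, torch_dtype)
    use_external_lib=False,          # exercise the vendored gptqmodel path
)
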
@@ -184,6 +202,7 @@ def test_quantizing_pretrained_model_outputs_match(
calibration_dataset,
quant_config_kwargs,
device,
FLOAT16,
use_external_lib=True
)
refactored_model = quantize_model(
@@ -192,6 +211,7 @@ def test_quantizing_pretrained_model_outputs_match(
calibration_dataset,
quant_config_kwargs,
device,
FLOAT16,
use_external_lib=False
)

@@ -228,7 +248,7 @@ def test_quantizing_pretrained_model_outputs_match(
refactored_logits = refactored_model(input_ids).logits

# Measure the distribution error with KD Loss
loss_fn = torch.nn.KLDivLoss(reduction="mean")
loss_fn = torch.nn.KLDivLoss(reduction="batchmean")
# input should be a distribution in the log space
input = torch.nn.functional.log_softmax(refactored_logits, dim=1)
# target must be prob distribution
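
For torch.nn.KLDivLoss, reduction="batchmean" sums the pointwise terms and divides by the batch size, which matches the mathematical KL divergence, whereas "mean" divides by the total element count and shrinks the value by roughly the size of the class dimension. A self-contained toy sketch:

import torch
import torch.nn.functional as F

ref_logits = torch.randn(4, 10)            # stand-in for the original model's logits
new_logits = torch.randn(4, 10)            # stand-in for the refactored model's logits
log_p = F.log_softmax(new_logits, dim=-1)  # input must be log-probabilities
q = F.softmax(ref_logits, dim=-1)          # target must be probabilities
kl = torch.nn.KLDivLoss(reduction="batchmean")(log_p, q)
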
