Fix formatter (#74)
* formatted accelerated-peft

Signed-off-by: 1000850000 user <[email protected]>

* formatted foak

Signed-off-by: 1000850000 user <[email protected]>

---------

Signed-off-by: 1000850000 user <[email protected]>
achew010 authored Aug 27, 2024
1 parent 4224c66 commit 32918eb
Showing 14 changed files with 116 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
@@ -48,7 +48,7 @@ jobs:
- name: Run formatter
run: |
cd plugins/${{ matrix.plugin_name }}
tox -e fmt
tox -e fmt -- . --check
- name: Run pytest
run: |
cd plugins/${{ matrix.plugin_name }}
@@ -16,20 +16,24 @@
# https://spdx.dev/learn/handling-license-info/

# Standard
from functools import partial
from typing import Callable, List

# Third Party
from fms_acceleration.model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
)
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
import torch

from fms_acceleration.model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from functools import partial

# these parameters are to be patched for triton v2
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]


def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
module,
torch_dtype,
@@ -38,9 +42,7 @@ def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
# so FSDP can shard them
for attr_name in PATCH_FOR_FSDP_TRITON_V2:
attr = getattr(module, attr_name)
attr = torch.nn.Parameter(
attr.view(torch_dtype), requires_grad=False
)
attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False)
setattr(module, attr_name, attr)

# this patches the forward to convert them back to original
@@ -51,18 +53,21 @@ def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
torch_dtype=torch.int32, # patch it back to
)


def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
# Register patch
ModelPatcher.register(
ModelPatcherRule(
rule_id="autogptq_patch_tensors_as_float_parameters",
trigger=ModelPatcherTrigger(check=target_module),
forward_builder = partial(
build_patch_to_view_tensor_to_parameter_for_fsdp_gptq, torch_dtype=torch_dtype
forward_builder=partial(
build_patch_to_view_tensor_to_parameter_for_fsdp_gptq,
torch_dtype=torch_dtype,
),
)
)


def make_sure_no_tensor_in_meta_device(
model,
use_triton: bool,
@@ -133,6 +138,7 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module


# consider to move this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
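The two hunks above implement the FSDP workaround for Triton v2 GPTQ layers: the packed int32 qweight/qzeros attributes are viewed as a floating-point dtype and re-registered as frozen nn.Parameters so FSDP will shard them, and the patched forward views them back to int32 before the kernel runs. A minimal standalone sketch of that trick, using a hypothetical stand-in module rather than the plugin's QuantLinear:

# Illustrative sketch only; FakeQuantLinear and the shapes are hypothetical.
import torch

class FakeQuantLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # packed int32 storage, as produced by GPTQ triton_v2 kernels
        self.qweight = torch.zeros(128, 16, dtype=torch.int32)

module = FakeQuantLinear()

# 1. view the int32 storage as a float dtype and register it as a frozen
#    Parameter so FSDP is willing to shard it
for attr_name in ["qweight"]:
    attr = getattr(module, attr_name)
    attr = torch.nn.Parameter(attr.view(torch.float16), requires_grad=False)
    setattr(module, attr_name, attr)

# 2. the patched forward views the attribute back to the dtype the kernel
#    expects (torch.int32) just before calling the original forward
qweight_for_kernel = module.qweight.data.view(torch.int32)
assert qweight_for_kernel.dtype == torch.int32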
@@ -35,6 +35,7 @@
# Local
from .autogptq_utils import register_tensors_as_parameters_patch_rule


class AutoGPTQAccelerationPlugin(AccelerationPlugin):

require_packages = []
@@ -57,11 +58,13 @@ def __init__(self, configurations: Dict[str, Dict]):
)

if self.use_external_lib:
from transformers.utils.import_utils import _is_package_available # pylint: disable=import-outside-toplevel
assert (
_is_package_available("auto_gptq") is True
), (
"Unable to use external library, auto_gptq module not found. "
# Third Party
from transformers.utils.import_utils import ( # pylint: disable=import-outside-toplevel
_is_package_available,
)

assert _is_package_available("auto_gptq") is True, (
"Unable to use external library, auto_gptq module not found. "
"Refer to README for installation instructions "
"as a specific version might be required."
)
@@ -71,19 +74,28 @@ def model_loader(self, model_name: str, **kwargs):
# Third Party
if self.use_external_lib:
# Third Party
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM as GPTQModel,
)
from auto_gptq import BaseQuantizeConfig as QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

from auto_gptq import ( # isort:skip pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM as GPTQModel,
)
from auto_gptq import ( # isort:skip pylint: disable=import-outside-toplevel,import-error
BaseQuantizeConfig as QuantizeConfig,
)
else:
from .gptqmodel import GPTQModel, QuantizeConfig # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.utils import Backend # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
# Local
from .gptqmodel import ( # pylint: disable=import-outside-toplevel,import-error
GPTQModel,
QuantizeConfig,
)
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from .gptqmodel.utils import ( # pylint: disable=import-outside-toplevel,import-error
Backend,
)

# Currently we allow only a quantized checkpoint to be loaded, we do not
# implement the quantization process here.
@@ -141,7 +153,9 @@ def model_loader(self, model_name: str, **kwargs):
kwargs["low_cpu_mem_usage"] = True
if self.use_external_lib:
# Local
from .autogptq_utils import make_sure_no_tensor_in_meta_device # pylint: disable=import-outside-toplevel
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
make_sure_no_tensor_in_meta_device,
)

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
@@ -250,7 +264,9 @@ def augmentation(
)
else:
# Local
from .gptqmodel.utils.peft import get_gptq_peft_model # pylint: disable=import-outside-toplevel,import-error
from .gptqmodel.utils.peft import ( # pylint: disable=import-outside-toplevel,import-error
get_gptq_peft_model,
)

(peft_config,) = modifiable_args # unpack modifiable args
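Most of the changes in this file are the formatter splitting long guarded imports into the parenthesized form, keeping the pylint pragma on the opening line and adding isort section headings (# Standard, # Third Party, # Local). A small hedged sketch of the resulting pattern, reusing the same transformers helper the plugin calls (the wrapper function is hypothetical):

# Illustration of the import style this commit standardizes.
def check_auto_gptq_available() -> bool:
    # Third Party
    from transformers.utils.import_utils import (  # pylint: disable=import-outside-toplevel
        _is_package_available,
    )

    return _is_package_available("auto_gptq")

Keeping the import inside the method means the optional auto_gptq dependency is only resolved when the external-library path is actually taken.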

10 changes: 5 additions & 5 deletions plugins/accelerated-peft/tests/test_gptqmodel.py
@@ -64,10 +64,10 @@ def load_autogptq_plugin_model(
"peft": {
"quantization": {
"auto_gptq": {
"kernel": "triton_v2",
"from_quantized": True,
"kernel": "triton_v2",
"from_quantized": True,
"use_external_lib": use_external_lib,
}
}
}
}
},
@@ -292,10 +292,10 @@ def test_quantizing_pretrained_model_outputs_match(
loss_fn = torch.nn.KLDivLoss(reduction="sum")
# input should be a distribution in the log space
input = torch.nn.functional.log_softmax(refactored_logits, dim=-1)
input = input.view(BS*SEQLEN, -1)
input = input.view(BS * SEQLEN, -1)
# target must be prob distribution
target = torch.nn.functional.softmax(original_logits, dim=-1)
target = target.view(BS*SEQLEN, -1)
target = target.view(BS * SEQLEN, -1)
error = loss_fn(input, target)
assert error.lt(
LOSS_TOLERANCE
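The reformatted assertion above compares the logits of the two GPTQ paths with a KL divergence. A self-contained sketch of the same check with random stand-in logits (BS, SEQLEN, VOCAB and LOSS_TOLERANCE are placeholder values here, not the test's actual constants):

# Sketch of the logits-equivalence check, with stand-in tensors.
import torch

BS, SEQLEN, VOCAB = 1, 32, 128
LOSS_TOLERANCE = 1e-3

original_logits = torch.randn(BS, SEQLEN, VOCAB)
refactored_logits = original_logits.clone()  # identical outputs -> ~0 divergence

loss_fn = torch.nn.KLDivLoss(reduction="sum")
# input must be log-probabilities, target a probability distribution
input = torch.nn.functional.log_softmax(refactored_logits, dim=-1).view(BS * SEQLEN, -1)
target = torch.nn.functional.softmax(original_logits, dim=-1).view(BS * SEQLEN, -1)
error = loss_fn(input, target)
assert error.lt(LOSS_TOLERANCE), "implementations diverge beyond tolerance"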
32 changes: 20 additions & 12 deletions plugins/accelerated-peft/tests/test_peft_plugins.py
@@ -16,6 +16,7 @@
# https://spdx.dev/learn/handling-license-info/

# Standard
from unittest.mock import patch
import os

# Third Party
@@ -26,7 +27,6 @@
update_configuration_contents,
)
import pytest
from unittest.mock import patch

MODEL_NAME_AUTO_GPTQ = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"

@@ -89,8 +89,10 @@ def test_configure_gptq_plugin():

e.match(f"AutoGPTQAccelerationPlugin: Value at '{key}'")


def test_autogptq_loading():
"Test for correctness of autogptq loading logic"

def autogptq_unavailable(package_name: str):
return False

@@ -100,13 +102,12 @@ def autogptq_unavailable(package_name: str):
# 3. check when using external package and it is not available, an AssertionError is thrown
with pytest.raises(
AssertionError,
match = "Unable to use external library, auto_gptq module not found. "
"Refer to README for installation instructions as a specific version might be required."
match="Unable to use external library, auto_gptq module not found. "
"Refer to README for installation instructions as a specific version might be required.",
):
with patch(
"transformers.utils.import_utils."
"_is_package_available",
autogptq_unavailable,
"transformers.utils.import_utils._is_package_available",
autogptq_unavailable,
):
with instantiate_framework(
update_configuration_contents(
@@ -118,7 +119,11 @@ def autogptq_unavailable(package_name: str):
) as framework:
pass

from fms_acceleration_peft.framework_plugin_autogptq import AutoGPTQAccelerationPlugin # pylint: disable=import-outside-toplevel
# First Party
from fms_acceleration_peft.framework_plugin_autogptq import ( # pylint: disable=import-outside-toplevel
AutoGPTQAccelerationPlugin,
)

# - Test that plugin attribute is set when config field `use_external_lib` is False
# When plugin attribute is set correctly, it will route to correct package on model loading
with instantiate_framework(
@@ -131,21 +136,24 @@ def autogptq_unavailable(package_name: str):
) as framework:
for _, plugin in framework.active_plugins:
if isinstance(plugin, AutoGPTQAccelerationPlugin):
assert plugin.use_external_lib is False, \
"Plugin attribute not correctly set from config field"
assert (
plugin.use_external_lib is False
), "Plugin attribute not correctly set from config field"

# - Test that plugin attribute is set when config field `use_external_lib` is None
# When plugin attribute is set correctly, it will route to correct package on model loading
default_config = read_configuration(CONFIG_PATH_AUTO_GPTQ)
default_config['peft']['quantization']['auto_gptq'].pop('use_external_lib')
default_config["peft"]["quantization"]["auto_gptq"].pop("use_external_lib")
with instantiate_framework(
default_config,
require_packages_check=False,
) as framework:
for _, plugin in framework.active_plugins:
if isinstance(plugin, AutoGPTQAccelerationPlugin):
assert plugin.use_external_lib is False, \
"Plugin attribute not correctly set from config field"
assert (
plugin.use_external_lib is False
), "Plugin attribute not correctly set from config field"


# We do not enable the skip since this test does not actually require the packages
# installed
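The `test_autogptq_loading` test above patches transformers' `_is_package_available` so the external auto_gptq path looks uninstalled and the plugin's assertion fires. A stripped-down sketch of that mocking pattern, with a hypothetical check function standing in for `instantiate_framework`:

# Minimal sketch of the availability-mocking pattern used in this test.
from unittest.mock import patch

import pytest


def autogptq_unavailable(package_name: str):
    # pretend no optional package can be found
    return False


def check_external_lib():
    # Third Party
    from transformers.utils.import_utils import (  # pylint: disable=import-outside-toplevel
        _is_package_available,
    )

    assert _is_package_available("auto_gptq") is True, (
        "Unable to use external library, auto_gptq module not found."
    )


def test_missing_auto_gptq_raises():
    with pytest.raises(AssertionError, match="auto_gptq module not found"):
        with patch(
            "transformers.utils.import_utils._is_package_available",
            autogptq_unavailable,
        ):
            check_external_lib()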
3 changes: 3 additions & 0 deletions plugins/accelerated-peft/tests/test_q4_triton.py
@@ -29,15 +29,18 @@

CUDA_AVAILABLE = False
if torch.cuda.is_available():
# First Party
from fms_acceleration_peft.gptqmodel import Backend, GPTQModel # noqa: E402
from fms_acceleration_peft.gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # noqa: E402
QuantLinear as TritonV2QuantLinear,
)

CUDA_AVAILABLE = True


GENERATE_EVAL_SIZE = 100


class TestsQ4Triton(unittest.TestCase):
@unittest.skipIf(
CUDA_AVAILABLE is False,
4 changes: 3 additions & 1 deletion plugins/accelerated-peft/tests/test_triton.py
@@ -31,8 +31,10 @@

CUDA_AVAILABLE = False
if torch.cuda.is_available():
# First Party
from fms_acceleration_peft.gptqmodel import Backend, GPTQModel # noqa: E402
CUDA_AVAILABLE = True

CUDA_AVAILABLE = True

MODEL_ID = "TheBloke/Llama-7B-GPTQ"
DATASET_ID = "timdettmers/openassistant-guanaco"
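Both GPU test modules above gate their GPU-only imports behind torch.cuda.is_available() and skip the test class when no device is present. A compact sketch of that gating pattern (the test body is hypothetical, and the real modules import GPU-only kernels inside the guard):

# Compact sketch of the CUDA gating shared by these test modules.
import unittest

import torch

CUDA_AVAILABLE = False
if torch.cuda.is_available():
    # GPU-only imports go here, inside the guard, so CPU-only environments
    # can still collect the module without import errors
    CUDA_AVAILABLE = True


class TestsRequiringCuda(unittest.TestCase):
    @unittest.skipIf(
        CUDA_AVAILABLE is False,
        "Only runs if there is a CUDA device available",
    )
    def test_runs_only_on_gpu(self):
        self.assertTrue(torch.cuda.is_available())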
4 changes: 2 additions & 2 deletions plugins/attention-and-distributed-packing/README.md
@@ -23,10 +23,10 @@ otherwise if `transformers < v4.44.0` the plugin will use an internal implementation
To reproduce the benchmarks, simply run the following commands,

Reproduce [Padding Free on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_pf.csv)
`bash scripts/run_benchmarks.sh "1 2" "4 8" benchmark_outputs scenarios-orca.yaml "none"`
`tox -e run-benches -- "1 2" "4 8" benchmark_outputs scenarios-orca.yaml "none"`

Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv)
`bash scripts/run_benchmarks.sh "2 4 8" "16 32 64" benchmark_outputs scenarios-orca.yaml "padding-free"`
`tox -e run-benches -- "2 4 8" "16 32 64" benchmark_outputs scenarios-orca.yaml "padding-free"`

## Known Issues

@@ -24,6 +24,7 @@
import torch
import torch.distributed as dist


# consider moving this somewhere else later
def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
"""
@@ -58,9 +59,20 @@ def _all_reduce_hook(grad):
if not B.weight.is_cuda:
set_module_tensor_to_device(B, "weight", "cuda")


def register_foak_model_patch_rules(base_type):
from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
from .models import llama, mistral, mixtral # pylint: disable=import-outside-toplevel
# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
ModelPatcher,
)

# Local
from .models import ( # pylint: disable=import-outside-toplevel
llama,
mistral,
mixtral,
)

rules = [
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
@@ -69,6 +81,7 @@ def register_foak_model_patch_rules(base_type):
for _rule in rules:
ModelPatcher.register(_rule)


class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):

# NOTE: may remove this when we have generic model rules
@@ -122,7 +135,7 @@ def augmentation(
), "need to run in fp16 mixed precision or load model in fp16"

# wrapper function to register foak patches
register_foak_model_patch_rules(base_type = self._base_layer)
register_foak_model_patch_rules(base_type=self._base_layer)
return model, modifiable_args

def get_callbacks_and_ready_for_train(
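The `lora_adapters_switch_ddp_from_fsdp` hunk above registers an all-reduce hook on the LoRA adapter gradients and moves the B weights onto CUDA; its full body is collapsed in this view, so the sketch below only illustrates the general hook mechanism under assumed module names (lora_A / lora_B), not the plugin's actual implementation:

# Illustrative only: averaging adapter gradients across ranks via a tensor
# hook; the module layout (lora_A / lora_B linear layers) is an assumption.
import torch
import torch.distributed as dist


def install_all_reduce_hooks(lora_modules):
    for module in lora_modules:
        for sub in (module.lora_A, module.lora_B):

            def _all_reduce_hook(grad):
                if not dist.is_initialized():
                    return grad
                reduced = grad.clone()
                dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
                return reduced / dist.get_world_size()

            sub.weight.register_hook(_all_reduce_hook)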
