Refactored Model Patcher Class (#55)
* set main to track current plugin versions

* move model_patcher to framework

* replace local patching with model_patcher

* add additional unit tests

* remove redundant patch function

* shifted patch summary logging to the framework plugin and renamed patch ids

* modified unit tests from PR comments

* incremental refactor of unit tests

* changes to mp trigger unit tests

* additional changes to trigger unit tests

* adding MP Rule unit tests

* add context manager to isolate patching unit tests

* some fixes

* clarified comments

* modelpatcher unit tests

* added forward_builder fn unit test

* lint changes

* more lint changes

* file renaming and added license headers on new files

* added guard to patch model only if model exists in framework plugin callback

* replaced buggy partial wrapping on ModelPatcher.patch and set tox env to allow triton access to global constexpr

* additional linting

* shifted patch trigger to main framework class

* additional modifications to foak patch rules

* linting

* additional changes from comments

* fixes to mp unit test

* updated with new benchmark results

---------

Co-authored-by: Yu Chin Fabian Lim <[email protected]>
achew010 and fabianlim authored Jul 29, 2024
1 parent f4cf311 commit b6c1455
Showing 33 changed files with 1,500 additions and 487 deletions.
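
To make the refactor concrete, here is a minimal sketch (not code from this commit) of the registration-then-patch flow the diffs below implement. The ModelPatcher, ModelPatcherRule and ModelPatcherTrigger names and the rule_id/trigger/forward_builder keywords are taken from the diffs; MyKernelLinear, build_scaled_forward and the rule id are hypothetical, and the binding of the built forward onto the module is assumed to follow the previous MethodType-based patching.

# Hedged sketch of the new flow: a plugin registers a ModelPatcherRule,
# and the framework later applies every registered rule to the model.
from functools import partial

import torch
from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)


class MyKernelLinear(torch.nn.Module):  # hypothetical module type to patch
    def forward(self, x):
        return x


def build_scaled_forward(module, scale):
    # hypothetical forward_builder: receives the matched module and returns a
    # replacement forward (assumed to be bound onto the module by ModelPatcher,
    # as the previous MethodType-based patching did)
    old_forward = module.forward

    def forward(self, x):
        return old_forward(x) * scale

    return forward


# 1. registration, e.g. inside a plugin's model_loader
ModelPatcher.register(
    ModelPatcherRule(
        rule_id="example_scale_forward",  # hypothetical rule id
        trigger=ModelPatcherTrigger(check=MyKernelLinear),
        forward_builder=partial(build_scaled_forward, scale=2.0),
    )
)

# 2. later, in the framework callback (see the framework.py diff below):
#     ModelPatcher.patch(model)
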
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fms-acceleration-peft"
-version = '0.0.1'
+version = '0.1.0.1.dev'
 description = "FMS Acceleration for PeFT"
 authors = [
     {name = "Fabian Lim", email = "[email protected]"},
@@ -16,43 +16,52 @@
 # https://spdx.dev/learn/handling-license-info/
 
 # Standard
-from typing import Any, Callable, List
-import importlib
+from typing import Callable, List
 
 # Third Party
 from peft import LoraConfig
 from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
 import torch
 
+from fms_acceleration.model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
+from functools import partial
 
 # these parameters are to be patched for triton v2
 # consider making a map if patching more kernels
 PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]
 
 
-# This function may be moved after merging
-# https://github.com/foundation-model-stack/fms-acceleration/pull/25
-def _patch_target_module(
-    to_patch: str,
-    replace_with: Any,
-    target_module: str = None,
+def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
+    module,
+    torch_dtype,
 ):
-    to_patch = to_patch.split(".")
-    assert len(to_patch) > 1, "must have an object to patch"
-
-    to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
-    to_patch = ".".join(to_patch)
-    source = importlib.import_module(to_patch)
-    original_obj = getattr(source, obj_name_to_patch)
-    setattr(source, obj_name_to_patch, replace_with)
-
-    if target_module is not None:
-        # reload and this should get the patched object
-        target_module = importlib.import_module(target_module)
-        importlib.reload(target_module)
-
-        # replace it
-        setattr(source, obj_name_to_patch, original_obj)
+    # convert all patched attributes to Parameters of torch_dtype
+    # so FSDP can shard them
+    for attr_name in PATCH_FOR_FSDP_TRITON_V2:
+        attr = getattr(module, attr_name)
+        attr = torch.nn.Parameter(
+            attr.view(torch_dtype), requires_grad=False
+        )
+        setattr(module, attr_name, attr)
+
+    # this patches the forward to convert them back to original
+    # type (i.e. int32) before the function call into the kernels
+    return patch_forward_to_view_attributes_before_call(
+        module.forward,
+        attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+        torch_dtype=torch.int32,  # patch it back to
+    )
 
+def register_tensors_as_parameters_patch_rule(target_module, torch_dtype):
+    # Register patch
+    ModelPatcher.register(
+        ModelPatcherRule(
+            rule_id="autogptq_patch_tensors_as_float_parameters",
+            trigger=ModelPatcherTrigger(check=target_module),
+            forward_builder = partial(
+                build_patch_to_view_tensor_to_parameter_for_fsdp_gptq, torch_dtype=torch_dtype
+            ),
+        )
+    )
 
 def make_sure_no_tensor_in_meta_device(
     model,
@@ -124,7 +133,6 @@ def create_new_module_peft(
     # if module cannot be found, return None which results in a raise in the call-stack
     return new_module
 
-
 # consider to move this somewhere more general
 def patch_forward_to_view_attributes_before_call(
     old_forward: Callable,
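
The forward_builder added above relies on a bit-level dtype view: FSDP only shards nn.Parameters, so the int32 qweight/qzeros tensors are re-registered as Parameters viewed as the training float dtype, then viewed back to int32 inside forward before the kernel call. A standalone PyTorch illustration (not plugin code) of why this round-trips losslessly:

import torch

# int32 quantized weights, as a GPTQ kernel would store them
qweight = torch.randint(-2**31, 2**31 - 1, (4, 8), dtype=torch.int32)

# reinterpret the raw bits as float16 and wrap in a (frozen) Parameter,
# which is what allows FSDP to shard the tensor
as_param = torch.nn.Parameter(qweight.view(torch.float16), requires_grad=False)

# inside the patched forward: view the bits back to int32 before the kernel
restored = as_param.data.view(torch.int32)

assert restored.dtype == torch.int32
assert torch.equal(restored, qweight)  # bit-exact round trip
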
@@ -24,13 +24,16 @@
 
 # Third Party
 from fms_acceleration import AccelerationPlugin
+from fms_acceleration.model_patcher import patch_target_module
 from peft import LoraConfig, prepare_model_for_kbit_training
 from peft.tuners.lora.model import LoraModel
 from transformers import AutoModelForCausalLM, TrainingArguments
 from transformers.modeling_utils import is_fsdp_enabled
 import torch
 import torch.distributed
 
+# Local
+from .autogptq_utils import register_tensors_as_parameters_patch_rule
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
 
@@ -81,11 +84,6 @@ def model_loader(self, model_name: str, **kwargs):
             from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import (  # pylint: disable=import-outside-toplevel,import-error
                 QuantLinear,
             )
-            # Local
-            from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
-                PATCH_FOR_FSDP_TRITON_V2,
-                patch_forward_to_view_attributes_before_call,
-            )
 
         # Currently we allow only a quantized checkpoint to be loaded, we do not
         # implement the quantization process here.
@@ -143,14 +141,11 @@ def model_loader(self, model_name: str, **kwargs):
             kwargs["low_cpu_mem_usage"] = True
             if self.use_external_lib:
                 # Local
-                from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
-                    _patch_target_module,
-                    make_sure_no_tensor_in_meta_device,
-                )
+                from .autogptq_utils import make_sure_no_tensor_in_meta_device  # pylint: disable=import-outside-toplevel
 
                 # We patch `make_sure_no_tensor_in_meta_device`
                 # from autogptq to avoid errors on models without bias
-                _patch_target_module(
+                patch_target_module(
                     to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
                     replace_with=make_sure_no_tensor_in_meta_device,
                     target_module="auto_gptq.modeling._base",
@@ -201,31 +196,14 @@ def model_loader(self, model_name: str, **kwargs):
             world_size > 1
             and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
         ):
+            # register FSDP patch
+            register_tensors_as_parameters_patch_rule(
+                target_module=QuantLinear,
+                torch_dtype=torch_dtype,
+            )
 
-            # patch all the QuantLinear base layers
-            for mod in model.modules():
-                if isinstance(mod, QuantLinear):
-
-                    # convert all patched attributes to Parameters of torch_dtype
-                    # so FSDP can shard them
-                    for attr_name in PATCH_FOR_FSDP_TRITON_V2:
-                        attr = getattr(mod, attr_name)
-                        attr = torch.nn.Parameter(
-                            attr.view(torch_dtype), requires_grad=False
-                        )
-                        setattr(mod, attr_name, attr)
-
-                    # this patches the forward to convert them back to original
-                    # type (i.e. int32) before the function call into the kernels
-                    _forward = patch_forward_to_view_attributes_before_call(
-                        mod.forward,
-                        attribute_names=PATCH_FOR_FSDP_TRITON_V2,
-                        torch_dtype=torch.int32,  # patch it back to
-                    )
-                    mod.forward = MethodType(_forward, mod)
-
-        # replace
-        AutoModelForCausalLM.from_config = _old_from_config
+        # replace
+        AutoModelForCausalLM.from_config = _old_from_config
 
         # AutoGPTQ does not set the torch_dtype of the model carefully
         model.config.torch_dtype = torch_dtype
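
For reference, a hedged usage sketch of the framework-level patch_target_module used above. The keyword interface mirrors the call in this diff, and the described behaviour (replace a dotted attribute path, then reload a dependent module so it re-binds the name) follows the removed local _patch_target_module shown earlier; the replacement function here is a hypothetical stand-in, and running it requires auto_gptq to be installed.

from fms_acceleration.model_patcher import patch_target_module


def patched_check(model, *args, **kwargs):
    # hypothetical replacement: tolerate layers without a bias tensor
    return model


patch_target_module(
    to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
    replace_with=patched_check,
    target_module="auto_gptq.modeling._base",  # reloaded so it picks up the patch
)
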
3 changes: 2 additions & 1 deletion plugins/framework/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "fms-acceleration"
-version = '0.1.1.dev'
+version = '0.1.2.dev'
 description = "FMS Acceleration Plugin Framework"
 authors = [
     {name = "Fabian Lim", email = "[email protected]"},
@@ -27,6 +27,7 @@ dependencies = [
     "transformers<4.40",
     "peft",
     "accelerate",
+    "pandas",
 ]
 
 [tool.hatch.build.targets.wheel]
25 changes: 25 additions & 0 deletions plugins/framework/src/fms_acceleration/framework.py
@@ -38,6 +38,22 @@
 logger.setLevel(logging._get_default_logging_level())
 logger.addHandler(logging._default_handler)
 
+def log_patch_summary(
+    logging_func: Callable = None,
+):
+    if logging_func is None:
+        logging_func = print
+
+    # this is a guarded import, because the model rule registration
+    # does not need to be loaded unless patch_model is required
+    # Local
+    from .model_patcher import (  # pylint: disable=import-outside-toplevel
+        patch_model_summary,
+    )
+
+    for line in patch_model_summary().split("\n"):
+        logging_func(line)
+
 
 def check_plugin_packages(plugin: AccelerationPlugin):
     if plugin.require_packages is None:
@@ -207,6 +223,12 @@ def requires_agumentation(self):
     def get_callbacks_and_ready_for_train(
         self, model: torch.nn.Module = None, accelerator: Accelerator = None
     ):
+
+        from .model_patcher import ModelPatcher  # pylint: disable=import-outside-toplevel
+        if model is not None:
+            # Finally apply all registered patches to the model
+            ModelPatcher.patch(model)
+
         # show the initialized message
         if accelerator is not None and accelerator.is_main_process:
             log_initialization_message(
@@ -215,6 +237,9 @@ def get_callbacks_and_ready_for_train(
                 logging_func=logger.info,
             )
 
+            # if patching is done, print patch summary to logger
+            log_patch_summary(logging_func=logger.info)
+
         cbks = []
         for _, plugin in self.active_plugins:
             cbks.extend(plugin.get_callbacks_and_ready_for_train(model, accelerator))
1 change: 0 additions & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -24,7 +24,6 @@
 from transformers import TrainingArguments
 import torch
 
-
 @dataclass
 class PluginRegistration:
     plugin: "AccelerationPlugin"
(diffs for the remaining changed files are not shown)
