Refactored Model Patcher Class #55

Merged 28 commits on Jul 29, 2024
Changes from 6 commits
Commits (28)
7941ed7
set main to track current plugin versions
achew010 Jul 17, 2024
4b871d0
move model_patcher to framework
achew010 Jul 17, 2024
3bf9a55
replace local patching with model_patcher
achew010 Jul 18, 2024
815b0c8
add additional unit tests
achew010 Jul 18, 2024
7efbfed
remove redundant patch function
achew010 Jul 18, 2024
33258ba
shifted patch summary logging to framework plugin and patch id renames
achew010 Jul 18, 2024
af7009c
modified unit tests from PR comments
achew010 Jul 20, 2024
6b6fca9
incremental refactor of unit tests
achew010 Jul 22, 2024
252a73c
changes to mp trigger unit tests
achew010 Jul 23, 2024
94e217e
additional changes to trigger unit tests
achew010 Jul 23, 2024
a31bf6e
adding MP Rule unit tests
achew010 Jul 23, 2024
2683d9e
add context manager to isolate patching unit tests
achew010 Jul 24, 2024
748595c
some fixes
fabianlim Jul 24, 2024
9438aba
clarified comments
fabianlim Jul 25, 2024
8c825d9
modelpatcher unit tests
achew010 Jul 24, 2024
df95ece
added forward_builder fn unit test
achew010 Jul 25, 2024
e653b80
lint changes
achew010 Jul 25, 2024
e6f2284
more lint changes
achew010 Jul 25, 2024
736e706
file renaming and added license headers on new files
achew010 Jul 26, 2024
7c302ba
added guard to patch model only if model exist in framework plugin ca…
achew010 Jul 26, 2024
cd253b3
replaced buggy partial wrapping on ModelPatcher.patch and set tox env…
achew010 Jul 27, 2024
1d498e0
additional linting
achew010 Jul 28, 2024
a4f8800
shifted patch trigger to main framework class
achew010 Jul 29, 2024
ac31192
additional modifications to foak patch rules
achew010 Jul 29, 2024
8895cad
linting
achew010 Jul 29, 2024
f6848a7
additional changes from comments
achew010 Jul 29, 2024
5e535b2
fixes to mp unit test
achew010 Jul 29, 2024
c204c86
updated with new benchmark results
achew010 Jul 29, 2024
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "fms-acceleration-peft"
version = '0.0.1'
version = '0.1.0.1.dev'
description = "FMS Acceleration for PeFT"
authors = [
{name = "Fabian Lim", email = "[email protected]"},
@@ -24,35 +24,45 @@
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
import torch

from fms_acceleration.model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from functools import partial

# these parameters are to be patched for triton v2
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
to_patch: str,
replace_with: Any,
target_module: str = None,
def build_patch_to_view_tensor_to_parameter_for_fsdp_gptq(
module,
torch_dtype,
):
to_patch = to_patch.split(".")
assert len(to_patch) > 1, "must have an object to patch"

to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
to_patch = ".".join(to_patch)
source = importlib.import_module(to_patch)
original_obj = getattr(source, obj_name_to_patch)
setattr(source, obj_name_to_patch, replace_with)

if target_module is not None:
# reload and this should get the patched object
target_module = importlib.import_module(target_module)
importlib.reload(target_module)

# replace it
setattr(source, obj_name_to_patch, original_obj)
# convert all patched attributes to Parameters of torch_dtype
# so FSDP can shard them
for attr_name in PATCH_FOR_FSDP_TRITON_V2:
attr = getattr(module, attr_name)
attr = torch.nn.Parameter(
attr.view(torch_dtype), requires_grad=False
)
setattr(module, attr_name, attr)

# this patches the forward to convert them back to original
# type (i.e. int32) before the function call into the kernels
return patch_forward_to_view_attributes_before_call(
module.forward,
attribute_names=PATCH_FOR_FSDP_TRITON_V2,
torch_dtype=torch.int32, # patch it back to
)

def load_fsdp_gptq_patch(target_module, torch_dtype):
# Register patch
ModelPatcher.register(
ModelPatcherRule(
rule_id="autogptq_patch_tensors_as_float_parameters",
trigger=ModelPatcherTrigger(check=target_module),
forward_builder = build_patch_to_view_tensor_to_parameter_for_fsdp_gptq,
forward_builder_args=["torch_dtype"],
)
)
ModelPatcher.patch = partial(ModelPatcher.patch, torch_dtype=torch_dtype)

def make_sure_no_tensor_in_meta_device(
model,
@@ -124,7 +134,6 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module


# consider to move this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
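The comments in the hunk above describe the FSDP workaround for Triton-v2 GPTQ layers: the int32 qweight/qzeros attributes are re-exposed as non-trainable Parameters of torch_dtype so FSDP can shard them, and the forward is wrapped to view them back to int32 before the kernel call. Below is a minimal, self-contained sketch of that view round-trip; FakeQuantLinear and the simplified wrapper are illustrative stand-ins, not the plugin's actual QuantLinear or patch_forward_to_view_attributes_before_call.

# Minimal sketch (not the plugin's implementation): round-trip int32 buffers
# through a float view so FSDP sees shardable Parameters, then view them
# back to int32 right before the original forward runs.
from types import MethodType

import torch


class FakeQuantLinear(torch.nn.Module):  # hypothetical stand-in for QuantLinear
    def __init__(self):
        super().__init__()
        self.register_buffer("qweight", torch.zeros(4, 4, dtype=torch.int32))
        self.register_buffer("qzeros", torch.zeros(1, 4, dtype=torch.int32))

    def forward(self, x):
        # a real Triton kernel expects the quantized tensors in int32
        assert self.qweight.dtype == torch.int32
        return x


def view_attributes_before_call(old_forward, attribute_names, torch_dtype):
    # simplified wrapper: view the named attributes back to torch_dtype
    # (here int32) before delegating to the original forward
    def _forward(self, *args, **kwargs):
        for name in attribute_names:
            param = getattr(self, name)
            param.data = param.data.view(torch_dtype)
        return old_forward(*args, **kwargs)
    return _forward


mod = FakeQuantLinear()

# 1. re-expose the int32 buffers as non-trainable float16 Parameters (shardable)
for name in ["qweight", "qzeros"]:
    attr = getattr(mod, name)
    setattr(mod, name, torch.nn.Parameter(attr.view(torch.float16), requires_grad=False))

# 2. wrap forward so the kernel still sees int32 at call time
mod.forward = MethodType(
    view_attributes_before_call(mod.forward, ["qweight", "qzeros"], torch.int32), mod
)
print(mod(torch.ones(2)))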
@@ -24,6 +24,7 @@

# Third Party
from fms_acceleration import AccelerationPlugin
from fms_acceleration.model_patcher import patch_target_module
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
@@ -81,11 +82,6 @@ def model_loader(self, model_name: str, **kwargs):
from .gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
patch_forward_to_view_attributes_before_call,
)

# Currently we allow only a quantized checkpoint to be loaded, we do not
# implement the quantization process here.
@@ -143,14 +139,11 @@ def model_loader(self, model_name: str, **kwargs):
kwargs["low_cpu_mem_usage"] = True
if self.use_external_lib:
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
_patch_target_module,
make_sure_no_tensor_in_meta_device,
)
from .autogptq_utils import make_sure_no_tensor_in_meta_device # pylint: disable=import-outside-toplevel

# We patch `make_sure_no_tensor_in_meta_device`
# from autogptq to avoid errors on models without bias
_patch_target_module(
patch_target_module(
to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
replace_with=make_sure_no_tensor_in_meta_device,
target_module="auto_gptq.modeling._base",
@@ -201,31 +194,12 @@ def model_loader(self, model_name: str, **kwargs):
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
# register FSDP patch
from .autogptq_utils import load_fsdp_gptq_patch
load_fsdp_gptq_patch(target_module = QuantLinear, torch_dtype = torch_dtype)

# patch all the QuantLinear base layers
for mod in model.modules():
if isinstance(mod, QuantLinear):

# convert all patched attributes to Parameters of torch_dtype
# so FSDP can shard them
for attr_name in PATCH_FOR_FSDP_TRITON_V2:
attr = getattr(mod, attr_name)
attr = torch.nn.Parameter(
attr.view(torch_dtype), requires_grad=False
)
setattr(mod, attr_name, attr)

# this patches the forward to convert them back to original
# type (i.e. int32) before the function call into the kernels
_forward = patch_forward_to_view_attributes_before_call(
mod.forward,
attribute_names=PATCH_FOR_FSDP_TRITON_V2,
torch_dtype=torch.int32, # patch it back to
)
mod.forward = MethodType(_forward, mod)

# replace
AutoModelForCausalLM.from_config = _old_from_config

# AutoGPTQ does not set the torch_dtype of the model carefully
model.config.torch_dtype = torch_dtype
2 changes: 1 addition & 1 deletion plugins/framework/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "fms-acceleration"
version = '0.1.1.dev'
version = '0.1.2.dev'
description = "FMS Acceleration Plugin Framework"
authors = [
{name = "Fabian Lim", email = "[email protected]"},
33 changes: 32 additions & 1 deletion plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -14,7 +14,7 @@

# Standard
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Callable
import importlib
import sys

@@ -24,6 +24,29 @@
from transformers import TrainingArguments
import torch

from transformers.utils import logging

# want to use the transformers logger, but a bit of pain
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
logger.setLevel(logging._get_default_logging_level())
logger.addHandler(logging._default_handler)

def log_patch_summary(
logging_func: Callable = None,
):
if logging_func is None:
logging_func = print

# this is a guarded import, because the model rule registration
# does not need to be loaded unless patch_model is required
# Local
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
patch_model_summary,
)

for line in patch_model_summary().split("\n"):
logging_func(line)


@dataclass
class PluginRegistration:
@@ -146,6 +169,14 @@ def get_callbacks_and_ready_for_train(
def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator: Accelerator = None
):
# Finally apply all registered patches to the model
from .model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
ModelPatcher.patch(model)

# if patching is done, print patch summary to logger
if len(ModelPatcher.history) > 0:
log_patch_summary(logging_func=logger.info)

return []

def _check_config_and_maybe_check_values(
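With this change, get_callbacks_and_ready_for_train is where every registered rule is applied, and the patch summary is only logged when ModelPatcher.history records at least one patch. The following is a minimal sketch of that flow outside a plugin, using only the APIs visible in this PR; TinyBlock and the rule id "demo_double_output" are made up for illustration.

# Sketch of the register -> patch -> summarize flow added above, assuming the
# ModelPatcher API as it appears in this PR.
import torch

from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
    patch_model_summary,
)


class TinyBlock(torch.nn.Module):  # hypothetical target module
    def forward(self, x):
        return x


# a plugin would register rules like this at load time
ModelPatcher.register(
    ModelPatcherRule(
        rule_id="demo_double_output",
        trigger=ModelPatcherTrigger(check=TinyBlock),
        forward=lambda self, x: x * 2,
    )
)

model = TinyBlock()
ModelPatcher.patch(model)  # what get_callbacks_and_ready_for_train now calls
assert torch.allclose(model(torch.ones(3)), torch.ones(3) * 2)

# mirror the guarded summary logging in framework_plugin.py
if len(ModelPatcher.history) > 0:
    for line in patch_model_summary().split("\n"):
        print(line)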
@@ -27,7 +27,7 @@
# ------------------------ helpers -----------------------


def _patch_target_module(
def patch_target_module(
to_patch: str,
replace_with: Any,
target_module: str = None,
@@ -310,7 +310,7 @@ def _import_and_reload(model: torch.nn.Module):
# handle those with reload first
for rule in _with_reload + _no_reload:
_target, _object, _reload = rule.import_and_maybe_reload
_patch_target_module(_target, _object, _reload)
patch_target_module(_target, _object, _reload)
ModelPatcher.history.append(
ModelPatcherHistory(
instance=id(model),
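patch_target_module (formerly the private _patch_target_module) is now exported from fms_acceleration.model_patcher, so plugins import it instead of keeping local copies. A short usage sketch in the spirit of the new unit test below; fake_read_configuration is a hypothetical replacement, and the optional target_module argument (used at the auto_gptq call site earlier in this PR) additionally reloads a module that bound the original object at import time.

# Usage sketch for patch_target_module, mirroring the new unit test below;
# fake_read_configuration is illustrative only.
from fms_acceleration.model_patcher import patch_target_module


def fake_read_configuration(path):
    # hypothetical replacement returning a canned config
    return {"patched": True}


# swap the module-level function addressed by its dotted path
patch_target_module(
    "fms_acceleration.utils.test_utils.read_configuration",
    fake_read_configuration,
)

# subsequent imports resolve to the replacement
from fms_acceleration.utils.test_utils import read_configuration

assert read_configuration("anything") == {"patched": True}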
11 changes: 11 additions & 0 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -21,6 +21,7 @@
from typing import Any, Callable, Dict, List, Set, Tuple, Type

# Third Party
from torch.nn import CrossEntropyLoss
import torch
import yaml

@@ -180,3 +181,13 @@ def dummy_augmentation(self, model, train_args, modifiable_args):
def dummy_custom_loader(self, model_name, **kwargs):
"dummy custom loader returning dummy model"
return create_noop_model_with_archs(archs=["DummyModel"]) #


class DummyModule(torch.nn.Module):
def __init__(self, hidden_size, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.linear = torch.nn.Linear(hidden_size, hidden_size)
self.loss_fn = CrossEntropyLoss()

def forward(self, X):
return self.linear(X)
136 changes: 136 additions & 0 deletions plugins/framework/tests/test_model_patcher.py
@@ -0,0 +1,136 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
import pytest  # pylint: disable=import-error
import torch

# First Party
from fms_acceleration.model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
patch_target_module,
)
from fms_acceleration.utils.test_utils import DummyModule

DUMMY_RULE_ID = "test_patch"
DUMMY_HIDDEN_DIM = 32


class DummyCrossEntropyLoss(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, inputs, targets):
return -1


@pytest.fixture()
def model():
return DummyModule(DUMMY_HIDDEN_DIM)

@pytest.fixture()
def model_inputs(seed: int = 42):
torch.manual_seed(seed)
return torch.rand(1, DUMMY_HIDDEN_DIM)

def test_rule_registration_and_simple_forward_patching(model_inputs, model): # pylint: disable=redefined-outer-name
"Test model patcher replaces the forward function with a dummy forward"
# 1. Register rule and specify a trigger on target module for the rule to be applied
# 2. Patch model
# 3. check that the target module's forward function and the dummy patch produce similar outputs
dummy_forward_to_patch = lambda self, X: X * 2 # pylint: disable=unnecessary-lambda-assignment
ModelPatcher.rules.pop(DUMMY_RULE_ID, None)
rule = ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
trigger=ModelPatcherTrigger(check=DummyModule),
forward=dummy_forward_to_patch,
)
ModelPatcher.register(rule)
assert DUMMY_RULE_ID in ModelPatcher.rules.keys(), "Rule Registration Failed" # pylint: disable=consider-iterating-dictionary
ModelPatcher.patch(model)
assert torch.allclose(
model(model_inputs), model_inputs * 2
), "Failed to patch forward function"


# Test patching of model attribute
def test_patching_downstream_module(model): # pylint: disable=redefined-outer-name
"Test patching an imported module indirectly managed by other modules using import_and_reload"
# 1. Register rule targeting downstream module and specify target to reload with patch applied
# 2. Patch model
# 3. check that the patched module now exists in the model
ModelPatcher.rules.pop(DUMMY_RULE_ID, None)

# Reload because we only want to patch CrossEntropyLoss for this target module
ModelPatcher.register(
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
import_and_maybe_reload=(
"torch.nn.CrossEntropyLoss",
DummyCrossEntropyLoss,
"fms_acceleration.utils.test_utils",
),
)
)
ModelPatcher.patch(model)
assert isinstance(
DummyModule(DUMMY_HIDDEN_DIM).loss_fn, DummyCrossEntropyLoss
), "Failed to patch attribute with import and reload"


# Test patching standalone functions
def test_patching_standalone_function(model_inputs): # pylint: disable=redefined-outer-name
"Test patching of standalone file functions"
# 1. Take an arbitrary function
# 2. replace with a dummy function
# 3. check that the arbitrary function and the dummy function produce similar outputs
dummy_function_to_patch = lambda X: X # pylint: disable=unnecessary-lambda-assignment
patch_target_module(
"fms_acceleration.utils.test_utils.read_configuration",
dummy_function_to_patch,
)
# First Party
from fms_acceleration.utils.test_utils import read_configuration # pylint: disable=import-outside-toplevel

assert torch.allclose(
read_configuration(model_inputs), model_inputs
), "Failed to patch standalone function"


def test_forward_patching_with_forward_builder(model_inputs, model): # pylint: disable=redefined-outer-name
"Test model patcher replaces forward using a dummy forward building function"

def dummy_forward_builder(module, multiplier):
# can apply modifications to module here
setattr(module, "dummy_attribute", True)
return lambda self, X: X * multiplier

ModelPatcher.rules.pop(DUMMY_RULE_ID, None)
ModelPatcher.register(
ModelPatcherRule(
rule_id=DUMMY_RULE_ID,
trigger=ModelPatcherTrigger(check=DummyModule),
forward_builder=dummy_forward_builder,
forward_builder_args=["multiplier"],
)
)
ModelPatcher.patch(model, multiplier=4)
assert hasattr(model, "dummy_attribute") and torch.allclose(
model(model_inputs), model_inputs * 4
), "Failed to patch forward function with forward building feature"