lint changes
achew010 committed Jul 25, 2024
1 parent df95ece commit e653b80
Showing 12 changed files with 58 additions and 57 deletions.
3 changes: 1 addition & 2 deletions plugins/framework/src/fms_acceleration/framework_plugin.py
@@ -21,11 +21,10 @@
# Third Party
from accelerate import Accelerator
from peft import LoraConfig
from transformers.utils import logging
from transformers import TrainingArguments
import torch

# want to use the transformers logger, but a bit of pain
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
logger.setLevel(logging._get_default_logging_level())
39 changes: 18 additions & 21 deletions plugins/framework/src/fms_acceleration/model_patcher.py
@@ -197,15 +197,6 @@ def __post_init__(self):
"forward_builder."
)

# if self.import_and_maybe_reload is not None and self.import_and_maybe_reload[2] in self.import_and_maybe_reload[0]:
# raise ValueError(
# f"Rule '{self.rule_id}' import_and_maybe_reload specified has argument 3 in the same path "
# "as argument 1. The path to reload has to be different from object to be patched."
# )

# helpful to keep a history of all patching that has been done
@dataclass
class ModelPatcherHistory:
@@ -269,8 +260,8 @@ def register(rule: ModelPatcherRule):

@staticmethod
def did_rule_trigger(module: torch.nn.Module, module_name: str):
active_rule_name, active_rule = None, None
for name, rule in ModelPatcher.rules.items():

# if there is no trigger
@@ -283,11 +274,12 @@ def did_rule_trigger(module: torch.nn.Module, module_name: str):
active_rule_name = name
active_rule = rule
# otherwise, if there is already an active rule, raise warning
# that subsequent compatible forward rules will be ignored for simple forward patches
# forwardbuilders are handled when they are decomposed into new simple forward rules
# that subsequent compatible forward rules will be ignored
# for simple forward patches. forward_builder args are handled
# when they are decomposed into new simple forward rules
elif rule.forward is not None:
warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule {active_rule.rule_id} has been applied")
#raise Exception(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule has been applied")
warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an \
earlier rule {active_rule.rule_id} has been applied")

return active_rule_name, active_rule

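For context, here is a minimal sketch of the first-match-wins behavior described above, using the ModelPatcher API as exercised in the tests later in this commit; the module class, rule ids, and replacement forward below are hypothetical:

import torch

from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)

class MyModule(torch.nn.Module):  # hypothetical target module
    pass

def patched_forward(self):  # hypothetical replacement forward
    return "patched"

# register two forward rules whose triggers both match MyModule
for rid in ("rule_a", "rule_b"):  # hypothetical rule ids
    ModelPatcher.register(
        ModelPatcherRule(
            rule_id=rid,
            trigger=ModelPatcherTrigger(check=MyModule),
            forward=patched_forward,
        )
    )

# the first registered rule wins; a warning reports the ignored one
name, rule = ModelPatcher.did_rule_trigger(MyModule(), "model.layer0")
assert name == "rule_a"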
@@ -338,18 +330,23 @@ def _import_and_reload(model: torch.nn.Module):
elif _target.startswith(module_path):
_no_reload.append(rule)

# If there are multiple reload targets,
# ensure that their paths do not conflict as reloading same module might reset patches
if len(_with_reload)>1:
# sort ascending target path length
_with_reload = sorted(_with_reload, key=lambda _rule: len(_rule.import_and_maybe_reload[2]), reverse=False)
_with_reload = sorted(
_with_reload,
key=lambda _rule: len(_rule.import_and_maybe_reload[2]),
reverse=False
)
for i, rule_s in enumerate(_with_reload):
for rule_l in _with_reload[i + 1:]:
# if the target path of rule_s is a prefix of rule_l's, raise an error
_, _, _path_s = rule_s.import_and_maybe_reload
_, _, _path_l = rule_l.import_and_maybe_reload
assert not _path_l.startswith(_path_s), \
f"Attempting to reload same path `{_path_s}` multiple times in {rule_s.rule_id} and {rule_l.rule_id}"
f"Attempting to reload same path `{_path_s}` multiple times in \
{rule_s.rule_id} and {rule_l.rule_id}"

# handle those with reload first
for rule in _with_reload + _no_reload:
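As an aside, the conflict check in this hunk can be reproduced standalone: sort the reload paths ascending by length, then assert that no shorter path is a prefix of a longer one. The dotted paths below are hypothetical:

# standalone sketch of the reload-conflict check (hypothetical paths)
paths = ["pkg.a.b", "pkg.c"]  # adding "pkg.a" would trip the assertion
paths = sorted(paths, key=len)  # ascending target path length
for i, path_s in enumerate(paths):
    for path_l in paths[i + 1:]:
        # a shorter path that prefixes a longer one means the same module
        # would be reloaded more than once, wiping earlier patches
        assert not path_l.startswith(path_s), (
            f"Attempting to reload same path `{path_s}` multiple times"
        )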
@@ -469,11 +466,11 @@ def patch(model: torch.nn.Module, **kwargs):
# only once. We do not have any checks for this at the moment

# 1. Iterate over all ModelPatcher rules
# 2. For import_and_maybe_reload rules, an assertion
# is currently thrown if there are multiple reloads
# 3. For _patch_forwards, ensure that the trigger check
# module or callable function is unique across all rules
# otherwise, an assertion is thrown as it could patch the
# forwards over previous patches

try:
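A hedged sketch of the flow these comments describe, assuming rules were registered beforehand as in the earlier example; the model is a hypothetical stand-in:

import torch
from fms_acceleration.model_patcher import ModelPatcher

model = torch.nn.Linear(4, 4)  # hypothetical stand-in model
# applies import_and_maybe_reload rules first, then forward patches,
# subject to the assertions described in the comments above
ModelPatcher.patch(model)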
6 changes: 3 additions & 3 deletions plugins/framework/src/fms_acceleration/utils/test_utils.py
@@ -183,8 +183,8 @@ def dummy_custom_loader(self, model_name, **kwargs):

@contextmanager
def instantiate_model_patcher():
from fms_acceleration.model_patcher import ModelPatcher
from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel
old_registrations = ModelPatcher.rules
ModelPatcher.rules = {}
yield
ModelPatcher.rules = old_registrations
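A minimal sketch of how a test might use this context manager to get an isolated rule registry, assuming rules are keyed by rule_id as did_rule_trigger suggests; the rule below is hypothetical:

import torch
from fms_acceleration.model_patcher import (
    ModelPatcher,
    ModelPatcherRule,
    ModelPatcherTrigger,
)
from fms_acceleration.utils.test_utils import instantiate_model_patcher

def test_registry_is_isolated():
    with instantiate_model_patcher():
        ModelPatcher.register(
            ModelPatcherRule(
                rule_id="isolated_rule",  # hypothetical id
                trigger=ModelPatcherTrigger(check=torch.nn.Linear),
                forward=lambda self: "patched",  # hypothetical forward
            )
        )
        assert "isolated_rule" in ModelPatcher.rules
    # on exit, the previous registrations are restored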
@@ -1,2 +1,2 @@
from .module3 import Module3Class
from .module1_1 import Module1Class, mod_1_function
@@ -5,4 +5,5 @@ def __init__(self) -> None:
self.attribute = Module2Class()

def mod_1_function():
return "unpatched_mod_function"
return "unpatched_mod_function"

@@ -1 +1 @@
from .module3_1 import Module3Class
@@ -1,2 +1,3 @@
def mod_4_function():
return "unpatched_mod_function"
return "unpatched_mod_function"

@@ -1 +1 @@
from .module5_1 import Module5Class, mod_5_function
@@ -1,8 +1,9 @@
import torch

class Module5Class(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __init__(self) -> None:
super().__init__()

def mod_5_function():
return "unpatched_mod_function"
return "unpatched_mod_function"

14 changes: 8 additions & 6 deletions plugins/framework/tests/model_patcher_test_utils.py
@@ -22,17 +22,17 @@
@contextmanager
def isolate_test_module_fixtures():
old_mod = {
k: sys.modules[k] for k in MODULE_PATHS
k: sys.modules[k] for k in MODULE_PATHS if k in sys.modules
}
yield

# Reload only reloads the specified module, but makes no attempt to reload
# the imports of that module.
# - i.e., this means that if an import had been changed,
# then the reload will take the changed import.
# - i.e., this also means that the individual modules must be reloaded
# separately for a complete reset.
#
# Therefore, we need to reload ALL modules in opposite tree order, meaning that
# the children must be reloaded before their parents

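The reload order the comment calls for can be sketched with the standard library alone; the module names below are hypothetical:

import importlib
import sys

MODULE_PATHS = ["pkg", "pkg.child", "pkg.child.leaf"]  # hypothetical

# reload deepest modules first so parents re-import restored children
for name in sorted(MODULE_PATHS, key=lambda p: p.count("."), reverse=True):
    if name in sys.modules:
        importlib.reload(sys.modules[name])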
@@ -43,7 +43,9 @@ def isolate_test_module_fixtures():

def create_module_class(
class_name: str,
namespaces: Dict[str, Any] = {},
namespaces: Dict[str, Any] = None,
parent_class: Type = torch.nn.Module
):
return type(class_name, (parent_class,), namespaces)
if namespaces is None:
namespaces = {}
return type(class_name, (parent_class,), namespaces)
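The change above fixes the classic mutable-default-argument pitfall; a standalone sketch of why it matters (function and attribute names are hypothetical):

def make_class_buggy(name, namespaces={}):  # one dict shared by ALL calls
    namespaces.setdefault("tag", name)
    return type(name, (object,), namespaces)

cls_a = make_class_buggy("A")
cls_b = make_class_buggy("B")
assert cls_b.tag == "A"  # surprise: B inherits A's namespace entry

def make_class_fixed(name, namespaces=None):  # fresh dict per call
    if namespaces is None:
        namespaces = {}
    namespaces.setdefault("tag", name)
    return type(name, (object,), namespaces)

assert make_class_fixed("B").tag == "B"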
26 changes: 15 additions & 11 deletions plugins/framework/tests/test_model_patcher.py
@@ -21,18 +21,15 @@

# First Party
from fms_acceleration.model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
patch_target_module,
ModelPatcherTriggerType,
ModelPatcherHistory,
combine_functions,
combine_triggers,
)

from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures
from .model_patcher_fixtures import module1, module2, module4
from .model_patcher_fixtures import module1

MOD_CLS_A = create_module_class("MOD_CLS_A")
MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A)
@@ -119,7 +116,8 @@ def test_mp_trigger_correctly_triggers():

# Scenario 1:
# if check is a Callable, is_triggered result must be equal to the boolean output of check
# 1. create function to check that returns true if module has attribute `attr_1`, otherwise return False
# 1. create function to check that returns true if module has attribute `attr_1`,
# otherwise return False
# 2. create trigger that checks the above function
# 3. create a subclass of module_A and ensure is_triggered returns True
# 4. create a module_B and ensure is_triggered returns False
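A hedged sketch of Scenario 1, assuming is_triggered takes the module instance as its argument; the check function and marker attribute are hypothetical:

import torch
from fms_acceleration.model_patcher import ModelPatcherTrigger

def has_attr_1(module):  # True iff the module carries `attr_1`
    return hasattr(module, "attr_1")

trigger = ModelPatcherTrigger(check=has_attr_1)

module_a = torch.nn.Linear(2, 2)
module_a.attr_1 = True  # hypothetical marker attribute
module_b = torch.nn.Linear(2, 2)

assert trigger.is_triggered(module_a)
assert not trigger.is_triggered(module_b)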
@@ -201,7 +199,12 @@ def check_module(module):
(AssertionError, "Only `AND`, `OR` logic implemented for combining triggers")
],
])
def test_combine_mp_triggers_produces_correct_output(target_module, trigger_checks, logic, expected_result):
def test_combine_mp_triggers_produces_correct_output(
target_module,
trigger_checks,
logic,
expected_result
):
triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks]

# if expected_result is a tuple of (Exception, Exception_message)
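A hedged sketch of the combination under test, assuming combine_triggers accepts the triggers plus a logic keyword taking "AND" or "OR", as the parametrized arguments and the error message above suggest:

import torch
from fms_acceleration.model_patcher import (
    ModelPatcherTrigger,
    combine_triggers,
)

t_linear = ModelPatcherTrigger(check=torch.nn.Linear)
t_conv = ModelPatcherTrigger(check=torch.nn.Conv1d)

# OR: triggers if either underlying trigger matches (signature assumed)
combined = combine_triggers(t_linear, t_conv, logic="OR")
assert combined.is_triggered(torch.nn.Linear(2, 2))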
@@ -225,7 +228,7 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
"Ensure MP rule is throws appropriate error when wrong argument combinations are passed"
# Test mp rule construction raises with multiple arguments
with pytest.raises(
ValueError,
match="must only have only one of forward, " \
"foward builder, or import_and_maybe_reload, specified."
):
@@ -238,7 +241,7 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured():

# Test mp rule construction raises with trigger and import_and_reload
with pytest.raises(
ValueError,
match="has import_and_maybe_reload specified, " \
"and trigger must be None."
):
@@ -248,9 +251,10 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured():
import_and_maybe_reload=(),
)

# Test that rule construction raises if forward_builder_args are provided without a forward_builder
# Test that rule construction raises if forward_builder_args are provided
# without a forward_builder
with pytest.raises(
ValueError,
match="has forward_builder_args but no " \
"forward_builder."
):
@@ -371,7 +375,7 @@ def patched_mod_function():
# with the original version
assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass)

# S4 - module1.module3 submodule has a dependency on
# module1.module1_1.mod_1_function
# 1. Replace the module1.module1_1.mod_1_function with a new function
# 2. Ensure the target reloading of module1.module3 picks up the patched function
8 changes: 2 additions & 6 deletions plugins/framework/tests/test_model_patcher2.py
@@ -1,20 +1,17 @@
# Third Party
import pytest  # pylint: disable=import-error
import torch

# First Party
from fms_acceleration.model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
patch_target_module,
combine_functions,
)

from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures
from .model_patcher_fixtures import module1, module2, module4
from .model_patcher_fixtures import module4
from fms_acceleration.utils.test_utils import instantiate_model_patcher

from .test_model_patcher import DUMMY_RULE_ID

# Test patching of model attribute
@@ -311,7 +308,7 @@ def build_list_of_triggers(
(ModelPatcherTrigger(check=SubModule1A), patched_forward_function),
(ModelPatcherTrigger(check=is_module_type_B), patched_forward_function),
(ModelPatcherTrigger(check=is_module_type_C), patched_forward_function),
]

ModelPatcher.register(
ModelPatcherRule(
@@ -326,4 +323,3 @@
for _, mod in model.named_children():
if hasattr(mod, "forward"):
assert mod.forward() == "patched_forward_function"
