From e653b80e7e536366a028e86e75bba1bbcee0dda0 Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Thu, 25 Jul 2024 17:56:53 +0000 Subject: [PATCH] lint changes --- .../src/fms_acceleration/framework_plugin.py | 3 +- .../src/fms_acceleration/model_patcher.py | 39 +++++++++---------- .../src/fms_acceleration/utils/test_utils.py | 6 +-- .../module1/__init__.py | 2 +- .../module1/module1_1.py | 3 +- .../module1/module3/__init__.py | 2 +- .../module4/module4_1.py | 3 +- .../module4/module5/__init__.py | 2 +- .../module4/module5/module5_1.py | 7 ++-- .../tests/model_patcher_test_utils.py | 14 ++++--- plugins/framework/tests/test_model_patcher.py | 26 +++++++------ .../framework/tests/test_model_patcher2.py | 8 +--- 12 files changed, 58 insertions(+), 57 deletions(-) diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index ed98c720..169d98eb 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -21,11 +21,10 @@ # Third Party from accelerate import Accelerator from peft import LoraConfig +from transformers.utils import logging from transformers import TrainingArguments import torch -from transformers.utils import logging - # want to use the transformers logger, but a bit of pain logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger.setLevel(logging._get_default_logging_level()) diff --git a/plugins/framework/src/fms_acceleration/model_patcher.py b/plugins/framework/src/fms_acceleration/model_patcher.py index 02b28eee..10cf2f02 100644 --- a/plugins/framework/src/fms_acceleration/model_patcher.py +++ b/plugins/framework/src/fms_acceleration/model_patcher.py @@ -197,15 +197,6 @@ def __post_init__(self): "forward_builder." ) - # if self.import_and_maybe_reload is not None and self.import_and_maybe_reload[2] in self.import_and_maybe_reload[0]: - # raise ValueError( - # f"Rule '{self.rule_id}' import_and_maybe_reload specified has argument 3 in the same path " - # "as argument 1. The path to reload has to be different from object to be patched." - # ) - - - - # helpful to keep a history of all patching that has been done @dataclass class ModelPatcherHistory: @@ -269,8 +260,8 @@ def register(rule: ModelPatcherRule): @staticmethod def did_rule_trigger(module: torch.nn.Module, module_name: str): - - active_rule_name, active_rule = None, None + + active_rule_name, active_rule = None, None for name, rule in ModelPatcher.rules.items(): # if there is no trigger @@ -283,11 +274,12 @@ def did_rule_trigger(module: torch.nn.Module, module_name: str): active_rule_name = name active_rule = rule # otherwise, if there is already an active rule, raise warning - # that subsequent compatible forward rules will be ignored for simple forward patches - # forwardbuilders are handled when they are decomposed into new simple forward rules + # that subsequent compatible forward rules will be ignored + # for simple forward patches. forward_builder args are handled + # when they are decomposed into new simple forward rules elif rule.forward is not None: - warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule {active_rule.rule_id} has been applied") - #raise Exception(f"rule {rule.rule_id} is ignored on {module_name} as an earlier rule has been applied") + warnings.warn(f"rule {rule.rule_id} is ignored on {module_name} as an \ + earlier rule {active_rule.rule_id} has been applied") return active_rule_name, active_rule @@ -338,18 +330,23 @@ def _import_and_reload(model: torch.nn.Module): elif _target.startswith(module_path): _no_reload.append(rule) - # If there are multiple reload targets, + # If there are multiple reload targets, # ensure that their paths do not conflict as reloading same module might reset patches if len(_with_reload)>1: # sort ascending target path length - _with_reload = sorted(_with_reload, key=lambda _rule: len(_rule.import_and_maybe_reload[2]), reverse=False) + _with_reload = sorted( + _with_reload, + key=lambda _rule: len(_rule.import_and_maybe_reload[2]), + reverse=False + ) for rule_s in _with_reload: for rule_l in _with_reload[1:]: # if target paths in rule s is a prefix of rule l, raise an error _, _, _path_s = rule_s.import_and_maybe_reload _, _, _path_l = rule_l.import_and_maybe_reload assert not _path_l.startswith(_path_s), \ - f"Attempting to reload same path `{_path_s}` multiple times in {rule_s.rule_id} and {rule_l.rule_id}" + f"Attempting to reload same path `{_path_s}` multiple times in \ + {rule_s.rule_id} and {rule_l.rule_id}" # handle those with reload first for rule in _with_reload + _no_reload: @@ -469,11 +466,11 @@ def patch(model: torch.nn.Module, **kwargs): # only once. We do not have any checks for this at the moment # 1. Iterate over all ModelPatcher rules - # 2. For import_and_maybe_reload rules, an assertion + # 2. For import_and_maybe_reload rules, an assertion # is currently thrown if there are multiple reloads - # 3. For _patch_forwards, ensure that the trigger check + # 3. For _patch_forwards, ensure that the trigger check # module or callable function is unique across all rules - # otherwise, an assertion is thrown as it could patch the + # otherwise, an assertion is thrown as it could patch the # forwards over previous patches try: diff --git a/plugins/framework/src/fms_acceleration/utils/test_utils.py b/plugins/framework/src/fms_acceleration/utils/test_utils.py index 3952b528..929c61e3 100644 --- a/plugins/framework/src/fms_acceleration/utils/test_utils.py +++ b/plugins/framework/src/fms_acceleration/utils/test_utils.py @@ -183,8 +183,8 @@ def dummy_custom_loader(self, model_name, **kwargs): @contextmanager def instantiate_model_patcher(): - from fms_acceleration.model_patcher import ModelPatcher + from fms_acceleration.model_patcher import ModelPatcher # pylint: disable=import-outside-toplevel old_registrations = ModelPatcher.rules ModelPatcher.rules = {} - yield - ModelPatcher.rules = old_registrations \ No newline at end of file + yield + ModelPatcher.rules = old_registrations diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py index 546e2bed..9fbb4ad6 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/__init__.py @@ -1,2 +1,2 @@ from .module3 import Module3Class -from .module1_1 import Module1Class, mod_1_function \ No newline at end of file +from .module1_1 import Module1Class, mod_1_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py index 71906b38..46541ce0 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module1_1.py @@ -5,4 +5,5 @@ def __init__(self) -> None: self.attribute = Module2Class() def mod_1_function(): - return "unpatched_mod_function" \ No newline at end of file + return "unpatched_mod_function" + \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py index 9aa0c47d..93c28a31 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module1/module3/__init__.py @@ -1 +1 @@ -from .module3_1 import Module3Class \ No newline at end of file +from .module3_1 import Module3Class diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py index aa7a9700..79b6c0c8 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module4_1.py @@ -1,2 +1,3 @@ def mod_4_function(): - return "unpatched_mod_function" \ No newline at end of file + return "unpatched_mod_function" + \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py index 7652a2b7..e803018f 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/__init__.py @@ -1 +1 @@ -from .module5_1 import Module5Class, mod_5_function \ No newline at end of file +from .module5_1 import Module5Class, mod_5_function diff --git a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py index dfba5e17..b4351eff 100644 --- a/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py +++ b/plugins/framework/tests/model_patcher_fixtures/module4/module5/module5_1.py @@ -1,8 +1,9 @@ import torch class Module5Class(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + def __init__(self) -> None: + super().__init__() def mod_5_function(): - return "unpatched_mod_function" \ No newline at end of file + return "unpatched_mod_function" + \ No newline at end of file diff --git a/plugins/framework/tests/model_patcher_test_utils.py b/plugins/framework/tests/model_patcher_test_utils.py index 10c56537..b977a757 100644 --- a/plugins/framework/tests/model_patcher_test_utils.py +++ b/plugins/framework/tests/model_patcher_test_utils.py @@ -22,17 +22,17 @@ @contextmanager def isolate_test_module_fixtures(): old_mod = { - k: sys.modules[k] for k in MODULE_PATHS + k: sys.modules[k] for k in MODULE_PATHS if k in sys.modules } yield # Reload only reloads the speicified module, but makes not attempt to reload - # the imports of that module. + # the imports of that module. # - i.e., This moeans that if and import had been changed - # then the reload will take the changed import. + # then the reload will take the changed import. # - i.e., This also means that the individuals must be reloaded seperatedly # for a complete reset. - # + # # Therefore, we need to reload ALL Modules in opposite tree order, meaning that # the children must be reloaded before their parent @@ -43,7 +43,9 @@ def isolate_test_module_fixtures(): def create_module_class( class_name: str, - namespaces: Dict[str, Any] = {}, + namespaces: Dict[str, Any] = None, parent_class: Type = torch.nn.Module ): - return type(class_name, (parent_class,), namespaces) \ No newline at end of file + if namespaces is None: + namespaces = {} + return type(class_name, (parent_class,), namespaces) diff --git a/plugins/framework/tests/test_model_patcher.py b/plugins/framework/tests/test_model_patcher.py index f26c20c7..0d24950a 100644 --- a/plugins/framework/tests/test_model_patcher.py +++ b/plugins/framework/tests/test_model_patcher.py @@ -21,18 +21,15 @@ # First Party from fms_acceleration.model_patcher import ( - ModelPatcher, ModelPatcherRule, ModelPatcherTrigger, patch_target_module, ModelPatcherTriggerType, - ModelPatcherHistory, - combine_functions, combine_triggers, ) from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures -from .model_patcher_fixtures import module1, module2, module4 +from .model_patcher_fixtures import module1 MOD_CLS_A = create_module_class("MOD_CLS_A") MOD_SUBCLS_A = create_module_class("MOD_SUBCLS_A", parent_class=MOD_CLS_A) @@ -119,7 +116,8 @@ def test_mp_trigger_correctly_triggers(): # Scenario 1: # if check is a Callable, is_triggered result must be equal to the boolean output of check - # 1. create function to check that returns true if module has attribute `attr_1`, otherwise return False + # 1. create function to check that returns true if module has attribute `attr_1`, + # otherwise return False # 2. create trigger that checks the above function # 3. create a subclass of module_A and ensure is_triggered returns True # 4. create a module_B and ensure is_triggered returns False @@ -201,7 +199,12 @@ def check_module(module): (AssertionError, "Only `AND`, `OR` logic implemented for combining triggers") ], ]) -def test_combine_mp_triggers_produces_correct_output(target_module, trigger_checks, logic, expected_result): +def test_combine_mp_triggers_produces_correct_output( + target_module, + trigger_checks, + logic, + expected_result +): triggers = [ModelPatcherTrigger(check=check) for check in trigger_checks] # if expected_result is a tuple of (Exception, Exception_message) @@ -225,7 +228,7 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): "Ensure MP rule is throws appropriate error when wrong argument combinations are passed" # Test mp rule construction raises with multiple arguments with pytest.raises( - ValueError, + ValueError, match="must only have only one of forward, " \ "foward builder, or import_and_maybe_reload, specified." ): @@ -238,7 +241,7 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): # Test mp rule construction raises with trigger and import_and_reload with pytest.raises( - ValueError, + ValueError, match="has import_and_maybe_reload specified, " \ "and trigger must be None." ): @@ -248,9 +251,10 @@ def test_mp_rule_raises_error_when_arguments_incorrectly_configured(): import_and_maybe_reload=(), ) - # Test that rule construction raises if forward_builder_args are provided without a forward_builder + # Test that rule construction raises if forward_builder_args are provided + # without a forward_builder with pytest.raises( - ValueError, + ValueError, match="has forward_builder_args but no " \ "forward_builder." ): @@ -371,7 +375,7 @@ def patched_mod_function(): # with the original version assert not isinstance(module1.module3.module3_1.Module3Class(), PatchedModuleClass) - # S4 - module1.module3 submodule has a dependency on + # S4 - module1.module3 submodule has a dependency on # module1.module1_1.mod_1_function # 1. Replace the module1.module1_1.mod_1_function with a new function # 2. Ensure the target reloading of module1.module3 picks up the patched function diff --git a/plugins/framework/tests/test_model_patcher2.py b/plugins/framework/tests/test_model_patcher2.py index 47a19ce2..2011ed0e 100644 --- a/plugins/framework/tests/test_model_patcher2.py +++ b/plugins/framework/tests/test_model_patcher2.py @@ -1,6 +1,5 @@ # Third Party import pytest # pylint: disable=(import-error -import torch # First Party from fms_acceleration.model_patcher import ( @@ -8,13 +7,11 @@ ModelPatcherRule, ModelPatcherTrigger, patch_target_module, - combine_functions, ) from .model_patcher_test_utils import create_module_class, isolate_test_module_fixtures -from .model_patcher_fixtures import module1, module2, module4 +from .model_patcher_fixtures import module4 from fms_acceleration.utils.test_utils import instantiate_model_patcher - from .test_model_patcher import DUMMY_RULE_ID #Test patching of model attribute @@ -311,7 +308,7 @@ def build_list_of_triggers( (ModelPatcherTrigger(check=SubModule1A), patched_forward_function), (ModelPatcherTrigger(check=is_module_type_B), patched_forward_function), (ModelPatcherTrigger(check=is_module_type_C), patched_forward_function), - ] + ] ModelPatcher.register( ModelPatcherRule( @@ -326,4 +323,3 @@ def build_list_of_triggers( for _, mod in model.named_children(): if hasattr(mod, "forward"): assert mod.forward() == "patched_forward_function" -