From 5bef19064b2e5905311ae718949553f4ad38a580 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 7 Dec 2024 17:24:46 -0500 Subject: [PATCH] [tests] reset known modules that are patched on each test function end (#2147) * reset known modules that are patched on each test function end * fix the llama model module name * prevent unsloth patching multiple times * pop classes out of the globals after reset * fix tuple indexing * manually workaround for llama fa2 --- src/axolotl/monkeypatch/trainer_grad_accum.py | 8 ++--- src/axolotl/monkeypatch/unsloth_.py | 8 +++++ tests/conftest.py | 29 +++++++++++++++++++ 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py index 5ee90f91ae..97e6b7f2c5 100644 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -3,14 +3,14 @@ see https://github.com/huggingface/transformers/pull/35128 """ import inspect +import logging -from accelerate.logging import get_logger from transformers import LlamaForCausalLM from transformers.trainer import Trainer from axolotl.monkeypatch.unsloth_ import detab_code -LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum") +LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") ORIGINAL_CONTEXT_CODE = """ with self.compute_loss_context_manager(): @@ -145,7 +145,7 @@ def patch_training_step_for_ga(): globals(), ) exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching training_step", main_process_only=True) + LOG.info("patching training_step") Trainer.training_step = ( # pylint: disable=protected-access _fixed_training_step # pylint: disable=undefined-variable # noqa: F821 ) @@ -201,7 +201,7 @@ def patch_forward_for_ga(): globals(), ) exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching forward", main_process_only=True) + LOG.info("patching forward") LlamaForCausalLM.forward = ( # pylint: disable=protected-access _fixed_forward # pylint: disable=undefined-variable # noqa: F821 ) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 7358803ba1..21fdb7edff 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]: return code, spaces +self_attn_lora_patched = False # pylint: disable=invalid-name + + def patch_self_attn_lora(): + global self_attn_lora_patched # pylint: disable=global-statement + if self_attn_lora_patched: + # prevent patching multiple times + return self_attn_forward = get_self_attn_code() LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access self_attn_forward @@ -134,6 +141,7 @@ def patch_self_attn_lora(): globals(), ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + self_attn_lora_patched = True LOG.info("patching unsloth attn lora", main_process_only=True) LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 diff --git a/tests/conftest.py b/tests/conftest.py index a775216fc0..1295d34b64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,9 @@ shared pytest fixtures """ import functools +import importlib import shutil +import sys import tempfile import time @@ -113,3 +115,30 @@ def temp_dir(): yield _temp_dir # Clean up the directory after the test shutil.rmtree(_temp_dir) + + +@pytest.fixture(scope="function", autouse=True) +def cleanup_monkeypatches(): + from transformers.models.llama.modeling_llama import LlamaFlashAttention2 + + original_fa2_forward = LlamaFlashAttention2.forward + # monkey patches can happen inside the tests + yield + # Reset LlamaFlashAttention2 forward + LlamaFlashAttention2.forward = original_fa2_forward + + # Reset other known monkeypatches + modules_to_reset: list[tuple[str, list[str]]] = [ + ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]), + ("transformers.trainer",), + ("transformers.loss.loss_utils",), + ] + for module_name_tuple in modules_to_reset: + module_name = module_name_tuple[0] + module = importlib.import_module(module_name) + sys.modules[module_name] = module + importlib.reload(sys.modules[module_name]) + if len(module_name_tuple) > 1: + module_globals = module_name_tuple[1] + for module_global in module_globals: + globals().pop(module_global, None)