From f569531afc5f1fb15732990b6a8d0d983c3d320d Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 30 Oct 2023 20:28:58 +0000
Subject: [PATCH] Make sure patch does not persist between tests

---
 tests/test_huggingface_flash.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py
index 31402fc74f..2d63c7c56e 100644
--- a/tests/test_huggingface_flash.py
+++ b/tests/test_huggingface_flash.py
@@ -3,6 +3,7 @@
 
 import contextlib
 import os
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -113,6 +114,12 @@ def test_attn_patch_integration(patch: str):
         pytest.skip(
             'The CI cluster does not have access to the Llama models, so skip this test.'
         )
+
+    # Save the original attention function to restore at the end of the test.
+    from transformers.models.llama.modeling_llama import \
+        LlamaAttention
+    original_attn = LlamaAttention.forward
+
     name = 'meta-llama/Llama-2-7b-hf'
     model_cfg = DictConfig({
         'name': 'hf_causal_lm',
@@ -145,6 +152,9 @@ def test_attn_patch_integration(patch: str):
     outputs = model(tokenized_input)
     loss = outputs.loss
     loss.backward()
+
+    # Ensure the patch does not persist beyond this test.
+    LlamaAttention.forward = original_attn
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
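
Note (not part of the patch): the change above saves LlamaAttention.forward at the start of the test and reassigns it at the end so the attention monkey patch cannot leak into other tests in the same process. Because the final reassignment is skipped if the test body raises, the same idea is sometimes written with try/finally instead. A minimal illustrative sketch of that pattern follows; the test name is hypothetical and the test body is elided:

    from transformers.models.llama.modeling_llama import LlamaAttention


    def test_attn_patch_is_restored():
        # Keep a reference to the unpatched attention forward method.
        original_attn = LlamaAttention.forward
        try:
            # ... apply the attention patch and run the forward/backward pass ...
            pass
        finally:
            # Restore the original implementation even if the test body raises,
            # so the patch cannot persist into later tests in the same process.
            LlamaAttention.forward = original_attn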