Make sure patch does not persist between tests
irenedea committed Oct 30, 2023
1 parent c099da3 commit f569531
Showing 1 changed file with 10 additions and 0 deletions.
tests/test_huggingface_flash.py (10 additions & 0 deletions)
@@ -3,6 +3,7 @@

import contextlib
import os
from unittest import mock
from unittest.mock import patch

import pytest
@@ -113,6 +114,12 @@ def test_attn_patch_integration(patch: str):
pytest.skip(
'The CI cluster does not have access to the Llama models, so skip this test.'
)

# Save the original attention function to restore at the end of the test.
from transformers.models.llama.modeling_llama import \
LlamaAttention
original_attn = LlamaAttention.forward

name = 'meta-llama/Llama-2-7b-hf'
model_cfg = DictConfig({
'name': 'hf_causal_lm',
@@ -145,6 +152,9 @@ def test_attn_patch_integration(patch: str):
outputs = model(tokenized_input)
loss = outputs.loss
loss.backward()

# Ensure the patch does not persist beyond this test.
LlamaAttention.forward = original_attn

@pytest.mark.gpu
@pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
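Note on the pattern above: saving LlamaAttention.forward and reassigning it on the last line of the test restores the original behavior only if the test runs to completion. A minimal sketch of an alternative, assuming the same LlamaAttention import used in the diff, relies on unittest.mock.patch.object as a context manager so the attribute is restored even if the test body raises; patched_forward here is a hypothetical stand-in for the real attention patch, not code from this repository.

from unittest import mock

from transformers.models.llama.modeling_llama import LlamaAttention


def test_attn_patch_is_scoped():
    original_forward = LlamaAttention.forward

    # Hypothetical stand-in for the real attention patch exercised by the test.
    def patched_forward(self, *args, **kwargs):
        return original_forward(self, *args, **kwargs)

    # patch.object swaps the attribute for the duration of the with-block and
    # restores the original on exit, even if an exception is raised inside.
    with mock.patch.object(LlamaAttention, 'forward', patched_forward):
        assert LlamaAttention.forward is patched_forward

    # After the block, no patched state leaks into subsequent tests.
    assert LlamaAttention.forward is original_forward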
