From 528340f6da2f93b99aca44c740ad587d0f1975ec Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 24 Oct 2023 17:42:55 +0000 Subject: [PATCH] remove commented out code --- tests/test_huggingface_flash.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py index f2f9c4bfcc..a71217ea1f 100644 --- a/tests/test_huggingface_flash.py +++ b/tests/test_huggingface_flash.py @@ -42,8 +42,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool, 'The CI cluster does not have access to the Llama models, so skip this test.' ) - # original_forward = LlamaAttention.forward - device = 'cuda:0' sequence_length = 4096 model_dim = 4096 if '7b' in model_name else 8192 @@ -96,7 +94,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool, reproducibility.seed_all(42) with patch.object(LlamaAttention, 'forward', new=patch_fn): - # LlamaAttention.forward = patch_fn attention = LlamaAttention(config=llama_config,) attention.to(dtype=dtype, device=device) new_output, _, _ = attention( @@ -107,9 +104,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool, use_cache=False, ) - # # Reset the forward function so patches don't persist - # LlamaAttention.forward = original_forward - assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)