Commit

remove commented out code
dakinggg committed Oct 24, 2023
1 parent 6db6315 commit 528340f
Showing 1 changed file, tests/test_huggingface_flash.py, with 0 additions and 6 deletions.
@@ -42,8 +42,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
'The CI cluster does not have access to the Llama models, so skip this test.'
)

- # original_forward = LlamaAttention.forward
-
device = 'cuda:0'
sequence_length = 4096
model_dim = 4096 if '7b' in model_name else 8192
@@ -96,7 +94,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,

reproducibility.seed_all(42)
with patch.object(LlamaAttention, 'forward', new=patch_fn):
- # LlamaAttention.forward = patch_fn
attention = LlamaAttention(config=llama_config,)
attention.to(dtype=dtype, device=device)
new_output, _, _ = attention(
@@ -107,9 +104,6 @@ def test_patch_equivalence(patch_fn_name: str, explicit_mask: bool,
use_cache=False,
)

- # # Reset the forward function so patches don't persist
- # LlamaAttention.forward = original_forward
-
assert torch.allclose(attn_output, new_output, atol=atol, rtol=rtol)
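
For context on why the deleted comments were safe to remove: unittest.mock.patch.object, used as a context manager (as the test already does), restores the original attribute when the block exits, so the manual save/restore that the commented-out lines sketched is unnecessary. Below is a minimal, self-contained sketch of that behavior; the toy LlamaAttention class and patch_fn here are stand-ins for illustration, not the real transformers classes used in the test.

    from unittest.mock import patch

    class LlamaAttention:
        def forward(self):
            return 'original'

    def patch_fn(self):
        return 'patched'

    original = LlamaAttention.forward

    # Inside the block, the class attribute is replaced by patch_fn.
    with patch.object(LlamaAttention, 'forward', new=patch_fn):
        assert LlamaAttention().forward() == 'patched'

    # On exit, patch.object restores the original attribute automatically,
    # so no manual reset (LlamaAttention.forward = original_forward) is needed.
    assert LlamaAttention.forward is original
    assert LlamaAttention().forward() == 'original'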


