Skip to content

Commit

Permalink
Fix FA2 tests (#29909)
Browse files Browse the repository at this point in the history
* fix FA2 tests

* refactor inference test name
  • Loading branch information
ylacombe authored Apr 1, 2024
1 parent 3b8e293 commit 569f6c7
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 19 deletions.
4 changes: 2 additions & 2 deletions tests/models/bark/test_modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ def test_resize_embeddings_untied(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
Expand Down Expand Up @@ -936,7 +936,7 @@ def test_flash_attn_2_inference(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
Expand Down
4 changes: 2 additions & 2 deletions tests/models/distilbert/test_modeling_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_torchscript_device_change(self):
@require_torch_accelerator
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
def test_flash_attn_2_inference_equivalence(self):
import torch

for model_class in self.all_model_classes:
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_flash_attn_2_inference(self):
@require_torch_accelerator
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
import torch

for model_class in self.all_model_classes:
Expand Down
2 changes: 1 addition & 1 deletion tests/models/gemma/test_modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def test_flash_attn_2_generate_use_cache(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Gemma flash attention does not support right padding")

@require_torch_sdpa
Expand Down
2 changes: 1 addition & 1 deletion tests/models/mistral/test_modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_flash_attn_2_generate_use_cache(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Mistral flash attention does not support right padding")


Expand Down
2 changes: 1 addition & 1 deletion tests/models/mixtral/test_modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def test_flash_attn_2_generate_use_cache(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Mixtral flash attention does not support right padding")

# Ignore copy
Expand Down
2 changes: 1 addition & 1 deletion tests/models/qwen2/test_modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def test_flash_attn_2_generate_use_cache(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Qwen2 flash attention does not support right padding")


Expand Down
2 changes: 1 addition & 1 deletion tests/models/starcoder2/test_modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def test_flash_attn_2_generate_use_cache(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Starcoder2 flash attention does not support right padding")


Expand Down
4 changes: 2 additions & 2 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
def test_flash_attn_2_inference_equivalence(self):
import torch

for model_class in self.all_model_classes:
Expand Down Expand Up @@ -934,7 +934,7 @@ def test_flash_attn_2_inference(self):
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
import torch

for model_class in self.all_model_classes:
Expand Down
12 changes: 4 additions & 8 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3245,7 +3245,7 @@ def test_flash_attn_2_conversion(self):
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
def test_flash_attn_2_inference_equivalence(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
Expand All @@ -3260,9 +3260,7 @@ def test_flash_attn_2_inference(self):
)
model_fa.to(torch_device)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)

dummy_input = inputs_dict[model.main_input_name][:1]
Expand Down Expand Up @@ -3340,7 +3338,7 @@ def test_flash_attn_2_inference(self):
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
Expand All @@ -3355,9 +3353,7 @@ def test_flash_attn_2_inference_padding_right(self):
)
model_fa.to(torch_device)

model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
model.to(torch_device)

dummy_input = inputs_dict[model.main_input_name][:1]
Expand Down

0 comments on commit 569f6c7

Please sign in to comment.