diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 8744cb168ff..04a6ad99b8d 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -879,7 +879,7 @@ def test_resize_embeddings_untied(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return
@@ -936,7 +936,7 @@ def test_flash_attn_2_inference(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 return
diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
index 481d4b24cd7..6bd821859ea 100644
--- a/tests/models/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -301,7 +301,7 @@ def test_torchscript_device_change(self):
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch

         for model_class in self.all_model_classes:
@@ -353,7 +353,7 @@ def test_flash_attn_2_inference(self):
     @require_torch_accelerator
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch

         for model_class in self.all_model_classes:
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 1b32f1b16ee..8c3aa392ba9 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -462,7 +462,7 @@ def test_flash_attn_2_generate_use_cache(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Gemma flash attention does not support right padding")

     @require_torch_sdpa
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 2e675a28515..432097e9d13 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -466,7 +466,7 @@ def test_flash_attn_2_generate_use_cache(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mistral flash attention does not support right padding")
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index efd48d6a9c3..98654c51335 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -465,7 +465,7 @@ def test_flash_attn_2_generate_use_cache(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Mixtral flash attention does not support right padding")

     # Ignore copy
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 49da4fec98e..21ee694bdca 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -477,7 +477,7 @@ def test_flash_attn_2_generate_use_cache(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Qwen2 flash attention does not support right padding")
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index f0794c46dce..95f604d06b3 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -461,7 +461,7 @@ def test_flash_attn_2_generate_use_cache(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("Starcoder2 flash attention does not support right padding")
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index b79f3a2c0da..7ff6387ff21 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -888,7 +888,7 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         import torch

         for model_class in self.all_model_classes:
@@ -934,7 +934,7 @@ def test_flash_attn_2_inference(self):
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         import torch

         for model_class in self.all_model_classes:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 3dbfea719a5..7241993b6d1 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3245,7 +3245,7 @@ def test_flash_attn_2_conversion(self):
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference(self):
+    def test_flash_attn_2_inference_equivalence(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3260,9 +3260,7 @@ def test_flash_attn_2_inference(self):
                 )
                 model_fa.to(torch_device)

-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)

                 dummy_input = inputs_dict[model.main_input_name][:1]
@@ -3340,7 +3338,7 @@ def test_flash_attn_2_inference(self):
     @require_torch_gpu
     @mark.flash_attn_test
     @slow
-    def test_flash_attn_2_inference_padding_right(self):
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
         for model_class in self.all_model_classes:
             if not model_class._supports_flash_attn_2:
                 self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
@@ -3355,9 +3353,7 @@ def test_flash_attn_2_inference_padding_right(self):
                 )
                 model_fa.to(torch_device)

-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
-                )
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                 model.to(torch_device)

                 dummy_input = inputs_dict[model.main_input_name][:1]