From 834e46f97e12f441f9d4e4ea28952fd0d55e6475 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 22 Oct 2023 11:58:26 -0700 Subject: [PATCH] fix test --- tests/test_huggingface_flash.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py index b25b06e31b..cc20ee0edc 100644 --- a/tests/test_huggingface_flash.py +++ b/tests/test_huggingface_flash.py @@ -181,8 +181,7 @@ def test_flash2(model_name: str, use_flash_attention_2: bool): model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer) # check that it actually used flash attention 2 - assert model.model.config._flash_attn_2_enabled if is_flash_v2_installed( - ) else not model.model.config._flash_attn_2_enabled + assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled attention_layer = rgetattr( rgetattr(model, attention_layers_attr)[0], attention_attr) assert isinstance(attention_layer, flash_attn_class)