diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 51fa67993a..0b4e49b1a0 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -9,6 +9,9 @@ # gated import otherwise. import transformers + from llmfoundry.models.utils.flash_attn_checker import is_flash_v1_installed + if is_flash_v1_installed(): + transformers.utils.is_flash_attn_available = lambda: False from llmfoundry import optim, utils from llmfoundry.data import (ConcatTokensDataset, MixtureOfDenoisersCollator, NoConcatDataset, @@ -19,8 +22,8 @@ ComposerHFT5) from llmfoundry.models.layers.attention import ( MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, - flash_attn_fn, is_flash_v1_installed, - scaled_multihead_dot_product_attention, triton_flash_attn_fn) + flash_attn_fn, scaled_multihead_dot_product_attention, + triton_flash_attn_fn) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP, build_ffn) @@ -29,8 +32,6 @@ MPTForCausalLM, MPTModel, MPTPreTrainedModel) from llmfoundry.tokenizers import TiktokenTokenizerWrapper - if is_flash_v1_installed(): - transformers.utils.is_flash_attn_available = lambda: False except ImportError as e: try: diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index eb90b07045..88829c67e1 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -24,8 +24,8 @@ from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss -from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights +from llmfoundry.models.utils.flash_attn_checker import is_flash_v2_installed try: from peft.peft_model import PeftModel diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 39fa7162ac..ea663e9069 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -15,22 +15,8 @@ from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY - - -def is_flash_v2_installed(): - try: - import flash_attn as flash_attn - except: - return False - return version.parse(flash_attn.__version__) >= version.parse('2.0.0') - - -def is_flash_v1_installed(): - try: - import flash_attn as flash_attn - except: - return False - return version.parse(flash_attn.__version__) < version.parse('2.0.0') +from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed, + is_flash_v2_installed) def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, diff --git a/llmfoundry/models/utils/flash_attn_checker.py b/llmfoundry/models/utils/flash_attn_checker.py new file mode 100644 index 0000000000..dc916a8628 --- /dev/null +++ b/llmfoundry/models/utils/flash_attn_checker.py @@ -0,0 +1,20 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from packaging import version + + +def is_flash_v2_installed(): + try: + import flash_attn as flash_attn + except: + return False + return version.parse(flash_attn.__version__) >= version.parse('2.0.0') + + +def is_flash_v1_installed(): + try: + import flash_attn as flash_attn + except: + return False + return version.parse(flash_attn.__version__) < version.parse('2.0.0') diff --git a/tests/test_huggingface_flash.py b/tests/test_huggingface_flash.py index a71217ea1f..9b5aa160dd 100644 --- a/tests/test_huggingface_flash.py +++ b/tests/test_huggingface_flash.py @@ -14,8 +14,8 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.models.hf.hf_fsdp import rgetattr -from llmfoundry.models.layers.attention import (is_flash_v1_installed, - is_flash_v2_installed) +from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed, + is_flash_v2_installed) from llmfoundry.utils.builders import build_tokenizer # Before importing any transformers models, we need to disable transformers flash attention if