diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 5bec1dbb70..1b02c41c29 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -17,12 +17,12 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
 
-def is_flash_v2_installed():
+def is_flash_v2_installed(v2_version: str = '2.0.0'):
     try:
         import flash_attn as flash_attn
     except:
         return False
-    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+    return version.parse(flash_attn.__version__) >= version.parse(v2_version)
 
 
 def is_flash_v1_installed():
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 9c6f4287f7..c4ca68d733 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -8,8 +8,17 @@
 
 from transformers import PretrainedConfig
 
+from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.layers.blocks import attn_config_defaults
 
+# NOTE: All utils are imported directly even if unused so that
+# HuggingFace can detect all the needed files to copy into its modules folder.
+# Otherwise, certain modules are missing.
+# isort: off
+from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY  # type: ignore (see note)
+from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
+from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
+
 ffn_config_defaults: Dict = {
     'ffn_type': 'mptmlp',
 }
@@ -224,13 +233,18 @@ def _validate_config(self) -> None:
             raise ValueError(
                 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
             )
-        if self.attn_config['rope'] and (
-                self.attn_config['rope_impl']
-                == 'dail') and (self.attn_config['rope_dail_config']['type']
-                                not in ['original', 'xpos']):
-            raise ValueError(
-                'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
-            )
+        if self.attn_config['rope'] and (self.attn_config['rope_impl']
+                                         == 'dail'):
+            if self.attn_config['rope_dail_config']['type'] not in [
+                    'original', 'xpos'
+            ]:
+                raise ValueError(
+                    'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
+                )
+            if not is_flash_v2_installed(v2_version='2.0.1'):
+                raise ImportError(
+                    'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support'
+                )
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError(
                 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
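A minimal usage sketch (not part of the diff) of how the widened `is_flash_v2_installed` check interacts with the new dail-rope requirement in `MPTConfig._validate_config`; the error message below is shortened and assumed, not copied from the PR:

```python
# Illustrative sketch of the new `v2_version` argument; assumes llm-foundry
# (and, optionally, flash-attn) is installed locally.
from llmfoundry.models.layers.attention import is_flash_v2_installed

# Default call keeps the previous behavior: True for any flash-attn >= 2.0.0.
print(is_flash_v2_installed())

# The dail rope path now requires flash-attn >= 2.0.1, mirroring the check
# added to MPTConfig._validate_config in this diff.
if not is_flash_v2_installed(v2_version='2.0.1'):
    raise ImportError(
        'rope_impl="dail" requires flash-attn >= 2.0.1 (see the LLM Foundry TUTORIAL).')
```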