Commit ac0fd40: fixed tests
ShashankMosaicML committed Nov 2, 2023
1 parent 1e59de5 commit ac0fd40
Showing 2 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -17,12 +17,12 @@
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
 
-def is_flash_v2_installed():
+def is_flash_v2_installed(v2_version: str = '2.0.0'):
     try:
         import flash_attn as flash_attn
     except:
         return False
-    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+    return version.parse(flash_attn.__version__) >= version.parse(v2_version)
 
 
 def is_flash_v1_installed():
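
For context: this change lets callers gate functionality on a specific minimum flash-attn 2.x release rather than only 2.0.0. A minimal usage sketch (the print messages are illustrative, not part of the library):

from llmfoundry.models.layers.attention import is_flash_v2_installed

# The default call still answers "is any flash-attn 2.x (>= 2.0.0) installed?";
# passing v2_version raises the bar, e.g. to the 2.0.1 needed by the dail rope path.
if is_flash_v2_installed(v2_version='2.0.1'):
    print('flash-attn >= 2.0.1 found; the dail rope implementation can be used')
else:
    print('flash-attn >= 2.0.1 not found; use the hf rope implementation instead')
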
28 changes: 21 additions & 7 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -8,8 +8,17 @@
 
 from transformers import PretrainedConfig
 
+from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.layers.blocks import attn_config_defaults
 
+# NOTE: All utils are imported directly even if unused so that
+# HuggingFace can detect all the needed files to copy into its modules folder.
+# Otherwise, certain modules are missing.
+# isort: off
+from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY  # type: ignore (see note)
+from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
+from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
+
 ffn_config_defaults: Dict = {
     'ffn_type': 'mptmlp',
 }
@@ -224,13 +233,18 @@ def _validate_config(self) -> None:
             raise ValueError(
                 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
             )
-        if self.attn_config['rope'] and (
-                self.attn_config['rope_impl']
-                == 'dail') and (self.attn_config['rope_dail_config']['type']
-                                not in ['original', 'xpos']):
-            raise ValueError(
-                'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
-            )
+        if self.attn_config['rope'] and (self.attn_config['rope_impl']
+                                         == 'dail'):
+            if self.attn_config['rope_dail_config']['type'] not in [
+                    'original', 'xpos'
+            ]:
+                raise ValueError(
+                    'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
+                )
+            if not is_flash_v2_installed(v2_version='2.0.1'):
+                raise ImportError(
+                    'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support'
+                )
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError(
                 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
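
To illustrate the new validation, a rough sketch of building an MPTConfig that enables dail rope on a machine without flash-attn 2.0.1+ (this assumes unspecified attn_config keys fall back to their registered defaults; only the keys read by the check above are shown):

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

try:
    # dail rope with a valid type ('original' or 'xpos') now additionally
    # requires flash-attn >= 2.0.1; otherwise _validate_config raises ImportError.
    # An invalid rope_dail_config type would raise ValueError before that check.
    cfg = MPTConfig(attn_config={
        'rope': True,
        'rope_impl': 'dail',
        'rope_dail_config': {'type': 'original'},
    })
    print('config built; flash-attn >= 2.0.1 is available')
except ImportError as err:
    print(f'dail rope unavailable: {err}')
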
