
Commit

refactored the code to put flash version checking in a separate file
ShashankMosaicML committed Oct 27, 2023
1 parent 87b2fdc commit bfc092e
Showing 5 changed files with 30 additions and 23 deletions.
9 changes: 5 additions & 4 deletions llmfoundry/__init__.py
@@ -9,6 +9,9 @@
     # gated import otherwise.
     import transformers

+    from llmfoundry.models.utils.flash_attn_checker import is_flash_v1_installed
+    if is_flash_v1_installed():
+        transformers.utils.is_flash_attn_available = lambda: False
     from llmfoundry import optim, utils
     from llmfoundry.data import (ConcatTokensDataset,
                                  MixtureOfDenoisersCollator, NoConcatDataset,
@@ -19,8 +22,8 @@
                                         ComposerHFT5)
     from llmfoundry.models.layers.attention import (
         MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
-        flash_attn_fn, is_flash_v1_installed,
-        scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+        flash_attn_fn, scaled_multihead_dot_product_attention,
+        triton_flash_attn_fn)
     from llmfoundry.models.layers.blocks import MPTBlock
     from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
                                               build_ffn)
@@ -29,8 +32,6 @@
                                        MPTForCausalLM, MPTModel,
                                        MPTPreTrainedModel)
     from llmfoundry.tokenizers import TiktokenTokenizerWrapper
-    if is_flash_v1_installed():
-        transformers.utils.is_flash_attn_available = lambda: False

 except ImportError as e:
     try:
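The net effect in llmfoundry/__init__.py is that the flash-attn v1 guard now runs immediately after transformers is imported, before the rest of llmfoundry is pulled in. A minimal sketch of the same pattern in isolation (the trailing check is illustrative and not part of the commit; transformers' flash-attention integration targets flash-attn v2, so a v1 install must be reported as unavailable):

import transformers

from llmfoundry.models.utils.flash_attn_checker import is_flash_v1_installed

if is_flash_v1_installed():
    # Monkey-patch transformers' detection hook so model code imported later
    # does not enable the flash-attn v2 integration against a v1 install.
    transformers.utils.is_flash_attn_available = lambda: False

# Illustrative check: with flash-attn v1 installed, the hook now reports False.
if is_flash_v1_installed():
    assert not transformers.utils.is_flash_attn_available()
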
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,8 @@

 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.utils import init_empty_weights
+from llmfoundry.models.utils.flash_attn_checker import is_flash_v2_installed

 try:
     from peft.peft_model import PeftModel
18 changes: 2 additions & 16 deletions llmfoundry/models/layers/attention.py
@@ -15,22 +15,8 @@

 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
-
-
-def is_flash_v2_installed():
-    try:
-        import flash_attn as flash_attn
-    except:
-        return False
-    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
-
-
-def is_flash_v1_installed():
-    try:
-        import flash_attn as flash_attn
-    except:
-        return False
-    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
+from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed,
+                                                        is_flash_v2_installed)


 def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
20 changes: 20 additions & 0 deletions llmfoundry/models/utils/flash_attn_checker.py
@@ -0,0 +1,20 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from packaging import version
+
+
+def is_flash_v2_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+
+
+def is_flash_v1_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
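
Both helpers are import-safe: they return False when flash-attn is not installed at all, and otherwise just compare the installed version against 2.0.0. A hedged usage sketch of how a caller might branch on them (the function flash_attn_major_version is hypothetical and not part of this commit):

from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed,
                                                         is_flash_v2_installed)


def flash_attn_major_version() -> int:
    # Hypothetical helper: report 2 for flash-attn >= 2.0.0, 1 for an older
    # install, and 0 when flash-attn is not importable at all.
    if is_flash_v2_installed():
        return 2
    if is_flash_v1_installed():
        return 1
    return 0
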
4 changes: 2 additions & 2 deletions tests/test_huggingface_flash.py
@@ -14,8 +14,8 @@

 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.models.hf.hf_fsdp import rgetattr
-from llmfoundry.models.layers.attention import (is_flash_v1_installed,
-                                                is_flash_v2_installed)
+from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed,
+                                                         is_flash_v2_installed)
 from llmfoundry.utils.builders import build_tokenizer

 # Before importing any transformers models, we need to disable transformers flash attention if
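The test module imports both checkers right next to the comment about disabling transformers flash attention before model imports. A sketch of one way such a version guard is commonly wired into tests (the test name and skip condition are assumptions, not taken from this diff):

import pytest

from llmfoundry.models.utils.flash_attn_checker import is_flash_v2_installed


@pytest.mark.skipif(not is_flash_v2_installed(),
                    reason='requires flash-attn v2 to be installed')
def test_flash_attn_v2_path():
    # Hypothetical placeholder; the real assertions live in
    # tests/test_huggingface_flash.py.
    ...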
