Commit d162855
..
ShashankMosaicML committed Oct 27, 2023
1 parent bfc092e commit d162855
Showing 5 changed files with 7 additions and 7 deletions.
llmfoundry/__init__.py (4 additions & 4 deletions)

```diff
@@ -4,15 +4,15 @@
 import torch
 
 try:
-    from llmfoundry import optim, utils
 
     # Before importing any transformers models, we need to disable transformers flash attention if
     # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
     # gated import otherwise.
-    import transformers
 
-    from llmfoundry.models.utils.flash_attn_checker import is_flash_v1_installed
+    from llmfoundry.utils.flash_attn_checker import is_flash_v1_installed
     if is_flash_v1_installed():
+        import transformers
         transformers.utils.is_flash_attn_available = lambda: False
+    from llmfoundry import optim, utils
     from llmfoundry.data import (ConcatTokensDataset,
                                  MixtureOfDenoisersCollator, NoConcatDataset,
                                  Seq2SeqFinetuningCollator,
```
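The relocated module's contents are not shown in this diff. Based on how `is_flash_v1_installed` and `is_flash_v2_installed` are used above, `llmfoundry/utils/flash_attn_checker.py` plausibly contains something like the sketch below (an assumption, not the file's verbatim contents):

```python
# Sketch of flash_attn_checker.py, assuming it reports the installed
# flash-attn major version; the real file may differ in detail.
from packaging import version


def is_flash_v1_installed() -> bool:
    """True if flash-attn is importable and older than 2.0.0."""
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


def is_flash_v2_installed() -> bool:
    """True if flash-attn is importable at version 2.0.0 or newer."""
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
```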
llmfoundry/models/hf/hf_causal_lm.py (1 addition & 1 deletion)

```diff
@@ -25,7 +25,7 @@
 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
 from llmfoundry.models.utils import init_empty_weights
-from llmfoundry.models.utils.flash_attn_checker import is_flash_v2_installed
+from llmfoundry.utils.flash_attn_checker import is_flash_v2_installed
 
 try:
     from peft.peft_model import PeftModel
```
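hf_causal_lm.py imports `is_flash_v2_installed` to verify flash-attn 2 is present before enabling it. A hedged sketch of that kind of guard follows; the parameter name `use_flash_attention_2` and the helper are assumptions for illustration, not the module's real API:

```python
# Hypothetical guard in the spirit of hf_causal_lm.py; `use_flash_attention_2`
# is an assumed config field, not necessarily the real key.
from llmfoundry.utils.flash_attn_checker import is_flash_v2_installed


def require_flash_attn_2(use_flash_attention_2: bool) -> None:
    """Fail fast if flash-attn 2 is requested but not installed."""
    if use_flash_attention_2 and not is_flash_v2_installed():
        raise ValueError('use_flash_attention_2 requires flash-attn >= 2.0.0; '
                         'try `pip install flash-attn>=2.0.0`.')
```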
llmfoundry/models/layers/attention.py (1 addition & 1 deletion)

```diff
@@ -15,7 +15,7 @@
 
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
-from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed,
+from llmfoundry.utils.flash_attn_checker import (is_flash_v1_installed,
                                                  is_flash_v2_installed)
 
 
```
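attention.py imports both checkers so it can choose a kernel path by installed flash-attn version. A minimal sketch of that dispatch pattern, using an invented helper name rather than attention.py's actual API:

```python
# Invented helper showing the version-gated dispatch these imports enable;
# not attention.py's actual code.
from llmfoundry.utils.flash_attn_checker import (is_flash_v1_installed,
                                                 is_flash_v2_installed)


def available_flash_attn_impl() -> str:
    """Name the attention code path usable in this environment."""
    if is_flash_v2_installed():
        return 'flash_v2'  # flash-attn >= 2.0.0 kernels
    if is_flash_v1_installed():
        return 'flash_v1'  # legacy flash-attn 1.x kernels
    return 'torch'  # fall back to plain PyTorch attention
```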
llmfoundry/{models/utils => utils}/flash_attn_checker.py
File renamed without changes.
tests/test_huggingface_flash.py (1 addition & 1 deletion)

```diff
@@ -14,7 +14,7 @@
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.models.hf.hf_fsdp import rgetattr
-from llmfoundry.models.utils.flash_attn_checker import (is_flash_v1_installed,
+from llmfoundry.utils.flash_attn_checker import (is_flash_v1_installed,
                                                  is_flash_v2_installed)
 from llmfoundry.utils.builders import build_tokenizer
```
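A typical way for a test like this to consume the relocated checkers is as a pytest skip gate. A hedged example (the test name and body are invented, not taken from test_huggingface_flash.py):

```python
# Sketch of gating a test on the relocated checker; invented test body.
import pytest
from packaging import version

from llmfoundry.utils.flash_attn_checker import is_flash_v2_installed


@pytest.mark.skipif(not is_flash_v2_installed(),
                    reason='requires flash-attn >= 2.0.0')
def test_flash_v2_version():
    import flash_attn
    assert version.parse(flash_attn.__version__) >= version.parse('2.0.0')
```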
