Updating the Flash Attention version to fix cross entropy loss (#812)
ShashankMosaicML authored Dec 20, 2023
1 parent a7e916b commit 2ba9224
Showing 7 changed files with 11 additions and 196 deletions.
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -100,7 +100,7 @@ def __init__(self, om_model_config: Union[DictConfig,
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
- 'Please install flash_attn==2.3.2`.')
+ 'Please install flash_attn==2.3.6`.')

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
config = AutoConfig.from_pretrained(
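
The hunk above only updates the version named in the error message; the underlying guard simply checks that a flash-attn 2.x build is importable before asking Transformers for the 'flash_attention_2' implementation. A minimal standalone sketch of that guard, using a hypothetical flash_attn_2_installed helper in place of llm-foundry's is_flash_v2_installed (the helper and the 2.3.6 threshold are assumptions for illustration, not part of this diff):

from packaging.version import Version


def flash_attn_2_installed(min_version: str = '2.3.6') -> bool:
    """Best-effort check that a new-enough flash-attn 2.x build is importable."""
    try:
        import flash_attn
    except ImportError:
        return False
    return Version(flash_attn.__version__) >= Version(min_version)


use_flash_attention_2 = True
if use_flash_attention_2 and not flash_attn_2_installed():
    raise ValueError(
        'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
        'Please install flash_attn==2.3.6.')
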
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -227,7 +227,7 @@ def flash_attn_fn(
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
raise RuntimeError(
- 'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
+ 'Please install flash-attn==1.0.9 or flash-attn==2.3.6')

check_valid_inputs(query, key, value)

@@ -344,7 +344,7 @@ def flash_attn_fn(
window_size=(sliding_window_size, sliding_window_size))
else:
raise RuntimeError(
- 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
+ 'flash-attn==1.0.9 or flash-attn==2.3.6 is required.')

output = bert_padding.pad_input(
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
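
Both hunks in this file only touch the install hint in the error message, but the surrounding code is the flash-attn v2 call path, including the sliding-window branch visible above (window_size=(sliding_window_size, sliding_window_size)). As a rough illustration of what that argument means, here is a minimal sketch calling flash_attn_func directly with a local attention window; it assumes flash-attn >= 2.3 and a CUDA GPU, and uses flash-attn's (batch, seqlen, n_heads, head_dim) layout:

import torch
from flash_attn import flash_attn_func  # requires flash-attn >= 2.3 and a CUDA GPU

batch, seqlen, n_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, n_heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

sliding_window_size = 128
# window_size=(left, right): each query attends to at most `left` tokens behind it
# and `right` tokens ahead of it; (-1, -1) disables the sliding window.
out = flash_attn_func(q, k, v, causal=True,
                      window_size=(sliding_window_size, sliding_window_size))
assert out.shape == (batch, seqlen, n_heads, head_dim)
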
173 changes: 0 additions & 173 deletions llmfoundry/models/layers/cross_entropy_loss.py

This file was deleted.

2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -304,5 +304,5 @@ def _validate_config(self) -> None:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
raise ImportError(
- 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2'
+ 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6'
)
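
Setting use_pad_tok_in_ffn=False makes the FFN skip padding tokens, which is why the validation above requires flash-attn's unpad_input and pad_input helpers. A rough sketch of that unpad/pad round trip, assuming flash-attn 2.3.x where unpad_input returns four values (later releases changed the return signature):

import torch
from flash_attn.bert_padding import pad_input, unpad_input  # requires flash-attn

batch, seqlen, d_model = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, d_model)
# 1 = real token, 0 = padding; the second sequence has three padded positions.
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])

# Flatten to (total_real_tokens, d_model), remembering where each token came from.
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
assert unpadded.shape[0] == int(attention_mask.sum())

# ... the FFN would run on `unpadded` only, never touching pad positions ...

# Scatter the results back into the original (batch, seqlen, d_model) layout.
restored = pad_input(unpadded, indices, batch, seqlen)
assert restored.shape == hidden_states.shape
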
6 changes: 1 addition & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -972,11 +972,7 @@ def __init__(
loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
if loss_fn_config == 'fused_crossentropy':
try:
- # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714).
- # from flash_attn.losses.cross_entropy import \
- #     CrossEntropyLoss as FusedCrossEntropyLoss
- # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library.
- from llmfoundry.models.layers.cross_entropy_loss import \
+ from flash_attn.losses.cross_entropy import \
CrossEntropyLoss as FusedCrossEntropyLoss

self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
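
This is the core of the commit: with flash-attn pinned to 2.3.6, the illegal-memory-access error in its fused cross entropy at long sequence lengths (flash-attention issue #714) that motivated the vendored copy of the 1.0.9 loss is resolved, so the workaround is removed and the import goes straight back to flash_attn.losses.cross_entropy. A hedged sketch of the surrounding pattern, falling back to torch.nn.CrossEntropyLoss when the fused kernel is unavailable (the helper name and warning text below are illustrative, not code from the repo):

import warnings

import torch


def build_loss_fn(loss_fn_config: str = 'fused_crossentropy') -> torch.nn.Module:
    """Prefer flash-attn's fused cross entropy; fall back to the torch implementation."""
    if loss_fn_config == 'fused_crossentropy':
        try:
            from flash_attn.losses.cross_entropy import \
                CrossEntropyLoss as FusedCrossEntropyLoss
            return FusedCrossEntropyLoss(ignore_index=-100)
        except ImportError:
            warnings.warn('flash-attn fused cross entropy is unavailable; '
                          'falling back to torch.nn.CrossEntropyLoss.')
    return torch.nn.CrossEntropyLoss(ignore_index=-100)
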
4 changes: 1 addition & 3 deletions setup.py
@@ -98,10 +98,8 @@
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
extra_deps['gpu-flash2'] = [
- 'flash-attn==2.3.2',
+ 'flash-attn==2.3.6',
'mosaicml-turbo==0.0.4',
- # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
- 'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
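
Because the gpu-flash2 extra now pins flash-attn==2.3.6 and drops its xentropy-cuda-lib dependency, an environment created before this commit can silently keep the older wheel. A small illustrative check of the installed version against the new pin (the reinstall hint assumes a source checkout of the repo):

from importlib.metadata import PackageNotFoundError, version

EXPECTED_FLASH_ATTN = '2.3.6'  # pin from extra_deps['gpu-flash2'] in this commit

try:
    installed = version('flash-attn')
except PackageNotFoundError:
    installed = None

if installed != EXPECTED_FLASH_ATTN:
    print(f'flash-attn is {installed!r}, expected {EXPECTED_FLASH_ATTN}; '
          'reinstall with: pip install -e ".[gpu-flash2]"')
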
16 changes: 5 additions & 11 deletions tests/models/test_model.py
@@ -422,22 +422,16 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,


@pytest.mark.gpu
- @pytest.mark.parametrize('ce_loss_implementation',
- ['FA_v1_copied', 'FA_imported'])
- def test_loss_fn(ce_loss_implementation: str):
+ def test_loss_fn():
"""Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.
We provide non-zero tolerances to account for small numerics differences
between the two loss implementations.
"""
-    if ce_loss_implementation == 'FA_imported':
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
-        except:
-            pytest.skip('Fused cross entropy was not installed')
-    else:
-        from llmfoundry.models.layers.cross_entropy_loss import \
-            CrossEntropyLoss as FusedCrossEntropyLoss
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+    except:
+        pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
from torch.backends import cuda, cudnn
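
The simplified test keeps the core comparison: flash-attn's fused cross entropy should agree with torch.nn.CrossEntropyLoss up to small numerical tolerances. A stripped-down, hedged version of that check (the sizes and tolerances here are illustrative, not the ones used in the repo's test), assuming a CUDA device with flash-attn installed:

import torch

try:
    from flash_attn.losses.cross_entropy import \
        CrossEntropyLoss as FusedCrossEntropyLoss
except ImportError:
    raise SystemExit('flash-attn fused cross entropy is not installed')

torch.manual_seed(0)
n_tokens, vocab_size = 512, 32000
logits = torch.randn(n_tokens, vocab_size, device='cuda', dtype=torch.float32)
targets = torch.randint(0, vocab_size, (n_tokens,), device='cuda')
targets[:8] = -100  # a few positions are masked out via ignore_index

fused_loss = FusedCrossEntropyLoss(ignore_index=-100)(logits, targets)
torch_loss = torch.nn.CrossEntropyLoss(ignore_index=-100)(logits, targets)

# The two implementations should agree to within small numerical tolerances.
torch.testing.assert_close(fused_loss, torch_loss, atol=1e-5, rtol=1e-4)
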
