
Commit

gate attn impls
dakinggg committed Oct 10, 2023
1 parent fdec363 commit ba8dd76
Showing 2 changed files with 45 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pr-gpu.yaml
@@ -31,6 +31,11 @@ jobs:
container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
markers: 'gpu'
pytest_command: 'coverage run -m pytest'
+ # TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
+ # - name: 'gpu-2.1.0-flash2'
+ #   container: # UPDATE AFTER THIS PR GOES TO PRODUCTION IMAGE
+ #   markers: 'gpu'
+ #   pytest_command: 'coverage run -m pytest'
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
61 changes: 40 additions & 21 deletions llmfoundry/models/layers/attention.py
@@ -17,17 +17,19 @@
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


-def raise_if_flash_attn_v2():
-    flash_attn_version = None
-    try:
-        from flash_attn import __version__ as flash_attn_version
-        if version.parse(flash_attn_version) >= version.parse('2.0.0'):
-            raise RuntimeError(
-                'flash-attn==2.0.0+ is not supported. Please use flash-attn==1.0.9.'
-            )
-    except:
-        pass
+# This only needs to be in a try except so that huggingface does not try to import it
+def is_flash_v2_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+
+
+def is_flash_v1_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
@@ -291,18 +293,35 @@ def flash_attn_fn(

    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
-        query_unpad,
-        key_unpad,
-        value_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        dropout_p,
-        softmax_scale=softmax_scale,
-        causal=reset_is_causal,
-        return_attn_probs=needs_weights)
+    if is_flash_v1_installed():
+        output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+            query_unpad,
+            key_unpad,
+            value_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    elif is_flash_v2_installed():
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(
+            query_unpad,
+            key_unpad,
+            value_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    else:
+        raise RuntimeError(
+            'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')

    output = bert_padding.pad_input(
        rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
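As a usage note (not part of this commit): below is a minimal sketch, assuming llm-foundry with this change installed, of how the new helpers could be used to pick an attention implementation at runtime. The attn_impl values 'flash' and 'torch' are illustrative assumptions here, not taken from this diff.

    # Minimal sketch: choose an attention implementation based on which
    # flash-attn version (if any) is installed. The helper functions come from
    # this commit; the attn_impl strings are assumptions for illustration only.
    from llmfoundry.models.layers.attention import (is_flash_v1_installed,
                                                    is_flash_v2_installed)

    if is_flash_v1_installed():
        # flash_attn_fn will dispatch to flash_attn_interface.flash_attn_unpadded_func
        attn_impl = 'flash'
    elif is_flash_v2_installed():
        # flash_attn_fn will dispatch to flash_attn_interface.flash_attn_varlen_func
        attn_impl = 'flash'
    else:
        # Neither flash-attn v1 nor v2 is available; fall back to a non-flash implementation.
        attn_impl = 'torch'

    print(f'Using attn_impl={attn_impl}')

If neither flash-attn version is present, installing one of the versions named in the error message above (flash-attn==1.0.9 or flash-attn==2.3.2) enables the flash path.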
