From ba8dd76ee0e6bfabfd322ed924241b89ee50af46 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Mon, 9 Oct 2023 17:24:57 -0700
Subject: [PATCH] gate attn impls

---
 .github/workflows/pr-gpu.yaml         |  5 +++
 llmfoundry/models/layers/attention.py | 61 ++++++++++++++++++---------
 2 files changed, 45 insertions(+), 21 deletions(-)

diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index 769b345e39..150b5488ce 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -31,6 +31,11 @@ jobs:
             container: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
             markers: 'gpu'
             pytest_command: 'coverage run -m pytest'
+          # TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
+          # - name: 'gpu-2.1.0-flash2'
+          #   container: # UPDATE AFTER THIS PR GOES TO PRODUCTION IMAGE
+          #   markers: 'gpu'
+          #   pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
     if: github.repository_owner == 'mosaicml'
    with:
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index ee1e009af4..90641e8267 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -17,17 +17,19 @@
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
 
-def raise_if_flash_attn_v2():
-    flash_attn_version = None
-    # This only needs to be in a try except so that huggingface does not try to import it
+def is_flash_v2_installed():
     try:
-        from flash_attn import __version__ as flash_attn_version
-        if version.parse(flash_attn_version) >= version.parse('2.0.0'):
-            raise RuntimeError(
-                'flash-attn==2.0.0+ is not supported. Please use flash-attn==1.0.9.'
-            )
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+
+def is_flash_v1_installed():
+    try:
+        import flash_attn as flash_attn
     except:
-        pass
+        return False
+    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
 
 
 def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
@@ -291,18 +293,35 @@ def flash_attn_fn(
 
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
 
-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
-        query_unpad,
-        key_unpad,
-        value_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        dropout_p,
-        softmax_scale=softmax_scale,
-        causal=reset_is_causal,
-        return_attn_probs=needs_weights)
+    if is_flash_v1_installed():
+        output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+            query_unpad,
+            key_unpad,
+            value_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    elif is_flash_v2_installed():
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(
+            query_unpad,
+            key_unpad,
+            value_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    else:
+        raise RuntimeError(
+            'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
 
     output = bert_padding.pad_input(
         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
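
The workflow hunk above leaves a TODO for a flash-attn 2 test suite. As a rough, hypothetical sketch (not part of this patch; the test name and skip wiring are assumptions), such a test could reuse the version helpers added in attention.py to skip itself when no flash-attn kernel is installed:

    # Hypothetical sketch only -- not part of the patch above. It shows how a
    # test in the future 'gpu-2.1.0-flash2' suite could gate itself on the
    # installed flash-attn major version using the new helpers.
    import pytest

    from llmfoundry.models.layers.attention import (is_flash_v1_installed,
                                                    is_flash_v2_installed)


    @pytest.mark.gpu
    @pytest.mark.skipif(not (is_flash_v1_installed() or is_flash_v2_installed()),
                        reason='requires flash-attn 1.x or 2.x')
    def test_flash_attn_version_gate():
        # For any single flash-attn install, exactly one predicate should hold.
        assert is_flash_v1_installed() != is_flash_v2_installed()

Skipping at collection time this way lets the same test file run unchanged in both the flash-attn 1 and flash-attn 2 CI images.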