diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index de8b06197e..18d1e5152b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -236,7 +236,7 @@ def flash_attn_fn( torch.Tensor]]]: del key_padding_mask try: - from flash_attn import bert_padding, flash_attn_interface, index_first_axis # type: ignore # yapf: disable # isort: skip + from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: raise RuntimeError( 'Please install flash-attn==1.0.9 or flash-attn==2.3.6') @@ -276,12 +276,12 @@ def flash_attn_fn( max_seqlen_q = flash_attn_padding_info['max_seqlen_q'] max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] - query_unpad = index_first_axis(rearrange(query, 'b s ... -> (b s) ...'), - indices_q) - key_unpad = index_first_axis(rearrange(key, 'b s ... -> (b s) ...'), - indices_k) - value_unpad = index_first_axis(rearrange(value, 'b s ... -> (b s) ...'), - indices_v) + query_unpad = bert_padding.index_first_axis( + rearrange(query, 'b s ... -> (b s) ...'), indices_q) + key_unpad = bert_padding.index_first_axis( + rearrange(key, 'b s ... -> (b s) ...'), indices_k) + value_unpad = bert_padding.index_first_axis( + rearrange(value, 'b s ... -> (b s) ...'), indices_v) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( not should_repeat_kv_for_gqa):