Added Support for Rotary Positional Embeddings (#99)
* feat: added rotary support in kvcache

* confirmed non-fused rotary passes all tests
alexkranias-amd authored Nov 20, 2024
1 parent 8947040 commit 1fcc51b
Showing 3 changed files with 53 additions and 7 deletions.
39 changes: 39 additions & 0 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -6,6 +6,8 @@
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout, DEBUG
from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')

@@ -516,6 +518,43 @@ def fwd_kvcache(
    batch, _, nheads_q, _ = q.shape
    metadata.need_alibi(alibi_slopes, batch, nheads_q)

    # rotary boolean
    apply_rotary = torch.is_tensor(rotary_cos) and torch.is_tensor(rotary_sin)
    if apply_rotary:
        metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved)

    # Rotary Embedding Implementation
    if apply_rotary:
        if metadata.causal: # NOTE: when support is added. Add `or metadata.local`
            q_ro = apply_rotary_emb(
                q,
                metadata.rotary_cos,
                metadata.rotary_sin,
                seqlen_offsets=metadata.cache_seqlens,
                interleaved=metadata.rotary_interleaved,
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    metadata.rotary_cos,
                    metadata.rotary_sin,
                    seqlen_offsets=metadata.cache_seqlens,
                    interleaved=metadata.rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=metadata.max_seqlens_q,
            )
        k_ro = apply_rotary_emb(
            metadata.k_new,
            metadata.rotary_cos,
            metadata.rotary_sin,
            seqlen_offsets=metadata.cache_seqlens,
            interleaved=metadata.rotary_interleaved,
        )

        q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)

    # launch kernel
    # TODO: pass output as an arg. Maybe we are copying output which is causing slow down
    output, softmax_lse = attention_decode_forward_triton_impl(
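For orientation, here is a minimal usage sketch (not part of the commit) of how this non-fused rotary path is reached through the public kvcache API. The tensor sizes, device, and dtype below are illustrative assumptions, and the `flash_attn_with_kvcache` call is based on the upstream interface; treat it as a sketch rather than a verified snippet from this repository.

```python
# Illustrative sketch: calling the kvcache API with rotary tables so that the
# branch added above applies apply_rotary_emb to q and k_new before the Triton
# decode kernel runs. Shapes follow the (batch, seqlen, nheads, head_dim)
# convention; the concrete sizes here are assumptions.
import math
import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_q, seqlen_k, nheads, head_dim, rotary_dim = 2, 1, 128, 8, 64, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_k, nheads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_k, nheads, head_dim, device=device, dtype=dtype)
k_new = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)
v_new = torch.randn(batch, seqlen_q, nheads, head_dim, device=device, dtype=dtype)

# Rotary tables of shape (seqlen, rotary_dim // 2); cache_seqlens gives the
# per-batch offset into these tables for the newly appended tokens.
angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
rotary_cos = torch.cos(angle).to(dtype)
rotary_sin = torch.sin(angle).to(dtype)
cache_seqlens = torch.full((batch,), seqlen_k - seqlen_q, dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    cache_seqlens=cache_seqlens, causal=True, rotary_interleaved=False,
)
```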
10 changes: 10 additions & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -27,6 +27,10 @@ class MetaData():
    dropout_p, return_scores = 0.0, False
    # NOTE: scale sm_scale by log_2(e) and use 2^x in the loop as we do not have native e^x support in HW.
    use_exp2 = False
    rotary_sin = None
    rotary_cos = None
    rotary_interleaved = False
    rotary_conjunction = False


    def __repr__(self) -> str:
@@ -85,6 +89,12 @@ def need_alibi(self, alibi_slopes, batch, nheads):
    def need_causal(self):
        self.causal = True

    def need_rotary(self, sin, cos, rotary_interleaved, rotary_conjunction=False):
        self.rotary_sin = sin
        self.rotary_cos = cos
        self.rotary_interleaved = rotary_interleaved
        self.rotary_conjunction = rotary_conjunction

    def need_dropout(self, dropout_p, return_scores):
        self.dropout_p = dropout_p
        self.return_scores = return_scores
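As an aside (not part of the diff), the `rotary_interleaved` flag stored above selects between two common pairing conventions for rotary embeddings. The plain-PyTorch reference below is a sketch of that distinction only; it is not the fused `flash_attn.layers.rotary.apply_rotary_emb` implementation, and the helper name is hypothetical.

```python
# Sketch of the two pairing conventions the rotary_interleaved flag selects
# between, applied to a single head vector x of even length rotary_dim.
# x: (rotary_dim,), cos/sin: (rotary_dim // 2,)
import torch

def rotate_reference(x, cos, sin, interleaved):
    if interleaved:
        # GPT-J style: rotate adjacent pairs (x0, x1), (x2, x3), ...
        x_even, x_odd = x[0::2], x[1::2]
        out = torch.empty_like(x)
        out[0::2] = x_even * cos - x_odd * sin
        out[1::2] = x_even * sin + x_odd * cos
        return out
    # GPT-NeoX style: rotate the first half against the second half
    d = x.shape[0] // 2
    x1, x2 = x[:d], x[d:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos])
```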
11 changes: 4 additions & 7 deletions tests/test_flash_attn_triton_amd.py
@@ -1850,10 +1850,10 @@ def test_flash_attn_varlen_causal(
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
-# @pytest.mark.parametrize("rotary_interleaved", [False, True])
-@pytest.mark.parametrize("rotary_interleaved", [False])
-# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-@pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None])
@@ -1907,9 +1907,6 @@ def test_flash_attn_kvcache(

    if local == True:
        pytest.skip("local sliding window attention not supported on AMD's Triton Backend yet")

-    if rotary_interleaved == True or rotary_fraction > 0.0:
-        pytest.skip("rotary embedding not supported on AMD's Triton Backend yet")

    if has_leftpad == True:
        pytest.skip("cache_leftpad not supported on AMD's Triton Backend yet")
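Not part of the commit: a sketch of how the now-enabled `rotary_fraction` parametrization typically maps to cos/sin tables in a kvcache test. The rounding of `rotary_dim` to a multiple of 16, the random-angle construction, and the helper name below are assumptions for illustration.

```python
# Hypothetical helper: turn a rotary_fraction value into cos/sin tables for a
# kvcache test. rotary_fraction == 0.0 disables rotary for that case.
import math
import torch

def make_rotary_tables(rotary_fraction, head_dim, seqlen_k, device="cuda", dtype=torch.float16):
    rotary_dim = math.floor(int(rotary_fraction * head_dim) / 16) * 16
    if rotary_dim == 0:
        return None, None  # rotary disabled for this parametrization
    angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
    return torch.cos(angle).to(dtype), torch.sin(angle).to(dtype)
```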
