Skip to content

Commit

Permalink
feat: disable rotary kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkranias-amd committed Sep 30, 2024
1 parent 704f976 commit 99f2b07
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions flash_attn/flash_attn_triton_kernel_decode_amd.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -598,36 +598,36 @@ def forward(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_metada
original_layout = input_metadata.layout

# Rotary Embedding Implementation
if torch.is_tensor(input_metadata.rotary_cos) and torch.is_tensor(input_metadata.rotary_sin):
if input_metadata.causal or input_metadata.local:
q_ro = apply_rotary_emb(
q,
input_metadata.rotary_cos,
input_metadata.rotary_sin,
seqlen_offsets=input_metadata.cache_seqlens,
interleaved=input_metadata.rotary_interleaved,
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
input_metadata.rotary_cos,
input_metadata.rotary_sin,
seqlen_offsets=input_metadata.cache_seqlens,
interleaved=input_metadata.rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=input_metadata.max_seqlens_q,
)
k_ro = apply_rotary_emb(
input_metadata.k_new,
input_metadata.rotary_cos,
input_metadata.rotary_sin,
seqlen_offsets=input_metadata.cache_seqlens,
interleaved=input_metadata.rotary_interleaved,
)

q, input_metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)
# if torch.is_tensor(input_metadata.rotary_cos) and torch.is_tensor(input_metadata.rotary_sin):
# if input_metadata.causal or input_metadata.local:
# q_ro = apply_rotary_emb(
# q,
# input_metadata.rotary_cos,
# input_metadata.rotary_sin,
# seqlen_offsets=input_metadata.cache_seqlens,
# interleaved=input_metadata.rotary_interleaved,
# )
# else:
# q_ro = rearrange(
# apply_rotary_emb(
# rearrange(q, "b s h d -> b 1 (s h) d"),
# input_metadata.rotary_cos,
# input_metadata.rotary_sin,
# seqlen_offsets=input_metadata.cache_seqlens,
# interleaved=input_metadata.rotary_interleaved,
# ),
# "b 1 (s h) d -> b s h d",
# s=input_metadata.max_seqlens_q,
# )
# k_ro = apply_rotary_emb(
# input_metadata.k_new,
# input_metadata.rotary_cos,
# input_metadata.rotary_sin,
# seqlen_offsets=input_metadata.cache_seqlens,
# interleaved=input_metadata.rotary_interleaved,
# )

# q, input_metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)

# kernels expect "bsghd"
if input_metadata.layout == "bshd":
Expand Down

0 comments on commit 99f2b07

Please sign in to comment.