Commit 564165b: save
micmelesse committed Dec 6, 2024 (1 parent: e51fe14)
Showing 5 changed files with 189 additions and 65 deletions.
11 changes: 8 additions & 3 deletions flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -540,7 +540,7 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
split_k = max(split_k, 1)
return split_k

def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes, layout, cache_seqlens, cache_batch_idx, new_kv, k_new, v_new):
def attention_decode_forward_triton_impl(q, k, v, k_new, v_new, sm_scale, causal, layout, alibi_slopes, cache_seqlens, cache_batch_idx):
# kernel config
BLOCK_M = 16
BLOCK_N = 64
@@ -553,16 +553,18 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
q=q.unsqueeze(2)
k=k.unsqueeze(2)
v=v.unsqueeze(2)
if new_kv:
if k_new is not None:
k_new = k_new.unsqueeze(2)
if v_new is not None:
v_new = v_new.unsqueeze(2)
layout = "bsghd"
elif layout == "bhsd":
q=q.permute(0, 2, 1, 3).unsqueeze(2)
k=k.permute(0, 2, 1, 3).unsqueeze(2)
v=v.permute(0, 2, 1, 3).unsqueeze(2)
if new_kv:
if k_new is not None:
k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
if v_new is not None:
v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
layout = "bsghd"
elif layout == "bsghd":
@@ -571,6 +573,9 @@ def attention_decode_forward_triton_impl(q, k, v, sm_scale, causal, alibi_slopes
raise ValueError("Layout not given")
assert layout == "bsghd"

# check that both are provided or both are none
assert ((k_new is None) and (v_new is None)) or ((k_new is not None) and (v_new is not None))

# get dims
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
87 changes: 87 additions & 0 deletions flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -140,6 +140,93 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, dropout_p, philox

return o, softmax_lse, sd_mask

def attention_decode_forward_pytorch_ref_impl(
q,
k_cache,
v_cache,
k_new,
v_new,
cache_seqlens,
cache_batch_idx,
sm_scale,
causal,
layout,
alibi_slopes,
rotary_cos,
rotary_sin,
rotary_interleaved,
use_exp2
):
if DEBUG:
print()
print("attention_forward_pytorch_ref_impl")
print("q:", q, q.shape)
print("k:", k_cache, k_cache.shape)
print("v:", v_cache, v_cache.shape)
print("k_new:", k_new, k_new.shape if k_new is not None else None)
print("v_new:", v_new, v_new.shape if v_new is not None else None)
print("cache_seqlens:", cache_seqlens)
print("cache_batch_idx:", cache_batch_idx)
print("sm_scale:", sm_scale)
print("causal:", causal)
print("alibi_slopes:", alibi_slopes)
print("layout:", layout)
print("rotary_cos:", rotary_cos)
print("rotary_sin:", rotary_sin)
print("rotary_interleaved:", rotary_interleaved)
print("use_exp2:", use_exp2)

# Ensure the layout is 'bhsd'
if layout == "bshd":
q = q.transpose(1, 2).contiguous()
k_cache = k_cache.transpose(1, 2).contiguous()
v_cache = v_cache.transpose(1, 2).contiguous()
if k_new is not None:
k_new = k_new.transpose(1, 2).contiguous()
if v_new is not None:
v_new = v_new.transpose(1, 2).contiguous()
elif layout != "bhsd":
raise ValueError(f"Unknown layout {layout}")

# check that both are provided or both are none
assert ((k_new is None) and (v_new is None)) or ((k_new is not None) and (v_new is not None))

# Prepare tensors
batch_size, nheads_q, seq_len_q, head_dim = q.shape
batch_size, nheads_k_cache, seq_len_k_cache, head_dim = k_cache.shape
if k_new is not None:
batch_size, nheads_k_new, seq_len_k_new, head_dim = k_new.shape

# insert new tensors in cache
# TODO
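# One possible approach (sketch only, not implemented in this commit): assuming cache_seqlens holds the
# current per-batch cache length, copy k_new[b] into k_cache[b, :, cache_seqlens[b]:cache_seqlens[b] + seq_len_k_new]
# (and v_new[b] into v_cache likewise) for every batch index b before running the core attention below.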

# convert to 3d tensors for core impl
q = q.reshape(batch_size * nheads_q, seq_len_q, head_dim)
k_cache = k_cache.reshape(batch_size * nheads_k_cache, seq_len_k_cache, head_dim)
v_cache = v_cache.reshape(batch_size * nheads_k_cache, seq_len_k_cache, head_dim)
# if k_new is not None:
# k_new = k_new.reshape(batch_size * nheads_k_new, seq_len_k_new, head_dim)
# if v_new is not None:
# v_new = v_new.reshape(batch_size * nheads_k_new, seq_len_k_new, head_dim)


# launch core impl
output, softmax_lse, sd_mask = attention_forward_core_ref_impl(
q, k_cache, v_cache, sm_scale, causal, 0.0, None, None, use_exp2
)

output = output.reshape(batch_size, nheads_q, seq_len_q, head_dim)
softmax_lse = softmax_lse.reshape(batch_size, nheads_q, seq_len_q)
sd_mask = sd_mask.reshape(batch_size, nheads_q, seq_len_q, seq_len_k_cache)

if layout == "bshd":
output = output.transpose(1, 2)

return output, softmax_lse




def attention_vanilla_forward_pytorch_ref_impl(q, k, v, sm_scale, causal, layout, dropout_p, philox_seed, philox_offset, use_exp2):
"""Compute reference output and softmax_lse using PyTorch's built-in function"""

82 changes: 57 additions & 25 deletions flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -3,14 +3,12 @@
from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl
from .fwd_ref import attention_forward_pytorch_ref_impl
from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout, DEBUG
from .utils import DEBUG, USE_REF, MetaData, get_shape_from_layout
from einops import rearrange, repeat
from flash_attn.layers.rotary import apply_rotary_emb

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')

def fwd(q,
k,
v,
@@ -502,6 +500,20 @@ def fwd_kvcache(
rotary_interleaved,
num_splits):

if DEBUG:
print()
print("flash_attn_triton_amd.py::fwd_kvcache")
print("q:", q, q.shape)
print("k:", k, k.shape if k is not None else None)
print("v:", v, v.shape if v is not None else None)
print("alibi_slopes:", alibi_slopes)
print("softmax_scale:", softmax_scale)
print("causal:", causal)
print("out:", out)
print("window_size_left:", window_size_left)
print("window_size_right:", window_size_right)
print("softcap:", softcap)

if out is None:
out = torch.empty_like(q)

@@ -513,11 +525,6 @@ def fwd_kvcache(
metadata.cache_seqlens = cache_seqlens
metadata.cache_batch_idx = cache_batch_idx

if k is not None and v is not None:
metadata.new_kv = True
metadata.seqlen_new = k.shape[1]
metadata.k_new = k
metadata.v_new = v

if causal:
metadata.need_causal()
@@ -563,20 +570,45 @@

q, metadata.k_new = q_ro.to(q.dtype), k_ro.to(q.dtype)

# launch kernel
# TODO: pass output as an arg. Maybe we are copying output which is causing slow down
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
v_cache,
metadata.sm_scale,
metadata.causal,
metadata.alibi_slopes,
metadata.layout,
metadata.cache_seqlens,
metadata.cache_batch_idx,
metadata.new_kv,
metadata.k_new,
metadata.v_new,
)
if USE_REF:
if DEBUG:
print("Using reference implementation")
output, softmax_lse = attention_decode_forward_pytorch_ref_impl(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
cache_batch_idx,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.alibi_slopes,
metadata.rotary_cos,
metadata.rotary_sin,
metadata.rotary_interleaved,
False
)
out.copy_(output)
else:
if DEBUG:
print("Using Triton implementation")

# launch kernel
# TODO: pass output as an arg. Maybe we are copying output which is causing slow down
output, softmax_lse = attention_decode_forward_triton_impl(
q,
k_cache,
v_cache,
k,
v,
metadata.sm_scale,
metadata.causal,
metadata.layout,
metadata.alibi_slopes,
metadata.cache_seqlens,
metadata.cache_batch_idx,
)
return output, softmax_lse
16 changes: 7 additions & 9 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -8,15 +8,18 @@
import triton
import triton.language as tl

# global variables
AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
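# USE_REF dispatches the interface_fa.py entry points (e.g. fwd_kvcache) to the PyTorch reference implementations instead of the Triton kernels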
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM: # TODO remove this
random.seed(42)
DROPOUT_USE_PYTORCH = False
DROPOUT_DUMP = False

# Flash Attention Metadata
class MetaData():
cu_seqlens_q = None
cu_seqlens_k = None
@@ -30,10 +33,6 @@ class MetaData():
layout = None
cache_seqlens = None
cache_batch_idx = None
new_kv = False
seqlen_new = None
k_new = None
v_new = None
return_scores= False
dropout_p= 0.0
philox_seed, philox_offset = None, None # if dropout_p > 0.0 seed the RNG so we get reproducible results for testing.
@@ -43,7 +42,7 @@ class MetaData():
rotary_cos = None
rotary_interleaved = False
rotary_conjunction = False

is_decode = False

def __repr__(self) -> str:
return (f"MetaData(\n"
@@ -60,17 +59,16 @@ def __repr__(self) -> str:
f" layout={self.layout},\n"
f" cache_seqlens={self.cache_seqlens},\n"
f" cache_batch_idx={self.cache_batch_idx},\n"
f" new_kv={self.new_kv},\n"
f" seqlen_new={self.seqlen_new},\n"
f" k_new={self.k_new},\n"
f" v_new={self.v_new},\n"
f" dropout_p={self.dropout_p},\n"
f" return_scores={self.return_scores}\n"
f")")

def __init__(self, sm_scale=1.0):
self.sm_scale = sm_scale

def need_decode(self):
self.is_decode = True

def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
58 changes: 30 additions & 28 deletions tests/test_flash_attn_triton_amd.py
@@ -1842,50 +1842,52 @@ def test_flash_attn_varlen_causal(

# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
# @pytest.mark.parametrize("num_splits", [1, 0])
@pytest.mark.parametrize("num_splits", [1])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("new_kv", [False])
# @pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
# @pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_interleaved", [False])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
(4, 4)
# (1, 128),
# (1, 339),
# (3, 1024),
# (64, 800),
# (64, 256),
# (3, 799),
# (64, 2048),
# (16, 20000),
# (1, 128 * 1024),
# (16, 128 * 1024),
# (128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])

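For reference, a minimal sketch of exercising the new USE_REF decode path end to end. Assumptions not stated in this commit: flash_attn is installed with the Triton AMD backend, flash_attn_with_kvcache is the public wrapper that reaches fwd_kvcache, and the environment flags must be set before the first flash_attn import because utils.py reads them at import time.

import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
os.environ["FLASH_ATTENTION_TRITON_AMD_REF"] = "1"  # use the PyTorch reference decode path

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, d = 2, 4, 32
seqlen_q, seqlen_cache = 1, 128
q = torch.randn(batch, seqlen_q, nheads, d, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, seqlen_cache, nheads, d, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, seqlen_cache, nheads, d, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((batch,), 64, dtype=torch.int32, device="cuda")

# single decode step against the cache; new k/v are omitted since their insertion is still TODO in the reference impl
out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)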