limit decode test to test_op_fwd
micmelesse committed Aug 2, 2024
1 parent 6415d9a commit 485ba55
Showing 2 changed files with 7 additions and 17 deletions.
.github/workflows/amd_tests.yml (2 changes: 1 addition & 1 deletion)
@@ -65,4 +65,4 @@ jobs:
- name: AMD Kernel Tests
run: |
pytest flash_attn/flash_attn_triton_kernel_prefill_amd.py
-pytest flash_attn/flash_attn_triton_kernel_decode_amd.py
+pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
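
The workflow change uses pytest's node-ID syntax (`path::test_name`) so only `test_op_fwd` runs from the decode file. For reference, the same selection can be made programmatically; a minimal sketch (the `-q` flag and exit-code handling are illustrative, not part of this commit):

```python
# Sketch: run only test_op_fwd from the decode kernel file, mirroring
# `pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd`.
import sys
import pytest

exit_code = pytest.main([
    "flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd",
    "-q",  # quiet output, purely for illustration
])
sys.exit(exit_code)
```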
flash_attn/flash_attn_triton_kernel_decode_amd.py (22 changes: 6 additions & 16 deletions)
@@ -146,9 +146,6 @@ def _fwd_kernel_splitK(
else:
kv_len = N_CTX_K
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
# print("kv_len:", kv_len)
# print("lo:", lo)
# print("hi:", hi)

HEAD_RATIO: tl.constexpr = H_q // H_kv
if IS_GQA:
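
For context on the surrounding lines: `HEAD_RATIO = H_q // H_kv` is the number of query heads that share a single KV head when `IS_GQA` is set. A hedged sketch of that mapping outside the kernel (head counts and variable names are illustrative, not taken from this file):

```python
# Sketch: under GQA, consecutive query heads share one KV head.
H_q, H_kv = 8, 2                 # illustrative head counts
head_ratio = H_q // H_kv         # mirrors HEAD_RATIO in the kernel

for q_head in range(H_q):
    kv_head = q_head // head_ratio   # KV head this query head reads from
    print(f"query head {q_head} -> kv head {kv_head}")
```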
@@ -354,7 +351,6 @@ def _fwd_kernel_splitK(
else:
qk = qk - m_i_new[:, None]

# print("qk before p:", qk)
p = tl.math.exp2(qk)
# print("p:", p)

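
The `tl.math.exp2` call is the usual flash-attention softmax trick: scores are shifted by the running row maximum (`m_i_new`) before exponentiating so the exponent stays non-positive, and a base-2 exponential is used because it maps to a cheaper hardware instruction (the scores are typically pre-scaled by log2(e) earlier in the kernel). A rough NumPy sketch of the idea, not the kernel code itself:

```python
# Sketch: max-shifted, base-2 softmax weights as used in flash-attention kernels.
import numpy as np

qk = np.array([10.0, 12.0, 9.0]) * np.log2(np.e)  # scores pre-scaled by log2(e)
m = qk.max()                                       # running row maximum (m_i_new)
p = np.exp2(qk - m)                                # bounded in (0, 1], no overflow
print(p / p.sum())                                 # matches softmax of the raw scores
```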
@@ -708,16 +704,6 @@ def forward(cls, q, k, v, input_metadata):
# Handle MQA/GQA case
if heads_per_group_q > heads_per_group_k:
input_metadata.is_gqa = True

-# n_heads_per_group = heads_per_group_q // heads_per_group_k
-
-# # Repeat each row of k and v to match the number of query heads
-# k = k.repeat_interleave(n_heads_per_group, dim=3)
-# v = v.repeat_interleave(n_heads_per_group, dim=3)
-
-# # Update heads_per_group_k and heads_per_group_v
-# heads_per_group_k = heads_per_group_q
-# heads_per_group_v = heads_per_group_q
elif heads_per_group_q < heads_per_group_k:
raise ValueError("heads_per_group_q < heads_per_group_k")
else:
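
The deleted block shows the older MQA/GQA path: K and V were physically expanded with `repeat_interleave` so every query head had its own copy. The surviving code only sets `input_metadata.is_gqa = True` and lets the kernel index the shared KV heads directly. A hedged sketch of what the removed expansion did (shapes and sizes are illustrative, not taken from this file):

```python
# Sketch: the repeat_interleave expansion that this commit drops in favor of is_gqa.
import torch

B, S, G, D = 2, 16, 1, 64                       # batch, seq, group, head dim (illustrative)
heads_per_group_q, heads_per_group_k = 8, 2
k = torch.randn(B, S, G, heads_per_group_k, D)  # "bsghd"-style layout

n_heads_per_group = heads_per_group_q // heads_per_group_k
k_expanded = k.repeat_interleave(n_heads_per_group, dim=3)  # copy each KV head
print(k_expanded.shape)  # torch.Size([2, 16, 1, 8, 64])
```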
@@ -981,7 +967,9 @@ def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32))
quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32))
scale = 1 / K**0.5
-tri_out = attention_decode(q, quant_k, quant_v, scale)
+input_metadata = MetaData(sm_scale=scale)
+input_metadata.layout = "bsghd"
+tri_out = attention_decode(q, quant_k, quant_v, input_metadata)

q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
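
After the updated call, the test reshapes q and k from the grouped "bsghd" layout into (batch, heads, seq, head_dim) so the Triton output can be checked against a reference attention. A small sketch of that layout conversion with placeholder sizes:

```python
# Sketch: flatten grouped heads and move to (B, H, M, K) for the reference path.
import torch

B, Mq, Hq, K = 2, 4, 8, 64            # illustrative sizes
q = torch.randn(B, Mq, 1, Hq, K)      # "bsghd" layout with a single group

q_ref = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
print(q_ref.shape)  # torch.Size([2, 8, 4, 64])
```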
@@ -1049,7 +1037,9 @@ def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=
requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1)

sm_scale = 1.3
-fn = lambda: attention_decode(q, k, v, sm_scale)
+input_metadata = MetaData(sm_scale=sm_scale)
+input_metadata.layout = "bsghd"
+fn = lambda: attention_decode(q, k, v, input_metadata)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

# flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K)
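
`triton.testing.do_bench` reports latency in milliseconds; the commented-out line above sketches the FLOP count one could use to turn that into throughput. A hedged example of the conversion (the shapes and FLOP estimate follow the commented formula and are illustrative, not part of this commit):

```python
# Sketch: convert do_bench's milliseconds into an approximate TFLOP/s figure.
B, Hq, Mq, Mkv, K = 4, 16, 1, 4096, 128    # illustrative decode shapes
ms = 0.25                                  # latency returned by triton.testing.do_bench

flops = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K)  # QK^T plus PV matmuls
tflops = flops / (ms * 1e-3) / 1e12
print(f"{tflops:.3f} TFLOP/s")
```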
