diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml
index a245b05d9..4ac0065d2 100644
--- a/.github/workflows/amd_tests.yml
+++ b/.github/workflows/amd_tests.yml
@@ -65,4 +65,4 @@ jobs:
       - name: AMD Kernel Tests
         run: |
           pytest flash_attn/flash_attn_triton_kernel_prefill_amd.py
-          pytest flash_attn/flash_attn_triton_kernel_decode_amd.py
\ No newline at end of file
+          pytest flash_attn/flash_attn_triton_kernel_decode_amd.py::test_op_fwd
\ No newline at end of file
diff --git a/flash_attn/flash_attn_triton_kernel_decode_amd.py b/flash_attn/flash_attn_triton_kernel_decode_amd.py
index 953261ab0..c027c54d9 100644
--- a/flash_attn/flash_attn_triton_kernel_decode_amd.py
+++ b/flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -146,9 +146,6 @@ def _fwd_kernel_splitK(
     else:
         kv_len = N_CTX_K
     hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
-    # print("kv_len:", kv_len)
-    # print("lo:", lo)
-    # print("hi:", hi)

     HEAD_RATIO: tl.constexpr = H_q // H_kv
     if IS_GQA:
@@ -354,7 +351,6 @@ def _fwd_kernel_splitK(
         else:
             qk = qk - m_i_new[:, None]

-        # print("qk before p:", qk)
         p = tl.math.exp2(qk)
         # print("p:", p)

@@ -708,16 +704,6 @@ def forward(cls, q, k, v, input_metadata):
         # Handle MQA/GQA case
         if heads_per_group_q > heads_per_group_k:
             input_metadata.is_gqa = True
-
-            # n_heads_per_group = heads_per_group_q // heads_per_group_k
-
-            # # Repeat each row of k and v to match the number of query heads
-            # k = k.repeat_interleave(n_heads_per_group, dim=3)
-            # v = v.repeat_interleave(n_heads_per_group, dim=3)
-
-            # # Update heads_per_group_k and heads_per_group_v
-            # heads_per_group_k = heads_per_group_q
-            # heads_per_group_v = heads_per_group_q
         elif heads_per_group_q < heads_per_group_k:
             raise ValueError("heads_per_group_q < heads_per_group_k")
         else:
@@ -981,7 +967,9 @@ def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
     quant_k = (quantize_kv_int4(k, num_groups=num_groups).contiguous().view(torch.int32))
     quant_v = (quantize_kv_int4(v, num_groups=num_groups).contiguous().view(torch.int32))
     scale = 1 / K**0.5
-    tri_out = attention_decode(q, quant_k, quant_v, scale)
+    input_metadata = MetaData(sm_scale=scale)
+    input_metadata.layout = "bsghd"
+    tri_out = attention_decode(q, quant_k, quant_v, input_metadata)

     q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
     k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
@@ -1049,7 +1037,9 @@ def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=
                            requires_grad=False).expand(-1, -1, -1, Hq // Hkv, -1)

     sm_scale = 1.3
-    fn = lambda: attention_decode(q, k, v, sm_scale)
+    input_metadata = MetaData(sm_scale=sm_scale)
+    input_metadata.layout = "bsghd"
+    fn = lambda: attention_decode(q, k, v, input_metadata)
     ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

     # flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K)
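The last two hunks show the call-site change this diff makes: attention_decode no longer accepts a bare softmax-scale float; callers build a MetaData object, set its layout, and pass that instead. Below is a minimal sketch of the updated call pattern. It assumes MetaData and attention_decode are importable from flash_attn.flash_attn_triton_kernel_decode_amd (MetaData may actually be defined in a sibling module) and uses made-up tensor sizes; the expand trick for k/v mirrors the GQA setup in the benchmark hunk.

import torch
from flash_attn.flash_attn_triton_kernel_decode_amd import MetaData, attention_decode

# Illustrative sizes only; "bsghd" is presumably batch, seqlen, group, head, head-dim.
B, Mq, Mkv, G, H, D = 2, 1, 256, 2, 4, 64
q = torch.randn(B, Mq, G, H, D, dtype=torch.float16, device="cuda")
# Single KV head per group, expanded across the query heads as in bench_flash_attention.
k = torch.randn(B, Mkv, G, 1, D, dtype=torch.float16, device="cuda").expand(-1, -1, -1, H, -1)
v = torch.randn(B, Mkv, G, 1, D, dtype=torch.float16, device="cuda").expand(-1, -1, -1, H, -1)

# New signature: a MetaData object carries sm_scale and layout instead of a float.
input_metadata = MetaData(sm_scale=1.0 / D**0.5)
input_metadata.layout = "bsghd"
out = attention_decode(q, k, v, input_metadata)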