From af935b92252252eb31e19050476dbcacc04d1578 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 26 Nov 2024 04:57:37 -0500 Subject: [PATCH 1/3] update for new ck interface --- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index e426ce2aa..e6b147b19 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -217,6 +217,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, args.page_block_size = 0; } + args.is_gappy = false; args.cache_batch_idx = nullptr; args.seqstart_q_ptr = seqlens_q.data_ptr(); From 8765dc11a7ae18fbeb18e72556a46824998c8e66 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 28 Nov 2024 03:27:46 -0500 Subject: [PATCH 2/3] update ck --- csrc/composable_kernel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/composable_kernel b/csrc/composable_kernel index fb1ccfa9d..e7b628644 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit fb1ccfa9df534c8c9f351dd959a0ff692d6f9210 +Subproject commit e7b6286441aae59d3a87db67f42369d3cc2636a4 From d2dff5ba43c30e6cb626ed25d2501d8e7a0b4849 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 28 Nov 2024 03:29:09 -0500 Subject: [PATCH 3/3] Fix default value of num_splits --- csrc/flash_attn_ck/mha_varlen_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index e6b147b19..b6a274f4f 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -443,7 +443,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si if (return_dropout_randval) {p.zero_();} } - int num_splits = 1; + int num_splits = 0; num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");