diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index e7b628644..c6cb8c52c 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit e7b6286441aae59d3a87db67f42369d3cc2636a4
+Subproject commit c6cb8c52c168fcc63ca5fc63fbe9650f81052a26
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index b6a274f4f..b305c7b84 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
                                    head_size,
                                    dtype,
                                    true, // is_group_mode
-                                   true, // is_v_rowmajor
+                                   false, // is_v_rowmajor
                                    mask.type,
                                    enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                                    has_lse,
@@ -183,8 +183,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
                                                           at::Tensor out_acc)
 {
     // q: (total_q, nheads, d)
-    // k: (num_blocks, page_block_size, num_heads_k, d)
-    // v: (num_blocks, page_block_size, num_heads_k, d)
+    // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8)
+    // v: (num_blocks, num_heads_k, d, page_block_size)
     // o: (total_q, nheads, d)

     // alibi_slopes:(batch_size, nheads) or (nhead)
@@ -241,12 +241,12 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.nhead_stride_q = q.stride(1);

     args.batch_stride_k = k.stride(0);
-    args.stride_k = k.stride(1);
-    args.nhead_stride_k = k.stride(2);
+    args.nhead_stride_k = k.stride(1);
+    args.stride_k = k.stride(2);

     args.batch_stride_v = v.stride(0);
-    args.stride_v = v.stride(1);
-    args.nhead_stride_v = v.stride(2);
+    args.nhead_stride_v = v.stride(1);
+    args.stride_v = v.stride(2);

     args.batch_stride_o = 0;
     args.stride_o = out.stride(0);
@@ -292,8 +292,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,

 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size / 8 x page_block_size x 8 if there's a block_table.
+               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size x page_block_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
@@ -348,11 +348,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = k.size(1);

     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
-    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    const int page_block_size = !paged_KV ? 1 : k.size(3);
     TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");

     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
@@ -394,8 +394,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         CHECK_SHAPE(k, total_k, num_heads_k, head_size);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8);
+        CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size);
         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
     }
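
Note (not part of the diff): the hunks above switch the paged KV cache from the conventional (num_blocks, page_block_size, num_heads_k, head_size) layout to the CK-native layout enforced by the new CHECK_SHAPE calls. The libtorch sketch below shows one way to produce tensors with those shapes from a conventionally laid-out cache. The sizes and variable names are hypothetical, and it assumes K's head dimension is split into contiguous groups of 8; the diff only fixes the shapes and strides, not the exact interleave the kernel consumes.

// Layout sketch only: builds tensors with the shapes the new CHECK_SHAPE calls
// expect, starting from the conventional paged layout. Sizes are made up;
// page_block_size is 128 to satisfy the TORCH_CHECK added above.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t num_blocks = 8, page_block_size = 128, num_heads_k = 4, head_size = 64;

    // Conventional paged layout: (num_blocks, page_block_size, num_heads_k, head_size)
    auto k_std = torch::randn({num_blocks, page_block_size, num_heads_k, head_size});
    auto v_std = torch::randn({num_blocks, page_block_size, num_heads_k, head_size});

    // K for CK: (num_blocks, num_heads_k, head_size / 8, page_block_size, 8)
    // -- split head_size into groups of 8, then move page_block_size between
    //    the group index and the group itself (assumed interleave).
    auto k_ck = k_std.view({num_blocks, page_block_size, num_heads_k, head_size / 8, 8})
                     .permute({0, 2, 3, 1, 4})
                     .contiguous();

    // V for CK: (num_blocks, num_heads_k, head_size, page_block_size)
    auto v_ck = v_std.permute({0, 2, 3, 1}).contiguous();

    std::cout << k_ck.sizes() << '\n';  // [8, 4, 8, 128, 8]
    std::cout << v_ck.sizes() << '\n';  // [8, 4, 64, 128]
    return 0;
}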