
Commit

support vllm splitkv layout
rocking5566 committed Dec 2, 2024
1 parent 58f6f67 commit ff24498
Showing 2 changed files with 14 additions and 14 deletions.
csrc/composable_kernel: 2 changes (1 addition & 1 deletion)
Submodule composable_kernel updated 31 files
+0 −4 Dockerfile
+26 −0 Dockerfile.compiler
+25 −20 Jenkinsfile
+1 −1 docs/sphinx/requirements.in
+1 −1 docs/sphinx/requirements.txt
+0 −1 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+1 −1 example/ck_tile/03_gemm/CMakeLists.txt
+22 −3 example/ck_tile/03_gemm/universal_gemm.cpp
+1 −0 example/ck_tile/16_batched_gemm/CMakeLists.txt
+37 −0 example/ck_tile/16_batched_gemm/README.md
+103 −0 example/ck_tile/16_batched_gemm/batched_gemm.cpp
+63 −0 example/ck_tile/16_batched_gemm/batched_gemm.hpp
+253 −0 example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+1 −1 example/ck_tile/CMakeLists.txt
+112 −0 include/ck_tile/host/reference/reference_gemm.hpp
+32 −6 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+3 −0 include/ck_tile/ops/gemm.hpp
+110 −119 include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp
+258 −0 include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+111 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+383 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+86 −180 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+1 −1 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+149 −0 python/ck4inductor/batched_universal_gemm/gen_instances.py
+99 −0 python/ck4inductor/batched_universal_gemm/op.py
+1 −3 python/ck4inductor/grouped_conv_fwd/gen_instances.py
+1 −0 test/ck_tile/CMakeLists.txt
+4 −0 test/ck_tile/batched_gemm/CMakeLists.txt
+29 −0 test/ck_tile/batched_gemm/test_batched_gemm.cpp
+9 −0 test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc
+225 −0 test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
csrc/flash_attn_ck/mha_varlen_fwd.cpp: 26 changes (13 additions & 13 deletions)
@@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
     head_size,
     dtype,
     true, // is_group_mode
-    true, // is_v_rowmajor
+    false, // is_v_rowmajor
     mask.type,
     enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
     has_lse,
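The is_v_rowmajor flip follows from the new V layout: per (block, head), the old cache stores a V tile as (page_block_size, d) with d contiguous, i.e. row-major in the (seqlen, d) sense, while the vLLM layout stores the transposed tile (d, page_block_size) with the sequence dimension contiguous. A minimal sketch of the stride difference (not part of the commit; libtorch, illustrative sizes):

```cpp
// Sketch: stride layout of a per-(block, head) V tile, old vs. vLLM cache.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t num_blocks = 2, page_block_size = 128, num_heads_k = 4, d = 64;

    // Old layout: (num_blocks, page_block_size, num_heads_k, d).
    auto v_old = torch::zeros({num_blocks, page_block_size, num_heads_k, d});
    // Stepping along d costs stride 1: the (seqlen, d) tile is row-major.
    std::cout << "old stride over d: " << v_old.stride(3) << '\n';   // 1

    // vLLM layout: (num_blocks, num_heads_k, d, page_block_size).
    auto v_new = torch::zeros({num_blocks, num_heads_k, d, page_block_size});
    // Stepping along the sequence costs stride 1, stepping along d costs
    // page_block_size: the tile is V transposed, hence is_v_rowmajor = false.
    std::cout << "new stride over d: " << v_new.stride(2) << '\n';   // 128
    return 0;
}
```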
@@ -183,8 +183,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
                            at::Tensor out_acc)
 {
     // q: (total_q, nheads, d)
-    // k: (num_blocks, page_block_size, num_heads_k, d)
-    // v: (num_blocks, page_block_size, num_heads_k, d)
+    // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8)
+    // v: (num_blocks, num_heads_k, d, page_block_size)
     // o: (total_q, nheads, d)

     // alibi_slopes:(batch_size, nheads) or (nhead)
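The new comments describe vLLM's paged KV cache: K keeps chunks of 8 consecutive head-dim elements contiguous (matching vLLM's x = 16 bytes / sizeof(fp16) = 8, hardcoded here), and V stores the transposed per-head tile. For reference, a cache in the old layout can be rearranged into this layout with a view/permute; the sketch below is illustrative only (libtorch, made-up sizes), not code from the commit:

```cpp
// Sketch: rearranging an old-layout cache into the vLLM splitkv layout.
#include <torch/torch.h>

int main() {
    const int64_t num_blocks = 2, page_block_size = 128, num_heads_k = 4, d = 64;
    auto k_old = torch::randn({num_blocks, page_block_size, num_heads_k, d});
    auto v_old = torch::randn({num_blocks, page_block_size, num_heads_k, d});

    // K: (nb, pbs, nh, d) -> (nb, nh, d/8, pbs, 8): split d into chunks of 8,
    // then move the page dimension between the chunk index and the chunk.
    auto k_new = k_old.view({num_blocks, page_block_size, num_heads_k, d / 8, 8})
                      .permute({0, 2, 3, 1, 4})
                      .contiguous();

    // V: (nb, pbs, nh, d) -> (nb, nh, d, pbs): transpose the per-head tile.
    auto v_new = v_old.permute({0, 2, 3, 1}).contiguous();

    TORCH_CHECK(k_new.sizes() ==
                at::IntArrayRef({num_blocks, num_heads_k, d / 8, page_block_size, 8}));
    TORCH_CHECK(v_new.sizes() ==
                at::IntArrayRef({num_blocks, num_heads_k, d, page_block_size}));
    return 0;
}
```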
@@ -241,12 +241,12 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.nhead_stride_q = q.stride(1);

     args.batch_stride_k = k.stride(0);
-    args.stride_k = k.stride(1);
-    args.nhead_stride_k = k.stride(2);
+    args.nhead_stride_k = k.stride(1);
+    args.stride_k = k.stride(2);

     args.batch_stride_v = v.stride(0);
-    args.stride_v = v.stride(1);
-    args.nhead_stride_v = v.stride(2);
+    args.nhead_stride_v = v.stride(1);
+    args.stride_v = v.stride(2);

     args.batch_stride_o = 0;
     args.stride_o = out.stride(0);
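The stride assignments swap because the head dimension moved ahead of the page dimension: dim 1 now carries nhead_stride_* and dim 2 carries stride_*. A quick illustrative check of which value lands in which field (libtorch, made-up sizes, not from the commit):

```cpp
// Sketch: strides of the vLLM-layout K and V caches, as passed into args.
#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t num_blocks = 2, num_heads_k = 4, d = 64, page_block_size = 128;
    auto k = torch::zeros({num_blocks, num_heads_k, d / 8, page_block_size, 8});
    auto v = torch::zeros({num_blocks, num_heads_k, d, page_block_size});

    // k: (num_blocks, num_heads_k, d/8, page_block_size, 8)
    std::cout << "batch_stride_k = " << k.stride(0)    // num_heads_k * d * page_block_size = 32768
              << "\nnhead_stride_k = " << k.stride(1)  // d * page_block_size = 8192
              << "\nstride_k       = " << k.stride(2)  // page_block_size * 8 = 1024
              << '\n';

    // v: (num_blocks, num_heads_k, d, page_block_size)
    std::cout << "batch_stride_v = " << v.stride(0)    // num_heads_k * d * page_block_size = 32768
              << "\nnhead_stride_v = " << v.stride(1)  // d * page_block_size = 8192
              << "\nstride_v       = " << v.stride(2)  // page_block_size = 128
              << '\n';
    return 0;
}
```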
@@ -292,8 +292,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,

 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size / 8 x page_block_size x 8 if there's a block_table.
+               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size x page_block_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
@@ -348,11 +348,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = k.size(1);

     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
-    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    const int page_block_size = !paged_KV ? 1 : k.size(3);
     TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");

     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
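The paged_KV branch for num_heads_k disappears because heads sit at dim 1 in both the varlen layout (total_k, num_heads_k, head_size) and the new paged layout (num_blocks, num_heads_k, head_size / 8, page_block_size, 8); page_block_size correspondingly moves to dim 3. A sketch of that invariant (the helper name is hypothetical; libtorch, illustrative sizes):

```cpp
// Sketch: num_heads_k is dim 1 in both layouts, so no paged_KV branch is needed.
#include <torch/torch.h>
#include <iostream>

// Hypothetical helper mirroring the simplified extraction above.
int64_t num_heads_k_of(const at::Tensor &k) { return k.size(1); }

int main() {
    auto k_varlen = torch::zeros({1024, 4, 64});          // (total_k, num_heads_k, head_size)
    auto k_paged  = torch::zeros({8, 4, 64 / 8, 128, 8}); // (num_blocks, num_heads_k, d/8, pbs, 8)

    std::cout << num_heads_k_of(k_varlen) << ' '          // 4
              << num_heads_k_of(k_paged)  << '\n';        // 4
    std::cout << "page_block_size = " << k_paged.size(3) << '\n'; // 128
    return 0;
}
```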
@@ -394,8 +394,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         CHECK_SHAPE(k, total_k, num_heads_k, head_size);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8);
+        CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size);
         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
     }
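The paged branch now asserts the 5-D K and 4-D V vLLM layouts. A simplified stand-in for the repo's CHECK_SHAPE macro, checked against tensors built in that layout (assumption: the real macro formats a richer error message; this sketch only compares sizes, with head_size divisible by 8):

```cpp
// Sketch: the paged-cache shape checks against vLLM-layout tensors.
#include <torch/torch.h>

// Simplified stand-in for the repo's CHECK_SHAPE macro.
#define CHECK_SHAPE(x, ...) \
    TORCH_CHECK((x).sizes() == at::IntArrayRef({__VA_ARGS__}), #x " has the wrong shape")

int main() {
    const int64_t num_blocks = 8, num_heads_k = 4, head_size = 64, page_block_size = 128;
    auto k = torch::zeros({num_blocks, num_heads_k, head_size / 8, page_block_size, 8});
    auto v = torch::zeros({num_blocks, num_heads_k, head_size, page_block_size});

    CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8);
    CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size);
    return 0;
}
```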

