Commit

Merge pull request #104 from ROCm/page-group-update
Fix mha_varlen_fwd num_split and change ck interface
rocking5566 authored Nov 28, 2024
2 parents 0538b43 + d2dff5b commit 58f6f67
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 282 files
3 changes: 2 additions & 1 deletion csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -217,6 +217,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
args.page_block_size = 0;
}

+ args.is_gappy = false;
args.cache_batch_idx = nullptr;

args.seqstart_q_ptr = seqlens_q.data_ptr();
@@ -442,7 +443,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
if (return_dropout_randval) {p.zero_();}
}

- int num_splits = 1;
+ int num_splits = 0;
num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
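
Note on the num_splits change: the diff does not show the body of flash::override_num_splits_if_necessary, so the following is only a sketch of one plausible reading. It assumes that a starting value below 1 means "no explicit request", in which case a heuristic picks the split count, while a positive value is passed through unchanged. Under that assumption, initializing to 1 would always bypass the heuristic and initializing to 0 would enable it. The function name pick_num_splits_sketch and its heuristic_choice parameter are hypothetical and are not part of the ROCm code.

```cpp
// Hypothetical stand-in for a split-count override; NOT the ROCm implementation.
// Assumption: requested < 1 means "unset", so the heuristic's choice is used.
inline int pick_num_splits_sketch(int requested, int heuristic_choice) {
    if (requested < 1) {
        // No explicit request: defer to the heuristic.
        return heuristic_choice;
    }
    // A positive request is treated as explicit and returned unchanged.
    return requested;
}
```

With this reading, pick_num_splits_sketch(1, 4) returns 1 (the old initial value pins the split count), while pick_num_splits_sketch(0, 4) returns 4 (the new initial value lets the heuristic decide). The two TORCH_CHECK calls that follow in the diff would then validate whatever the heuristic returned.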