
Commit

#0: Chunked SDPA tests pass
cglagovichTT committed Dec 10, 2024
1 parent 6deae12 commit d018e2a
Showing 3 changed files with 55 additions and 51 deletions.
Changed file 1 of 3: SDPA test parametrization (pytest)
@@ -304,17 +304,17 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
# @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat16])
@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
# @pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
# @pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
@pytest.mark.parametrize("q_chunk_size", [32])
@pytest.mark.parametrize("k_chunk_size", [32])
@pytest.mark.parametrize("prefill_chunk_size", [32])
@pytest.mark.parametrize("page_block_size", [32])
@pytest.mark.parametrize("q_chunk_size", [64, 128])
@pytest.mark.parametrize("k_chunk_size", [64, 128])
@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
@pytest.mark.parametrize("page_block_size", [64, 128])
@pytest.mark.parametrize(
"b, nh, nkv, s, d",
(
-[1, 1, 1, 64, 32], # Llama2-70B
+[1, 1, 1, 16 * 1024, 32], # Llama2-70B
# [1, 16, 1, 2048, 64], # Falcon-40B
# [1, 71, 1, 2048, 64], # Falcon-7B
),
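The widened parametrization exercises chunked prefill over a 16K-token sequence at several chunk sizes. A minimal sketch of the chunk loop such a test implies is shown below; the helper name and loop shape are assumptions for illustration, and only the sizes come from the parametrization above.

# Hypothetical chunk loop implied by the parametrization (not the test's actual code).
# chunk_start_idx is assumed to be a multiple of q_chunk_size and of the 32-row tile
# height, since the host code derives chunk_start_t and chunk_start_t_in_q_chunks from it.
def prefill_chunk_starts(s=16 * 1024, prefill_chunk_size=2048, q_chunk_size=128):
    assert prefill_chunk_size % q_chunk_size == 0 and prefill_chunk_size % 32 == 0
    return list(range(0, s, prefill_chunk_size))  # [0, 2048, 4096, ..., 14336]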
Changed file 2 of 3: SDPA writer dataflow kernel (C++)
@@ -5,7 +5,6 @@
#include "dataflow_api.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"
#include "debug/dprint.h"

template <uint32_t tile_bytes, uint32_t num_readers>
constexpr uint32_t get_barrier_read_threshold() {
@@ -158,7 +157,7 @@ void kernel_main() {
const uint32_t local_nh_end = get_arg_val<uint32_t>(5);
const uint32_t local_q_start = get_arg_val<uint32_t>(6);
const uint32_t local_q_end = get_arg_val<uint32_t>(7);
-const uint32_t chunk_start_t = get_arg_val<uint32_t>(8);
+const uint32_t chunk_start_t_in_q_chunks = get_arg_val<uint32_t>(8);

const uint32_t q_chunks_per_core = local_q_end - local_q_start;

@@ -208,14 +207,19 @@ void kernel_main() {
out_tile_id = q_batch_offset + q_head_offset + q_chunk_offset;

if constexpr (is_causal) {
+if constexpr (is_chunked) {
+// Bump it up to the chunk start
+q_chunk = chunk_start_t_in_q_chunks + q_chunk;
+}
uint32_t q_low_idx =
q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk
uint32_t q_high_idx = q_low_idx + Sq_chunk_t;

-if constexpr (is_chunked) {
-q_low_idx = chunk_start_t + q_low_idx;
-q_high_idx = chunk_start_t + q_high_idx;
-}
+// if constexpr (is_chunked) {
+// q_low_idx = chunk_start_t + q_low_idx;
+// q_high_idx = chunk_start_t + q_high_idx;
+// q_chunk = chunk_start_t_in_q_chunks + q_chunk;
+// }

for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
@@ -227,11 +231,6 @@ void kernel_main() {
// Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check
// Read mask chunk
if (!(q_low_idx >= k_high_idx)) {
DPRINT << "WRITER: chunk_start_t: " << chunk_start_t << ENDL();
DPRINT << "WRITER: q_low_idx: " << q_low_idx << ENDL();
DPRINT << "WRITER: k_high_idx: " << k_high_idx << ENDL();
DPRINT << "WRITER: masking for q_chunk: " << q_chunk << " and k_chunk: " << k_chunk
<< ENDL();
generate_mask<cb_mask_in>(Sq_chunk_t, Sk_chunk_t, q_chunk, k_chunk);
}
// DPRINT << "WRITER: done with k_chunk: " << k_chunk << ENDL();
@@ -257,5 +256,4 @@ void kernel_main() {
}
}
}
DPRINT << "WRITER: done" << ENDL();
}
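The kernel change above moves the chunk offset from the tile domain (chunk_start_t) into the q-chunk domain (chunk_start_t_in_q_chunks): q_chunk itself is bumped before the causal bounds are computed, so q_low_idx and q_high_idx no longer need a separate adjustment. A small Python restatement of that logic follows for readability; it mirrors the kernel's loop and masking condition but is not kernel code, and the function name is invented for illustration.

# Sketch of the causal-bound logic with the chunked offset applied to q_chunk.
# Sq_chunk_t / Sk_chunk_t are the Q/K chunk sizes in tiles.
def k_chunks_to_visit(q_chunk, chunk_start_t_in_q_chunks, Sq_chunk_t, Sk_chunk_t):
    q_chunk += chunk_start_t_in_q_chunks        # shift to the chunk's global position
    q_low_idx = q_chunk * Sq_chunk_t            # first tile row of this Q chunk
    q_high_idx = q_low_idx + Sq_chunk_t
    visited = []
    k_chunk = 0
    while k_chunk * Sk_chunk_t < q_high_idx:    # causal bound: stop once past the diagonal
        k_high_idx = k_chunk * Sk_chunk_t + Sk_chunk_t
        needs_mask = not (q_low_idx >= k_high_idx)  # this K chunk reaches the diagonal
        visited.append((k_chunk, needs_mask))
        k_chunk += 1
    return visited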
Changed file 3 of 3: SDPA program factory (host C++)
@@ -49,7 +49,39 @@ operation::ProgramWithCallbacks sdpa_multi_core(

// Paged cache parameters when in chunked mode
bool is_chunked = chunk_start_idx.has_value();
+// In chunked mode, we only need to process K/V up to chunk_start_idx + Sq
+const uint32_t Sk = is_chunked ? (chunk_start_idx.value() + Sq) : k_shape[2];
+
+const uint32_t Sqt = Sq / TILE_HEIGHT;
+const uint32_t Skt = Sk / TILE_HEIGHT;
+const uint32_t DHt = DH / TILE_WIDTH;
+
+const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT;
+const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT;
+const uint32_t q_num_chunks = Sq / q_chunk_size;
+const uint32_t k_num_chunks = Sk / k_chunk_size;
+const bool use_provided_mask = attn_mask.has_value();
+
+// log_debug all of the above
+tt::log_debug("B: {}", B);
+tt::log_debug("NQH: {}", NQH);
+
+tt::log_debug("Sq: {}", Sq);
+tt::log_debug("Sk: {}", Sk);
+tt::log_debug("DH: {}", DH);
+tt::log_debug("Sqt: {}", Sqt);
+tt::log_debug("Skt: {}", Skt);
+tt::log_debug("DHt: {}", DHt);
+tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t);
+tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t);
+tt::log_debug("q_chunk_size: {}", q_chunk_size);
+tt::log_debug("k_chunk_size: {}", k_chunk_size);
+tt::log_debug("q_num_chunks: {}", q_num_chunks);
+tt::log_debug("k_num_chunks: {}", k_num_chunks);
+tt::log_debug("NKH: {}", NKH);
+
uint32_t chunk_start_t = 0;
+uint32_t chunk_start_t_in_q_chunks = 0;
uint32_t block_size = 0;
uint32_t block_size_t = 0;
uint32_t max_blocks_per_seq = 0;
@@ -59,6 +91,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(

if (is_chunked) {
chunk_start_t = chunk_start_idx.value() / TILE_HEIGHT; // Q offset in tiles
+chunk_start_t_in_q_chunks = chunk_start_idx.value() / q_chunk_size;
const auto& page_table_tensor = page_table.value();
block_size = k_shape[2]; // K's sequence dimension represents block size
block_size_t = block_size / TILE_HEIGHT;
@@ -85,37 +118,6 @@ operation::ProgramWithCallbacks sdpa_multi_core(
tt::log_info("page_table_df: {}", page_table_df);
}

-// In chunked mode, we only need to process K/V up to chunk_start_idx + Sq
-const uint32_t Sk = is_chunked ? (chunk_start_idx.value() + Sq) : k_shape[2];
-
-const uint32_t Sqt = Sq / TILE_HEIGHT;
-const uint32_t Skt = Sk / TILE_HEIGHT;
-const uint32_t DHt = DH / TILE_WIDTH;
-
-const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT;
-const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT;
-const uint32_t q_num_chunks = Sq / q_chunk_size;
-const uint32_t k_num_chunks = Sk / k_chunk_size;
-const bool use_provided_mask = attn_mask.has_value();
-
-// log_debug all of the above
-tt::log_debug("B: {}", B);
-tt::log_debug("NQH: {}", NQH);
-
-tt::log_debug("Sq: {}", Sq);
-tt::log_debug("Sk: {}", Sk);
-tt::log_debug("DH: {}", DH);
-tt::log_debug("Sqt: {}", Sqt);
-tt::log_debug("Skt: {}", Skt);
-tt::log_debug("DHt: {}", DHt);
-tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t);
-tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t);
-tt::log_debug("q_chunk_size: {}", q_chunk_size);
-tt::log_debug("k_chunk_size: {}", k_chunk_size);
-tt::log_debug("q_num_chunks: {}", q_num_chunks);
-tt::log_debug("k_num_chunks: {}", k_num_chunks);
-tt::log_debug("NKH: {}", NKH);

Program program = CreateProgram();

Device* device = input_tensor_q.device();
@@ -549,7 +551,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
local_nh_end,
local_q_start,
local_q_end,
-chunk_start_t});
+chunk_start_t_in_q_chunks});
SetRuntimeArgs(
program,
compute_kernels_id,
@@ -581,7 +583,8 @@ operation::ProgramWithCallbacks sdpa_multi_core(
is_causal,
cb_in0_id,
cb_out0_id,
-is_chunked](
+is_chunked,
+q_chunk_size](
const void* operation,
Program& program,
const std::vector<Tensor>& input_tensors,
Expand All @@ -602,10 +605,13 @@ operation::ProgramWithCallbacks sdpa_multi_core(

uint32_t page_table_addr = 0;
uint32_t chunk_start_t = 0;
+uint32_t chunk_start_t_in_q_chunks = 0;
if (is_chunked) {
page_table_addr = optional_input_tensors.at(1).value().buffer()->address();
chunk_start_t = static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() /
TILE_HEIGHT; // Q offset in tiles
+chunk_start_t_in_q_chunks =
+static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() / q_chunk_size;
}

auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id);
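A worked example of the host-side quantities, with sizes assumed from the test parametrization above and a 32-row tile height:

# Assumed sizes, for illustration only.
chunk_start_idx = 2048       # second prefill chunk of a 16K-token sequence
Sq = 2048                    # prefill_chunk_size
q_chunk_size = 128
TILE_HEIGHT = 32
Sk = chunk_start_idx + Sq                                      # 4096: K/V is processed only up to here
chunk_start_t = chunk_start_idx // TILE_HEIGHT                 # 64 tiles (Q offset in tiles)
chunk_start_t_in_q_chunks = chunk_start_idx // q_chunk_size    # 16: what the kernel now reads as arg 8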
