From d018e2a2bf4968a20693d2934c03c7b1c5fb7372 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Tue, 10 Dec 2024 14:21:13 -0800
Subject: [PATCH] #0: Chunked SDPA tests pass

---
 .../misc/test_scaled_dot_product_attention.py | 12 ++--
 .../kernels/dataflow/writer_interleaved.cpp   | 22 +++---
 .../sdpa/device/sdpa_program_factory.cpp      | 72 ++++++++++---------
 3 files changed, 55 insertions(+), 51 deletions(-)

diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
index bad374f1832..a1d61dbf287 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
@@ -304,17 +304,17 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
 @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
 # @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
 @pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
-@pytest.mark.parametrize("k_dtype", [ttnn.bfloat16])
+@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
 # @pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
 # @pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
-@pytest.mark.parametrize("q_chunk_size", [32])
-@pytest.mark.parametrize("k_chunk_size", [32])
-@pytest.mark.parametrize("prefill_chunk_size", [32])
-@pytest.mark.parametrize("page_block_size", [32])
+@pytest.mark.parametrize("q_chunk_size", [64, 128])
+@pytest.mark.parametrize("k_chunk_size", [64, 128])
+@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
+@pytest.mark.parametrize("page_block_size", [64, 128])
 @pytest.mark.parametrize(
     "b, nh, nkv, s, d",
     (
-        [1, 1, 1, 64, 32],  # Llama2-70B
+        [1, 1, 1, 16 * 1024, 32],  # Llama2-70B
         # [1, 16, 1, 2048, 64],  # Falcon-40B
         # [1, 71, 1, 2048, 64],  # Falcon-7B
     ),
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
index 94954563afc..fe5852d5a7f 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
@@ -5,7 +5,6 @@
 #include "dataflow_api.h"
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"
 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp"
-#include "debug/dprint.h"
 
 template <uint32_t tile_bytes, uint32_t num_readers>
 constexpr uint32_t get_barrier_read_threshold() {
@@ -158,7 +157,7 @@ void kernel_main() {
     const uint32_t local_nh_end = get_arg_val<uint32_t>(5);
     const uint32_t local_q_start = get_arg_val<uint32_t>(6);
     const uint32_t local_q_end = get_arg_val<uint32_t>(7);
-    const uint32_t chunk_start_t = get_arg_val<uint32_t>(8);
+    const uint32_t chunk_start_t_in_q_chunks = get_arg_val<uint32_t>(8);
 
     const uint32_t q_chunks_per_core = local_q_end - local_q_start;
 
@@ -208,14 +207,19 @@
             out_tile_id = q_batch_offset + q_head_offset + q_chunk_offset;
 
             if constexpr (is_causal) {
+                if constexpr (is_chunked) {
+                    // Bump it up to the chunk start
+                    q_chunk = chunk_start_t_in_q_chunks + q_chunk;
+                }
                 uint32_t q_low_idx =
                     q_chunk * Sq_chunk_t;  // This is the sequence index of the first tile of this chunk
                 uint32_t q_high_idx = q_low_idx + Sq_chunk_t;
-                if constexpr (is_chunked) {
-                    q_low_idx = chunk_start_t + q_low_idx;
-                    q_high_idx = chunk_start_t + q_high_idx;
-                }
+                // if constexpr (is_chunked) {
+                //     q_low_idx = chunk_start_t + q_low_idx;
+                //     q_high_idx = chunk_start_t + q_high_idx;
+                //     q_chunk = chunk_start_t_in_q_chunks + q_chunk;
+                // }
 
                 for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
                     const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
@@ -227,11 +231,6 @@ void kernel_main() {
                     // Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check
                     // Read mask chunk
                     if (!(q_low_idx >= k_high_idx)) {
-                        DPRINT << "WRITER: chunk_start_t: " << chunk_start_t << ENDL();
-                        DPRINT << "WRITER: q_low_idx: " << q_low_idx << ENDL();
-                        DPRINT << "WRITER: k_high_idx: " << k_high_idx << ENDL();
-                        DPRINT << "WRITER: masking for q_chunk: " << q_chunk << " and k_chunk: " << k_chunk
-                               << ENDL();
                         generate_mask(Sq_chunk_t, Sk_chunk_t, q_chunk, k_chunk);
                     }
                     // DPRINT << "WRITER: done with k_chunk: " << k_chunk << ENDL();
@@ -257,5 +256,4 @@ void kernel_main() {
             }
         }
     }
-    DPRINT << "WRITER: done" << ENDL();
 }
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
index 48025c65caf..6396d6f5394 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
@@ -49,7 +49,39 @@ operation::ProgramWithCallbacks sdpa_multi_core(
     // Paged cache parameters when in chunked mode
     bool is_chunked = chunk_start_idx.has_value();
 
+    // In chunked mode, we only need to process K/V up to chunk_start_idx + Sq
+    const uint32_t Sk = is_chunked ? (chunk_start_idx.value() + Sq) : k_shape[2];
+
+    const uint32_t Sqt = Sq / TILE_HEIGHT;
+    const uint32_t Skt = Sk / TILE_HEIGHT;
+    const uint32_t DHt = DH / TILE_WIDTH;
+
+    const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT;
+    const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT;
+    const uint32_t q_num_chunks = Sq / q_chunk_size;
+    const uint32_t k_num_chunks = Sk / k_chunk_size;
+    const bool use_provided_mask = attn_mask.has_value();
+
+    // log_debug all of the above
+    tt::log_debug("B: {}", B);
+    tt::log_debug("NQH: {}", NQH);
+
+    tt::log_debug("Sq: {}", Sq);
+    tt::log_debug("Sk: {}", Sk);
+    tt::log_debug("DH: {}", DH);
+    tt::log_debug("Sqt: {}", Sqt);
+    tt::log_debug("Skt: {}", Skt);
+    tt::log_debug("DHt: {}", DHt);
+    tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t);
+    tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t);
+    tt::log_debug("q_chunk_size: {}", q_chunk_size);
+    tt::log_debug("k_chunk_size: {}", k_chunk_size);
+    tt::log_debug("q_num_chunks: {}", q_num_chunks);
+    tt::log_debug("k_num_chunks: {}", k_num_chunks);
+    tt::log_debug("NKH: {}", NKH);
+
     uint32_t chunk_start_t = 0;
+    uint32_t chunk_start_t_in_q_chunks = 0;
     uint32_t block_size = 0;
     uint32_t block_size_t = 0;
     uint32_t max_blocks_per_seq = 0;
@@ -59,6 +91,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
 
     if (is_chunked) {
         chunk_start_t = chunk_start_idx.value() / TILE_HEIGHT;  // Q offset in tiles
+        chunk_start_t_in_q_chunks = chunk_start_idx.value() / q_chunk_size;
         const auto& page_table_tensor = page_table.value();
         block_size = k_shape[2];  // K's sequence dimension represents block size
         block_size_t = block_size / TILE_HEIGHT;
@@ -85,37 +118,6 @@ operation::ProgramWithCallbacks sdpa_multi_core(
         tt::log_info("page_table_df: {}", page_table_df);
     }
 
-    // In chunked mode, we only need to process K/V up to chunk_start_idx + Sq
-    const uint32_t Sk = is_chunked ? (chunk_start_idx.value() + Sq) : k_shape[2];
-
-    const uint32_t Sqt = Sq / TILE_HEIGHT;
-    const uint32_t Skt = Sk / TILE_HEIGHT;
-    const uint32_t DHt = DH / TILE_WIDTH;
-
-    const uint32_t Sq_chunk_t = q_chunk_size / TILE_HEIGHT;
-    const uint32_t Sk_chunk_t = k_chunk_size / TILE_HEIGHT;
-    const uint32_t q_num_chunks = Sq / q_chunk_size;
-    const uint32_t k_num_chunks = Sk / k_chunk_size;
-    const bool use_provided_mask = attn_mask.has_value();
-
-    // log_debug all of the above
-    tt::log_debug("B: {}", B);
-    tt::log_debug("NQH: {}", NQH);
-
-    tt::log_debug("Sq: {}", Sq);
-    tt::log_debug("Sk: {}", Sk);
-    tt::log_debug("DH: {}", DH);
-    tt::log_debug("Sqt: {}", Sqt);
-    tt::log_debug("Skt: {}", Skt);
-    tt::log_debug("DHt: {}", DHt);
-    tt::log_debug("Sq_chunk_t: {}", Sq_chunk_t);
-    tt::log_debug("Sk_chunk_t: {}", Sk_chunk_t);
-    tt::log_debug("q_chunk_size: {}", q_chunk_size);
-    tt::log_debug("k_chunk_size: {}", k_chunk_size);
-    tt::log_debug("q_num_chunks: {}", q_num_chunks);
-    tt::log_debug("k_num_chunks: {}", k_num_chunks);
-    tt::log_debug("NKH: {}", NKH);
-
     Program program = CreateProgram();
 
     Device* device = input_tensor_q.device();
@@ -549,7 +551,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
                  local_nh_end,
                  local_q_start,
                  local_q_end,
-                 chunk_start_t});
+                 chunk_start_t_in_q_chunks});
             SetRuntimeArgs(
                 program,
                 compute_kernels_id,
@@ -581,7 +583,8 @@ operation::ProgramWithCallbacks sdpa_multi_core(
         is_causal,
         cb_in0_id,
         cb_out0_id,
-        is_chunked](
+        is_chunked,
+        q_chunk_size](
             const void* operation,
             Program& program,
             const std::vector<Tensor>& input_tensors,
@@ -602,10 +605,13 @@ operation::ProgramWithCallbacks sdpa_multi_core(
 
         uint32_t page_table_addr = 0;
         uint32_t chunk_start_t = 0;
+        uint32_t chunk_start_t_in_q_chunks = 0;
         if (is_chunked) {
             page_table_addr = optional_input_tensors.at(1).value().buffer()->address();
             chunk_start_t =
                 static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() / TILE_HEIGHT;  // Q offset in tiles
+            chunk_start_t_in_q_chunks =
+                static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() / q_chunk_size;
         }
 
         auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id);
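
Notes on the causal-mask change (commentary, not part of the patch): the writer kernel previously offset q_low_idx/q_high_idx by chunk_start_t, the prefill offset in tiles. It now receives the offset in whole Q chunks (chunk_start_t_in_q_chunks = chunk_start_idx / q_chunk_size) and bumps the local q_chunk index up to its global position before computing the causal bounds, so generate_mask() also sees the global chunk index. Below is a minimal host-side Python sketch of the resulting loop for one Q chunk; causal_k_chunks is a hypothetical helper for illustration only, mirroring the kernel's variable names, with chunk sizes given in tiles as in the kernel:

    def causal_k_chunks(q_chunk, chunk_start_t_in_q_chunks, Sq_chunk_t, Sk_chunk_t):
        """Which K chunks a Q chunk reads, and which of those need a generated mask."""
        # Bump the local chunk index up to the chunk start, as the kernel now does.
        q_chunk = chunk_start_t_in_q_chunks + q_chunk
        q_low_idx = q_chunk * Sq_chunk_t  # sequence index (in tiles) of the chunk's first tile
        q_high_idx = q_low_idx + Sq_chunk_t
        unmasked, masked = [], []
        k_chunk = 0
        while k_chunk * Sk_chunk_t < q_high_idx:
            k_high_idx = (k_chunk + 1) * Sk_chunk_t
            # K chunks entirely in the past of this Q chunk need no mask;
            # chunks that straddle the causal diagonal go through generate_mask().
            (unmasked if q_low_idx >= k_high_idx else masked).append(k_chunk)
            k_chunk += 1
        return unmasked, masked

    # Second prefill chunk: chunk_start_idx = 1024 with q_chunk_size = k_chunk_size = 128,
    # so chunk_start_t_in_q_chunks = 1024 / 128 = 8 and Sq_chunk_t = Sk_chunk_t = 128 / 32 = 4.
    # Local Q chunk 0 is global chunk 8: K chunks 0-7 are unmasked history,
    # and K chunk 8 straddles the diagonal.
    assert causal_k_chunks(0, 8, 4, 4) == ([0, 1, 2, 3, 4, 5, 6, 7], [8])

One consequence of carrying the offset in whole Q chunks is that chunk_start_idx must divide evenly by q_chunk_size; the updated test parametrization (prefill_chunk_size 1024/2048 against q_chunk_size 64/128) keeps that invariant.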