diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
index a1d61dbf287..1174360adda 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py
@@ -73,7 +73,7 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype
     tt_back = ttnn.transformer.scaled_dot_product_attention(
         tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config
     )
-    tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+    tt_back = ttnn.to_torch(tt_back)
 
     K_repeated = torch.cat([K[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, d, S
     V_repeated = torch.cat([V[:, i : i + 1, :, :].repeat(1, nh // nkv, 1, 1) for i in range(nkv)], dim=1)  # b, nh, d, S
@@ -238,7 +238,7 @@ def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dt
         program_config=program_config,
         compute_kernel_config=compute_kernel_config,
     )
-    tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+    tt_back = ttnn.to_torch(tt_back)
 
     if nkv > 1 and nkv != nh:
         assert nh % nkv == 0
@@ -305,19 +305,13 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si
 # @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"])
 @pytest.mark.parametrize("q_dtype", [ttnn.bfloat16])
 @pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b])
-# @pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
-# @pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
-@pytest.mark.parametrize("q_chunk_size", [64, 128])
-@pytest.mark.parametrize("k_chunk_size", [64, 128])
+@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"])
+@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"])
 @pytest.mark.parametrize("prefill_chunk_size", [1024, 2048])
 @pytest.mark.parametrize("page_block_size", [64, 128])
 @pytest.mark.parametrize(
     "b, nh, nkv, s, d",
-    (
-        [1, 1, 1, 16 * 1024, 32],  # Llama2-70B
-        # [1, 16, 1, 2048, 64],  # Falcon-40B
-        # [1, 71, 1, 2048, 64],  # Falcon-7B
-    ),
+    ([1, 8, 1, 16 * 1024, 128],),  # Llama2-70B
 )
 def test_sdpa_chunked(
     device,
@@ -332,8 +326,8 @@ def test_sdpa_chunked(
     page_block_size,
     q_dtype,
     k_dtype,
+    use_program_cache,
     use_high_precision_compute=False,
-    use_program_cache=False,
 ):
     for _ in range(2):
         run_test_chunked_sdpa(
@@ -352,6 +346,11 @@ def test_sdpa_chunked(
             use_high_precision_compute,
         )
 
+    # Check the number of program cache entries
+    assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
+        device.num_program_cache_entries()
+    )
+
 
 def run_test_chunked_sdpa(
     device,
@@ -369,7 +368,7 @@ def run_test_chunked_sdpa(
     use_high_precision_compute,
 ):
     program_config = ttnn.SDPAProgramConfig(
-        compute_with_storage_grid_size=(1, 1),
+        compute_with_storage_grid_size=device.compute_with_storage_grid_size(),
         q_chunk_size=q_chunk_size,
         k_chunk_size=k_chunk_size,
         exp_approx_mode=False,
@@ -438,24 +437,13 @@ def unpage_cache(cache):
         return paged_cache_back
 
     # Check that we can convert from normal to paged to normal
-    assert torch.allclose(unpage_cache(page_cache(K)), K)
-    assert torch.allclose(unpage_cache(page_cache(V)), V)
+    assert torch.allclose(unpage_cache(page_cache(K)), K), "K is not equal to unpage_cache(page_cache(K))"
+    assert torch.allclose(unpage_cache(page_cache(V)), V), "V is not equal to unpage_cache(page_cache(V))"
 
     tt_paged_K = ttnn.Tensor(page_cache(K), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
     tt_paged_V = ttnn.Tensor(page_cache(V), k_dtype).to(ttnn.TILE_LAYOUT).to(device)
     page_table_tt = ttnn.Tensor(page_table, ttnn.int32).to(device)
 
-    # tt_K = ttnn.Tensor(K, k_dtype).to(ttnn.TILE_LAYOUT).to(device)
-    # tt_V = ttnn.Tensor(V, k_dtype).to(ttnn.TILE_LAYOUT).to(device)
-
-    # # TODO: Chunk Q
-    # tt_Q = ttnn.Tensor(Q, q_dtype).to(ttnn.TILE_LAYOUT).to(device)
-
-    # tt_back = ttnn.transformer.scaled_dot_product_attention(
-    #     tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config
-    # )
-    # tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
-
     for chunk_idx in range(num_prefill_chunks):
         # Chunk Q
         Q_chunk = Q[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
@@ -472,13 +460,7 @@ def unpage_cache(cache):
             compute_kernel_config=compute_kernel_config,
         )
         tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
-
-        # tt_chunk = tt_back[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
         gt_chunk = gt[:, :, chunk_idx * prefill_chunk_size : (chunk_idx + 1) * prefill_chunk_size]
-        out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.994)
+        out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998)
         logger.debug(f"python vs pytorch: {out_pcc}")
         assert out_pass
-
-        # out_pass, out_pcc = comp_pcc(gt, tt_back, 0.994)
-        logger.debug(f"python vs pytorch: {out_pcc}")
-        assert out_pass
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp
index cea44be64ef..1fa196dc220 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp
@@ -15,7 +15,6 @@
 #include "compute_kernel_api/tile_move_copy.h"
 #include "compute_kernel_api/matmul.h"
 #include "compute_kernel_api/reduce.h"
-#include "debug/dprint.h"
 
 namespace NAMESPACE {
 template
@@ -370,7 +369,7 @@ void MAIN {
     const uint32_t local_nh_end = get_arg_val<uint32_t>(4);
     const uint32_t local_q_start = get_arg_val<uint32_t>(5);
     const uint32_t local_q_end = get_arg_val<uint32_t>(6);
-    const uint32_t chunk_start_t = get_arg_val<uint32_t>(7);
+    const uint32_t chunked_q_chunk_offset = get_arg_val<uint32_t>(7);
 
     const uint32_t q_chunks_per_core = local_q_end - local_q_start;
 
@@ -416,6 +415,9 @@ void MAIN {
 #endif
 
                 // Get Q chunk
+                if constexpr (is_chunked) {
+                    q_chunk = chunked_q_chunk_offset + q_chunk;
+                }
                 uint32_t q_low_idx =
                     q_chunk * Sq_chunk_t;  // This is the sequence index of the first tile of this chunk
                 uint32_t q_high_idx;
@@ -424,19 +426,12 @@ void MAIN {
                 } else {
                     q_high_idx = Skt;
                 }
-                if constexpr (is_chunked) {
-                    q_low_idx = chunk_start_t + q_low_idx;
-                    q_high_idx = chunk_start_t + q_high_idx;
-                }
-                // UNPACK(DPRINT << "UNPACK: q_low_idx: " << q_low_idx << ENDL();)
-                // UNPACK(DPRINT << "UNPACK: q_high_idx: " << q_high_idx << ENDL();)
                 cb_wait_front(cb_q_in, q_chunk_tiles);
 
                 // loop while k_low < q_high
                 for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
                     const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
                     const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
-                    // UNPACK(DPRINT << "UNPACK: k_chunk: " << k_chunk << ENDL();)
 
                     /* QK = Q_CHUNK @ K_CHUNK */
                     pack_reconfig_data_format(cb_qk_im);
@@ -455,8 +450,6 @@ void MAIN {
                         qk_subblock_w,
                         true /*transpose*/);
 
-                    // UNPACK(DPRINT << "UNPACK: done with qk_im" << ENDL();)
-
                     /* QK *= SCALE */
                     mul_block_bcast_scalar_inplace();
 
@@ -467,7 +460,6 @@ void MAIN {
                     // Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check
                     if constexpr (is_causal) {
                         if (!(q_low_idx >= k_high_idx)) {
-                            // UNPACK(DPRINT << "UNPACK: adding mask" << ENDL();)
                             /* QK += MASK */
                             reconfig_data_format(cb_qk_im, cb_mask_in);
                             add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles);
@@ -479,7 +471,6 @@ void MAIN {
                             add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles);
                         }
                     }
-                    // UNPACK(DPRINT << "UNPACK: done with mask" << ENDL();)
 
                     reconfig_data_format(cb_qk_im, cb_identity_scale_in);
                     reduce_c<
@@ -524,7 +515,7 @@ void MAIN {
                         out_subblock_h,
                         out_subblock_w,
                         false /*transpose*/);
-                    // UNPACK(DPRINT << "UNPACK: done with out_im" << ENDL();)
+
                     reconfig_data_format_srca(cb_out_im);
                     cb_pop_front(cb_qk_im, qk_chunk_tiles);
 
@@ -570,8 +561,5 @@ void MAIN {
             }
         }
     }
-    // UNPACK(DPRINT << "UNPACK: done" << ENDL();)
-    // MATH(DPRINT << "MATH: done" << ENDL();)
-    // PACK(DPRINT << "PACK: done" << ENDL();)
 }
 }  // namespace NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp
index dd36223997a..d390839aa50 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp
@@ -4,7 +4,6 @@
 
 #include <stdint.h>
 #include "dataflow_api.h"
-#include "debug/dprint.h"
 
 template <uint32_t tile_bytes, uint32_t num_readers>
 constexpr uint32_t get_barrier_read_threshold() {
@@ -58,7 +57,7 @@ void kernel_main() {
     const uint32_t local_nh_end = get_arg_val<uint32_t>(argidx++);
     const uint32_t local_q_start = get_arg_val<uint32_t>(argidx++);
     const uint32_t local_q_end = get_arg_val<uint32_t>(argidx++);
-    const uint32_t chunk_start_t = get_arg_val<uint32_t>(argidx++);
+    const uint32_t chunked_q_chunk_offset = get_arg_val<uint32_t>(argidx++);
 
     const uint32_t q_chunks_per_core = local_q_end - local_q_start;
 
@@ -126,7 +125,6 @@ void kernel_main() {
         const uint32_t mask_batch_offset = nb * Sqt * Skt;
         for (uint32_t nq = local_nh_start; nq < local_nh_end; ++nq) {
             for (uint32_t q_iter = 0; q_iter < q_chunks_per_core; ++q_iter) {
-                // DPRINT << "q_iter: " << q_iter << ENDL();
                 uint32_t q_chunk;
 #if defined BALANCED_Q_PARALLEL
                 uint32_t q_chunk_div_2 = q_chunks_per_core / 2;
@@ -163,8 +161,9 @@ void kernel_main() {
 
                 cb_push_back(cb_q_in, q_chunk_tiles);
 
-                // DPRINT << "q_chunk: " << q_chunk << "\n";
-
+                if constexpr (is_chunked) {
+                    q_chunk = chunked_q_chunk_offset + q_chunk;
+                }
                 uint32_t q_low_idx =
                     q_chunk * Sq_chunk_t;  // This is the sequence index of the first tile of this chunk
                 uint32_t q_high_idx;
@@ -173,14 +172,6 @@ void kernel_main() {
                 } else {
                     q_high_idx = Skt;
                 }
-                if constexpr (is_chunked) {
-                    // Add the chunk offset to the low and high indices
-                    q_low_idx = chunk_start_t + q_low_idx;
-                    q_high_idx = chunk_start_t + q_high_idx;
-                }
-
-                // DPRINT << "q_low_idx: " << q_low_idx << ENDL();
-                // DPRINT << "q_high_idx: " << q_high_idx << ENDL();
 
                 const uint32_t kv_head = nq / q_heads_per_kv;
                 const uint32_t kv_head_offset = kv_head * Skt * DHt;
@@ -190,7 +181,6 @@ void kernel_main() {
                     const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
                     const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
                     const uint32_t k_start_tile_id = kv_batch_offset + kv_head_offset + k_chunk * Sk_chunk_t * DHt;
-                    // DPRINT << "READER: k_chunk: " << k_chunk << ENDL();
 
                     if constexpr (is_chunked) {
                        // Use page table to read K chunk
@@ -240,7 +230,6 @@ void kernel_main() {
                         cb_push_back(cb_k_in, k_chunk_tiles);
                     }
 
-                    // DPRINT << "READER: done with k_chunk: " << k_chunk << ENDL();
                     if constexpr (use_provided_mask) {
                         // Finding the diagonal is harder now that q_chunk_size and k_chunk_size can differ
                         // Q-range = [q_low, q_high)
@@ -314,11 +303,8 @@ void kernel_main() {
                         noc_async_read_barrier();
                         cb_push_back(cb_v_in, k_chunk_tiles);
                     }
-
-                    // DPRINT << "READER: done with v_chunk: " << k_chunk << ENDL();
                 }
             }
         }
     }
-    // DPRINT << "READER: done" << ENDL();
 }
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
index fe5852d5a7f..5ac5d251036 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp
@@ -50,7 +50,6 @@ void fill_diagonal_tile(uint32_t cb_id, uint32_t tile_id, uint32_t partial_val)
 
     fill_tile(cb_id, tile_id, 0);
 
-    // DPRINT << "Fill partial tile" << ENDL();
     const uint16_t datum_val = partial_val >> 16;
     volatile tt_l1_ptr uint16_t* uint16_ptr =
         reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_write_ptr(cb_id) + tile_id * tile_bytes);
@@ -215,12 +214,6 @@ void kernel_main() {
                 q_chunk * Sq_chunk_t;  // This is the sequence index of the first tile of this chunk
             uint32_t q_high_idx = q_low_idx + Sq_chunk_t;
 
-            // if constexpr (is_chunked) {
-            //     q_low_idx = chunk_start_t + q_low_idx;
-            //     q_high_idx = chunk_start_t + q_high_idx;
-            //     q_chunk = chunk_start_t_in_q_chunks + q_chunk;
-            // }
-
             for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) {
                 const uint32_t k_low_idx = k_chunk * Sk_chunk_t;
                 const uint32_t k_high_idx = k_low_idx + Sk_chunk_t;
@@ -233,7 +226,6 @@ void kernel_main() {
                 if (!(q_low_idx >= k_high_idx)) {
                     generate_mask(Sq_chunk_t, Sk_chunk_t, q_chunk, k_chunk);
                 }
-                // DPRINT << "WRITER: done with k_chunk: " << k_chunk << ENDL();
             }
         }
 
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
index 4c4baff4989..cc137f759ed 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp
@@ -259,4 +259,19 @@ operation::ProgramWithCallbacks ScaledDotProductAttention::create_program(
         this->program_config);
 }
 
+operation::Hash ScaledDotProductAttention::compute_program_hash(
+    const std::vector<Tensor>& input_tensors,
+    const std::vector<std::optional<const Tensor>>& optional_input_tensors) const {
+    bool is_chunked_prefill = this->chunk_start_idx.has_value();
+    return operation::hash_operation<ScaledDotProductAttention>(
+        this->scale,
+        this->output_mem_config,
+        this->program_config,
+        this->is_causal,
+        is_chunked_prefill,
+        this->compute_kernel_config,
+        input_tensors,
+        optional_input_tensors);
+}
+
 }  // namespace ttnn::operations::transformer
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
index 05674c52785..02f76971bec 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.hpp
@@ -31,6 +31,10 @@ struct ScaledDotProductAttention {
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors,
         std::vector<Tensor>& output_tensors) const;
+
+    operation::Hash compute_program_hash(
+        const std::vector<Tensor>& input_tensors,
+        const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
 };
 
 }  // namespace ttnn::operations::transformer
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
index 6396d6f5394..ec9805872e7 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp
@@ -80,8 +80,8 @@ operation::ProgramWithCallbacks sdpa_multi_core(
     tt::log_debug("k_num_chunks: {}", k_num_chunks);
     tt::log_debug("NKH: {}", NKH);
 
-    uint32_t chunk_start_t = 0;
-    uint32_t chunk_start_t_in_q_chunks = 0;
+    // In chunked prefill mode, the offset of Q in terms of Q chunks
+    uint32_t chunked_q_chunk_offset = 0;
     uint32_t block_size = 0;
     uint32_t block_size_t = 0;
     uint32_t max_blocks_per_seq = 0;
@@ -90,8 +90,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
     tt::DataFormat page_table_df = tt::DataFormat::Int32;
 
     if (is_chunked) {
-        chunk_start_t = chunk_start_idx.value() / TILE_HEIGHT;  // Q offset in tiles
-        chunk_start_t_in_q_chunks = chunk_start_idx.value() / q_chunk_size;
+        chunked_q_chunk_offset = chunk_start_idx.value() / q_chunk_size;
         const auto& page_table_tensor = page_table.value();
         block_size = k_shape[2];  // K's sequence dimension represents block size
         block_size_t = block_size / TILE_HEIGHT;
@@ -107,15 +106,14 @@ operation::ProgramWithCallbacks sdpa_multi_core(
             "page table page size in bytes must be a multiple of 32 due to address alignment");
     }
     // Log page table info
-    tt::log_info("is_chunked: {}", is_chunked);
+    tt::log_debug("is_chunked: {}", is_chunked);
     if (is_chunked) {
-        tt::log_info("chunk_start_t: {}", chunk_start_t);
-        tt::log_info("block_size: {}", block_size);
-        tt::log_info("block_size_t: {}", block_size_t);
-        tt::log_info("max_blocks_per_seq: {}", max_blocks_per_seq);
-        tt::log_info("page_table_stick_size: {}", page_table_stick_size);
-        tt::log_info("page_table_is_dram: {}", page_table_is_dram);
-        tt::log_info("page_table_df: {}", page_table_df);
+        tt::log_debug("block_size: {}", block_size);
+        tt::log_debug("block_size_t: {}", block_size_t);
+        tt::log_debug("max_blocks_per_seq: {}", max_blocks_per_seq);
+        tt::log_debug("page_table_stick_size: {}", page_table_stick_size);
+        tt::log_debug("page_table_is_dram: {}", page_table_is_dram);
+        tt::log_debug("page_table_df: {}", page_table_df);
     }
 
     Program program = CreateProgram();
@@ -538,7 +536,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
                 local_nh_end,
                 local_q_start,
                 local_q_end,
-                chunk_start_t});
+                chunked_q_chunk_offset});
         SetRuntimeArgs(
             program,
             writer_kernels_id,
@@ -551,7 +549,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
                 local_nh_end,
                 local_q_start,
                 local_q_end,
-                chunk_start_t_in_q_chunks});
+                chunked_q_chunk_offset});
         SetRuntimeArgs(
             program,
             compute_kernels_id,
@@ -563,7 +561,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
                 local_nh_end,
                 local_q_start,
                 local_q_end,
-                chunk_start_t});
+                chunked_q_chunk_offset});
     }
 
     auto override_runtime_arguments_callback =
@@ -604,13 +602,10 @@ operation::ProgramWithCallbacks sdpa_multi_core(
             uint32_t out_addr = out0_buffer->address();
 
             uint32_t page_table_addr = 0;
-            uint32_t chunk_start_t = 0;
-            uint32_t chunk_start_t_in_q_chunks = 0;
+            uint32_t chunked_q_chunk_offset = 0;
            if (is_chunked) {
                 page_table_addr = optional_input_tensors.at(1).value().buffer()->address();
-                chunk_start_t = static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() /
-                                TILE_HEIGHT;  // Q offset in tiles
-                chunk_start_t_in_q_chunks =
+                chunked_q_chunk_offset =
                     static_cast<const ScaledDotProductAttention*>(operation)->chunk_start_idx.value() / q_chunk_size;
             }
 
@@ -654,12 +649,12 @@ operation::ProgramWithCallbacks sdpa_multi_core(
             reader_args[2] = v_addr;
             reader_args[3] = mask_addr;
             reader_args[4] = page_table_addr;
-            reader_args[12] = chunk_start_t;
+            reader_args[12] = chunked_q_chunk_offset;
 
             writer_args[0] = out_addr;
-            writer_args[8] = chunk_start_t;
+            writer_args[8] = chunked_q_chunk_offset;
 
-            compute_args[7] = chunk_start_t;
+            compute_args[7] = chunked_q_chunk_offset;
         }
     };