diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py index 1174360adda..37df74b5bf5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py @@ -299,59 +299,6 @@ def test_sdpa_noncausal_unequal_seqlen(device, b, nh, nkv, sq, sk, d, q_chunk_si run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=sk) -@skip_for_blackhole("Mismatching on BH, see #12349") -@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") -@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") -# @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b, ttnn.bfloat16], ids=["bfp8", "bf16"]) -@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16]) -@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b]) -@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"]) -@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"]) -@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048]) -@pytest.mark.parametrize("page_block_size", [64, 128]) -@pytest.mark.parametrize( - "b, nh, nkv, s, d", - ([1, 8, 1, 16 * 1024, 128],), # Llama2-70B -) -def test_sdpa_chunked( - device, - b, - nh, - nkv, - s, - d, - q_chunk_size, - k_chunk_size, - prefill_chunk_size, - page_block_size, - q_dtype, - k_dtype, - use_program_cache, - use_high_precision_compute=False, -): - for _ in range(2): - run_test_chunked_sdpa( - device, - b, - nh, - nkv, - s, - d, - q_chunk_size, - k_chunk_size, - prefill_chunk_size, - page_block_size, - q_dtype, - k_dtype, - use_high_precision_compute, - ) - - # Print number of program cache entries - assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( - device.num_program_cache_entries() - ) - - def run_test_chunked_sdpa( device, b, @@ -366,9 +313,10 @@ def run_test_chunked_sdpa( q_dtype, k_dtype, use_high_precision_compute, + grid_size=None, ): program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=device.compute_with_storage_grid_size(), + compute_with_storage_grid_size=grid_size or device.compute_with_storage_grid_size(), q_chunk_size=q_chunk_size, k_chunk_size=k_chunk_size, exp_approx_mode=False, @@ -464,3 +412,115 @@ def unpage_cache(cache): out_pass, out_pcc = comp_pcc(gt_chunk, tt_back, 0.998) logger.debug(f"python vs pytorch: {out_pcc}") assert out_pass + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16]) +@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("q_chunk_size", [128, 256], ids=["q128", "q256"]) +@pytest.mark.parametrize("k_chunk_size", [128, 256], ids=["k128", "k256"]) +@pytest.mark.parametrize("prefill_chunk_size", [1024, 2048]) +@pytest.mark.parametrize("page_block_size", [64, 128]) +@pytest.mark.parametrize( + "b, nh, nkv, s, d", + [ + [1, 8, 1, 16 * 1024, 128], + ], # Llama2-70B +) +def test_sdpa_chunked( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_program_cache, + use_high_precision_compute=False, +): + for _ in range(2): + run_test_chunked_sdpa( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_high_precision_compute, + ) + + # Print number of program cache entries + assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( + device.num_program_cache_entries() + ) + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled") +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize("q_dtype", [ttnn.bfloat16]) +@pytest.mark.parametrize("k_dtype", [ttnn.bfloat8_b]) +@pytest.mark.parametrize("q_chunk_size", [128]) +@pytest.mark.parametrize("k_chunk_size", [128]) +@pytest.mark.parametrize("prefill_chunk_size", [1024]) +@pytest.mark.parametrize("page_block_size", [64]) +@pytest.mark.parametrize( + "b, nh, nkv, s, d", + [ + [2, 1, 1, 4096, 128], + ], # Llama2-70B +) +def test_sdpa_chunked_iterate_batch( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_program_cache, + use_high_precision_compute=False, +): + """ + This tests chunked prefill where a single core has more than one user to process. + """ + for _ in range(2): + run_test_chunked_sdpa( + device, + b, + nh, + nkv, + s, + d, + q_chunk_size, + k_chunk_size, + prefill_chunk_size, + page_block_size, + q_dtype, + k_dtype, + use_high_precision_compute, + grid_size=(1, 1), + ) + + # Print number of program cache entries + assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format( + device.num_program_cache_entries() + ) diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp index d390839aa50..294a1b0b2a2 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp @@ -71,6 +71,7 @@ void kernel_main() { constexpr uint32_t cb_k_in = tt::CBIndex::c_1; constexpr uint32_t cb_v_in = tt::CBIndex::c_2; constexpr uint32_t cb_mask_in = tt::CBIndex::c_3; + constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6; constexpr uint32_t onetile = 1; constexpr uint32_t q_tile_bytes = get_tile_size(cb_q_in); @@ -109,7 +110,6 @@ void kernel_main() { for (uint32_t nb = local_batch_start; nb < local_batch_end; ++nb) { if constexpr (is_chunked) { // Chunked means that we have paged attention - constexpr uint32_t cb_id_page_table = tt::CBIndex::c_6; const InterleavedAddrGen page_table_gen = { .bank_base_address = page_table_addr, .page_size = page_table_stick_size}; cb_reserve_back(cb_id_page_table, 1); @@ -306,5 +306,9 @@ void kernel_main() { } } } + + if constexpr (is_chunked) { + cb_pop_front(cb_id_page_table, 1); + } } }