diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
index 8be5f8b317e..1c908f0ab94 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py
@@ -321,11 +321,21 @@ def run_test_sdpa_decode_single_iter(
     sharded_out=False,
     start_indices=None,
     causal=True,
+    start_core=ttnn.CoreCoord(0, 0),
+    sub_core_grids=None,
 ):
     compute_grid_size = device.compute_with_storage_grid_size()
-    if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
-        pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")
-
+    if sub_core_grids is None:
+        if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y:
+            pytest.skip(f"Need {grid_size} grid size to run this test but core grid is {compute_grid_size}")
+    else:
+        unharvested_grid_size = (7, 10)
+        if compute_grid_size.x > unharvested_grid_size[0] or compute_grid_size.y > unharvested_grid_size[1]:
+            pytest.skip(f"Need {unharvested_grid_size} grid size to run this test but core grid is {compute_grid_size}")
+        if grid_size[0] * grid_size[1] > sub_core_grids.num_cores():
+            pytest.skip(
+                f"Need {grid_size[0]*grid_size[1]} grid size to run this test but core grid is {sub_core_grids.num_cores()}"
+            )
     padded_num_heads = nearest_pow_2(nearest_n(nh, n=32))
     torch.manual_seed(1234)
 
@@ -346,7 +356,14 @@ def run_test_sdpa_decode_single_iter(
     )
     dram_memcfg = ttnn.DRAM_MEMORY_CONFIG
 
-    shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)})
+    if sub_core_grids is None:
+        shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)})
+        compute_sub_core_grids = None
+    else:
+        shard_grid = ttnn.num_cores_to_corerangeset_in_subcoregrids(start_core, b, sub_core_grids, row_wise=True)
+        compute_sub_core_grids = ttnn.num_cores_to_corerangeset_in_subcoregrids(
+            start_core, grid_size[0] * grid_size[1], sub_core_grids, row_wise=True
+        )
     shard_spec = ttnn.ShardSpec(shard_grid, (padded_num_heads, d), ttnn.ShardOrientation.ROW_MAJOR, False)
 
     height_sharded_memcfg = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec)
@@ -364,6 +381,7 @@ def run_test_sdpa_decode_single_iter(
        k_chunk_size = get_chunk_size(max_start_idx + 1, s)
        program_config = ttnn.SDPAProgramConfig(
            compute_with_storage_grid_size=grid_size,
+           sub_core_grids=compute_sub_core_grids,
            q_chunk_size=padded_num_heads,
            k_chunk_size=k_chunk_size,
            exp_approx_mode=False,
        )
@@ -904,6 +922,75 @@ def test_sdpa_decode_sharded(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype
     )
 
 
+@skip_for_blackhole("Unsupported on BH, see #12349")
+@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
+@pytest.mark.parametrize("device_params", [{"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}], indirect=True)
+@pytest.mark.parametrize(
+    "dtype, q_dtype",
+    [
+        [ttnn.bfloat8_b, ttnn.bfloat16],
+    ],
+    ids=[
+        "bfp8_cache_bf16_act",
+    ],
+)
+@pytest.mark.parametrize(
+    "b, nh, nkv, s, d, grid_size",
+    (
+        [8, 8, 1, 2048, 128, (8, 4)],
+        [8, 8, 1, 256, 128, (8, 4)],
+    ),  # Llama2-70B
+)
+@pytest.mark.parametrize(
+    "start_core, sub_core_grids",
+    [
+        (
+            ttnn.CoreCoord(1, 0),
+            ttnn.CoreRangeSet(
+                [
+                    ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
+                    ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
+                ]
+            ),
+        ),
+    ],
+)
+def test_sdpa_decode_sharded_on_subcoregrids(
+    device, use_program_cache, b, nh, nkv, s, d, dtype, grid_size, q_dtype, start_core, sub_core_grids
+):
+    run_test_sdpa_decode_single_iter(
+        device,
+        b,
+        nh,
+        nkv,
+        s,
+        d,
+        dtype,
+        grid_size,
+        q_dtype,
+        sharded_in=True,
+        sharded_out=True,
+        start_core=start_core,
+        sub_core_grids=sub_core_grids,
+    )
+    run_test_sdpa_decode_single_iter(
+        device,
+        b,
+        nh,
+        nkv,
+        s,
+        d,
+        dtype,
+        grid_size,
+        q_dtype,
+        sharded_in=True,
+        sharded_out=True,
+        start_core=start_core,
+        sub_core_grids=sub_core_grids,
+    )
+    assert device.num_program_cache_entries() == 1
+
+
 @skip_for_blackhole("Unsupported on BH, see #12349")
 @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
 @pytest.mark.skip("Skipping Perf Test in CI")
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp
index 8dc18614f00..c968f5d8a7f 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_config.hpp
@@ -11,6 +11,7 @@ namespace ttnn::operations::transformer {
 
 struct SDPAProgramConfig {
     CoreCoord compute_with_storage_grid_size;
+    std::optional<CoreRangeSet> sub_core_grids;
     std::size_t q_chunk_size;
     std::size_t k_chunk_size;
     std::optional<bool> exp_approx_mode;
diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
index 7c09d0e4de0..9615b729578 100644
--- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp
@@ -122,9 +122,20 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
 
     CoreCoord grid_size = program_config.has_value() ? program_config->compute_with_storage_grid_size
                                                       : device->compute_with_storage_grid_size();
-    auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1});
     uint32_t num_cores_available = grid_size.x * grid_size.y;
 
+    CoreRangeSet core_grid;
+    bool on_subcoregrid = false;
+    if (program_config.has_value() && program_config->sub_core_grids.has_value()) {
+        core_grid = program_config->sub_core_grids.value();
+        TT_FATAL(
+            core_grid.num_cores() == num_cores_available,
+            "Number of cores in sub_core_grids must match the number of cores available");
+        on_subcoregrid = true;
+    } else {
+        core_grid = CoreRangeSet(std::vector{CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1})});
+    }
+
     uint32_t num_cores_in_grid = device->compute_with_storage_grid_size().x * device->compute_with_storage_grid_size().y;
     TT_FATAL(num_cores_available <= num_cores_in_grid, "Expected number of cores available to be less than or equal to the number of cores in the grid, got {} and {}", num_cores_available, num_cores_in_grid);
     TT_FATAL(num_cores_available >= B, "Expect number of cores available to be greater or equal to batch size, got {} and {}", num_cores_available, B);
@@ -154,32 +165,53 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core(
     // h_worker2) head_reducer2 to head_reducerk then send the result to head_reducer1, which is also the batch_output1
     std::vector<CoreCoord> core_group;
     std::vector<CoreCoord> core_group_idle;
-    if (is_q_sharded || is_output_sharded) {
-        int reducer_idx = 0;
-        int worker_idx = num_output_cores;
-
-        for (int i = 0; i < num_cores_available; ++i) {
-            CoreCoord core;
-            if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
-                core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x};
-                reducer_idx++;
-            } else {
-                core = {worker_idx % grid_size.x, worker_idx / grid_size.x};
-                worker_idx++;
-            }
-            if (i < num_active_cores) {
-                core_group.push_back(core);
-            } else {
-                core_group_idle.push_back(core);
+    if (on_subcoregrid) {
+        if (is_q_sharded || is_output_sharded) {
+            auto cores_vec = corerange_to_cores(core_grid, num_cores_available, true);
+            int reducer_idx = 0;
+            int worker_idx = num_output_cores;
+            for (int i = 0; i < num_cores_available; ++i) {
+                if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
+                    i < num_active_cores ? core_group.push_back(cores_vec[reducer_idx])
+                                         : core_group_idle.push_back(cores_vec[reducer_idx]);
+                    reducer_idx++;
+                } else {
+                    i < num_active_cores ? core_group.push_back(cores_vec[worker_idx])
+                                         : core_group_idle.push_back(cores_vec[worker_idx]);
+                    worker_idx++;
+                }
             }
+        } else {
+            TT_FATAL(false, "We only support SDPA on subcoregrids with sharded Q and sharded output");
         }
     } else {
-        for (int i = 0; i < num_cores_available; ++i) {
-            CoreCoord core = {i % grid_size.x, i / grid_size.x};
-            if (i < num_active_cores) {
-                core_group.push_back(core);
-            } else {
-                core_group_idle.push_back(core);
+        if (is_q_sharded || is_output_sharded) {
+            int reducer_idx = 0;
+            int worker_idx = num_output_cores;
+
+            for (int i = 0; i < num_cores_available; ++i) {
+                CoreCoord core;
+                if (i % num_cores_per_batch == 0 && reducer_idx < num_output_cores) {
+                    core = {reducer_idx % grid_size.x, reducer_idx / grid_size.x};
+                    reducer_idx++;
+                } else {
+                    core = {worker_idx % grid_size.x, worker_idx / grid_size.x};
+                    worker_idx++;
+                }
+                if (i < num_active_cores) {
+                    core_group.push_back(core);
+                } else {
+                    core_group_idle.push_back(core);
+                }
+            }
+        } else {
+            for (int i = 0; i < num_cores_available; ++i) {
+                CoreCoord core = {i % grid_size.x, i / grid_size.x};
+                if (i < num_active_cores) {
+                    core_group.push_back(core);
+                } else {
+                    core_group_idle.push_back(core);
+                }
             }
         }
     }
diff --git a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
index a1c1129cea6..75ae4fffad6 100644
--- a/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/transformer/transformer_pybind.cpp
@@ -21,13 +21,15 @@ namespace py = pybind11;
 void py_module(py::module& module) {
     py::class_<SDPAProgramConfig>(module, "SDPAProgramConfig")
         .def(
-            py::init<CoreCoord, std::size_t, std::size_t, std::optional<bool>>(),
+            py::init<CoreCoord, std::optional<CoreRangeSet>, std::size_t, std::size_t, std::optional<bool>>(),
             py::kw_only(),
             py::arg("compute_with_storage_grid_size"),
+            py::arg("sub_core_grids") = std::nullopt,
             py::arg("q_chunk_size").noconvert(),
             py::arg("k_chunk_size").noconvert(),
             py::arg("exp_approx_mode") = std::nullopt)
         .def_readwrite("compute_with_storage_grid_size", &SDPAProgramConfig::compute_with_storage_grid_size)
+        .def_readwrite("sub_core_grids", &SDPAProgramConfig::sub_core_grids)
         .def_readwrite("q_chunk_size", &SDPAProgramConfig::q_chunk_size)
         .def_readwrite("k_chunk_size", &SDPAProgramConfig::k_chunk_size)
         .def_readwrite("exp_approx_mode", &SDPAProgramConfig::exp_approx_mode);
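
Usage sketch (not part of the patch): the new sub_core_grids field can be driven from Python roughly as below, mirroring the new test. The core ranges, start core, grid size, and chunk sizes are illustrative assumptions taken from the test parametrization, not requirements of the API.

    import ttnn

    # Hypothetical sub-grid: two column bands, e.g. on a 7x10 worker grid, as in the new test.
    sub_core_grids = ttnn.CoreRangeSet(
        [
            ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
            ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
        ]
    )
    start_core = ttnn.CoreCoord(1, 0)
    grid_size = (8, 4)  # must not use more cores than sub_core_grids provides

    # Carve the compute cores out of the sub-grid row-wise, then hand them to the op.
    compute_sub_core_grids = ttnn.num_cores_to_corerangeset_in_subcoregrids(
        start_core, grid_size[0] * grid_size[1], sub_core_grids, row_wise=True
    )
    program_config = ttnn.SDPAProgramConfig(
        compute_with_storage_grid_size=grid_size,
        sub_core_grids=compute_sub_core_grids,
        q_chunk_size=32,  # illustrative; the test uses padded_num_heads
        k_chunk_size=128,  # illustrative; the test derives this from the current cache length
        exp_approx_mode=False,
    )

With Q and the output sharded on the same sub-grid, the program factory enumerates cores via corerange_to_cores over the supplied sub_core_grids when assigning reducer and worker cores instead of walking the full compute grid; if neither Q nor the output is sharded, the sub-core-grid path is rejected with TT_FATAL.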