diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_numcores_to_corerangeset_subcoregrids.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py similarity index 90% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_numcores_to_corerangeset_subcoregrids.py rename to tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py index f07e6f21230c..0840e55513db 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_numcores_to_corerangeset_subcoregrids.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_num_cores_to_corerangeset_in_subcoregrids.py @@ -7,9 +7,9 @@ @pytest.mark.parametrize( - "start_core, num_cores, subcoregrids, row_wise, expected_core_range_set", + "start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set", [ - # Test Case 1: Basic row-wise scenario with enough cores in subcoregrids + # Test Case 1: Basic row-wise scenario with enough cores in sub_core_grids ( ttnn.CoreCoord(1, 0), 32, @@ -84,8 +84,10 @@ ), ], ) -def test_numcores_to_corerangeset(start_core, num_cores, subcoregrids, row_wise, expected_core_range_set): +def test_numcores_to_corerangeset_in_subcoregrids( + start_core, num_cores, sub_core_grids, row_wise, expected_core_range_set +): output_corerangeset = ttnn.num_cores_to_corerangeset_in_subcoregrids( - start_core, num_cores, subcoregrids, row_wise=row_wise + start_core, num_cores, sub_core_grids, row_wise=row_wise ) assert output_corerangeset.to_json() == expected_core_range_set.to_json() diff --git a/tt_metal/common/work_split.cpp b/tt_metal/common/work_split.cpp index 43241801cb21..ba687d9d3dab 100644 --- a/tt_metal/common/work_split.cpp +++ b/tt_metal/common/work_split.cpp @@ -151,18 +151,18 @@ CoreRangeSet num_cores_to_corerangeset( CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( const CoreCoord start_core, const uint32_t target_num_cores, - const CoreRangeSet subcoregrids, + const CoreRangeSet& sub_core_grids, const bool row_wise = false) { // If target_num_cores is 0 or input_corerangeset is empty, return empty CoreRangeSet TT_FATAL(target_num_cores > 0, "Target number of cores must be greater than 0"); TT_FATAL( - target_num_cores <= subcoregrids.num_cores(), + target_num_cores <= sub_core_grids.num_cores(), "Target number of cores {} is greater than total number of available cores {}", target_num_cores, - subcoregrids.num_cores()); + sub_core_grids.num_cores()); // Validate that the start core is contained within the entire CoreRangeSet - TT_FATAL(subcoregrids.contains(start_core), "Start core must be contained within the input CoreRangeSet"); + TT_FATAL(sub_core_grids.contains(start_core), "Start core must be contained within the input CoreRangeSet"); std::vector result_coreranges; bool start_core_found = false; @@ -241,7 +241,7 @@ CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( }; // Iterate over subcoregrids and process based on row_wise - for (const auto& subcoregrid : subcoregrids.ranges()) { + for (const auto& subcoregrid : sub_core_grids.ranges()) { if (subcoregrid.contains(start_core)) { start_core_found = true; } else { diff --git a/tt_metal/common/work_split.hpp b/tt_metal/common/work_split.hpp index 76cc43c77f3e..2b5ae0ecb9d8 100644 --- a/tt_metal/common/work_split.hpp +++ b/tt_metal/common/work_split.hpp @@ -43,7 +43,7 @@ CoreRangeSet num_cores_to_corerangeset( CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( const CoreCoord start_core, const uint32_t target_num_cores, - const CoreRangeSet subcoregrids, + const CoreRangeSet& sub_core_grids, const bool row_wise = false); // This function takes in the core grid size, as well as the number of units of work to divide between the cores // This function returns the number of cores, the CoreRangeSet of all cores, and then the CoreRangeSet that does diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index e15f29bd9aaa..4aed692974d5 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -351,7 +351,7 @@ void py_module(py::module& module) { module.def( "num_cores_to_corerangeset_in_subcoregrids", - py::overload_cast( + py::overload_cast( &tt::tt_metal::num_cores_to_corerangeset_in_subcoregrids), R"doc(Create a CoreRangeSet containing the specified number of cores starting from start_core in given subcoregrids)doc"); } diff --git a/ttnn/ttnn/core.py b/ttnn/ttnn/core.py index f161af720254..a1879be24f94 100644 --- a/ttnn/ttnn/core.py +++ b/ttnn/ttnn/core.py @@ -62,7 +62,7 @@ def num_cores_to_corerangeset( def num_cores_to_corerangeset_in_subcoregrids( start_core: ttnn.CoreCoord, target_num_cores: int, - subcoregrids: ttnn.CoreRangeSet, + sub_core_grids: ttnn.CoreRangeSet, row_wise: bool = False, ): """ @@ -71,7 +71,7 @@ def num_cores_to_corerangeset_in_subcoregrids( return ttnn._ttnn.operations.core.num_cores_to_corerangeset_in_subcoregrids( start_core, target_num_cores, - subcoregrids, + sub_core_grids, row_wise, )