Skip to content

Commit

Permalink
#0: added api for generating corerangeset from given subcoregrid
Browse files Browse the repository at this point in the history
  • Loading branch information
kpaigwar committed Dec 9, 2024
1 parent 84401ed commit c4ac298
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import pytest


@pytest.mark.parametrize(
"start_core, num_cores, subcoregrids, row_wise, expected_core_range_set",
[
# Test Case 1: Basic row-wise scenario with enough cores in subcoregrids
(
ttnn.CoreCoord(1, 0),
32,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
),
True,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 0)),
]
),
),
# Test Case 2: Basic Column-wise processing
(
ttnn.CoreCoord(1, 0),
32,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
),
False,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(5, 1)),
]
),
),
# Test Case 3: row-wise scenario with small target cores and start offset
(
ttnn.CoreCoord(3, 2),
8,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
),
True,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(3, 2), ttnn.CoreCoord(3, 2)),
ttnn.CoreRange(ttnn.CoreCoord(1, 3), ttnn.CoreCoord(3, 4)),
ttnn.CoreRange(ttnn.CoreCoord(1, 5), ttnn.CoreCoord(1, 5)),
]
),
),
# Test Case 4: col-wise scenario with small target cores and start offset
(
ttnn.CoreCoord(1, 8),
8,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)),
ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(6, 9)),
]
),
False,
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 8), ttnn.CoreCoord(1, 9)),
ttnn.CoreRange(ttnn.CoreCoord(2, 0), ttnn.CoreCoord(2, 5)),
]
),
),
],
)
def test_numcores_to_corerangeset(start_core, num_cores, subcoregrids, row_wise, expected_core_range_set):
output_corerangeset = ttnn.num_cores_to_corerangeset_in_subcoregrids(
start_core, num_cores, subcoregrids, row_wise=row_wise
)
assert output_corerangeset.to_json() == expected_core_range_set.to_json()
117 changes: 117 additions & 0 deletions tt_metal/common/work_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,123 @@ CoreRangeSet num_cores_to_corerangeset(
return num_cores_to_corerangeset({0, 0}, target_num_cores, grid_size, row_wise);
}

CoreRangeSet num_cores_to_corerangeset_in_subcoregrids(
const CoreCoord start_core,
const uint32_t target_num_cores,
const CoreRangeSet subcoregrids,
const bool row_wise = false) {
// If target_num_cores is 0 or input_corerangeset is empty, return empty CoreRangeSet
TT_FATAL(target_num_cores > 0, "Target number of cores must be greater than 0");
TT_FATAL(
target_num_cores <= subcoregrids.num_cores(),
"Target number of cores {} is greater than total number of available cores {}",
target_num_cores,
subcoregrids.num_cores());

// Validate that the start core is contained within the entire CoreRangeSet
TT_FATAL(subcoregrids.contains(start_core), "Start core must be contained within the input CoreRangeSet");

std::vector<CoreRange> result_coreranges;
bool start_core_found = false;
CoreCoord current_start_core = start_core;
CoreCoord current_end_core = start_core;
uint32_t remaining_cores = target_num_cores;

auto process_row_wise = [&](const CoreRange& subcoregrid) {
uint32_t subcoregrid_width = subcoregrid.grid_size().x;

for (uint32_t y = current_start_core.y; y <= subcoregrid.end_coord.y; ++y) {
if (remaining_cores == 0) {
break;
}

uint32_t current_width =
std::min(static_cast<uint32_t>(subcoregrid.end_coord.x - current_start_core.x + 1), remaining_cores);

if (current_width < subcoregrid_width) {
if (current_start_core != current_end_core) {
result_coreranges.push_back(CoreRange(current_start_core, current_end_core));
}

current_end_core = CoreCoord(current_start_core.x + current_width - 1, y);
remaining_cores -= current_width;

result_coreranges.push_back(
CoreRange(CoreCoord(current_start_core.x, y), CoreCoord(current_end_core.x, y)));

current_start_core = CoreCoord(subcoregrid.start_coord.x, y + 1);
current_end_core = current_start_core;
} else {
current_end_core = CoreCoord(subcoregrid.end_coord.x, y);
remaining_cores -= current_width;
}
}

if (current_start_core != current_end_core) {
result_coreranges.push_back(CoreRange(current_start_core, current_end_core));
}
};

auto process_col_wise = [&](const CoreRange& subcoregrid) {
uint32_t subcoregrid_height = subcoregrid.grid_size().y;

for (uint32_t x = current_start_core.x; x <= subcoregrid.end_coord.x; ++x) {
if (remaining_cores == 0) {
break;
}

uint32_t current_height =
std::min(static_cast<uint32_t>(subcoregrid.end_coord.y - current_start_core.y + 1), remaining_cores);

if (current_height < subcoregrid_height) {
if (current_start_core != current_end_core) {
result_coreranges.push_back(CoreRange(current_start_core, current_end_core));
}

current_end_core = CoreCoord(x, current_start_core.y + current_height - 1);
remaining_cores -= current_height;

result_coreranges.push_back(
CoreRange(CoreCoord(x, current_start_core.y), CoreCoord(x, current_end_core.y)));

current_start_core = CoreCoord(x + 1, subcoregrid.start_coord.y);
current_end_core = current_start_core;
} else {
current_end_core = CoreCoord(x, subcoregrid.end_coord.y);
remaining_cores -= current_height;
}
}

if (current_start_core != current_end_core) {
result_coreranges.push_back(CoreRange(current_start_core, current_end_core));
}
};

// Iterate over subcoregrids and process based on row_wise
for (const auto& subcoregrid : subcoregrids.ranges()) {
if (subcoregrid.contains(start_core)) {
start_core_found = true;
} else {
if (!start_core_found) {
continue;
} else {
current_start_core = subcoregrid.start_coord;
current_end_core = current_start_core;
}
}

if (row_wise) {
process_row_wise(subcoregrid);
} else {
process_col_wise(subcoregrid);
}
}

TT_FATAL(remaining_cores == 0, "Failed to split target number of cores into CoreRangeSet");

return CoreRangeSet(std::move(result_coreranges));
}

std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_t> split_work_to_cores(
const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise) {
ZoneScoped;
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/common/work_split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ CoreRangeSet num_cores_to_corerangeset(
CoreRangeSet num_cores_to_corerangeset(
const uint32_t target_num_cores, const CoreCoord grid_size, const bool row_wise = false);

CoreRangeSet num_cores_to_corerangeset_in_subcoregrids(
const CoreCoord start_core,
const uint32_t target_num_cores,
const CoreRangeSet subcoregrids,
const bool row_wise = false);
// This function takes in the core grid size, as well as the number of units of work to divide between the cores
// This function returns the number of cores, the CoreRangeSet of all cores, and then the CoreRangeSet that does
// the greater amount of work, and the CoreRangeSet that does less work if work cannot be evenly divided
Expand Down
6 changes: 6 additions & 0 deletions ttnn/cpp/pybind11/operations/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ void py_module(py::module& module) {
"num_cores_to_corerangeset",
py::overload_cast<const uint32_t, const CoreCoord, const bool>(&tt::tt_metal::num_cores_to_corerangeset),
R"doc(Create a CoreRangeSet containing the specified number of cores)doc");

module.def(
"num_cores_to_corerangeset_in_subcoregrids",
py::overload_cast<const CoreCoord, const uint32_t, const CoreRangeSet, const bool>(
&tt::tt_metal::num_cores_to_corerangeset_in_subcoregrids),
R"doc(Create a CoreRangeSet containing the specified number of cores starting from start_core in given subcoregrids)doc");
}

} // namespace core
Expand Down
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def manage_config(name, value):
load_memory_config,
dump_stack_trace_on_segfault,
num_cores_to_corerangeset,
num_cores_to_corerangeset_in_subcoregrids,
)

import ttnn.reflection
Expand Down
17 changes: 17 additions & 0 deletions ttnn/ttnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def num_cores_to_corerangeset(
)


def num_cores_to_corerangeset_in_subcoregrids(
start_core: ttnn.CoreCoord,
target_num_cores: int,
subcoregrids: ttnn.CoreRangeSet,
row_wise: bool = False,
):
"""
Create a CoreRangeSet containing the specified number of cores starting from start_core in given subcoregrids
"""
return ttnn._ttnn.operations.core.num_cores_to_corerangeset_in_subcoregrids(
start_core,
target_num_cores,
subcoregrids,
row_wise,
)


def has_tile_padding(tensor, *, dim=None):
if dim is not None:
rank = tensor.shape.rank
Expand Down

0 comments on commit c4ac298

Please sign in to comment.