#0: Update eltwise binary to support sharding on arbitrary cores on an arbitrary sub-device grid
tt-aho committed Dec 16, 2024
1 parent 388e187 commit 51f7ef6
Showing 11 changed files with 608 additions and 616 deletions.
70 changes: 70 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_add.py
@@ -515,3 +515,73 @@ def test_01_volume_tensors(device, data, memory_config):
c = ttnn.to_torch(ttnn_c).reshape((-1))

assert c.tolist() == c_golden


@pytest.mark.parametrize("input_a_sharded", [True, False])
@pytest.mark.parametrize("input_b_sharded", [True, False])
@pytest.mark.parametrize("out_sharded", [True, False])
@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
def test_add_with_sub_devices(device, input_a_sharded, input_b_sharded, out_sharded, shard_orientation):
torch.manual_seed(0)
shape = (1, 1, 1024, 1024)
torch_input_tensor_a = torch.rand(shape, dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand(shape, dtype=torch.bfloat16)

if shard_orientation == ttnn.ShardOrientation.ROW_MAJOR:
shard_shape = (1024 // 8, 1024)
else:
shard_shape = (1024, 1024 // 8)

core_range_set = ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(2, 2), ttnn.CoreCoord(3, 3)),
ttnn.CoreRange(ttnn.CoreCoord(1, 1), ttnn.CoreCoord(1, 1)),
ttnn.CoreRange(ttnn.CoreCoord(4, 0), ttnn.CoreCoord(4, 2)),
]
)

height_sharded_mem_config = ttnn.create_sharded_memory_config(
shape=shard_shape,
core_grid=core_range_set,
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=shard_orientation,
use_height_and_width_as_shard_shape=True,
)

torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

if input_a_sharded:
input_tensor_a = ttnn.to_memory_config(input_tensor_a, height_sharded_mem_config)

input_tensor_b = ttnn.from_torch(
torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

if input_b_sharded:
input_tensor_b = ttnn.to_memory_config(input_tensor_b, height_sharded_mem_config)

if out_sharded:
out_mem_config = height_sharded_mem_config
else:
out_mem_config = ttnn.DRAM_MEMORY_CONFIG

sub_device = ttnn.SubDevice(
[
ttnn.CoreRangeSet(
[
ttnn.CoreRange(ttnn.CoreCoord(1, 1), ttnn.CoreCoord(4, 4)),
ttnn.CoreRange(ttnn.CoreCoord(4, 0), ttnn.CoreCoord(5, 0)),
]
)
]
)
sub_device_manager_id = device.create_sub_device_manager([sub_device], 0)
device.load_sub_device_manager(sub_device_manager_id)
output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=out_mem_config)
output_tensor = ttnn.to_torch(output_tensor)
assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988
assert output_tensor.shape == shape
24 changes: 24 additions & 0 deletions tt_metal/common/core_coord.cpp
@@ -522,6 +522,30 @@ std::vector<CoreCoord> grid_to_cores_with_noop(
return cores;
}

// Noop cores are appended at the end with no guarantees on ordering
std::vector<CoreCoord> grid_to_cores_with_noop(
const CoreRangeSet& used_cores, const CoreRangeSet& all_cores, const bool row_wise) {
ZoneScoped;
TT_ASSERT(all_cores.contains(used_cores));
// Most likely a lot of optimizations to do here
// Implemented this way for simplicity for now
std::vector<CoreCoord> cores;
cores.reserve(all_cores.num_cores());
cores = corerange_to_cores(used_cores, std::nullopt, row_wise);
std::vector<CoreCoord> all_cores_vec = corerange_to_cores(all_cores, std::nullopt, row_wise);
auto sorted_used_cores = cores;
std::sort(sorted_used_cores.begin(), sorted_used_cores.end());
std::sort(all_cores_vec.begin(), all_cores_vec.end());
std::set_difference(
all_cores_vec.begin(),
all_cores_vec.end(),
sorted_used_cores.begin(),
sorted_used_cores.end(),
std::back_inserter(cores));

return cores;
}
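For reference, a minimal usage sketch of the new overload (illustrative only, not part of the commit; the include path and the free functions in the global namespace are assumed from the files in this diff):

// Sketch: used cores come first in traversal order, noop cores are appended.
#include "tt_metal/common/core_coord.hpp"  // assumed include path

#include <cassert>
#include <vector>

void grid_to_cores_with_noop_example() {
    // Two used cores inside a 2x2 grid of four cores.
    CoreRangeSet used_cores(CoreRange(CoreCoord(0, 0), CoreCoord(1, 0)));
    CoreRangeSet all_cores(CoreRange(CoreCoord(0, 0), CoreCoord(1, 1)));

    std::vector<CoreCoord> cores = grid_to_cores_with_noop(used_cores, all_cores, /*row_wise=*/true);

    // The used cores appear first in row-major order ...
    assert(cores[0] == CoreCoord(0, 0) && cores[1] == CoreCoord(1, 0));
    // ... and the remaining (noop) cores are appended with no ordering guarantee.
    assert(cores.size() == all_cores.num_cores());
}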

std::vector<CoreCoord> corerange_to_cores(const CoreRangeSet& crs, std::optional<uint32_t> max_cores, bool row_wise) {
std::vector<CoreCoord> all_cores;
auto num_cores = crs.num_cores();
4 changes: 4 additions & 0 deletions tt_metal/common/core_coord.hpp
@@ -191,6 +191,10 @@ std::vector<CoreCoord> grid_to_cores_with_noop(
const uint32_t grid_size_y,
const bool row_wise = false);

// Noop cores are appended at the end with no guarantees on ordering
std::vector<CoreCoord> grid_to_cores_with_noop(
const CoreRangeSet& used_cores, const CoreRangeSet& all_cores, const bool row_wise = false);

std::vector<CoreCoord> corerange_to_cores(
const CoreRangeSet& crs, std::optional<uint32_t> max_cores = std::nullopt, bool row_wise = false);

180 changes: 130 additions & 50 deletions tt_metal/common/work_split.cpp
@@ -268,66 +268,146 @@ CoreRangeSet num_cores_to_corerangeset_in_subcoregrids(
std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_t> split_work_to_cores(
const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise) {
ZoneScoped;
uint32_t num_cores_x = grid_size.x, num_cores_y = grid_size.y;
auto target_num_cores = std::min(units_to_divide, num_cores_x * num_cores_y);
CoreRangeSet all_cores = num_cores_to_corerangeset(target_num_cores, grid_size, row_wise);
if (units_to_divide == 0) {
return std::make_tuple(0, CoreRangeSet(), CoreRangeSet(), CoreRangeSet(), 0, 0);
}
uint32_t num_cores_x = grid_size.x, num_cores_y = grid_size.y, max_num_cores = num_cores_x * num_cores_y,
target_num_cores;
CoreRangeSet all_cores;
if (units_to_divide >= max_num_cores) {
target_num_cores = max_num_cores;
all_cores = CoreRangeSet(CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}));
} else {
target_num_cores = units_to_divide;
all_cores = num_cores_to_corerangeset(target_num_cores, grid_size, row_wise);
}

CoreRangeSet core_group_1;
CoreRangeSet core_group_2;
uint32_t units_per_core_group_1 = target_num_cores == 0 ? 0 : units_to_divide / target_num_cores;
uint32_t units_per_core_group_1 = units_to_divide / target_num_cores;
uint32_t units_per_core_group_2 = 0;
uint32_t num_cores_with_more_work = units_to_divide % target_num_cores;
// Evenly divided units to all target cores
if (target_num_cores == 0 || units_to_divide % target_num_cores == 0) {
if (units_to_divide % target_num_cores == 0) {
core_group_1 = all_cores;
// Uneven division of units across cores
// This case should only be hit when there are more units of work than a full grid of cores
// which is implicitly assumed in the following logic
} else {
}
// Uneven division of units across cores
// This case should only be hit when there are more units of work than a full grid of cores
// which is implicitly assumed in the following logic
else {
// Group of cores that do more work
core_group_1 = num_cores_to_corerangeset(units_to_divide % target_num_cores, grid_size, row_wise);
const auto& last_block_group_1 = (*core_group_1.ranges().rbegin());
const auto& last_block_all_cores = (*all_cores.ranges().rbegin());
uint32_t num_core_group_1_cores = num_cores_with_more_work;
uint32_t num_core_group_2_cores = target_num_cores - num_core_group_1_cores;
core_group_1 = num_cores_to_corerangeset(num_core_group_1_cores, grid_size, row_wise);
const auto& last_core_group_1 = (*core_group_1.ranges().rbegin()).end_coord;
if (row_wise) {
// Case where only the last row is divided between core group 1 and 2
if (last_block_group_1.end_coord.y == last_block_all_cores.end_coord.y &&
last_block_group_1.end_coord.x != last_block_all_cores.end_coord.x) {
CoreRange leftover_block(
CoreCoord(last_block_group_1.end_coord.x + 1, last_block_group_1.end_coord.y),
last_block_all_cores.end_coord);
core_group_2 = CoreRangeSet(leftover_block);
} else {
std::vector<CoreRange> core_group_2_set;
// Case where a middle row is divided between core group 1 and 2
if (last_block_group_1.end_coord.x != num_cores_x - 1) {
core_group_2_set.emplace_back(
CoreCoord(last_block_group_1.end_coord.x + 1, last_block_group_1.end_coord.y),
CoreCoord(num_cores_x - 1, last_block_group_1.end_coord.y));
}
// Remaining rows of cores that does less work
core_group_2_set.emplace_back(
CoreCoord(0, last_block_group_1.end_coord.y + 1), last_block_all_cores.end_coord);
core_group_2 = CoreRangeSet(std::move(core_group_2_set));
// Start in the same row
if (last_core_group_1.x != num_cores_x - 1) {
core_group_2 = num_cores_to_corerangeset(
{last_core_group_1.x + 1, last_core_group_1.y}, num_core_group_2_cores, grid_size, row_wise);
}
// Start in the next row
else {
core_group_2 = num_cores_to_corerangeset(
{0, last_core_group_1.y + 1}, num_core_group_2_cores, grid_size, row_wise);
}
} else {
// Case where only the last column is divided between core group 1 and 2
if (last_block_group_1.end_coord.x == last_block_all_cores.end_coord.x &&
last_block_group_1.end_coord.y != last_block_all_cores.end_coord.y) {
CoreRange leftover_block(
CoreCoord(last_block_group_1.end_coord.x, last_block_group_1.end_coord.y + 1),
last_block_all_cores.end_coord);
core_group_2 = CoreRangeSet(leftover_block);
} else {
std::vector<CoreRange> core_group_2_set;
// Case where a middle column is divided between core group 1 and 2
if (last_block_group_1.end_coord.y != num_cores_y - 1) {
core_group_2_set.emplace_back(
CoreCoord(last_block_group_1.end_coord.x, last_block_group_1.end_coord.y + 1),
CoreCoord(last_block_group_1.end_coord.x, num_cores_y - 1));
}
// Remaining columns of cores that does less work
core_group_2_set.emplace_back(
CoreCoord(last_block_group_1.end_coord.x + 1, 0), last_block_all_cores.end_coord);
core_group_2 = CoreRangeSet(std::move(core_group_2_set));
// Start in the same column
if (last_core_group_1.y != num_cores_y - 1) {
core_group_2 = num_cores_to_corerangeset(
{last_core_group_1.x, last_core_group_1.y + 1}, num_core_group_2_cores, grid_size, row_wise);
}
// Start in the next column
else {
core_group_2 = num_cores_to_corerangeset(
{last_core_group_1.x + 1, 0}, num_core_group_2_cores, grid_size, row_wise);
}
}
units_per_core_group_2 = units_per_core_group_1;
units_per_core_group_1++;
}

return std::make_tuple(
target_num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2);
}
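A worked example of the split above (illustrative only, not part of the commit; the include path and the tt::tt_metal namespace are assumed from work_split.hpp below):

#include "tt_metal/common/work_split.hpp"  // assumed include path

void split_work_to_cores_example() {
    // 70 units over an 8x8 grid: 70 = 6 * 2 + 58 * 1.
    CoreCoord grid_size(8, 8);  // 64 cores
    auto [num_cores, all_cores, core_group_1, core_group_2, units_group_1, units_group_2] =
        tt::tt_metal::split_work_to_cores(grid_size, /*units_to_divide=*/70, /*row_wise=*/true);
    // num_cores == 64, units_group_1 == 2 on the 6 cores of core_group_1,
    // units_group_2 == 1 on the remaining 58 cores of core_group_2.
}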

std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_t> split_work_to_cores(
const CoreRangeSet& core_grid, const uint32_t units_to_divide, const bool row_wise) {
ZoneScoped;
if (units_to_divide == 0) {
return std::make_tuple(0, CoreRangeSet(), CoreRangeSet(), CoreRangeSet(), 0, 0);
}
uint32_t max_num_cores = core_grid.num_cores(), target_num_cores;
TT_FATAL(max_num_cores > 0, "Core grid must contain at least one core");
auto start_core = core_grid.ranges().begin()->start_coord;
CoreRangeSet all_cores;
if (units_to_divide >= max_num_cores) {
target_num_cores = max_num_cores;
all_cores = core_grid;
} else {
target_num_cores = units_to_divide;
all_cores = num_cores_to_corerangeset_in_subcoregrids(start_core, target_num_cores, core_grid, row_wise);
}

CoreRangeSet core_group_1;
CoreRangeSet core_group_2;
uint32_t units_per_core_group_1 = units_to_divide / target_num_cores;
uint32_t units_per_core_group_2 = 0;
uint32_t num_cores_with_more_work = units_to_divide % target_num_cores;
// Evenly divided units to all target cores
if (target_num_cores == 0 || num_cores_with_more_work == 0) {
core_group_1 = all_cores;
}
// Uneven division of units across cores
// This case should only be hit when there are more units of work than a full grid of cores
// which is implicitly assumed in the following logic
else {
// Group of cores that do more work
uint32_t num_core_group_1_cores = num_cores_with_more_work;
uint32_t num_core_group_2_cores = target_num_cores - num_core_group_1_cores;
core_group_1 =
num_cores_to_corerangeset_in_subcoregrids(start_core, num_core_group_1_cores, core_grid, row_wise);
const auto& last_core_group_1 = (*core_group_1.ranges().rbegin()).end_coord;
const auto& core_grid_ranges = core_grid.ranges();
uint32_t num_cores_counted = 0, i;
for (i = 0; i < core_grid_ranges.size(); i++) {
num_cores_counted += core_grid_ranges[i].size();
if (num_cores_counted >= num_core_group_1_cores) {
break;
}
}
const auto& range_containing_last_core_group_1 = core_grid_ranges[i];
// Start in next core range
if (last_core_group_1 == range_containing_last_core_group_1.end_coord) {
core_group_2 = num_cores_to_corerangeset_in_subcoregrids(
core_grid_ranges[i + 1].start_coord, num_core_group_2_cores, core_grid, row_wise);
} else if (row_wise) {
// Start in the same row
if (last_core_group_1.x != range_containing_last_core_group_1.end_coord.x) {
core_group_2 = num_cores_to_corerangeset_in_subcoregrids(
{last_core_group_1.x + 1, last_core_group_1.y}, num_core_group_2_cores, core_grid, row_wise);
}
// Start in the next row
else {
core_group_2 = num_cores_to_corerangeset_in_subcoregrids(
{range_containing_last_core_group_1.start_coord.x, last_core_group_1.y + 1},
num_core_group_2_cores,
core_grid,
row_wise);
}
} else {
// Start in the same column
if (last_core_group_1.y != range_containing_last_core_group_1.end_coord.y) {
core_group_2 = num_cores_to_corerangeset_in_subcoregrids(
{last_core_group_1.x, last_core_group_1.y + 1}, num_core_group_2_cores, core_grid, row_wise);
}
// Start in the next column
else {
core_group_2 = num_cores_to_corerangeset_in_subcoregrids(
{last_core_group_1.x + 1, range_containing_last_core_group_1.end_coord.y},
num_core_group_2_cores,
core_grid,
row_wise);
}
}
units_per_core_group_2 = units_per_core_group_1;
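A similar sketch for the new CoreRangeSet overload, using the same non-rectangular sub-device grid as the Python test above (illustrative only, not part of the commit; names and include path are assumed):

#include "tt_metal/common/work_split.hpp"  // assumed include path

#include <vector>

void split_work_on_sub_device_grid_example() {
    // 4x4 block plus a 2x1 strip, matching the test's sub-device: 18 cores total.
    CoreRangeSet core_grid(std::vector<CoreRange>{
        CoreRange(CoreCoord(1, 1), CoreCoord(4, 4)),
        CoreRange(CoreCoord(4, 0), CoreCoord(5, 0))});

    auto [num_cores, all_cores, core_group_1, core_group_2, units_group_1, units_group_2] =
        tt::tt_metal::split_work_to_cores(core_grid, /*units_to_divide=*/32, /*row_wise=*/true);
    // num_cores == 18, core_group_1 holds 14 cores doing 2 units each,
    // core_group_2 holds 4 cores doing 1 unit each (32 = 14 * 2 + 4 * 1).
}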
3 changes: 3 additions & 0 deletions tt_metal/common/work_split.hpp
@@ -53,5 +53,8 @@ CoreRangeSet num_cores_to_corerangeset_in_subcoregrids(
std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_t> split_work_to_cores(
const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise = false);

std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_t> split_work_to_cores(
const CoreRangeSet& core_grid, const uint32_t units_to_divide, const bool row_wise = false);

} // namespace tt_metal
} // namespace tt
@@ -351,15 +351,55 @@ BinaryDeviceOperation::invoke(
output_dtype.value() == optional_output_tensor.value().get_dtype(),
"If both output dtype and output tensor provided dtype should match");
}
CoreRangeSet worker_grid;
// We assert all shard specs are the same if sharded, so only need to check the first shard spec
// This will create the worker grid based on the used sub-devices when sharded
// Otherwise this will use all cores of the sub-devices
// TODO #13655: Note that the current program infrastructure still only supports a single sub-device per program
if (input_tensor_a_arg.is_sharded()) {
const auto& input_grid = input_tensor_a_arg.shard_spec().value().grid;
auto device = input_tensor_a_arg.device();
for (const auto& sub_device_id : device->get_sub_device_ids()) {
const auto& sub_device_workers = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id);
if (sub_device_workers.intersects(input_grid)) {
worker_grid = worker_grid.merge(sub_device_workers);
}
}
} else if (input_tensor_b_arg.is_sharded()) {
const auto& input_grid = input_tensor_b_arg.shard_spec().value().grid;
auto device = input_tensor_b_arg.device();
for (const auto& sub_device_id : device->get_sub_device_ids()) {
const auto& sub_device_workers = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id);
if (sub_device_workers.intersects(input_grid)) {
worker_grid = worker_grid.merge(sub_device_workers);
}
}
} else if (optional_output_tensor.has_value() && optional_output_tensor->is_sharded()) {
const auto& output_grid = optional_output_tensor->shard_spec().value().grid;
auto device = optional_output_tensor->device();
for (const auto& sub_device_id : device->get_sub_device_ids()) {
const auto& sub_device_workers = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id);
if (sub_device_workers.intersects(output_grid)) {
worker_grid = worker_grid.merge(sub_device_workers);
}
}
} else {
auto device = input_tensor_a_arg.device();
for (const auto& sub_device_id : device->get_sub_device_ids()) {
const auto& sub_device_workers = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id);
worker_grid = worker_grid.merge(sub_device_workers);
}
}

return {
operation_attributes_t{
binary_op_type,
std::move(activations),
std::move(input_tensor_a_activation),
std::nullopt,
memory_config.value_or(input_tensor_a_arg.memory_config()),
memory_config.value_or(optional_output_tensor.has_value() ? optional_output_tensor->memory_config() : input_tensor_a_arg.memory_config()),
output_dtype.value_or(input_tensor_a_arg.get_dtype()),
std::move(worker_grid),
std::nullopt},
tensor_args_t{input_tensor_a_arg, input_tensor_b_arg, optional_output_tensor}};
}
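The three sharded branches above repeat the same loop; a minimal sketch of a helper that captures it (illustrative only, not part of the commit, reusing the device and CoreRangeSet calls already shown):

// Sketch: merge the worker cores of every sub-device whose Tensix grid
// intersects the given shard grid. The unsharded branch merges all
// sub-devices unconditionally instead.
auto worker_grid_for_shard = [](auto device, const CoreRangeSet& shard_grid) {
    CoreRangeSet grid;
    for (const auto& sub_device_id : device->get_sub_device_ids()) {
        const auto& sub_device_workers = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id);
        if (sub_device_workers.intersects(shard_grid)) {
            grid = grid.merge(sub_device_workers);
        }
    }
    return grid;
};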
@@ -380,6 +420,8 @@ BinaryDeviceOperation::invoke(
"If both output dtype and output tensor provided dtype should match");
}

// Currently unused/unsupported
CoreRangeSet worker_grid = CoreRangeSet();
return {
operation_attributes_t{
binary_op_type,
Expand All @@ -388,6 +430,7 @@ BinaryDeviceOperation::invoke(
scalar,
memory_config.value_or(input_tensor_a_arg.memory_config()),
output_dtype.value_or(input_tensor_a_arg.get_dtype()),
std::move(worker_grid),
std::nullopt},
tensor_args_t{input_tensor_a_arg, std::nullopt, optional_output_tensor}};
}