Skip to content

Commit

Permalink
#9773: Fix PCC issue for bcast unary ops when input is sharded and ou…
Browse files Browse the repository at this point in the history
…tput is interleaved

- Add eltwise mul test for sharded input
- Use grid_to_cores to iterate across output cores for eltwise binary
- Update output work distribution for bcast hw to respect input shard orientation when output is interleaved
  • Loading branch information
TT-BrianLiu committed Jul 5, 2024
1 parent 5ee242c commit afec31e
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 19 deletions.
24 changes: 24 additions & 0 deletions tests/ttnn/unit_tests/operations/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,30 @@ def test_multiply_with_scalar(device, scalar):
assert_with_pcc(torch_output_tensor, output, 0.9999)


@pytest.mark.parametrize("output_memory_config", [ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
@pytest.mark.parametrize("input_shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
@pytest.mark.parametrize("scalar", [3.0, 0.125])
def test_multiply_with_scalar_sharded(device, scalar, input_shard_orientation, output_memory_config):
    """Scalar multiply on a height-sharded input, with sharded or interleaved output.

    Covers both shard orientations so the op's work distribution is exercised
    when the output memory layout differs from the input's sharding.
    """
    torch.manual_seed(0)

    # Height-shard the (32, 32, 32) tile-sized input across a 4x8 core grid,
    # one (32, 32) shard per core.
    sharded_config = ttnn.create_sharded_memory_config(
        shape=(32, 32),
        core_grid=ttnn.CoreGrid(y=4, x=8),
        strategy=ttnn.ShardStrategy.HEIGHT,
        orientation=input_shard_orientation,
        use_height_and_width_as_shard_shape=True,
    )

    torch_input = torch.rand(1024 * 32, dtype=torch.bfloat16).reshape(32, 32, 32)
    expected = scalar * torch_input

    device_input = ttnn.from_torch(
        torch_input, layout=ttnn.TILE_LAYOUT, memory_config=sharded_config, device=device
    )
    actual = ttnn.to_torch(ttnn.mul(device_input, scalar, memory_config=output_memory_config))

    assert_with_pcc(expected, actual, 0.9999)


@pytest.mark.skip(reason="Unable to multiply scalar to tensor with int")
# fmt: off
@pytest.mark.parametrize("input_a,scalar", [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,17 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});

bool row_major = false;
if (shard_spec.has_value()) {
row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR;
}
auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles);
split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

auto src0_buffer = a.buffer();
auto src1_buffer = b.buffer();
Expand Down Expand Up @@ -169,8 +176,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
all_device_cores,
tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_compute_defines});

for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t num_tensor_tiles_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
num_tensor_tiles_per_core = num_tiles_per_core_group_1;
Expand Down Expand Up @@ -255,6 +262,7 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
auto& program = cached_program.program;
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;

auto src_buffer_a = input_tensor_a.buffer();
auto src_dram_buffer_b = input_tensor_b.buffer();
Expand Down Expand Up @@ -291,8 +299,14 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a

uint32_t bnc1 = (bN * bC == 1);

bool row_major = false;
if (shard_spec.has_value()) {
row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR;
}
auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles);
split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

if (shard_spec.has_value()) {
uint32_t num_tiles_per_shard = 0;
Expand All @@ -304,8 +318,8 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
core_group_2 = CoreRangeSet({});
}

for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t num_tensor_tiles_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
num_tensor_tiles_per_core = num_tiles_per_core_group_1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create(
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});

constexpr bool row_major = false;
auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, Ht);
split_work_to_cores(compute_with_storage_grid_size, Ht, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

auto src0_buffer = a.buffer();
auto src1_buffer = b.buffer();
Expand Down Expand Up @@ -128,8 +132,8 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create(
all_device_cores,
tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines});

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Ht_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_1;
Expand Down Expand Up @@ -220,6 +224,7 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument

uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;

auto src_dram_buffer_a = input_tensor_a.buffer();
auto src_dram_buffer_b = input_tensor_b.buffer();
Expand Down Expand Up @@ -247,11 +252,14 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument

uint32_t bnc1 = (bN * bC == 1) ? 1 : 0;

constexpr bool row_major = false;
auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, Ht);
split_work_to_cores(compute_with_storage_grid_size, Ht, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Ht_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});

constexpr bool row_major = false;
auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, Wt);
split_work_to_cores(compute_with_storage_grid_size, Wt, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

auto src0_buffer = a.buffer();
auto src1_buffer = b.buffer();
Expand Down Expand Up @@ -127,8 +131,8 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe
all_device_cores,
tt_metal::ComputeConfig{.compile_args = {}, .defines = bcast_defines});

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Wt_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_1;
Expand Down Expand Up @@ -220,6 +224,7 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments(

uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;

auto src_dram_buffer_a = input_tensor_a.buffer();
auto src_dram_buffer_b = input_tensor_b.buffer();
Expand Down Expand Up @@ -247,11 +252,14 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments(

uint32_t bnc1 = (bN * bC == 1) ? 1 : 0;

constexpr bool row_major = false;
auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] =
split_work_to_cores(compute_with_storage_grid_size, Wt);
split_work_to_cores(compute_with_storage_grid_size, Wt, row_major);

auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) {
const CoreCoord& core = cores.at(i);
uint32_t Wt_per_core;
if (core_group_1.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args(
split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major);
block_cnt_per_core_group_1 = num_tiles_per_core_group_1;
block_cnt_per_core_group_2 = num_tiles_per_core_group_2;
cores = grid_to_cores(num_cores_x * num_cores_y, num_cores_x, num_cores_y, row_major);
cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
}

uint32_t g1_numcores = core_group_1.num_cores();
Expand Down

0 comments on commit afec31e

Please sign in to comment.