Skip to content

Commit

Permalink
#8865: Optimized ttnn.bcast dispatch times (#12383)
Browse files Browse the repository at this point in the history
#8865: Optimize ttnn.bcast dispatch times
  • Loading branch information
nemanjagrujic authored Sep 12, 2024
1 parent 3964adc commit afdd15d
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,69 +222,68 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor

auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] = tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, Ht);

auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id);
auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id);
auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++){
CoreCoord core = {i / num_cores_y, i % num_cores_y};
uint32_t Ht_per_core;

auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y);
auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y);
auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y);

if (core_group_1.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Ht_per_core = Ht_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(15, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
binary_reader_args[3] = 0;
binary_reader_args[7] = 0;
binary_reader_args[8] = 0;

bcast_kernel_args[0] = 0;
bcast_kernel_args[1] = 0;
bcast_kernel_args[2] = 0;

unary_writer_args[3] = 0;
unary_writer_args[4] = 0;
unary_writer_args[6] = 0;
unary_writer_args[7] = 0;
continue;
}
uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt;

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
core,
{
src_dram_buffer_a->address(), // 0
0, // 1
0, // 2
num_tensor_tiles_per_core, // 3
src_dram_buffer_b->address(), // 4
0, // 5
0, // 6
num_btensor_tiles, // 7
num_tensor_tiles_per_core, // 8
NC, // 9
Ht_per_core, // 10
Wt, // 11
bnc1, // 12
num_Wtiles_read, // 13
Ht*Wt, // 14
}
);

tt_metal::SetRuntimeArgs(
program,
bcast_kernel_id,
core,
{
NC, // B
Ht_per_core, // Ht
Wt // Wt
}
);

tt_metal::SetRuntimeArgs(
program, unary_writer_kernel_id, core,
{
dst_dram_buffer->address(),
0,
0,
Ht_per_core,
Wt,
num_Wtiles_read,
0,
NC,
Ht*Wt,
}
);
binary_reader_args[0] = src_dram_buffer_a->address();
// binary_reader_args[1] = 0;
// binary_reader_args[2] = 0;
binary_reader_args[3] = num_tensor_tiles_per_core;
binary_reader_args[4] = src_dram_buffer_b->address();
// binary_reader_args[5] = 0;
// binary_reader_args[6] = 0;
binary_reader_args[7] = num_btensor_tiles;
binary_reader_args[8] = num_tensor_tiles_per_core;
binary_reader_args[9] = NC;
binary_reader_args[10] = Ht_per_core;
binary_reader_args[11] = Wt;
binary_reader_args[12] = bnc1;
binary_reader_args[13] = num_Wtiles_read;
binary_reader_args[14] = Ht*Wt;

bcast_kernel_args[0] = NC;
bcast_kernel_args[1] = Ht_per_core;
bcast_kernel_args[2] = Wt;

unary_writer_args[0] = dst_dram_buffer->address();
// unary_writer_args[1] = 0;
// unary_writer_args[2] = 0;
unary_writer_args[3] = Ht_per_core;
unary_writer_args[4] = Wt;
unary_writer_args[5] = num_Wtiles_read;
// unary_writer_args[6] = 0;
unary_writer_args[7] = NC;
unary_writer_args[8] = Ht*Wt;

num_Wtiles_read += Ht_per_core * Wt;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
num_tensor_tiles_per_core = num_tiles_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(7, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, {1, 1, 0});
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(3, 0));
continue;
}
Expand Down Expand Up @@ -255,7 +255,6 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
uint32_t HtWt = Ht * Wt;

uint32_t num_tensor_tiles = NC*Ht*Wt;

uint32_t bnc1 = (bN*bC == 1);

auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles);
Expand All @@ -270,54 +269,45 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
core_group_2 = CoreRangeSet({});
}

auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id);
auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id);
auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id);

for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++){
CoreCoord core = {i / num_cores_y, i % num_cores_y};
uint32_t num_tensor_tiles_per_core;

auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y);
auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y);
auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y);

if (core_group_1.core_coord_in_core_ranges(core)) {
num_tensor_tiles_per_core = num_tiles_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
num_tensor_tiles_per_core = num_tiles_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(7, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(3, 0));
binary_reader_args[2] = 0;
bcast_kernel_args[2] = 0;
unary_writer_args[1] = 0;
continue;
}

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
core,
{
src_buffer_a->address(), // 0
src_dram_buffer_b->address(),
num_tensor_tiles_per_core,
HtWt,
num_tiles_read / HtWt * HtWt,
num_tiles_read % HtWt,
bnc1 ? 0 : num_tiles_read / HtWt
}
);

tt_metal::SetRuntimeArgs(
program,
bcast_kernel_id,
core,
{
1, // B
1, // Ht
num_tensor_tiles_per_core // Wt
}
);

tt_metal::SetRuntimeArgs(
program, unary_writer_kernel_id, core,
{
dst_buffer->address(),
num_tensor_tiles_per_core,
num_tiles_read,
}
);
binary_reader_args[0] = src_buffer_a->address();
binary_reader_args[1] = src_dram_buffer_b->address();
binary_reader_args[2] = num_tensor_tiles_per_core;
binary_reader_args[3] = HtWt;
binary_reader_args[4] = num_tiles_read / HtWt * HtWt;
binary_reader_args[5] = num_tiles_read % HtWt;
binary_reader_args[6] = bnc1 ? 0 : num_tiles_read / HtWt;

// bcast_kernel_args[0] = 1;
// bcast_kernel_args[1] = 1;
bcast_kernel_args[2] = num_tensor_tiles_per_core;

unary_writer_args[0] = dst_buffer->address();
unary_writer_args[1] = num_tensor_tiles_per_core;
unary_writer_args[2] = num_tiles_read;

num_tiles_read += num_tensor_tiles_per_core;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,70 +223,72 @@ operation::ProgramWithCallbacks bcast_multi_core_w(const Tensor &a, const Tensor
auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] =
tt_metal::split_work_to_cores(compute_with_storage_grid_size, Wt);

auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id);
auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id);
auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id);

for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) {
CoreCoord core = {i / num_cores_y, i % num_cores_y};
uint32_t Wt_per_core;

auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y);
auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y);
auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y);

if (core_group_1.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_1;
} else if (core_group_2.core_coord_in_core_ranges(core)) {
Wt_per_core = Wt_per_core_group_2;
} else {
tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector<uint32_t>(16, 0));
tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector<uint32_t>(3, 0));
tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector<uint32_t>(9, 0));
binary_reader_args[3] = 0;
binary_reader_args[7] = 0;
binary_reader_args[8] = 0;

bcast_kernel_args[0] = 0;
bcast_kernel_args[1] = 0;
bcast_kernel_args[2] = 0;

unary_writer_args[3] = 0;
unary_writer_args[4] = 0;
unary_writer_args[7] = 0;
unary_writer_args[8] = 0;
continue;
}

uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core;
uint32_t Wt_skip = Wt - Wt_per_core;

tt_metal::SetRuntimeArgs(
program,
binary_reader_kernel_id,
core,
{
src_dram_buffer_a->address(), // 0
0, // 1
0, // 2
num_tensor_tiles_per_core, // 3
src_dram_buffer_b->address(), // 4
0, // 5
0, // 6
num_btensor_tiles, // 7
num_tensor_tiles_per_core, // 8
NC, // 9
Ht, // 10
Wt_per_core, // 11
bnc1, // 12
num_Wtiles_read, // 13
Ht * Wt, // 14
Wt_skip, // 15
});

tt_metal::SetRuntimeArgs(
program,
bcast_kernel_id,
core,
{
NC, // B
Ht, // Ht
Wt_per_core // Wt
});

tt_metal::SetRuntimeArgs(
program,
unary_writer_kernel_id,
core,
{
dst_dram_buffer->address(),
0,
0,
Ht,
Wt_per_core,
num_Wtiles_read,
Wt_skip,
NC,
Ht * Wt,
});
binary_reader_args[0] = src_dram_buffer_a->address();
// binary_reader_args[1] = 0;
// binary_reader_args[2] = 0;
binary_reader_args[3] = num_tensor_tiles_per_core;
binary_reader_args[4] = src_dram_buffer_b->address();
// binary_reader_args[5] = 0;
// binary_reader_args[6] = 0;
binary_reader_args[7] = num_btensor_tiles;
binary_reader_args[8] = num_tensor_tiles_per_core;
binary_reader_args[9] = NC;
binary_reader_args[10] = Ht;
binary_reader_args[11] = Wt_per_core;
binary_reader_args[12] = bnc1;
binary_reader_args[13] = num_Wtiles_read;
binary_reader_args[14] = Ht * Wt;
binary_reader_args[15] = Wt_skip;

bcast_kernel_args[0] = NC;
bcast_kernel_args[1] = Ht;
bcast_kernel_args[2] = Wt_per_core;

unary_writer_args[0] = dst_dram_buffer->address();
// unary_writer_args[1] = 0;
// unary_writer_args[2] = 0;
unary_writer_args[3] = Ht;
unary_writer_args[4] = Wt_per_core;
unary_writer_args[5] = num_Wtiles_read;
unary_writer_args[6] = Wt_skip;
unary_writer_args[7] = NC;
unary_writer_args[8] = Ht * Wt;

num_Wtiles_read += Wt_per_core;
}
};
Expand Down

0 comments on commit afdd15d

Please sign in to comment.