From afdd15dc9ba3093de682ca9841107531f77771e1 Mon Sep 17 00:00:00 2001 From: Nemanja Grujic <109360083+nemanjagrujic@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:22:01 +0200 Subject: [PATCH] #8865: Optimized ttnn.bcast dispatch times (#12383) #8865: Optimize ttnn.bcast dispatch times --- .../multi_core_h/bcast_op_multi_core_h.cpp | 101 +++++++++-------- .../multi_core_hw/bcast_op_multi_core_hw.cpp | 68 +++++------- .../multi_core_w/bcast_op_multi_core_w.cpp | 104 +++++++++--------- 3 files changed, 132 insertions(+), 141 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_h/bcast_op_multi_core_h.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_h/bcast_op_multi_core_h.cpp index 8661038564e..4c013e66bd8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_h/bcast_op_multi_core_h.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_h/bcast_op_multi_core_h.cpp @@ -222,69 +222,68 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor auto [num_cores, all_cores, core_group_1, core_group_2, Ht_per_core_group_1, Ht_per_core_group_2] = tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, Ht); + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++){ CoreCoord core = {i / num_cores_y, i % num_cores_y}; uint32_t Ht_per_core; + + auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y); + auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y); + auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y); + if (core_group_1.core_coord_in_core_ranges(core)) { Ht_per_core = Ht_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { Ht_per_core = Ht_per_core_group_2; } else { - tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(15, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + binary_reader_args[3] = 0; + binary_reader_args[7] = 0; + binary_reader_args[8] = 0; + + bcast_kernel_args[0] = 0; + bcast_kernel_args[1] = 0; + bcast_kernel_args[2] = 0; + + unary_writer_args[3] = 0; + unary_writer_args[4] = 0; + unary_writer_args[6] = 0; + unary_writer_args[7] = 0; continue; } uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt; - tt_metal::SetRuntimeArgs( - program, - binary_reader_kernel_id, - core, - { - src_dram_buffer_a->address(), // 0 - 0, // 1 - 0, // 2 - num_tensor_tiles_per_core, // 3 - src_dram_buffer_b->address(), // 4 - 0, // 5 - 0, // 6 - num_btensor_tiles, // 7 - num_tensor_tiles_per_core, // 8 - NC, // 9 - Ht_per_core, // 10 - Wt, // 11 - bnc1, // 12 - num_Wtiles_read, // 13 - Ht*Wt, // 14 - } - ); - - tt_metal::SetRuntimeArgs( - program, - bcast_kernel_id, - core, - { - NC, // B - Ht_per_core, // Ht - Wt // Wt - } - ); - - tt_metal::SetRuntimeArgs( - program, unary_writer_kernel_id, core, - { - dst_dram_buffer->address(), - 0, - 0, - Ht_per_core, - Wt, - num_Wtiles_read, - 0, - NC, - Ht*Wt, - } - ); + binary_reader_args[0] = src_dram_buffer_a->address(); + // binary_reader_args[1] = 0; + // binary_reader_args[2] = 0; + binary_reader_args[3] = num_tensor_tiles_per_core; + binary_reader_args[4] = src_dram_buffer_b->address(); + // binary_reader_args[5] = 0; + // binary_reader_args[6] = 0; + binary_reader_args[7] = num_btensor_tiles; + binary_reader_args[8] = num_tensor_tiles_per_core; + binary_reader_args[9] = NC; + binary_reader_args[10] = Ht_per_core; + binary_reader_args[11] = Wt; + binary_reader_args[12] = bnc1; + binary_reader_args[13] = num_Wtiles_read; + binary_reader_args[14] = Ht*Wt; + + bcast_kernel_args[0] = NC; + bcast_kernel_args[1] = Ht_per_core; + bcast_kernel_args[2] = Wt; + + unary_writer_args[0] = dst_dram_buffer->address(); + // unary_writer_args[1] = 0; + // unary_writer_args[2] = 0; + unary_writer_args[3] = Ht_per_core; + unary_writer_args[4] = Wt; + unary_writer_args[5] = num_Wtiles_read; + // unary_writer_args[6] = 0; + unary_writer_args[7] = NC; + unary_writer_args[8] = Ht*Wt; num_Wtiles_read += Ht_per_core * Wt; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_hw/bcast_op_multi_core_hw.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_hw/bcast_op_multi_core_hw.cpp index 89152d88143..b88c38589bb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_hw/bcast_op_multi_core_hw.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_hw/bcast_op_multi_core_hw.cpp @@ -158,7 +158,7 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso num_tensor_tiles_per_core = num_tiles_per_core_group_2; } else { tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(7, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); + tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, {1, 1, 0}); tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(3, 0)); continue; } @@ -255,7 +255,6 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso uint32_t HtWt = Ht * Wt; uint32_t num_tensor_tiles = NC*Ht*Wt; - uint32_t bnc1 = (bN*bC == 1); auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles); @@ -270,54 +269,45 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso core_group_2 = CoreRangeSet({}); } + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_y * num_cores_x; i++){ CoreCoord core = {i / num_cores_y, i % num_cores_y}; uint32_t num_tensor_tiles_per_core; + + auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y); + auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y); + auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y); + if (core_group_1.core_coord_in_core_ranges(core)) { num_tensor_tiles_per_core = num_tiles_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { num_tensor_tiles_per_core = num_tiles_per_core_group_2; } else { - tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(7, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(3, 0)); + binary_reader_args[2] = 0; + bcast_kernel_args[2] = 0; + unary_writer_args[1] = 0; continue; } - tt_metal::SetRuntimeArgs( - program, - binary_reader_kernel_id, - core, - { - src_buffer_a->address(), // 0 - src_dram_buffer_b->address(), - num_tensor_tiles_per_core, - HtWt, - num_tiles_read / HtWt * HtWt, - num_tiles_read % HtWt, - bnc1 ? 0 : num_tiles_read / HtWt - } - ); - - tt_metal::SetRuntimeArgs( - program, - bcast_kernel_id, - core, - { - 1, // B - 1, // Ht - num_tensor_tiles_per_core // Wt - } - ); - - tt_metal::SetRuntimeArgs( - program, unary_writer_kernel_id, core, - { - dst_buffer->address(), - num_tensor_tiles_per_core, - num_tiles_read, - } - ); + binary_reader_args[0] = src_buffer_a->address(); + binary_reader_args[1] = src_dram_buffer_b->address(); + binary_reader_args[2] = num_tensor_tiles_per_core; + binary_reader_args[3] = HtWt; + binary_reader_args[4] = num_tiles_read / HtWt * HtWt; + binary_reader_args[5] = num_tiles_read % HtWt; + binary_reader_args[6] = bnc1 ? 0 : num_tiles_read / HtWt; + + // bcast_kernel_args[0] = 1; + // bcast_kernel_args[1] = 1; + bcast_kernel_args[2] = num_tensor_tiles_per_core; + + unary_writer_args[0] = dst_buffer->address(); + unary_writer_args[1] = num_tensor_tiles_per_core; + unary_writer_args[2] = num_tiles_read; + num_tiles_read += num_tensor_tiles_per_core; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_w/bcast_op_multi_core_w.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_w/bcast_op_multi_core_w.cpp index b54876bb54d..b899b8c7aad 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_w/bcast_op_multi_core_w.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/multi_core_w/bcast_op_multi_core_w.cpp @@ -223,70 +223,72 @@ operation::ProgramWithCallbacks bcast_multi_core_w(const Tensor &a, const Tensor auto [num_cores, all_cores, core_group_1, core_group_2, Wt_per_core_group_1, Wt_per_core_group_2] = tt_metal::split_work_to_cores(compute_with_storage_grid_size, Wt); + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_y * num_cores_x; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; uint32_t Wt_per_core; + + auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y); + auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y); + auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y); + if (core_group_1.core_coord_in_core_ranges(core)) { Wt_per_core = Wt_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { Wt_per_core = Wt_per_core_group_2; } else { - tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(16, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + binary_reader_args[3] = 0; + binary_reader_args[7] = 0; + binary_reader_args[8] = 0; + + bcast_kernel_args[0] = 0; + bcast_kernel_args[1] = 0; + bcast_kernel_args[2] = 0; + + unary_writer_args[3] = 0; + unary_writer_args[4] = 0; + unary_writer_args[7] = 0; + unary_writer_args[8] = 0; continue; } + uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core; uint32_t Wt_skip = Wt - Wt_per_core; - tt_metal::SetRuntimeArgs( - program, - binary_reader_kernel_id, - core, - { - src_dram_buffer_a->address(), // 0 - 0, // 1 - 0, // 2 - num_tensor_tiles_per_core, // 3 - src_dram_buffer_b->address(), // 4 - 0, // 5 - 0, // 6 - num_btensor_tiles, // 7 - num_tensor_tiles_per_core, // 8 - NC, // 9 - Ht, // 10 - Wt_per_core, // 11 - bnc1, // 12 - num_Wtiles_read, // 13 - Ht * Wt, // 14 - Wt_skip, // 15 - }); - - tt_metal::SetRuntimeArgs( - program, - bcast_kernel_id, - core, - { - NC, // B - Ht, // Ht - Wt_per_core // Wt - }); - - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - { - dst_dram_buffer->address(), - 0, - 0, - Ht, - Wt_per_core, - num_Wtiles_read, - Wt_skip, - NC, - Ht * Wt, - }); + binary_reader_args[0] = src_dram_buffer_a->address(); + // binary_reader_args[1] = 0; + // binary_reader_args[2] = 0; + binary_reader_args[3] = num_tensor_tiles_per_core; + binary_reader_args[4] = src_dram_buffer_b->address(); + // binary_reader_args[5] = 0; + // binary_reader_args[6] = 0; + binary_reader_args[7] = num_btensor_tiles; + binary_reader_args[8] = num_tensor_tiles_per_core; + binary_reader_args[9] = NC; + binary_reader_args[10] = Ht; + binary_reader_args[11] = Wt_per_core; + binary_reader_args[12] = bnc1; + binary_reader_args[13] = num_Wtiles_read; + binary_reader_args[14] = Ht * Wt; + binary_reader_args[15] = Wt_skip; + + bcast_kernel_args[0] = NC; + bcast_kernel_args[1] = Ht; + bcast_kernel_args[2] = Wt_per_core; + + unary_writer_args[0] = dst_dram_buffer->address(); + // unary_writer_args[1] = 0; + // unary_writer_args[2] = 0; + unary_writer_args[3] = Ht; + unary_writer_args[4] = Wt_per_core; + unary_writer_args[5] = num_Wtiles_read; + unary_writer_args[6] = Wt_skip; + unary_writer_args[7] = NC; + unary_writer_args[8] = Ht * Wt; + num_Wtiles_read += Wt_per_core; } };