diff --git a/tests/tt_eager/python_api_testing/sweep_tests/common.py b/tests/tt_eager/python_api_testing/sweep_tests/common.py index 0f44028361ef..6438f5b2cd81 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/common.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/common.py @@ -668,11 +668,17 @@ def _gen_tt_nn_rmsnorm_shapes(shape): def _gen_tt_nn_bcast_shapes(shape): shape_type = random.randint(0, 2) + second_shape = shape.copy() if shape_type == 0: - second_shape = [1] + second_shape[-2] = 1 + second_shape[-1] = 1 elif shape_type == 1: - second_shape = [shape[-1]] + second_shape[-2] = shape[-2] + second_shape[-1] = 1 + elif shape_type == 2: + second_shape[-2] = 1 + second_shape[-1] = shape[-1] # elif shape_type == 2: # second_shape = [shape[-2], shape[-1]] # elif shape_type == 3: diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp index 005802f66dbd..1d1d77f59168 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp @@ -258,68 +258,75 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) { const CoreCoord& core = cores.at(i); uint32_t Ht_per_core; + + auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y); + auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y); + auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y); + if (core_group_1.core_coord_in_core_ranges(core)) { Ht_per_core = Ht_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { Ht_per_core = Ht_per_core_group_2; } else { - tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(15, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + binary_reader_args[3] = 0; + binary_reader_args[7] = 0; + binary_reader_args[8] = 0; + binary_reader_args[9] = 0; + binary_reader_args[10] = 0; + binary_reader_args[11] = 0; + binary_reader_args[12] = 0; + binary_reader_args[13] = 0; + binary_reader_args[14] = 0; + + bcast_kernel_args[0] = 0; + bcast_kernel_args[1] = 0; + bcast_kernel_args[2] = 0; + + unary_writer_args[3] = 0; + unary_writer_args[4] = 0; + unary_writer_args[5] = 0; + unary_writer_args[7] = 0; + unary_writer_args[8] = 0; continue; } uint32_t num_tensor_tiles_per_core = NC * Ht_per_core * Wt; - tt_metal::SetRuntimeArgs( - program, - binary_reader_kernel_id, - core, - { - src_dram_buffer_a->address(), // 0 - 0, // 1 - 0, // 2 - num_tensor_tiles_per_core, // 3 - src_dram_buffer_b->address(), // 4 - 0, // 5 - 0, // 6 - num_btensor_tiles, // 7 - num_tensor_tiles_per_core, // 8 - NC, // 9 - Ht_per_core, // 10 - Wt, // 11 - bnc1, // 12 - num_Wtiles_read, // 13 - Ht * Wt, // 14 - }); - - tt_metal::SetRuntimeArgs( - program, - bcast_kernel_id, - core, - { - NC, // B - Ht_per_core, // Ht - Wt // Wt - }); - - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - { - dst_dram_buffer->address(), - 0, - 0, - Ht_per_core, - Wt, - num_Wtiles_read, - 0, - NC, - Ht * Wt, - }); + binary_reader_args[0] = src_dram_buffer_a->address(); + // binary_reader_args[1] = 0; + // binary_reader_args[2] = 0; + binary_reader_args[3] = num_tensor_tiles_per_core; + binary_reader_args[4] = src_dram_buffer_b->address(); + // binary_reader_args[5] = 0; + // binary_reader_args[6] = 0; + binary_reader_args[7] = num_btensor_tiles; + binary_reader_args[8] = num_tensor_tiles_per_core; + binary_reader_args[9] = NC; + binary_reader_args[10] = Ht_per_core; + binary_reader_args[11] = Wt; + binary_reader_args[12] = bnc1; + binary_reader_args[13] = num_Wtiles_read; + binary_reader_args[14] = Ht * Wt; + + bcast_kernel_args[0] = NC; + bcast_kernel_args[1] = Ht_per_core; + bcast_kernel_args[2] = Wt; + + unary_writer_args[0] = dst_dram_buffer->address(); + // unary_writer_args[1] = 0; + // unary_writer_args[2] = 0; + unary_writer_args[3] = Ht_per_core; + unary_writer_args[4] = Wt; + unary_writer_args[5] = num_Wtiles_read; + // unary_writer_args[6] = 0; + unary_writer_args[7] = NC; + unary_writer_args[8] = Ht * Wt; num_Wtiles_read += Ht_per_core * Wt; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp index 996ba3128ec1..cf64b77b250d 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp @@ -258,70 +258,79 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments( auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, bcast_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + for (uint32_t i = 0, num_Wtiles_read = 0; i < num_cores_total; i++) { const CoreCoord& core = cores.at(i); uint32_t Wt_per_core; + + auto& binary_reader_args = cached_reader_args.at(core.x).at(core.y); + auto& bcast_kernel_args = cached_eltwise_args.at(core.x).at(core.y); + auto& unary_writer_args = cached_writer_args.at(core.x).at(core.y); + if (core_group_1.core_coord_in_core_ranges(core)) { Wt_per_core = Wt_per_core_group_1; } else if (core_group_2.core_coord_in_core_ranges(core)) { Wt_per_core = Wt_per_core_group_2; } else { - tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, std::vector(16, 0)); - tt_metal::SetRuntimeArgs(program, bcast_kernel_id, core, std::vector(3, 0)); - tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, std::vector(9, 0)); + binary_reader_args[3] = 0; + binary_reader_args[7] = 0; + binary_reader_args[8] = 0; + binary_reader_args[9] = 0; + binary_reader_args[10] = 0; + binary_reader_args[11] = 0; + binary_reader_args[12] = 0; + binary_reader_args[13] = 0; + binary_reader_args[14] = 0; + binary_reader_args[15] = 0; + + bcast_kernel_args[0] = 0; + bcast_kernel_args[1] = 0; + bcast_kernel_args[2] = 0; + + unary_writer_args[3] = 0; + unary_writer_args[4] = 0; + unary_writer_args[5] = 0; + unary_writer_args[7] = 0; + unary_writer_args[8] = 0; continue; } uint32_t num_tensor_tiles_per_core = NC * Ht * Wt_per_core; uint32_t Wt_skip = Wt - Wt_per_core; - tt_metal::SetRuntimeArgs( - program, - binary_reader_kernel_id, - core, - { - src_dram_buffer_a->address(), // 0 - 0, // 1 - 0, // 2 - num_tensor_tiles_per_core, // 3 - src_dram_buffer_b->address(), // 4 - 0, // 5 - 0, // 6 - num_btensor_tiles, // 7 - num_tensor_tiles_per_core, // 8 - NC, // 9 - Ht, // 10 - Wt_per_core, // 11 - bnc1, // 12 - num_Wtiles_read, // 13 - Ht * Wt, // 14 - Wt_skip, // 15 - }); - - tt_metal::SetRuntimeArgs( - program, - bcast_kernel_id, - core, - { - NC, // B - Ht, // Ht - Wt_per_core // Wt - }); + binary_reader_args[0] = src_dram_buffer_a->address(); + // binary_reader_args[1] = 0; + // binary_reader_args[2] = 0; + binary_reader_args[3] = num_tensor_tiles_per_core; + binary_reader_args[4] = src_dram_buffer_b->address(); + // binary_reader_args[5] = 0; + // binary_reader_args[6] = 0; + binary_reader_args[7] = num_btensor_tiles; + binary_reader_args[8] = num_tensor_tiles_per_core; + binary_reader_args[9] = NC; + binary_reader_args[10] = Ht; + binary_reader_args[11] = Wt_per_core; + binary_reader_args[12] = bnc1; + binary_reader_args[13] = num_Wtiles_read; + binary_reader_args[14] = Ht * Wt; + binary_reader_args[15] = Wt_skip; + + bcast_kernel_args[0] = NC; + bcast_kernel_args[1] = Ht; + bcast_kernel_args[2] = Wt_per_core; + + unary_writer_args[0] = dst_dram_buffer->address(); + // unary_writer_args[1] = 0; + // unary_writer_args[2] = 0; + unary_writer_args[3] = Ht; + unary_writer_args[4] = Wt_per_core; + unary_writer_args[5] = num_Wtiles_read; + unary_writer_args[6] = Wt_skip; + unary_writer_args[7] = NC; + unary_writer_args[8] = Ht * Wt; - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - { - dst_dram_buffer->address(), - 0, - 0, - Ht, - Wt_per_core, - num_Wtiles_read, - Wt_skip, - NC, - Ht * Wt, - }); num_Wtiles_read += Wt_per_core; } }