From 965129940a83d92676d0797a9607a8a62147f699 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Fri, 13 Dec 2024 21:38:18 +0000 Subject: [PATCH] #0: Update eltwise binary to support sharding on arbitrary cores on an arbitrary sub-device grid --- .../unit_tests/operations/eltwise/test_add.py | 70 ++++ tt_metal/common/core_coord.cpp | 24 ++ tt_metal/common/core_coord.hpp | 4 + tt_metal/common/work_split.cpp | 180 +++++++--- tt_metal/common/work_split.hpp | 3 + tt_metal/impl/device/device.cpp | 6 +- tt_metal/impl/device/device.hpp | 1 + .../impl/sub_device/sub_device_manager.cpp | 2 + .../impl/sub_device/sub_device_manager.hpp | 1 + .../binary/device/binary_device_operation.cpp | 41 ++- .../binary/device/binary_device_operation.hpp | 5 +- ...lement_wise_multi_core_program_factory.cpp | 283 +--------------- ...ement_wise_multi_core_sfpu_pgm_factory.cpp | 287 +--------------- ...wise_multi_core_program_factory_common.hpp | 317 ++++++++++++++++++ .../reader_bcast_h_sharded_optimised.cpp | 6 +- 15 files changed, 613 insertions(+), 617 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_add.py b/tests/ttnn/unit_tests/operations/eltwise/test_add.py index 9344e59ccf0a..f29d5b4783a7 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_add.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_add.py @@ -515,3 +515,73 @@ def test_01_volume_tensors(device, data, memory_config): c = ttnn.to_torch(ttnn_c).reshape((-1)) assert c.tolist() == c_golden + + +@pytest.mark.parametrize("input_a_sharded", [True, False]) +@pytest.mark.parametrize("input_b_sharded", [True, False]) +@pytest.mark.parametrize("out_sharded", [True, False]) +@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR]) +def test_add_with_sub_devices(device, input_a_sharded, input_b_sharded, out_sharded, shard_orientation): + torch.manual_seed(0) + shape = (1, 1, 1024, 1024) + torch_input_tensor_a = torch.rand(shape, dtype=torch.bfloat16) + torch_input_tensor_b = torch.rand(shape, dtype=torch.bfloat16) + + if shard_orientation == ttnn.ShardOrientation.ROW_MAJOR: + shard_shape = (1024 // 8, 1024) + else: + shard_shape = (1024, 1024 // 8) + + core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(2, 2), ttnn.CoreCoord(3, 3)), + ttnn.CoreRange(ttnn.CoreCoord(1, 1), ttnn.CoreCoord(1, 1)), + ttnn.CoreRange(ttnn.CoreCoord(4, 0), ttnn.CoreCoord(4, 2)), + ] + ) + + height_sharded_mem_config = ttnn.create_sharded_memory_config( + shape=shard_shape, + core_grid=core_range_set, + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=shard_orientation, + use_height_and_width_as_shard_shape=True, + ) + + torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + if input_a_sharded: + input_tensor_a = ttnn.to_memory_config(input_tensor_a, height_sharded_mem_config) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG + ) + + if input_b_sharded: + input_tensor_b = ttnn.to_memory_config(input_tensor_b, height_sharded_mem_config) + + if out_sharded: + out_mem_config = height_sharded_mem_config + else: + out_mem_config = ttnn.DRAM_MEMORY_CONFIG + + sub_device = ttnn.SubDevice( + [ + ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(1, 1), ttnn.CoreCoord(4, 4)), + ttnn.CoreRange(ttnn.CoreCoord(4, 0), ttnn.CoreCoord(5, 0)), + ] + ) + ] + ) + sub_device_manager_id = device.create_sub_device_manager([sub_device], 0) + device.load_sub_device_manager(sub_device_manager_id) + output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=out_mem_config) + output_tensor = ttnn.to_torch(output_tensor) + assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988 + assert output_tensor.shape == shape diff --git a/tt_metal/common/core_coord.cpp b/tt_metal/common/core_coord.cpp index 4e9a6123477f..c8281b40ac50 100644 --- a/tt_metal/common/core_coord.cpp +++ b/tt_metal/common/core_coord.cpp @@ -522,6 +522,30 @@ std::vector grid_to_cores_with_noop( return cores; } +// Noop cores are appended at the end with no guarantees on ordering +std::vector grid_to_cores_with_noop( + const CoreRangeSet& used_cores, const CoreRangeSet& all_cores, const bool row_wise) { + ZoneScoped; + TT_ASSERT(all_cores.contains(used_cores)); + // Most likely a lot of optimizations to do here + // Implemented this way for simplicity for now + std::vector cores; + cores.reserve(all_cores.num_cores()); + cores = corerange_to_cores(used_cores, std::nullopt, row_wise); + std::vector all_cores_vec = corerange_to_cores(all_cores, std::nullopt, row_wise); + auto sorted_used_cores = cores; + std::sort(sorted_used_cores.begin(), sorted_used_cores.end()); + std::sort(all_cores_vec.begin(), all_cores_vec.end()); + std::set_difference( + all_cores_vec.begin(), + all_cores_vec.end(), + sorted_used_cores.begin(), + sorted_used_cores.end(), + std::back_inserter(cores)); + + return cores; +} + std::vector corerange_to_cores(const CoreRangeSet& crs, std::optional max_cores, bool row_wise) { std::vector all_cores; auto num_cores = crs.num_cores(); diff --git a/tt_metal/common/core_coord.hpp b/tt_metal/common/core_coord.hpp index 93d55ac39f43..38dcdc225cb3 100644 --- a/tt_metal/common/core_coord.hpp +++ b/tt_metal/common/core_coord.hpp @@ -191,6 +191,10 @@ std::vector grid_to_cores_with_noop( const uint32_t grid_size_y, const bool row_wise = false); +// Noop cores are appended at the end with no guarantees on ordering +std::vector grid_to_cores_with_noop( + const CoreRangeSet& used_cores, const CoreRangeSet& all_cores, const bool row_wise = false); + std::vector corerange_to_cores( const CoreRangeSet& crs, std::optional max_cores = std::nullopt, bool row_wise = false); diff --git a/tt_metal/common/work_split.cpp b/tt_metal/common/work_split.cpp index ba687d9d3dab..ad04e169232b 100644 --- a/tt_metal/common/work_split.cpp +++ b/tt_metal/common/work_split.cpp @@ -268,66 +268,146 @@ CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( std::tuple split_work_to_cores( const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise) { ZoneScoped; - uint32_t num_cores_x = grid_size.x, num_cores_y = grid_size.y; - auto target_num_cores = std::min(units_to_divide, num_cores_x * num_cores_y); - CoreRangeSet all_cores = num_cores_to_corerangeset(target_num_cores, grid_size, row_wise); + if (units_to_divide == 0) { + return std::make_tuple(0, CoreRangeSet(), CoreRangeSet(), CoreRangeSet(), 0, 0); + } + uint32_t num_cores_x = grid_size.x, num_cores_y = grid_size.y, max_num_cores = num_cores_x * num_cores_y, + target_num_cores; + CoreRangeSet all_cores; + if (units_to_divide >= max_num_cores) { + target_num_cores = max_num_cores; + all_cores = CoreRangeSet(CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1})); + } else { + target_num_cores = units_to_divide; + all_cores = num_cores_to_corerangeset(target_num_cores, grid_size, row_wise); + } CoreRangeSet core_group_1; CoreRangeSet core_group_2; - uint32_t units_per_core_group_1 = target_num_cores == 0 ? 0 : units_to_divide / target_num_cores; + uint32_t units_per_core_group_1 = units_to_divide / target_num_cores; uint32_t units_per_core_group_2 = 0; + uint32_t num_cores_with_more_work = units_to_divide % target_num_cores; // Evenly divided units to all target cores - if (target_num_cores == 0 || units_to_divide % target_num_cores == 0) { + if (units_to_divide % target_num_cores == 0) { core_group_1 = all_cores; - // Uneven division of units across cores - // This case should only be hit when there are more units of work than a full grid of cores - // which is implicitly assumed in the following logic - } else { + } + // Uneven division of units across cores + // This case should only be hit when there are more units of work than a full grid of cores + // which is implicitly assumed in the following logic + else { // Group of cores that do more work - core_group_1 = num_cores_to_corerangeset(units_to_divide % target_num_cores, grid_size, row_wise); - const auto& last_block_group_1 = (*core_group_1.ranges().rbegin()); - const auto& last_block_all_cores = (*all_cores.ranges().rbegin()); + uint32_t num_core_group_1_cores = num_cores_with_more_work; + uint32_t num_core_group_2_cores = target_num_cores - num_core_group_1_cores; + core_group_1 = num_cores_to_corerangeset(num_core_group_1_cores, grid_size, row_wise); + const auto& last_core_group_1 = (*core_group_1.ranges().rbegin()).end_coord; if (row_wise) { - // Case where only the last row is divided between core group 1 and 2 - if (last_block_group_1.end_coord.y == last_block_all_cores.end_coord.y && - last_block_group_1.end_coord.x != last_block_all_cores.end_coord.x) { - CoreRange leftover_block( - CoreCoord(last_block_group_1.end_coord.x + 1, last_block_group_1.end_coord.y), - last_block_all_cores.end_coord); - core_group_2 = CoreRangeSet(leftover_block); - } else { - std::vector core_group_2_set; - // Case where a middle row is divided between core group 1 and 2 - if (last_block_group_1.end_coord.x != num_cores_x - 1) { - core_group_2_set.emplace_back( - CoreCoord(last_block_group_1.end_coord.x + 1, last_block_group_1.end_coord.y), - CoreCoord(num_cores_x - 1, last_block_group_1.end_coord.y)); - } - // Remaining rows of cores that does less work - core_group_2_set.emplace_back( - CoreCoord(0, last_block_group_1.end_coord.y + 1), last_block_all_cores.end_coord); - core_group_2 = CoreRangeSet(std::move(core_group_2_set)); + // Start in the same row + if (last_core_group_1.x != num_cores_x - 1) { + core_group_2 = num_cores_to_corerangeset( + {last_core_group_1.x + 1, last_core_group_1.y}, num_core_group_2_cores, grid_size, row_wise); + } + // Start in the next row + else { + core_group_2 = num_cores_to_corerangeset( + {0, last_core_group_1.y + 1}, num_core_group_2_cores, grid_size, row_wise); } } else { - // Case where only the last column is divided between core group 1 and 2 - if (last_block_group_1.end_coord.x == last_block_all_cores.end_coord.x && - last_block_group_1.end_coord.y != last_block_all_cores.end_coord.y) { - CoreRange leftover_block( - CoreCoord(last_block_group_1.end_coord.x, last_block_group_1.end_coord.y + 1), - last_block_all_cores.end_coord); - core_group_2 = CoreRangeSet(leftover_block); - } else { - std::vector core_group_2_set; - // Case where a middle column is divided between core group 1 and 2 - if (last_block_group_1.end_coord.y != num_cores_y - 1) { - core_group_2_set.emplace_back( - CoreCoord(last_block_group_1.end_coord.x, last_block_group_1.end_coord.y + 1), - CoreCoord(last_block_group_1.end_coord.x, num_cores_y - 1)); - } - // Remaining columns of cores that does less work - core_group_2_set.emplace_back( - CoreCoord(last_block_group_1.end_coord.x + 1, 0), last_block_all_cores.end_coord); - core_group_2 = CoreRangeSet(std::move(core_group_2_set)); + // Start in the same column + if (last_core_group_1.y != num_cores_y - 1) { + core_group_2 = num_cores_to_corerangeset( + {last_core_group_1.x, last_core_group_1.y + 1}, num_core_group_2_cores, grid_size, row_wise); + } + // Start in the next column + else { + core_group_2 = num_cores_to_corerangeset( + {last_core_group_1.x + 1, 0}, num_core_group_2_cores, grid_size, row_wise); + } + } + units_per_core_group_2 = units_per_core_group_1; + units_per_core_group_1++; + } + + return std::make_tuple( + target_num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2); +} + +std::tuple split_work_to_cores( + const CoreRangeSet& core_grid, const uint32_t units_to_divide, const bool row_wise) { + ZoneScoped; + if (units_to_divide == 0) { + return std::make_tuple(0, CoreRangeSet(), CoreRangeSet(), CoreRangeSet(), 0, 0); + } + uint32_t max_num_cores = core_grid.num_cores(), target_num_cores; + TT_FATAL(max_num_cores > 0, "Core grid must contain at least one core"); + auto start_core = core_grid.ranges().begin()->start_coord; + CoreRangeSet all_cores; + if (units_to_divide >= max_num_cores) { + target_num_cores = max_num_cores; + all_cores = core_grid; + } else { + target_num_cores = units_to_divide; + all_cores = num_cores_to_corerangeset_in_subcoregrids(start_core, target_num_cores, core_grid, row_wise); + } + + CoreRangeSet core_group_1; + CoreRangeSet core_group_2; + uint32_t units_per_core_group_1 = units_to_divide / target_num_cores; + uint32_t units_per_core_group_2 = 0; + uint32_t num_cores_with_more_work = units_to_divide % target_num_cores; + // Evenly divided units to all target cores + if (target_num_cores == 0 || num_cores_with_more_work == 0) { + core_group_1 = all_cores; + } + // Uneven division of units across cores + // This case should only be hit when there are more units of work than a full grid of cores + // which is implicitly assumed in the following logic + else { + // Group of cores that do more work + uint32_t num_core_group_1_cores = num_cores_with_more_work; + uint32_t num_core_group_2_cores = target_num_cores - num_core_group_1_cores; + core_group_1 = + num_cores_to_corerangeset_in_subcoregrids(start_core, num_core_group_1_cores, core_grid, row_wise); + const auto& last_core_group_1 = (*core_group_1.ranges().rbegin()).end_coord; + const auto& core_grid_ranges = core_grid.ranges(); + uint32_t num_cores_counted = 0, i; + for (i = 0; i < core_grid_ranges.size(); i++) { + num_cores_counted += core_grid_ranges[i].size(); + if (num_cores_counted >= num_core_group_1_cores) { + break; + } + } + const auto& range_containing_last_core_group_1 = core_grid_ranges[i]; + // Start in next core range + if (last_core_group_1 == range_containing_last_core_group_1.end_coord) { + core_group_2 = num_cores_to_corerangeset_in_subcoregrids( + core_grid_ranges[i + 1].start_coord, num_core_group_2_cores, core_grid, row_wise); + } else if (row_wise) { + // Start in the same row + if (last_core_group_1.x != range_containing_last_core_group_1.end_coord.x) { + core_group_2 = num_cores_to_corerangeset_in_subcoregrids( + {last_core_group_1.x + 1, last_core_group_1.y}, num_core_group_2_cores, core_grid, row_wise); + } + // Start in the next row + else { + core_group_2 = num_cores_to_corerangeset_in_subcoregrids( + {range_containing_last_core_group_1.start_coord.x, last_core_group_1.y + 1}, + num_core_group_2_cores, + core_grid, + row_wise); + } + } else { + // Start in the same column + if (last_core_group_1.y != range_containing_last_core_group_1.end_coord.y) { + core_group_2 = num_cores_to_corerangeset_in_subcoregrids( + {last_core_group_1.x, last_core_group_1.y + 1}, num_core_group_2_cores, core_grid, row_wise); + } + // Start in the next column + else { + core_group_2 = num_cores_to_corerangeset_in_subcoregrids( + {last_core_group_1.x + 1, range_containing_last_core_group_1.end_coord.y}, + num_core_group_2_cores, + core_grid, + row_wise); } } units_per_core_group_2 = units_per_core_group_1; diff --git a/tt_metal/common/work_split.hpp b/tt_metal/common/work_split.hpp index 2b5ae0ecb9d8..f024f016e655 100644 --- a/tt_metal/common/work_split.hpp +++ b/tt_metal/common/work_split.hpp @@ -53,5 +53,8 @@ CoreRangeSet num_cores_to_corerangeset_in_subcoregrids( std::tuple split_work_to_cores( const CoreCoord grid_size, const uint32_t units_to_divide, const bool row_wise = false); +std::tuple split_work_to_cores( + const CoreRangeSet& core_grid, const uint32_t units_to_divide, const bool row_wise = false); + } // namespace tt_metal } // namespace tt diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 815973247225..ed496630e5c0 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -3766,10 +3766,14 @@ void Device::remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id) this->sub_device_managers_.erase(sub_device_manager); } -const std::vector &Device::get_sub_device_ids() const { +const std::vector& Device::get_sub_device_ids() const { return this->active_sub_device_manager_->get_sub_device_ids(); } +const std::vector& Device::get_sub_devices() const { + return this->active_sub_device_manager_->get_sub_devices(); +} + std::vector Device::get_optimal_dram_bank_to_logical_worker_assignment() { // Top level function that users (ex: Op Writers) can use to assign Tensix Worker cores // as DRAM readers or writers. Returns logical coordinates of optimally placed workers. diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 7fbe89bd853c..21a5d3c9f77e 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -389,6 +389,7 @@ class Device { void clear_loaded_sub_device_manager(); void remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id); const std::vector &get_sub_device_ids() const; + const std::vector &get_sub_devices() const; // TODO #15944: Temporary api until migration to actual fabric is complete std::tuple create_sub_device_manager_with_fabric(tt::stl::Span sub_devices, DeviceAddr local_l1_size); diff --git a/tt_metal/impl/sub_device/sub_device_manager.cpp b/tt_metal/impl/sub_device/sub_device_manager.cpp index 2c4706590c84..7c9bf77862c0 100644 --- a/tt_metal/impl/sub_device/sub_device_manager.cpp +++ b/tt_metal/impl/sub_device/sub_device_manager.cpp @@ -77,6 +77,8 @@ uint8_t SubDeviceManager::num_sub_devices() const { return sub_devices_.size(); const std::vector& SubDeviceManager::get_sub_device_ids() const { return sub_device_ids_; } +const std::vector& SubDeviceManager::get_sub_devices() const { return sub_devices_; } + const SubDevice& SubDeviceManager::sub_device(SubDeviceId sub_device_id) const { auto sub_device_index = this->get_sub_device_index(sub_device_id); return sub_devices_[sub_device_index]; diff --git a/tt_metal/impl/sub_device/sub_device_manager.hpp b/tt_metal/impl/sub_device/sub_device_manager.hpp index 356555dfff9b..eef523aeeef4 100644 --- a/tt_metal/impl/sub_device/sub_device_manager.hpp +++ b/tt_metal/impl/sub_device/sub_device_manager.hpp @@ -46,6 +46,7 @@ class SubDeviceManager { ~SubDeviceManager(); const std::vector& get_sub_device_ids() const; + const std::vector& get_sub_devices() const; const SubDevice& sub_device(SubDeviceId sub_device_id) const; const vector_memcpy_aligned& noc_mcast_unicast_data() const; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index ce524ac4ae6c..d29651cd60af 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -351,6 +351,41 @@ BinaryDeviceOperation::invoke( output_dtype.value() == optional_output_tensor.value().get_dtype(), "If both output dtype and output tensor provided dtype should match"); } + CoreRangeSet worker_grid; + // We assert all shard specs are the same if sharded, so only need to check the first shard spec + // This will create the worker grid based on the used sub-devices when sharded + // Otherwise this will use all cores of the sub-devices + // TODO #13655: Note that the current program ingfrastructure still only supports a single sub-device per program + if (input_tensor_a_arg.is_sharded()) { + const auto& input_grid = input_tensor_a_arg.shard_spec().value().grid; + for (const auto& sub_device : input_tensor_a_arg.device()->get_sub_devices()) { + const auto& sub_device_workers = sub_device.cores(HalProgrammableCoreType::TENSIX); + if (sub_device_workers.intersects(input_grid)) { + worker_grid = worker_grid.merge(sub_device_workers); + } + } + } else if (input_tensor_b_arg.is_sharded()) { + const auto& input_grid = input_tensor_b_arg.shard_spec().value().grid; + for (const auto& sub_device : input_tensor_b_arg.device()->get_sub_devices()) { + const auto& sub_device_workers = sub_device.cores(HalProgrammableCoreType::TENSIX); + if (sub_device_workers.intersects(input_grid)) { + worker_grid = worker_grid.merge(sub_device_workers); + } + } + } else if (optional_output_tensor.has_value() && optional_output_tensor->is_sharded()) { + const auto& output_grid = optional_output_tensor->shard_spec().value().grid; + for (const auto& sub_device : optional_output_tensor->device()->get_sub_devices()) { + const auto& sub_device_workers = sub_device.cores(HalProgrammableCoreType::TENSIX); + if (sub_device_workers.intersects(output_grid)) { + worker_grid = worker_grid.merge(sub_device_workers); + } + } + } else { + for (const auto& sub_device : input_tensor_a_arg.device()->get_sub_devices()) { + const auto& sub_device_workers = sub_device.cores(HalProgrammableCoreType::TENSIX); + worker_grid = worker_grid.merge(sub_device_workers); + } + } return { operation_attributes_t{ @@ -358,8 +393,9 @@ BinaryDeviceOperation::invoke( std::move(activations), std::move(input_tensor_a_activation), std::nullopt, - memory_config.value_or(input_tensor_a_arg.memory_config()), + memory_config.value_or(optional_output_tensor.has_value() ? optional_output_tensor->memory_config() : input_tensor_a_arg.memory_config()), output_dtype.value_or(input_tensor_a_arg.get_dtype()), + worker_grid, std::nullopt}, tensor_args_t{input_tensor_a_arg, input_tensor_b_arg, optional_output_tensor}}; } @@ -380,6 +416,8 @@ BinaryDeviceOperation::invoke( "If both output dtype and output tensor provided dtype should match"); } + // Currently unused/unsupported + CoreRangeSet worker_grid = CoreRangeSet(); return { operation_attributes_t{ binary_op_type, @@ -388,6 +426,7 @@ BinaryDeviceOperation::invoke( scalar, memory_config.value_or(input_tensor_a_arg.memory_config()), output_dtype.value_or(input_tensor_a_arg.get_dtype()), + worker_grid, std::nullopt}, tensor_args_t{input_tensor_a_arg, std::nullopt, optional_output_tensor}}; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp index 11a77a206e9f..bef676471ba5 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp @@ -34,6 +34,7 @@ struct BinaryDeviceOperation { const std::optional scalar; const MemoryConfig memory_config; const DataType dtype; + const CoreRangeSet worker_grid; std::optional compute_kernel_config; tt::stl::hash::hash_t to_hash() const { @@ -58,7 +59,7 @@ struct BinaryDeviceOperation { CBHandle cb_src0; CBHandle cb_src1; CBHandle cb_output; - CoreCoord compute_with_storage_grid_size; + CoreRangeSet all_device_cores; uint32_t src0_single_tile_size; uint32_t src1_single_tile_size; uint32_t dst_single_tile_size; @@ -85,7 +86,7 @@ struct BinaryDeviceOperation { CBHandle cb_src0; CBHandle cb_src1; CBHandle cb_output; - CoreCoord compute_with_storage_grid_size; + CoreRangeSet all_device_cores; uint32_t src0_single_tile_size; uint32_t src1_single_tile_size; uint32_t dst_single_tile_size; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp index ceef285e7172..438ff62cb135 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp @@ -5,6 +5,7 @@ #include #include "binary_device_operation.hpp" +#include "ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" #include "tt_metal/common/work_split.hpp" @@ -15,276 +16,6 @@ namespace ttnn::operations::binary { -template -inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( - Program& program, - const Tensor& a, - const Tensor& b, - const Tensor& output, - const KernelHandle binary_reader_kernel_id, - const KernelHandle unary_writer_kernel_id, - const KernelHandle eltwise_binary_kernel_id, - const CBHandle cb_src0, - const CBHandle cb_src1, - const CBHandle cb_output, - const CoreCoord compute_with_storage_grid_size, - const uint32_t src0_single_tile_size, - const uint32_t src1_single_tile_size, - const uint32_t dst_single_tile_size) { - using namespace tt; - using namespace tt::tt_metal; - using namespace tt::constants; - - auto src_buffer_a = a.buffer(); - auto src_buffer_b = b.buffer(); - auto dst_buffer = output.buffer(); - - CoreRangeSet all_cores, core_group_1, core_group_2; - - std::optional shard_spec = std::nullopt; - std::optional sharded_layout = std::nullopt; - bool src0_sharded = a.memory_config().is_sharded(); - bool src1_sharded = b.memory_config().is_sharded(); - bool out_sharded = output.memory_config().is_sharded(); - - bool block_or_width_sharded = false; - if (src0_sharded) { - shard_spec = a.shard_spec().value(); - block_or_width_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = a.memory_config().memory_layout; - } else if (src1_sharded) { - shard_spec = b.shard_spec().value(); - block_or_width_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = b.memory_config().memory_layout; - } else if (out_sharded) { - shard_spec = output.shard_spec().value(); - block_or_width_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = output.memory_config().memory_layout; - } - - uint32_t num_tiles = a.volume() / TILE_HW; - - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2; - uint32_t num_cores_total = num_cores_x * num_cores_y; - - uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1; - - uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; - - bool row_major; - uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, - last_unpadded_block_width = 0; - CoreCoord end_core; - std::vector cores; - - if (shard_spec.has_value()) { - all_cores = shard_spec.value().grid; - num_cores = all_cores.num_cores(); - core_group_1 = all_cores; - core_group_2 = CoreRangeSet(); - num_tiles_per_core_group_1 = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; - num_tiles_per_core_group_2 = 0; - block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); - max_block_size = block_size_per_core_group_1; - - block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; - row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; - block_height = shard_spec.value().shape[0] / TILE_HEIGHT; - block_width = shard_spec.value().shape[1] / TILE_WIDTH; - if (block_or_width_sharded) { - block_size = block_width * block_height; - end_core = (*shard_spec.value().grid.ranges().begin()).end_coord; - output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; - last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); - last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); - } - auto bbox = core_group_1.bounding_box(); - cores = grid_to_cores_with_noop(bbox.end_coord.x, bbox.end_coord.y, num_cores_x, num_cores_y, row_major); - } else { - row_major = true; - std::tie( - num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); - block_cnt_per_core_group_1 = num_tiles_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2; - cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); - } - - uint32_t g1_numcores = core_group_1.num_cores(); - uint32_t g2_numcores = core_group_2.num_cores(); - - std::vector> binary_reader_args; - std::vector> eltwise_binary_args; - std::vector> unary_writer_args; - if constexpr (initialize_args) { - binary_reader_args = {cores.size(), std::vector(7)}; - eltwise_binary_args = {cores.size(), std::vector(2)}; - if (block_or_width_sharded and not out_sharded) { - unary_writer_args = {cores.size(), std::vector(7)}; - } else { - unary_writer_args = {cores.size(), std::vector(3)}; - } - } - - auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); - auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); - auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); - - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i) { - const CoreCoord& core = cores.at(i); - uint32_t num_tiles_per_core = 0; - uint32_t block_cnt_per_core = 0; - uint32_t block_size_per_core = 0; - uint32_t num_shardes_per_height = 0; - uint32_t num_shardes_per_width = 0; - uint32_t start_id = 0; - if (shard_spec.has_value()) { - if (sharded_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED) { - num_shardes_per_height = num_cores; - num_shardes_per_width = 1; - } else if (sharded_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { - num_shardes_per_width = num_cores; - num_shardes_per_height = 1; - } else { // block sharded - auto bbox = core_group_1.bounding_box(); - if (shard_spec.value().orientation == ShardOrientation::ROW_MAJOR) { - num_shardes_per_height = bbox.end_coord.y - bbox.start_coord.y + 1; - num_shardes_per_width = bbox.end_coord.x - bbox.start_coord.x + 1; - } else { - num_shardes_per_height = bbox.end_coord.x - bbox.start_coord.x + 1; - num_shardes_per_width = bbox.end_coord.y - bbox.start_coord.y + 1; - } - } - start_id = (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width; - } else { - start_id = num_tiles_read; - } - - if (i < g1_numcores) { - num_tiles_per_core = num_tiles_per_core_group_1; - block_cnt_per_core = block_cnt_per_core_group_1; - block_size_per_core = block_size_per_core_group_1; - } else if (i < num_cores) { - num_tiles_per_core = num_tiles_per_core_group_2; - block_cnt_per_core = block_cnt_per_core_group_2; - block_size_per_core = block_size_per_core_group_2; - } else { - // Zero out non-working cores RT args. Only necessary in override - // since initialization pushes zero vectors to unused cores. - if constexpr (!initialize_args) { - auto& reader_args = cached_reader_args.at(core.x).at(core.y); - reader_args[2] = 0; - auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); - eltwise_args[0] = 0; - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[1] = 0; - } - continue; - } - if constexpr (initialize_args) { - binary_reader_args[i] = { - src_buffer_a->address(), - src_buffer_b->address(), - num_tiles_per_core, - start_id, - block_height, - block_width, - num_shardes_per_width, - num_shardes_per_width}; - eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; - } else { - auto& reader_args = cached_reader_args.at(core.x).at(core.y); - reader_args[0] = src_buffer_a->address(); - reader_args[1] = src_buffer_b->address(); - reader_args[2] = num_tiles_per_core; - reader_args[3] = start_id; - reader_args[4] = block_height; - reader_args[5] = block_width; - reader_args[6] = num_shardes_per_width; - auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); - eltwise_args[0] = block_cnt_per_core; - eltwise_args[1] = block_size_per_core; - } - if (block_or_width_sharded and not out_sharded) { - uint32_t unpadded_block_height = block_height; - uint32_t unpadded_block_width = block_width; - if (row_major) { - if (core.x == end_core.x) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.y == end_core.y) { - unpadded_block_height = last_unpadded_block_height; - } - } else { - if (core.y == end_core.y) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.x == end_core.x) { - unpadded_block_height = last_unpadded_block_height; - } - } - if constexpr (initialize_args) { - unary_writer_args[i] = { - dst_buffer->address(), - block_height, - block_width, - unpadded_block_height, - unpadded_block_width, - output_width, - block_size, - (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width, - 0}; - } else { - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[0] = dst_buffer->address(); - writer_args[1] = block_height; - writer_args[2] = block_width; - writer_args[3] = unpadded_block_height; - writer_args[4] = unpadded_block_width; - writer_args[5] = output_width; - writer_args[6] = block_size; - writer_args[7] = (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width; - writer_args[8] = 0; - } - } else { - if constexpr (initialize_args) { - unary_writer_args[i] = {dst_buffer->address(), num_tiles_per_core, num_tiles_read}; - } else { - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[0] = dst_buffer->address(); - writer_args[1] = num_tiles_per_core; - writer_args[2] = num_tiles_read; - } - } - num_tiles_read += num_tiles_per_core; - } - - if constexpr (initialize_args) { - SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); - SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); - SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); - } - - if (src0_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_src0, *src_buffer_a, num_tiles_per_core_group_1 * src0_single_tile_size); - } - if (src1_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_src1, *src_buffer_b, num_tiles_per_core_group_1 * src1_single_tile_size); - } - if (out_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_output, *dst_buffer, num_tiles_per_core_group_1 * dst_single_tile_size); - } -} BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperation::ElementWiseMultiCore::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -324,10 +55,6 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat bool src1_sharded = b->memory_config().is_sharded(); bool out_sharded = output.memory_config().is_sharded(); - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - bool block_or_width_sharded = false; if (src0_sharded) { @@ -350,7 +77,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat tt_metal::Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + const auto& all_device_cores = operation_attributes.worker_grid; uint32_t src0_cb_index = tt::CBIndex::c_0; uint32_t num_input_tiles = src0_sharded ? num_tiles_per_shard : 2 * max_block_size; @@ -465,7 +192,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat cb_src0, cb_src1, cb_output, - compute_with_storage_grid_size, + all_device_cores, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size); @@ -478,7 +205,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat cb_src0, cb_src1, cb_output, - compute_with_storage_grid_size, + all_device_cores, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size}}; @@ -506,7 +233,7 @@ void BinaryDeviceOperation::ElementWiseMultiCore::override_runtime_arguments( shared_variables.cb_src0, shared_variables.cb_src1, shared_variables.cb_output, - shared_variables.compute_with_storage_grid_size, + shared_variables.all_device_cores, shared_variables.src0_single_tile_size, shared_variables.src1_single_tile_size, shared_variables.dst_single_tile_size); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_sfpu_pgm_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_sfpu_pgm_factory.cpp index f23f9b7fe608..a362e5ee9309 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_sfpu_pgm_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_sfpu_pgm_factory.cpp @@ -5,6 +5,7 @@ #include #include "binary_device_operation.hpp" +#include "ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" #include "tt_metal/common/work_split.hpp" @@ -15,276 +16,6 @@ namespace ttnn::operations::binary { -template -inline __attribute__((always_inline)) void set_eltwise_binary_sfpu_runtime_args( - Program& program, - const Tensor& a, - const Tensor& b, - const Tensor& output, - const KernelHandle binary_reader_kernel_id, - const KernelHandle unary_writer_kernel_id, - const KernelHandle eltwise_binary_kernel_id, - const CBHandle cb_src0, - const CBHandle cb_src1, - const CBHandle cb_output, - const CoreCoord compute_with_storage_grid_size, - const uint32_t src0_single_tile_size, - const uint32_t src1_single_tile_size, - const uint32_t dst_single_tile_size) { - using namespace tt; - using namespace tt::tt_metal; - using namespace tt::constants; - - auto src_buffer_a = a.buffer(); - auto src_buffer_b = b.buffer(); - auto dst_buffer = output.buffer(); - - CoreRangeSet all_cores, core_group_1, core_group_2; - - std::optional shard_spec = std::nullopt; - std::optional sharded_layout = std::nullopt; - bool src0_sharded = a.memory_config().is_sharded(); - bool src1_sharded = b.memory_config().is_sharded(); - bool out_sharded = output.memory_config().is_sharded(); - - bool block_or_width_sharded = false; - if (src0_sharded) { - shard_spec = a.shard_spec().value(); - block_or_width_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = a.memory_config().memory_layout; - } else if (src1_sharded) { - shard_spec = b.shard_spec().value(); - block_or_width_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = b.memory_config().memory_layout; - } else if (out_sharded) { - shard_spec = output.shard_spec().value(); - block_or_width_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; - sharded_layout = output.memory_config().memory_layout; - } - - uint32_t num_tiles = a.volume() / TILE_HW; - - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2; - uint32_t num_cores_total = num_cores_x * num_cores_y; - - uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1; - - uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; - - bool row_major; - uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, - last_unpadded_block_width = 0; - CoreCoord end_core; - std::vector cores; - - if (shard_spec.has_value()) { - all_cores = shard_spec.value().grid; - num_cores = all_cores.num_cores(); - core_group_1 = all_cores; - core_group_2 = CoreRangeSet(); - num_tiles_per_core_group_1 = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; - num_tiles_per_core_group_2 = 0; - block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); - max_block_size = block_size_per_core_group_1; - - block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; - row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; - block_height = shard_spec.value().shape[0] / TILE_HEIGHT; - block_width = shard_spec.value().shape[1] / TILE_WIDTH; - if (block_or_width_sharded) { - block_size = block_width * block_height; - end_core = (*shard_spec.value().grid.ranges().begin()).end_coord; - output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; - last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); - last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); - } - auto bbox = core_group_1.bounding_box(); - cores = grid_to_cores_with_noop(bbox.end_coord.x, bbox.end_coord.y, num_cores_x, num_cores_y, row_major); - } else { - row_major = true; - std::tie( - num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); - block_cnt_per_core_group_1 = num_tiles_per_core_group_1; - block_cnt_per_core_group_2 = num_tiles_per_core_group_2; - cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); - } - - uint32_t g1_numcores = core_group_1.num_cores(); - uint32_t g2_numcores = core_group_2.num_cores(); - - std::vector> binary_reader_args; - std::vector> eltwise_binary_args; - std::vector> unary_writer_args; - if constexpr (initialize_args) { - binary_reader_args = {cores.size(), std::vector(7)}; - eltwise_binary_args = {cores.size(), std::vector(2)}; - if (block_or_width_sharded and not out_sharded) { - unary_writer_args = {cores.size(), std::vector(7)}; - } else { - unary_writer_args = {cores.size(), std::vector(3)}; - } - } - - auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); - auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); - auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); - - for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i) { - const CoreCoord& core = cores.at(i); - uint32_t num_tiles_per_core = 0; - uint32_t block_cnt_per_core = 0; - uint32_t block_size_per_core = 0; - uint32_t num_shardes_per_height = 0; - uint32_t num_shardes_per_width = 0; - uint32_t start_id = 0; - if (shard_spec.has_value()) { - if (sharded_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED) { - num_shardes_per_height = num_cores; - num_shardes_per_width = 1; - } else if (sharded_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { - num_shardes_per_width = num_cores; - num_shardes_per_height = 1; - } else { // block sharded - auto bbox = core_group_1.bounding_box(); - if (shard_spec.value().orientation == ShardOrientation::ROW_MAJOR) { - num_shardes_per_height = bbox.end_coord.y - bbox.start_coord.y + 1; - num_shardes_per_width = bbox.end_coord.x - bbox.start_coord.x + 1; - } else { - num_shardes_per_height = bbox.end_coord.x - bbox.start_coord.x + 1; - num_shardes_per_width = bbox.end_coord.y - bbox.start_coord.y + 1; - } - } - start_id = (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width; - } else { - start_id = num_tiles_read; - } - - if (i < g1_numcores) { - num_tiles_per_core = num_tiles_per_core_group_1; - block_cnt_per_core = block_cnt_per_core_group_1; - block_size_per_core = block_size_per_core_group_1; - } else if (i < num_cores) { - num_tiles_per_core = num_tiles_per_core_group_2; - block_cnt_per_core = block_cnt_per_core_group_2; - block_size_per_core = block_size_per_core_group_2; - } else { - // Zero out non-working cores RT args. Only necessary in override - // since initialization pushes zero vectors to unused cores. - if constexpr (!initialize_args) { - auto& reader_args = cached_reader_args.at(core.x).at(core.y); - reader_args[2] = 0; - auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); - eltwise_args[0] = 0; - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[1] = 0; - } - continue; - } - if constexpr (initialize_args) { - binary_reader_args[i] = { - src_buffer_a->address(), - src_buffer_b->address(), - num_tiles_per_core, - start_id, - block_height, - block_width, - num_shardes_per_width, - num_shardes_per_width}; - eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; - } else { - auto& reader_args = cached_reader_args.at(core.x).at(core.y); - reader_args[0] = src_buffer_a->address(); - reader_args[1] = src_buffer_b->address(); - reader_args[2] = num_tiles_per_core; - reader_args[3] = start_id; - reader_args[4] = block_height; - reader_args[5] = block_width; - reader_args[6] = num_shardes_per_width; - auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); - eltwise_args[0] = block_cnt_per_core; - eltwise_args[1] = block_size_per_core; - } - if (block_or_width_sharded and not out_sharded) { - uint32_t unpadded_block_height = block_height; - uint32_t unpadded_block_width = block_width; - if (row_major) { - if (core.x == end_core.x) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.y == end_core.y) { - unpadded_block_height = last_unpadded_block_height; - } - } else { - if (core.y == end_core.y) { - unpadded_block_width = last_unpadded_block_width; - } - if (core.x == end_core.x) { - unpadded_block_height = last_unpadded_block_height; - } - } - if constexpr (initialize_args) { - unary_writer_args[i] = { - dst_buffer->address(), - block_height, - block_width, - unpadded_block_height, - unpadded_block_width, - output_width, - block_size, - (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width, - 0}; - } else { - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[0] = dst_buffer->address(); - writer_args[1] = block_height; - writer_args[2] = block_width; - writer_args[3] = unpadded_block_height; - writer_args[4] = unpadded_block_width; - writer_args[5] = output_width; - writer_args[6] = block_size; - writer_args[7] = (i / num_shardes_per_width) * (block_height * block_width * num_shardes_per_width) + - (i % num_shardes_per_width) * block_width; - writer_args[8] = 0; - } - } else { - if constexpr (initialize_args) { - unary_writer_args[i] = {dst_buffer->address(), num_tiles_per_core, num_tiles_read}; - } else { - auto& writer_args = cached_writer_args.at(core.x).at(core.y); - writer_args[0] = dst_buffer->address(); - writer_args[1] = num_tiles_per_core; - writer_args[2] = num_tiles_read; - } - } - num_tiles_read += num_tiles_per_core; - } - - if constexpr (initialize_args) { - SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); - SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); - SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); - } - - if (src0_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_src0, *src_buffer_a, num_tiles_per_core_group_1 * src0_single_tile_size); - } - if (src1_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_src1, *src_buffer_b, num_tiles_per_core_group_1 * src1_single_tile_size); - } - if (out_sharded) { - UpdateDynamicCircularBufferAddressAndTotalSize( - program, cb_output, *dst_buffer, num_tiles_per_core_group_1 * dst_single_tile_size); - } -} BinaryDeviceOperation::ElementWiseMultiCoreSfpu::cached_program_t BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( const operation_attributes_t& operation_attributes, @@ -325,10 +56,6 @@ BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( bool src1_sharded = b->memory_config().is_sharded(); bool out_sharded = output.memory_config().is_sharded(); - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - bool block_or_width_sharded = false; if (src0_sharded) { @@ -351,7 +78,7 @@ BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( tt_metal::Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + const auto& all_device_cores = operation_attributes.worker_grid; uint32_t src0_cb_index = tt::CBIndex::c_0; uint32_t num_input_tiles = src0_sharded ? num_tiles_per_shard : 2 * max_block_size; @@ -460,7 +187,7 @@ BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( .unpack_to_dest_mode = unpack_to_dest_mode, .defines = eltwise_defines}); - set_eltwise_binary_sfpu_runtime_args( + set_eltwise_binary_runtime_args( program, a, *b, @@ -471,7 +198,7 @@ BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( cb_src0, cb_src1, cb_output, - compute_with_storage_grid_size, + all_device_cores, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size); @@ -484,7 +211,7 @@ BinaryDeviceOperation::ElementWiseMultiCoreSfpu::create( cb_src0, cb_src1, cb_output, - compute_with_storage_grid_size, + all_device_cores, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size}}; @@ -501,7 +228,7 @@ void BinaryDeviceOperation::ElementWiseMultiCoreSfpu::override_runtime_arguments const auto& shared_variables = cached_program.shared_variables; - set_eltwise_binary_sfpu_runtime_args( + set_eltwise_binary_runtime_args( cached_program.program, input_tensor_a, *input_tensor_b, @@ -512,7 +239,7 @@ void BinaryDeviceOperation::ElementWiseMultiCoreSfpu::override_runtime_arguments shared_variables.cb_src0, shared_variables.cb_src1, shared_variables.cb_output, - shared_variables.compute_with_storage_grid_size, + shared_variables.all_device_cores, shared_variables.src0_single_tile_size, shared_variables.src1_single_tile_size, shared_variables.dst_single_tile_size); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp new file mode 100644 index 000000000000..0b0d1366c7ca --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/eltwise_multi_core_program_factory_common.hpp @@ -0,0 +1,317 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "binary_device_operation.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" + +#include "tt_metal/common/work_split.hpp" + +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +namespace ttnn::operations::binary { + +template +inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( + Program& program, + const Tensor& a, + const Tensor& b, + const Tensor& output, + const KernelHandle binary_reader_kernel_id, + const KernelHandle unary_writer_kernel_id, + const KernelHandle eltwise_binary_kernel_id, + const CBHandle cb_src0, + const CBHandle cb_src1, + const CBHandle cb_output, + const CoreRangeSet& all_device_cores, + const uint32_t src0_single_tile_size, + const uint32_t src1_single_tile_size, + const uint32_t dst_single_tile_size) { + using namespace tt; + using namespace tt::tt_metal; + using namespace tt::constants; + bool zero_start_grid = false; + CoreCoord compute_with_storage_grid_size; + if (all_device_cores.size() == 1) { + const auto& cr = *all_device_cores.ranges().begin(); + if (cr.start_coord.x == 0 && cr.start_coord.y == 0) { + zero_start_grid = true; + compute_with_storage_grid_size = CoreCoord(cr.end_coord.x + 1, cr.end_coord.y + 1); + } + } + + auto src_buffer_a = a.buffer(); + auto src_buffer_b = b.buffer(); + auto dst_buffer = output.buffer(); + + CoreRangeSet all_cores, core_group_1, core_group_2; + + std::optional shard_spec = std::nullopt; + std::optional sharded_layout = std::nullopt; + bool src0_sharded = a.memory_config().is_sharded(); + bool src1_sharded = b.memory_config().is_sharded(); + bool out_sharded = output.memory_config().is_sharded(); + + bool block_or_width_sharded = false; + if (src0_sharded) { + shard_spec = a.shard_spec().value(); + block_or_width_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + sharded_layout = a.memory_config().memory_layout; + } else if (src1_sharded) { + shard_spec = b.shard_spec().value(); + block_or_width_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + sharded_layout = b.memory_config().memory_layout; + } else if (out_sharded) { + shard_spec = output.shard_spec().value(); + block_or_width_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + sharded_layout = output.memory_config().memory_layout; + } + + uint32_t num_tiles = a.volume() / TILE_HW; + + uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2, num_cores_total; + if (zero_start_grid) { + num_cores_total = compute_with_storage_grid_size.x * compute_with_storage_grid_size.y; + } else { + num_cores_total = all_device_cores.num_cores(); + } + + uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1; + + uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; + + bool row_major; + uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, + last_unpadded_block_width = 0; + CoreCoord end_core; + std::vector cores; + + if (shard_spec.has_value()) { + all_cores = shard_spec.value().grid; + num_cores = all_cores.num_cores(); + core_group_1 = all_cores; + core_group_2 = CoreRangeSet(); + num_tiles_per_core_group_1 = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + num_tiles_per_core_group_2 = 0; + block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); + max_block_size = block_size_per_core_group_1; + + block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; + row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; + block_height = shard_spec.value().shape[0] / TILE_HEIGHT; + block_width = shard_spec.value().shape[1] / TILE_WIDTH; + if (block_or_width_sharded) { + block_size = block_width * block_height; + end_core = (*shard_spec.value().grid.ranges().begin()).end_coord; + output_width = output.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; + last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); + last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); + } + if (zero_start_grid) { + auto bbox = core_group_1.bounding_box(); + cores = grid_to_cores_with_noop( + bbox.end_coord.x, + bbox.end_coord.y, + compute_with_storage_grid_size.x, + compute_with_storage_grid_size.y, + row_major); + } else { + cores = grid_to_cores_with_noop(all_cores, all_device_cores, row_major); + } + } else { + row_major = true; + std::tie( + num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = + zero_start_grid ? tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major) + : tt::tt_metal::split_work_to_cores(all_device_cores, num_tiles, row_major); + block_cnt_per_core_group_1 = num_tiles_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2; + if (zero_start_grid) { + cores = grid_to_cores( + num_cores_total, compute_with_storage_grid_size.x, compute_with_storage_grid_size.y, row_major); + } else { + cores = corerange_to_cores(all_cores, {}, row_major); + } + } + + uint32_t g1_numcores = core_group_1.num_cores(); + uint32_t g2_numcores = core_group_2.num_cores(); + + std::vector> binary_reader_args; + std::vector> eltwise_binary_args; + std::vector> unary_writer_args; + if constexpr (initialize_args) { + binary_reader_args = {cores.size(), std::vector(7)}; + eltwise_binary_args = {cores.size(), std::vector(2)}; + if (block_or_width_sharded and not out_sharded) { + unary_writer_args = {cores.size(), std::vector(7)}; + } else { + unary_writer_args = {cores.size(), std::vector(3)}; + } + } + + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i) { + const CoreCoord& core = cores.at(i); + uint32_t num_tiles_per_core = 0; + uint32_t block_cnt_per_core = 0; + uint32_t block_size_per_core = 0; + uint32_t num_shards_per_height = 0; + uint32_t num_shards_per_width = 0; + uint32_t start_id = 0; + if (shard_spec.has_value()) { + if (sharded_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED) { + num_shards_per_height = num_cores; + num_shards_per_width = 1; + } else if (sharded_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + num_shards_per_width = num_cores; + num_shards_per_height = 1; + } else { // block sharded + auto bbox = core_group_1.bounding_box(); + if (shard_spec.value().orientation == ShardOrientation::ROW_MAJOR) { + num_shards_per_height = bbox.end_coord.y - bbox.start_coord.y + 1; + num_shards_per_width = bbox.end_coord.x - bbox.start_coord.x + 1; + } else { + num_shards_per_height = bbox.end_coord.x - bbox.start_coord.x + 1; + num_shards_per_width = bbox.end_coord.y - bbox.start_coord.y + 1; + } + } + start_id = (i / num_shards_per_width) * (block_height * block_width * num_shards_per_width) + + (i % num_shards_per_width) * block_width; + } else { + start_id = num_tiles_read; + } + + if (i < g1_numcores) { + num_tiles_per_core = num_tiles_per_core_group_1; + block_cnt_per_core = block_cnt_per_core_group_1; + block_size_per_core = block_size_per_core_group_1; + } else if (i < num_cores) { + num_tiles_per_core = num_tiles_per_core_group_2; + block_cnt_per_core = block_cnt_per_core_group_2; + block_size_per_core = block_size_per_core_group_2; + } else { + // Zero out non-working cores RT args. Only necessary in override + // since initialization pushes zero vectors to unused cores. + if constexpr (!initialize_args) { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[2] = 0; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = 0; + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[1] = 0; + } + continue; + } + if constexpr (initialize_args) { + binary_reader_args[i] = { + src_buffer_a->address(), + src_buffer_b->address(), + num_tiles_per_core, + start_id, + block_height, + block_width, + num_shards_per_width, + num_shards_per_width}; + eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; + } else { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[0] = src_buffer_a->address(); + reader_args[1] = src_buffer_b->address(); + reader_args[2] = num_tiles_per_core; + reader_args[3] = start_id; + reader_args[4] = block_height; + reader_args[5] = block_width; + reader_args[6] = num_shards_per_width; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = block_cnt_per_core; + eltwise_args[1] = block_size_per_core; + } + if (block_or_width_sharded and not out_sharded) { + uint32_t unpadded_block_height = block_height; + uint32_t unpadded_block_width = block_width; + if (row_major) { + if (core.x == end_core.x) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.y == end_core.y) { + unpadded_block_height = last_unpadded_block_height; + } + } else { + if (core.y == end_core.y) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.x == end_core.x) { + unpadded_block_height = last_unpadded_block_height; + } + } + if constexpr (initialize_args) { + unary_writer_args[i] = { + dst_buffer->address(), + block_height, + block_width, + unpadded_block_height, + unpadded_block_width, + output_width, + block_size, + (i / num_shards_per_width) * (block_height * block_width * num_shards_per_width) + + (i % num_shards_per_width) * block_width, + 0}; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = block_height; + writer_args[2] = block_width; + writer_args[3] = unpadded_block_height; + writer_args[4] = unpadded_block_width; + writer_args[5] = output_width; + writer_args[6] = block_size; + writer_args[7] = (i / num_shards_per_width) * (block_height * block_width * num_shards_per_width) + + (i % num_shards_per_width) * block_width; + writer_args[8] = 0; + } + } else { + if constexpr (initialize_args) { + unary_writer_args[i] = {dst_buffer->address(), num_tiles_per_core, num_tiles_read}; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = num_tiles_per_core; + writer_args[2] = num_tiles_read; + } + } + num_tiles_read += num_tiles_per_core; + } + + if constexpr (initialize_args) { + SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); + SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); + SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); + } + + if (src0_sharded) { + UpdateDynamicCircularBufferAddressAndTotalSize( + program, cb_src0, *src_buffer_a, num_tiles_per_core_group_1 * src0_single_tile_size); + } + if (src1_sharded) { + UpdateDynamicCircularBufferAddressAndTotalSize( + program, cb_src1, *src_buffer_b, num_tiles_per_core_group_1 * src1_single_tile_size); + } + if (out_sharded) { + UpdateDynamicCircularBufferAddressAndTotalSize( + program, cb_output, *dst_buffer, num_tiles_per_core_group_1 * dst_single_tile_size); + } +} + +} // namespace ttnn::operations::binary diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/dataflow/reader_bcast_h_sharded_optimised.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/dataflow/reader_bcast_h_sharded_optimised.cpp index 84e091ba56a5..464cdf72bd06 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/dataflow/reader_bcast_h_sharded_optimised.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/dataflow/reader_bcast_h_sharded_optimised.cpp @@ -3,12 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 // This code is temporarily copied from ttnn/cpp/ttnn/operations/datamovement/binary/device/ to demonstrate -<<<<<<< HEAD -// the new ability to keep the CircularBufferConfigs continuous during dispatching. See the use of CBIndex::c_16 below. -======= // the new ability to keep the CircularBufferConfigs continuous during dispatching. See the use of CBIndex::c_2 below. ->>>>>>> 500923c2b7... #7493: Updating some ops to use c_2 instead of c_16 given the dependency on eltwise -// When broadcating is properly supported we expect this code to be deleted or refactored substantially. +// When broadcasting is properly supported we expect this code to be deleted or refactored substantially. #include #include "dataflow_api.h"