From 92bce922fe0c1c0e377119cdb3f4cd1deabea75f Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Wed, 12 Jun 2024 19:34:58 +0000 Subject: [PATCH] #0: Fix calculation for remaining pages in fast height reshard --- .../sharded/multi_core/sharded_op_multi_core.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp index f896ccffbf0..6af4500fe73 100644 --- a/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/sharded/multi_core/sharded_op_multi_core.cpp @@ -886,6 +886,7 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu uint32_t total_size, unit_size, input_units_per_shard, output_units_per_shard; auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + uint32_t num_output_units = input.buffer()->num_pages(); if (input.get_layout() == Layout::TILE) { unit_size = tt_metal::detail::TileSize(data_format); input_units_per_shard = input_shard_spec.numel() / TILE_HW; @@ -922,15 +923,17 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu auto input_buffer_type = input.buffer()->buffer_type(); std::array kernels = {kernel_id_0, kernel_id_1}; + uint32_t output_units_left = num_output_units; for (const auto& core : output_cores) { - uint32_t output_units_left = output_units_per_shard; - uint32_t output_units_per_kernel = div_up(output_units_per_shard, kernels.size()); + uint32_t output_units_per_core = std::min(output_units_left, output_units_per_shard); + output_units_left -= output_units_per_core; + uint32_t output_units_per_kernel = div_up(output_units_per_core, kernels.size()); for (const auto& kernel_id : kernels) { std::vector kernel_args = {input_address, 0, 0}; - uint32_t output_units_to_get = std::min(output_units_left, output_units_per_kernel); + uint32_t output_units_to_get = std::min(output_units_per_core, output_units_per_kernel); if (output_units_to_get != 0) { uint32_t num_reads = 0; - kernel_args[1] = (output_units_per_shard - output_units_left) * unit_size; + kernel_args[1] = (output_units_per_shard - output_units_per_core) * unit_size; auto bank_id = device->bank_ids_from_logical_core(input_buffer_type, input_cores[input_core_idx])[0]; uint32_t bank_offset = device->bank_offset(input_buffer_type, bank_id); while (output_units_to_get > 0) { @@ -943,7 +946,7 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu static_cast(input_core.y), (input_units_per_shard - input_core_units_rem) * unit_size + bank_offset, units_to_read * unit_size}); - output_units_left -= units_to_read; + output_units_per_core -= units_to_read; output_units_to_get -= units_to_read; input_core_units_rem -= units_to_read; if (input_core_units_rem == 0) {