Skip to content

Commit

Permalink
#0: Fix calculation for remaining pages in fast height reshard
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Jun 13, 2024
1 parent 891a333 commit 92bce92
Showing 1 changed file with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu
uint32_t total_size, unit_size, input_units_per_shard, output_units_per_shard;
auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype());

uint32_t num_output_units = input.buffer()->num_pages();
if (input.get_layout() == Layout::TILE) {
unit_size = tt_metal::detail::TileSize(data_format);
input_units_per_shard = input_shard_spec.numel() / TILE_HW;
Expand Down Expand Up @@ -922,15 +923,17 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu
auto input_buffer_type = input.buffer()->buffer_type();

std::array<tt_metal::KernelHandle, 2> kernels = {kernel_id_0, kernel_id_1};
uint32_t output_units_left = num_output_units;
for (const auto& core : output_cores) {
uint32_t output_units_left = output_units_per_shard;
uint32_t output_units_per_kernel = div_up(output_units_per_shard, kernels.size());
uint32_t output_units_per_core = std::min(output_units_left, output_units_per_shard);
output_units_left -= output_units_per_core;
uint32_t output_units_per_kernel = div_up(output_units_per_core, kernels.size());
for (const auto& kernel_id : kernels) {
std::vector<uint32_t> kernel_args = {input_address, 0, 0};
uint32_t output_units_to_get = std::min(output_units_left, output_units_per_kernel);
uint32_t output_units_to_get = std::min(output_units_per_core, output_units_per_kernel);
if (output_units_to_get != 0) {
uint32_t num_reads = 0;
kernel_args[1] = (output_units_per_shard - output_units_left) * unit_size;
kernel_args[1] = (output_units_per_shard - output_units_per_core) * unit_size;
auto bank_id = device->bank_ids_from_logical_core(input_buffer_type, input_cores[input_core_idx])[0];
uint32_t bank_offset = device->bank_offset(input_buffer_type, bank_id);
while (output_units_to_get > 0) {
Expand All @@ -943,7 +946,7 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu
static_cast<uint32_t>(input_core.y),
(input_units_per_shard - input_core_units_rem) * unit_size + bank_offset,
units_to_read * unit_size});
output_units_left -= units_to_read;
output_units_per_core -= units_to_read;
output_units_to_get -= units_to_read;
input_core_units_rem -= units_to_read;
if (input_core_units_rem == 0) {
Expand Down

0 comments on commit 92bce92

Please sign in to comment.