diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py index 219b68433761..1037101d1b37 100644 --- a/models/demos/resnet/tt/metalResnetBlock50.py +++ b/models/demos/resnet/tt/metalResnetBlock50.py @@ -2217,7 +2217,6 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: unpadded_shape = x.shape_without_padding() x = tt_lib.tensor.untilize_with_unpadding( x, - (0, 0, 0, 0), (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1), self.memory_config, ) @@ -2274,7 +2273,7 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: ] if self.sharded: x = tt_lib.tensor.untilize_with_unpadding( - x, (0, 0, 0, 0), unpadded_shape_end, output_mem_config=self.width_sharded_memory_config + x, unpadded_shape_end, output_mem_config=self.width_sharded_memory_config ) else: x = tt_lib.tensor.untilize(x, self.memory_config, use_multicore=True) @@ -2313,7 +2312,6 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: desired_shape[-1] = 1000 x = tt_lib.tensor.untilize_with_unpadding( x, - [0, 0, 0, 0], (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1), self.memory_config, ) diff --git a/models/experimental/resnet/tt/ttnn_functional_resnet50.py b/models/experimental/resnet/tt/ttnn_functional_resnet50.py index 1b5e127388bc..a98e0f648942 100644 --- a/models/experimental/resnet/tt/ttnn_functional_resnet50.py +++ b/models/experimental/resnet/tt/ttnn_functional_resnet50.py @@ -676,7 +676,6 @@ def __call__(self, input_tensor) -> ttnn.Tensor: unpadded_shape = x.shape_without_padding() x = ttnn.experimental.tensor.untilize_with_unpadding( x, - (0, 0, 0, 0), (unpadded_shape[0] - 1, unpadded_shape[1] - 1, unpadded_shape[2] - 1, unpadded_shape[3] - 1), ttnn.L1_MEMORY_CONFIG, ) @@ -735,7 +734,7 @@ def __call__(self, input_tensor) -> ttnn.Tensor: x.get_legacy_shape()[3] - 1, ] x = ttnn.experimental.tensor.untilize_with_unpadding( - x, (0, 0, 0, 0), unpadded_shape_end, output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG + x, unpadded_shape_end, output_mem_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG ) x = ttnn.reshape( @@ -763,7 +762,6 @@ def __call__(self, input_tensor) -> ttnn.Tensor: desired_shape[-1] = 1000 x = ttnn.experimental.tensor.untilize_with_unpadding( x, - [0, 0, 0, 0], (desired_shape[0] - 1, desired_shape[1] - 1, desired_shape[2] - 1, desired_shape[3] - 1), ttnn.L1_MEMORY_CONFIG, ) diff --git a/tests/tt_eager/python_api_testing/non_working_unit_tests/wormhole/test_untilize_with_unpadding.py b/tests/tt_eager/python_api_testing/non_working_unit_tests/wormhole/test_untilize_with_unpadding.py index 0fbfdcf11292..aabdd6ec7c00 100644 --- a/tests/tt_eager/python_api_testing/non_working_unit_tests/wormhole/test_untilize_with_unpadding.py +++ b/tests/tt_eager/python_api_testing/non_working_unit_tests/wormhole/test_untilize_with_unpadding.py @@ -25,7 +25,6 @@ def run_untilize_with_unpadding_tests( in_mem_config, out_mem_config, data_seed, - output_tensor_start, output_tensor_end, device, ): @@ -38,12 +37,11 @@ def run_untilize_with_unpadding_tests( # compute ref value x_ref = x.detach().clone() ref_value = pytorch_ops.untilize_with_unpadding( - x_ref, output_tensor_start=output_tensor_start, output_tensor_end=output_tensor_end + x_ref, output_tensor_end=output_tensor_end ) tt_result = tt_untilize_with_unpadding( x=x, - output_tensor_start=output_tensor_start, output_tensor_end=output_tensor_end, device=device, dtype=[dtype], @@ -68,14 +66,13 @@ def run_untilize_with_unpadding_tests( 
"SYSTEM_MEMORY", ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), 5263366, - [0, 0, 0, 0], [10, 9, 4, 1], ), ] @pytest.mark.parametrize( - "input_shape, dtype, dlayout, in_mem_config, out_mem_config, data_seed, output_tensor_start, output_tensor_end", + "input_shape, dtype, dlayout, in_mem_config, out_mem_config, data_seed, output_tensor_end", (test_sweep_args), ) def test_untilize_with_unpadding_test( @@ -85,7 +82,6 @@ def test_untilize_with_unpadding_test( in_mem_config, out_mem_config, data_seed, - output_tensor_start, output_tensor_end, device, ): @@ -97,7 +93,6 @@ def test_untilize_with_unpadding_test( in_mem_config, out_mem_config, data_seed, - output_tensor_start, output_tensor_end, device, ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py b/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py index 4cbaaa75e042..e507123efaba 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/generation_funcs.py @@ -823,14 +823,13 @@ def gen_untilize_with_unpadding_args( input_shapes, dtypes, layouts, mem_configs, do_sanitize_args=do_sanitize_args ): if input_info is not None: - output_tensor_start = [0, 0, 0, 0] - output_tensor_end = [random.randrange(output_tensor_start[i], input_shapes[0][i], 1) for i in range(4)] + output_tensor_end = [random.randrange(0, input_shapes[0][i], 1) for i in range(4)] if output_tensor_end[-1] % 2 == 0: output_tensor_end[-1] += 1 input_info.update( { - "output_tensor_start": output_tensor_start, "output_tensor_end": output_tensor_end, + "use_multicore": True, } ) yield input_info diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py index 52b68e525f54..a5390bcd7773 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py @@ -38,7 +38,6 @@ def create_grid(x, y): "output_mem_config": ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM ), - "output_tensor_start": [0, 0, 0, 0], "output_tensor_end": [0, 0, 119, 7299], }, ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index b2be7ec52c46..e5a97e749e99 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -1200,13 +1200,13 @@ def tilize_with_val_padding(x, output_tensor_shape, pad_value, *args, **kwargs): return tilized -def untilize_with_unpadding(x, output_tensor_start, output_tensor_end, *args, **kwargs): +def untilize_with_unpadding(x, output_tensor_end, *args, **kwargs): untilized = untilize_util(x) unpad = untilized[ - output_tensor_start[0] : output_tensor_end[0] + 1, - output_tensor_start[1] : output_tensor_end[1] + 1, - output_tensor_start[2] : output_tensor_end[2] + 1, - output_tensor_start[3] : output_tensor_end[3] + 1, + : output_tensor_end[0] + 1, + : output_tensor_end[1] + 1, + : output_tensor_end[2] + 1, + : output_tensor_end[3] + 1, ] return unpad diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index 38d4fe6dcd2a..c748acefa75c 100644 --- 
a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -2063,7 +2063,6 @@ def untilize_with_unpadding( layout, input_mem_config, output_mem_config, - output_tensor_start, output_tensor_end, **kwargs, ): @@ -2085,7 +2084,7 @@ def untilize_with_unpadding( ) t1 = ttl.tensor.untilize_with_unpadding( - t0, output_tensor_start, output_tensor_end, output_mem_config=output_mem_config + t0, output_tensor_end, output_mem_config=output_mem_config ) return tt2torch_tensor(t1) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py index 83565ec96978..443b08c678a7 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py @@ -1655,7 +1655,6 @@ def test_block_sharded_untilize_with_unpadding(in_sharded, out_sharded, dtype, d yt = ttl.tensor.untilize_with_unpadding( xt, - ttl.tensor.Shape([0, 0, 0, 0]), ttl.tensor.Shape([0, 0, 391, 511]), output_mem_config=out_mem_config, ) @@ -1744,7 +1743,6 @@ def test_width_sharded_untilize_with_unpadding( yt = ttl.tensor.untilize_with_unpadding( xt, - ttl.tensor.Shape([0, 0, 0, 0]), ttl.tensor.Shape([N - 1, C - 1, output_H - 1, W - 1]), output_mem_config=out_mem_config, ) diff --git a/tt_eager/tt_dnn/op_library/auto_format.cpp b/tt_eager/tt_dnn/op_library/auto_format.cpp index d2bc209b35e0..56ad92cc639f 100644 --- a/tt_eager/tt_dnn/op_library/auto_format.cpp +++ b/tt_eager/tt_dnn/op_library/auto_format.cpp @@ -153,7 +153,6 @@ Tensor AutoFormat::format_output_tensor( } else if (formatted_output.get_layout() == Layout::TILE && AutoFormat::legal_rm_shape(shape)) { formatted_output = untilize_with_unpadding( formatted_output, - {0, 0, 0, 0}, {shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}, mem_config); return formatted_output; @@ -163,7 +162,6 @@ Tensor AutoFormat::format_output_tensor( AutoFormat::legal_rm_shape(shape)) { formatted_output = untilize_with_unpadding( formatted_output, - {0, 0, 0, 0}, {shape[0] - 1, shape[1] - 1, shape[2] - 1, shape[3] - 1}, mem_config); return formatted_output; diff --git a/tt_eager/tt_dnn/op_library/tilize/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp b/tt_eager/tt_dnn/op_library/tilize/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp index 3ff0569cd747..bbb583e089a1 100644 --- a/tt_eager/tt_dnn/op_library/tilize/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp +++ b/tt_eager/tt_dnn/op_library/tilize/kernels/dataflow/reader_unary_pad_dims_split_rows_multicore.cpp @@ -53,9 +53,8 @@ void kernel_main() { cb_reserve_back(cb_id_in0, num_tiles_per_row * has_rows); uint32_t l1_write_addr = get_write_ptr(cb_id_in0); - uint32_t curr_stick_id = base_stick_id; for (uint32_t k = 0; k < num_rows; k++) { - uint64_t src_noc_addr = get_noc_addr(curr_stick_id + k, s); + uint64_t src_noc_addr = get_noc_addr(base_stick_id + k, s); // Read from DRAM to tmp buffer noc_async_read(src_noc_addr, l1_write_addr, unpadded_X_size); diff --git a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/padding.h b/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/padding.h deleted file mode 100644 index ec9a90308042..000000000000 --- a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/padding.h +++ /dev/null @@ -1,113 +0,0 @@ -// # SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// # SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include - -namespace tt::tt_metal { -// BlockRep represents a repeated sequence of data blocks, mixed blocks, and padding blocks. -// It is convient to pass to the device kernels because a single data structure made of 4 ints -// can represent pure data rows, pure padding rows or a mixture thereof. -struct BlockRep { - // number of data blocks - uint32_t n_data; - // number of mixed data rows in a mixed block, 0 means no mixed block - uint32_t n_mixed; - // number of padding blocks - uint32_t n_pads; - // total repeat times - uint32_t times; - - BlockRep(uint32_t n_data, uint32_t n_mixed, uint32_t n_pads, uint32_t times) - : n_data(n_data), n_mixed(n_mixed), n_pads(n_pads), times(times) { - if (n_data == 0 && n_mixed == 0) { - n_pads *= times; - times = 1; - } else if (n_pads == 0 && n_mixed == 0) { - n_data *= times; - times = 1; - } - } - - bool has_mixed_block() const { return n_mixed > 0; } - - uint32_t single_rep() const { return n_data + has_mixed_block() + n_pads; } - - uint32_t block_count() const { return single_rep() * times; } - - uint32_t data_row_count() const { return (n_data * 32 + n_mixed) * times; } - - std::pair, std::vector> split_at(uint32_t idx) const { - // TT_ASSERT(idx <= block_count()); - - std::vector first; - std::vector second; - - int rep_idx = idx / single_rep(); - if (rep_idx > 0) { - first.emplace_back(n_data, n_mixed, n_pads, rep_idx); - } - - int within_rep_idx = idx % single_rep(); - bool is_within_rep = within_rep_idx > 0; - if (is_within_rep) { - if (within_rep_idx <= n_data) { - first.emplace_back(within_rep_idx, 0, 0, 1); - second.emplace_back(n_data - within_rep_idx, n_mixed, n_pads, 1); - } else if (within_rep_idx == n_data + 1 && has_mixed_block()) { - first.emplace_back(n_data, n_mixed, 0, 1); - second.emplace_back(0, 0, n_pads, 1); - } else { - within_rep_idx -= n_data + has_mixed_block(); - first.emplace_back(n_data, n_mixed, within_rep_idx, 1); - second.emplace_back(0, 0, n_pads - within_rep_idx, 1); - } - } - - int remaining_times = times - rep_idx - is_within_rep; - if (remaining_times > 0) { - second.emplace_back(n_data, n_mixed, n_pads, remaining_times); - } - - return {first, second}; - } -}; - -// FullRep is a repeated sequence of data rows followed by pure padding. It represents the row -// pattern seen from the outer-most dimension of a 4D tensor when padding is added to the second -// or the thrird dimension. 
-struct FullRep { - uint32_t n_rows; - uint32_t n_pads; - uint32_t times; - - uint32_t pads_mul; - uint32_t times_total; - - BlockRep rep; - BlockRep pad; - - FullRep(uint32_t n_rows, uint32_t n_pads, uint32_t times, uint32_t pads_mul, uint32_t times_total) - : n_rows(n_rows), n_pads(n_pads), times(times), pads_mul(pads_mul), - times_total(times_total), rep{n_rows / 32, n_rows % 32, n_pads / 32, times}, - pad{0, 0, (n_rows + n_pads) * pads_mul, 1} { - // TT_ASSERT((n_rows + n_pads) % 32 == 0 && "total rows must be divisible by 32"); - } - - std::vector to_block_reps() const { - std::vector block_reps; - block_reps.reserve(2 * times_total); - - for (int i = 0; i < times_total; ++i) { - block_reps.push_back(rep); - block_reps.push_back(pad); - } - - return block_reps; - } -}; - -} // namespace tt::tt_metal diff --git a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp index bd67e5deb7a3..36cdfb9986c3 100644 --- a/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/tilize/tilize_multi_core/tilize_op_multi_core.cpp @@ -4,12 +4,9 @@ #include -#include "padding.h" #include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/operation.hpp" #include "tt_dnn/op_library/work_split_tilize.hpp" - -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/common/math.hpp" #include "tt_metal/detail/util.hpp" @@ -20,60 +17,6 @@ using namespace tt::constants; namespace tt::tt_metal { -inline std::vector> distribute_work(const Shape& shape, uint32_t num_cores, uint32_t blocks_per_core, bool has_cliff, uint32_t nblocks_per_core_cliff) { - const auto& unpadded = shape.without_padding(); - auto input_w = unpadded.rank() >= 4 ? unpadded[-4] : 1; - auto input_z = unpadded.rank() >= 3 ? unpadded[-3] : 1; - auto input_y = unpadded.rank() >= 2 ? unpadded[-2] : 1; - - const auto& padding = shape.padding(); - auto padding_w = unpadded.rank() >= 4 ? padding[shape.get_normalized_index(-4)].back : 0; - auto padding_z = unpadded.rank() >= 3 ? padding[shape.get_normalized_index(-3)].back : 0; - auto padding_y = unpadded.rank() >= 2 ? padding[shape.get_normalized_index(-2)].back : 0; - - // total work is a full rep followed by a padding. 
- auto full_rep_blocks = FullRep(input_y, padding_y, input_z, padding_z, input_w).to_block_reps(); - std::deque total_work(full_rep_blocks.begin(), full_rep_blocks.end()); - total_work.emplace_back(0, 0, (input_y + padding_y) * (input_z + padding_z) * padding_w, 1); - - std::vector> core_assignments; - for (int i = 0; i < num_cores; i++) { - int blocks_to_process = blocks_per_core; - if (i == num_cores - 1 && has_cliff) { - blocks_to_process = nblocks_per_core_cliff; - } - - // Assign blocks to cores - std::vector core_blocks; - int core_blocks_count = 0; - while (core_blocks_count < blocks_to_process) { - if (total_work.empty()) { - break; - } - - int remaining_core_blocks = blocks_to_process - core_blocks_count; - auto& first = total_work.front(); - if (first.block_count() <= remaining_core_blocks) { - core_blocks.push_back(first); - core_blocks_count += first.block_count(); - total_work.pop_front(); - } else { - auto [head, tail] = first.split_at(remaining_core_blocks); - for (auto& el : head) { - core_blocks.push_back(el); - core_blocks_count += el.block_count(); - } - total_work.pop_front(); - total_work.insert(total_work.begin(), tail.begin(), tail.end()); - } - } - - core_assignments.push_back(core_blocks); - } - - return core_assignments; -} - operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, Tensor& output) { tt_metal::Program program = tt_metal::CreateProgram(); @@ -89,7 +32,8 @@ operation::ProgramWithCallbacks tilize_multi_core_interleaved(const Tensor& a, T Device* device = a.device(); auto grid_size = device->compute_with_storage_grid_size(); - auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = split_blocks_for_tilize(grid_size, nblocks); + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, nblocks); uint32_t src0_cb_index = CB::c_in0; uint32_t num_input_tiles = ntiles_per_block; @@ -346,21 +290,22 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); uint32_t output_single_tile_size = detail::TileSize(output_cb_data_format); - const Shape& true_input_shape = a.get_legacy_shape(); - const Shape& true_output_shape = output.get_legacy_shape(); + const Shape& input_shape = a.get_legacy_shape(); + const Shape& output_shape = output.get_legacy_shape(); Device* device = a.device(); CoreCoord grid_size = device->compute_with_storage_grid_size(); - uint32_t num_blocks = output.volume() / true_output_shape[-1] / TILE_HEIGHT; + uint32_t num_blocks = output.volume() / output_shape[-1] / TILE_HEIGHT; uint32_t num_tiles_per_row = output.get_legacy_shape()[-1] / TILE_WIDTH; - auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = split_blocks_for_tilize(grid_size, num_blocks); + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, num_blocks); bool has_cliff = core_range_cliff.size() > 0; - uint32_t unpadded_row_size_bytes = true_input_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat - uint32_t padded_row_size_bytes = true_output_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat + uint32_t unpadded_row_size_bytes = input_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat + uint32_t padded_row_size_bytes = output_shape[-1] * a.element_size(); // Assuming 
bfloat16 dataformat uint32_t src0_cb_index = CB::c_in0; tt_metal::CircularBufferConfig src0_cb_config = @@ -426,7 +371,13 @@ operation::ProgramWithCallbacks tilize_with_val_padding_multi_core_interleaved( uint32_t packed_pad_value = pack_two_bfloat16_into_uint32({bfloat_pad_value, bfloat_pad_value}); // 1D distribution of blocks across cores - auto core_assignments = distribute_work(true_output_shape, ncores, nblocks_per_core, has_cliff, nblocks_per_core_cliff); + auto core_assignments = distribute_work( + output_shape.without_padding(), + output_shape.padding(), + ncores, + nblocks_per_core, + has_cliff, + nblocks_per_core_cliff); uint32_t tile_start_id = 0; uint32_t row_start_id = 0; diff --git a/tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/writer_unary_stick_layout_split_rows_multicore.cpp b/tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/writer_unary_stick_layout_split_rows_multicore.cpp new file mode 100644 index 000000000000..45dbf2aede74 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/writer_unary_stick_layout_split_rows_multicore.cpp @@ -0,0 +1,80 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + // Constexpr + constexpr uint32_t cb_id_out0 = 16; + constexpr uint32_t tile_height = 32; + + const uint32_t dst_addr = get_arg_val(0); + const uint32_t unpadded_X_size = get_arg_val(1); + const uint32_t padded_X_size = get_arg_val(2); + const uint32_t start_stick_id = get_arg_val(3); + const uint32_t n_block_reps = get_arg_val(4); + + constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool FLOAT32_DTYPE = get_compile_time_arg_val(3) == 1; + + const uint32_t num_tiles_per_row = padded_X_size >> (FLOAT32_DTYPE ? 
7 : 6); + +#define stick_size_is_power_of_two get_compile_time_arg_val(1) == 1 + +#if (stick_size_is_power_of_two) + constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(2); + const InterleavedPow2AddrGen s = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = log_base_2_of_page_size}; +#else + const InterleavedAddrGen s = {.bank_base_address = dst_addr, .page_size = unpadded_X_size}; +#endif + + auto pop_blocks = [&](uint32_t num_blocks) { + for (uint32_t i = 0; i < num_blocks; i++) { + cb_wait_front(cb_id_out0, num_tiles_per_row); + cb_pop_front(cb_id_out0, num_tiles_per_row); + } + }; + + auto write_block = [&](uint32_t base_stick_id, uint32_t num_rows) { + uint32_t padding_rows = (tile_height - num_rows) & 31; + bool has_rows = (num_rows + padding_rows) > 0; + + cb_wait_front(cb_id_out0, num_tiles_per_row * has_rows); + uint32_t l1_read_addr = get_read_ptr(cb_id_out0); + for (uint32_t k = 0; k < num_rows; k++) { + uint64_t dst_noc_addr = get_noc_addr(base_stick_id + k, s); + + // Write out tmp buffer + noc_async_write(l1_read_addr, dst_noc_addr, unpadded_X_size); + + noc_async_write_barrier(); + l1_read_addr += padded_X_size; + } + cb_pop_front(cb_id_out0, num_tiles_per_row * has_rows); + }; + + uint32_t stick_id = start_stick_id; + uint32_t rt_arg_idx = 5; + for (uint32_t block_rep_idx = 0; block_rep_idx < n_block_reps; ++block_rep_idx) { + const uint32_t n_data = get_arg_val(rt_arg_idx++); // number of full tile-rows + const uint32_t n_mixed = get_arg_val(rt_arg_idx++); // number of rows in a partially filled tile-row + const uint32_t n_pads = get_arg_val(rt_arg_idx++); // number of padding tile-rows + const uint32_t times = get_arg_val(rt_arg_idx++); // number of times the pattern of tile-rows repeats + + for (uint32_t t = 0; t < times; ++t) { + for (uint32_t y_t = 0; y_t < n_data; y_t++) { + write_block(stick_id, tile_height); + stick_id += tile_height; + } + + write_block(stick_id, n_mixed); + stick_id += n_mixed; + + pop_blocks(n_pads); + } + } +} diff --git a/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp index 7aeacfaf5bc0..843d8c100b64 100644 --- a/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/multi_core/untilize_op_multi_core.cpp @@ -4,13 +4,12 @@ #include - +#include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/untilize/untilize_op.hpp" #include "tt_dnn/op_library/work_split_tilize.hpp" -#include "tt_dnn/op_library/math.hpp" -#include "tt_metal/host_api.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; @@ -25,14 +24,15 @@ uint32_t get_num_cores(CoreCoord grid_size, uint32_t nblocks) { if (nblocks <= ncores) { ncores = nblocks; } else { - uint32_t nblocks_per_core = ceil((float) nblocks / ncores); - ncores = ceil((float) nblocks / nblocks_per_core); + uint32_t nblocks_per_core = ceil((float)nblocks / ncores); + ncores = ceil((float)nblocks / nblocks_per_core); } return ncores; } -} +} // namespace untilize_helpers -operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { +operation::ProgramWithCallbacks untilize_multi_core( + const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { tt_metal::Program program = tt_metal::CreateProgram(); bool src_sharded = 
a.memory_config().is_sharded(); @@ -43,21 +43,23 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_cb_data_format); - Device *device = a.device(); + Device* device = a.device(); uint32_t ntiles = a.volume() / TILE_HW; uint32_t ntiles_per_block = a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t nblocks = ceil((float) ntiles / ntiles_per_block); + uint32_t nblocks = ceil((float)ntiles / ntiles_per_block); uint32_t block_size_nbytes = a.get_legacy_shape()[-1] * output.element_size(); auto grid_size = device->compute_with_storage_grid_size(); - auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = split_blocks_for_tilize(grid_size, nblocks); + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, nblocks); uint32_t ncores_x = grid_size.x; uint32_t ncores_y = std::ceil(static_cast(ncores) / ncores_x); bool row_major = true; bool src_block_sharded = false; - uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0, num_output_rows_unpadded = 0; + uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0, + num_output_rows_unpadded = 0; CoreCoord end_core; std::vector cores_with_rtargs; @@ -76,19 +78,22 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out nblocks_per_core_cliff = 0; num_rows_block = shard_spec.shape[0]; - block_row_size = shard_spec.shape[1] * output.element_size(); // in0_block_w * TILE_WIDTH * dtype_nbytes - output_row_size = output.get_legacy_shape()[-1] * output.element_size(); // output row size bytes - last_block_row_size_unpadded = block_row_size - (round_up(output.get_legacy_shape()[-1], shard_spec.shape[1]) - output.get_legacy_shape()[-1]) * output.element_size(); + block_row_size = shard_spec.shape[1] * output.element_size(); // in0_block_w * TILE_WIDTH * dtype_nbytes + output_row_size = output.get_legacy_shape()[-1] * output.element_size(); // output row size bytes + last_block_row_size_unpadded = block_row_size - (round_up(output.get_legacy_shape()[-1], shard_spec.shape[1]) - + output.get_legacy_shape()[-1]) * + output.element_size(); uint32_t num_output_rows = output.volume() / output.get_legacy_shape()[-1]; num_output_rows_unpadded = num_rows_block - (round_up(num_output_rows, shard_spec.shape[0]) - num_output_rows); end_core = (*shard_spec.grid.ranges().begin()).end; - } uint32_t src0_cb_index = CB::c_in0; uint32_t num_input_tiles = src_sharded ? ntiles_per_block * nblocks_per_core : ntiles_per_block * 2; - tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) - .set_page_size(src0_cb_index, input_single_tile_size); + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, input_single_tile_size); if (src_sharded) { src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer()); } @@ -96,15 +101,17 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out uint32_t output_cb_index = CB::c_out0; uint32_t num_output_tiles = out_sharded ? 
ntiles_per_block * nblocks_per_core : ntiles_per_block * 2; - tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) - .set_page_size(output_cb_index, output_single_tile_size); + tt_metal::CircularBufferConfig output_cb_config = + tt_metal::CircularBufferConfig( + num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); if (out_sharded) { output_cb_config = output_cb_config.set_globally_allocated_address(*output.buffer()); } auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); /** reader @@ -112,9 +119,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out KernelHandle unary_reader_kernel_id; if (src_sharded) { - std::vector reader_ct_args = { - (std::uint32_t) src0_cb_index - }; + std::vector reader_ct_args = {(std::uint32_t)src0_cb_index}; unary_reader_kernel_id = tt_metal::CreateKernel( program, @@ -123,9 +128,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out tt_metal::ReaderDataMovementConfig(reader_ct_args)); } else { bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; - vector reader_ct_args = { - (uint32_t) src0_is_dram - }; + vector reader_ct_args = {(uint32_t)src0_is_dram}; unary_reader_kernel_id = CreateKernel( program, @@ -138,9 +141,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out */ KernelHandle unary_writer_kernel_id; if (out_sharded) { - std::vector writer_ct_args = { - (std::uint32_t) output_cb_index - }; + std::vector writer_ct_args = {(std::uint32_t)output_cb_index}; unary_writer_kernel_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded.cpp", @@ -150,9 +151,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; if (src_block_sharded) { vector writer_ct_args = { - (uint32_t) out_is_dram, - (uint32_t) (input_cb_data_format == tt::DataFormat::Float32) - }; + (uint32_t)out_is_dram, (uint32_t)(input_cb_data_format == tt::DataFormat::Float32)}; unary_writer_kernel_id = CreateKernel( program, "tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_blocks.cpp", @@ -160,16 +159,17 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out WriterDataMovementConfig(writer_ct_args)); } else { bool stick_size_is_power_of_two = is_power_of_two_at_least_32(block_size_nbytes); - uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t) std::log2(block_size_nbytes) : 0; + uint32_t log2_stick_size = stick_size_is_power_of_two ? 
(std::uint32_t)std::log2(block_size_nbytes) : 0; vector writer_ct_args = { - (uint32_t) out_is_dram, - (uint32_t) stick_size_is_power_of_two, - (uint32_t) log2_stick_size, + (uint32_t)out_is_dram, + (uint32_t)stick_size_is_power_of_two, + (uint32_t)log2_stick_size, }; unary_writer_kernel_id = CreateKernel( program, - "tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/writer_unary_stick_layout_split_rows_interleaved.cpp", + "tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/" + "writer_unary_stick_layout_split_rows_interleaved.cpp", all_cores, WriterDataMovementConfig(writer_ct_args)); } @@ -178,12 +178,12 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out /** compute */ vector compute_args = { - (uint32_t) nblocks_per_core, // per_core_block_cnt - (uint32_t) ntiles_per_block, // per_block_ntiles + (uint32_t)nblocks_per_core, // per_core_block_cnt + (uint32_t)ntiles_per_block, // per_block_ntiles }; vector compute_args_cliff = { - (uint32_t) nblocks_per_core_cliff, - (uint32_t) ntiles_per_block, // per_block_ntiles + (uint32_t)nblocks_per_core_cliff, + (uint32_t)ntiles_per_block, // per_block_ntiles }; std::string compute_kernel("tt_eager/tt_dnn/op_library/untilize/kernels/compute/pack_untilize.cpp"); @@ -199,18 +199,14 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out program, compute_kernel, core_range, - ComputeConfig{ - .fp32_dest_acc_en = fp32_dest_acc_en, - .compile_args = compute_args}); + ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_args}); } if (core_range_cliff.ranges().size() > 0) { auto untilize_cliff_kernel_id = CreateKernel( program, compute_kernel, core_range_cliff, - ComputeConfig{ - .fp32_dest_acc_en = fp32_dest_acc_en, - .compile_args = compute_args_cliff}); + ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_args_cliff}); } // 1D distribution of blocks across all cores @@ -224,7 +220,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out uint32_t tile_start_id = 0; uint32_t row_start_id = 0; auto cores = grid_to_cores(ncores_x * ncores_y, ncores_x, ncores_y, row_major); - for (uint32_t i = 0; i < cores.size(); i++){ + for (uint32_t i = 0; i < cores.size(); i++) { CoreCoord core = cores[i]; if (!full_cores.core_coord_in_core_ranges(core)) { continue; @@ -234,22 +230,23 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out if (src_sharded) { reader_rt_args = { - ntiles_per_block * nblocks_per_core // ntiles + ntiles_per_block * nblocks_per_core // ntiles }; } else { reader_rt_args = { - src0_buffer->address(), // src_addr - ntiles_per_block * nblocks_per_core, // ntiles - tile_start_id // start_id + src0_buffer->address(), // src_addr + ntiles_per_block * nblocks_per_core, // ntiles + tile_start_id // start_id }; } - // log_debug("reader[{}]: {},{} = {} ({})", src0_buffer->address(), core.x, core.y, tile_start_id, ntiles_per_block * nblocks_per_core); + // log_debug("reader[{}]: {},{} = {} ({})", src0_buffer->address(), core.x, core.y, tile_start_id, + // ntiles_per_block * nblocks_per_core); // writer runtime args vector writer_rt_args; - if (out_sharded) { + if (out_sharded) { writer_rt_args = { - ntiles_per_block * nblocks_per_core // ntiles + ntiles_per_block * nblocks_per_core // ntiles }; } else { if (src_block_sharded) { @@ -278,7 +275,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out } writer_rt_args = { - dst_buffer->address(), // dst_addr + 
dst_buffer->address(), // dst_addr num_rows_block, block_row_size, 1, @@ -288,65 +285,56 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out row_size_unpadded, num_rows_unpadded, block_start_row_id_offset, - block_start_row_offset - }; + block_start_row_offset}; } else { writer_rt_args = { - dst_buffer->address(), // dst_addr - nblocks_per_core * TILE_HEIGHT, // nblocks per core - block_size_nbytes, // block_size_nbytes - ntiles_per_block, // ntiles_per_block - block_size_nbytes, // block_size_nbytes - 1, // full blocks in a row + dst_buffer->address(), // dst_addr + nblocks_per_core * TILE_HEIGHT, // nblocks per core + block_size_nbytes, // block_size_nbytes + ntiles_per_block, // ntiles_per_block + block_size_nbytes, // block_size_nbytes + 1, // full blocks in a row 0, 0, - row_start_id - }; + row_start_id}; } } - // log_debug("writer[{}]: {},{} = {} {}", dst_buffer->address(), core.x, core.y, block_size_nbytes, row_start_id); + // log_debug("writer[{}]: {},{} = {} {}", dst_buffer->address(), core.x, core.y, block_size_nbytes, + // row_start_id); - tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - core, - reader_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - writer_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); cores_with_rtargs.push_back(core); tile_start_id += ntiles_per_block * nblocks_per_core; row_start_id += TILE_HEIGHT * nblocks_per_core; } if (ncores_full < ncores) { // last core is the cliff core with nblocks_per_core_cliff blocks - CoreCoord core = row_major ? CoreCoord{ ncores_full % ncores_x, ncores_full / ncores_x} : CoreCoord{ ncores_full / ncores_y, ncores_full % ncores_y}; + CoreCoord core = row_major ? 
CoreCoord{ncores_full % ncores_x, ncores_full / ncores_x} + : CoreCoord{ncores_full / ncores_y, ncores_full % ncores_y}; // reader runtime args vector reader_rt_args; if (src_sharded) { reader_rt_args = { - ntiles_per_block * nblocks_per_core_cliff // ntiles + ntiles_per_block * nblocks_per_core_cliff // ntiles }; } else { reader_rt_args = { - src0_buffer->address(), // src_addr - (uint32_t) ntiles_per_block * nblocks_per_core_cliff, // ntiles - tile_start_id // start_id + src0_buffer->address(), // src_addr + (uint32_t)ntiles_per_block * nblocks_per_core_cliff, // ntiles + tile_start_id // start_id }; } - // log_debug("reader: {},{} = {} ({})", core.x, core.y, tile_start_id, ntiles_per_block * nblocks_per_core_cliff); + // log_debug("reader: {},{} = {} ({})", core.x, core.y, tile_start_id, ntiles_per_block * + // nblocks_per_core_cliff); // writer runtime args vector writer_rt_args; if (out_sharded) { writer_rt_args = { - ntiles_per_block * nblocks_per_core_cliff // ntiles + ntiles_per_block * nblocks_per_core_cliff // ntiles }; } else { if (src_block_sharded) { @@ -374,7 +362,7 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out } } writer_rt_args = { - dst_buffer->address(), // dst_addr + dst_buffer->address(), // dst_addr num_rows_block, block_row_size, 1, @@ -384,54 +372,37 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out row_size_unpadded, num_rows_unpadded, block_start_row_id_offset, - block_start_row_offset - }; + block_start_row_offset}; } else { writer_rt_args = { - dst_buffer->address(), // dst_addr - nblocks_per_core_cliff * TILE_HEIGHT, // nsticks - block_size_nbytes, // stick_size_nbytes - ntiles_per_block, // ntiles_per_block - block_size_nbytes, // block_width_nbytes - 1, // full blocks in a row - 0, // UNUSED - 0, // UNUSED - row_start_id - }; + dst_buffer->address(), // dst_addr + nblocks_per_core_cliff * TILE_HEIGHT, // nsticks + block_size_nbytes, // stick_size_nbytes + ntiles_per_block, // ntiles_per_block + block_size_nbytes, // block_width_nbytes + 1, // full blocks in a row + 0, // UNUSED + 0, // UNUSED + row_start_id}; } } // log_debug("writer: {},{} = {} {}", core.x, core.y, block_size_nbytes, row_start_id); - tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - core, - reader_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - writer_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); cores_with_rtargs.push_back(core); } - auto override_runtime_arguments_callback = [ - reader_kernel_id=unary_reader_kernel_id, - writer_kernel_id=unary_writer_kernel_id, - cb_src0=cb_src0, - cb_output=cb_output, - cores_with_rtargs - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors - ) { - + auto override_runtime_arguments_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id, + cb_src0 = cb_src0, + cb_output = cb_output, + cores_with_rtargs]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); @@ -441,8 +412,8 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out if (src_sharded) { 
UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); } else { - for (const CoreCoord& core : cores_with_rtargs){ - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (const CoreCoord& core : cores_with_rtargs) { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); runtime_args[0] = src_buffer->address(); } } @@ -450,19 +421,196 @@ operation::ProgramWithCallbacks untilize_multi_core(const Tensor& a, Tensor& out if (out_sharded) { UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); } else { - for (const CoreCoord& core : cores_with_rtargs){ - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (const CoreCoord& core : cores_with_rtargs) { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); } } }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; } -// This purely supports input block shard -> output interleaved for now -operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor &a, Tensor& output, const Shape &output_tensor_start, const Shape &output_tensor_end, bool use_pack_untilize, bool fp32_dest_acc_en) { +operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( + const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { + tt_metal::Program program = tt_metal::CreateProgram(); + + DataFormat input_cb_data_format = datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = detail::TileSize(input_cb_data_format); + DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = detail::TileSize(output_cb_data_format); + + const Shape& input_shape = a.get_legacy_shape(); + const Shape& output_shape = output.get_legacy_shape(); + + Device* device = a.device(); + CoreCoord grid_size = device->compute_with_storage_grid_size(); + + uint32_t num_blocks = a.volume() / output_shape[-1] / TILE_HEIGHT; + uint32_t num_tiles_per_row = a.get_legacy_shape()[-1] / TILE_WIDTH; + + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + split_blocks_for_tilize(grid_size, num_blocks); + + bool has_cliff = core_range_cliff.size() > 0; + + uint32_t padded_row_size_bytes = input_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat + uint32_t unpadded_row_size_bytes = output_shape[-1] * a.element_size(); // Assuming bfloat16 dataformat + uint32_t src0_cb_index = CB::c_in0; + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig( + num_tiles_per_row * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, input_single_tile_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config); + + uint32_t output_cb_index = CB::c_out0; + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig( + num_tiles_per_row * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); + auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_output_config); + + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer 
should be allocated on device!"); + + /** reader + */ + uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + + KernelHandle unary_reader_kernel_id = CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/reader_unary_interleaved_start_id.cpp", + all_cores, + ReaderDataMovementConfig({src0_is_dram})); + + /** writer + */ + uint32_t out_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t stick_size = unpadded_row_size_bytes; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t)std::log2(stick_size) : 0; + + KernelHandle unary_writer_kernel_id = CreateKernel( + program, + "tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/" + "writer_unary_stick_layout_split_rows_multicore.cpp", + all_cores, + WriterDataMovementConfig( + {out_is_dram, + stick_size_is_power_of_two, + log2_stick_size, + input_cb_data_format == tt::DataFormat::Float32})); + + /** compute + */ + std::string compute_kernel("tt_eager/tt_dnn/op_library/untilize/kernels/compute/pack_untilize.cpp"); + if (num_tiles_per_row > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) { + compute_kernel = "tt_eager/tt_dnn/op_library/untilize/kernels/compute/untilize.cpp"; + } + + if (core_range.size() > 0) { + auto tilize_kernel_id = CreateKernel( + program, + compute_kernel, + core_range, + ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = {nblocks_per_core, num_tiles_per_row}}); + } + if (has_cliff) { + auto tilize_cliff_kernel_id = CreateKernel( + program, + compute_kernel, + core_range_cliff, + ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = {nblocks_per_core_cliff, num_tiles_per_row}}); + } + + auto input_w = input_shape.rank() >= 4 ? input_shape[-4] : 1; + auto input_z = input_shape.rank() >= 3 ? input_shape[-3] : 1; + auto input_y = input_shape.rank() >= 2 ? input_shape[-2] : 1; + auto input_x = input_shape[-1]; + + auto output_w = output_shape.rank() >= 4 ? output_shape[-4] : 1; + auto output_z = output_shape.rank() >= 3 ? output_shape[-3] : 1; + auto output_y = output_shape.rank() >= 2 ? 
output_shape[-2] : 1; + auto output_x = output_shape[-1]; + + Padding padding( + {{0, input_w - output_w}, {0, input_z - output_z}, {0, input_y - output_y}, {0, input_x - output_x}}, + Padding::PadValue::Any); + auto core_assignments = + distribute_work(output_shape, padding, ncores, nblocks_per_core, has_cliff, nblocks_per_core_cliff); + + uint32_t tile_start_id = 0; + uint32_t row_start_id = 0; + uint32_t ncores_x = grid_size.x; + + for (uint32_t i = 0; i < ncores; ++i) { + const std::vector& assignment = core_assignments.at(i); + + uint32_t num_tiles_per_core = num_tiles_per_row * nblocks_per_core; + + // reader runtime args + vector reader_rt_args = {src0_buffer->address(), num_tiles_per_core, tile_start_id}; + + // writer runtime args + vector writer_rt_args = { + dst_buffer->address(), + unpadded_row_size_bytes, + padded_row_size_bytes, + row_start_id, + static_cast(assignment.size()), + }; + + uint32_t nblocks_per_core = 0; + + for (const auto& el : assignment) { + nblocks_per_core += el.block_count(); + row_start_id += el.data_row_count(); + writer_rt_args.push_back(el.n_data); + writer_rt_args.push_back(el.n_mixed); + writer_rt_args.push_back(el.n_pads); + writer_rt_args.push_back(el.times); + } + + CoreCoord core = {i % ncores_x, i / ncores_x}; + + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); + SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); + + tile_start_id += num_tiles_per_core; + } + + auto override_runtime_args_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id, + ncores = ncores, + ncores_x = ncores_x]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + for (uint32_t i = 0; i < ncores; ++i) { + CoreCoord core = {i % ncores_x, i / ncores_x}; + { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +// This purely supports input block shard -> output interleaved for now +operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_sharded( + const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { tt_metal::Program program = tt_metal::CreateProgram(); bool src_sharded = a.memory_config().is_sharded(); @@ -473,10 +621,11 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor DataFormat output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_cb_data_format); - Device *device = a.device(); + Device* device = a.device(); auto grid_size = device->compute_with_storage_grid_size(); - uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0, num_output_rows_unpadded = 0; + uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0, + num_output_rows_unpadded = 0; CoreCoord end_core; uint32_t last_idx = 0; auto shard_spec = a.shard_spec().value(); @@ -496,9 +645,11 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor uint32_t ntiles_per_batch = ntiles_per_block * nblocks_per_core / batch; num_rows_block = 
out_shard_spec.shape[0]; - block_row_size = out_shard_spec.shape[1] * output.element_size(); // in0_block_w * TILE_WIDTH * dtype_nbytes - output_row_size = output.get_legacy_shape()[-1] * output.element_size(); // output row size bytes - last_block_row_size_unpadded = block_row_size - (round_up(output.get_legacy_shape()[-1], out_shard_spec.shape[1]) - output.get_legacy_shape()[-1]) * output.element_size(); + block_row_size = out_shard_spec.shape[1] * output.element_size(); // in0_block_w * TILE_WIDTH * dtype_nbytes + output_row_size = output.get_legacy_shape()[-1] * output.element_size(); // output row size bytes + last_block_row_size_unpadded = block_row_size - (round_up(output.get_legacy_shape()[-1], out_shard_spec.shape[1]) - + output.get_legacy_shape()[-1]) * + output.element_size(); uint32_t num_output_rows = output.volume() / output.get_legacy_shape()[-1]; num_output_rows_unpadded = num_rows_block - (round_up(num_output_rows, out_shard_spec.shape[0]) - num_output_rows); if (a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { @@ -506,15 +657,19 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor } else if (a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { last_idx = div_up(num_output_rows, out_shard_spec.shape[0]) - 1; } else { - end_core = {div_up(output.get_legacy_shape()[-1], out_shard_spec.shape[1]) - 1, div_up(num_output_rows, out_shard_spec.shape[0]) - 1}; + end_core = { + div_up(output.get_legacy_shape()[-1], out_shard_spec.shape[1]) - 1, + div_up(num_output_rows, out_shard_spec.shape[0]) - 1}; } if (!row_major) { std::swap(end_core.x, end_core.y); } uint32_t src0_cb_index = CB::c_in0; uint32_t num_input_tiles = ntiles_per_block * nblocks_per_core; - tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) - .set_page_size(src0_cb_index, input_single_tile_size); + tt_metal::CircularBufferConfig src0_cb_config = + tt_metal::CircularBufferConfig( + num_input_tiles * input_single_tile_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, input_single_tile_size); if (src_sharded) { src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer()); } @@ -522,28 +677,31 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor uint32_t output_cb_index = CB::c_out0; uint32_t num_output_tiles = out_sharded ? 
ntiles_per_batch * 2 : ntiles_per_block * 2; - tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) - .set_page_size(output_cb_index, output_single_tile_size); + tt_metal::CircularBufferConfig output_cb_config = + tt_metal::CircularBufferConfig( + num_output_tiles * output_single_tile_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, output_single_tile_size); auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); CBHandle cb_sharded_output = 0; uint32_t sharded_output_cb_index = CB::c_out1; if (out_sharded) { - tt_metal::CircularBufferConfig sharded_output_cb_config = tt_metal::CircularBufferConfig(num_output_rows_unpadded * block_row_size, {{sharded_output_cb_index, output_cb_data_format}}) - .set_page_size(sharded_output_cb_index, block_row_size).set_globally_allocated_address(*output.buffer()); + tt_metal::CircularBufferConfig sharded_output_cb_config = + tt_metal::CircularBufferConfig( + num_output_rows_unpadded * block_row_size, {{sharded_output_cb_index, output_cb_data_format}}) + .set_page_size(sharded_output_cb_index, block_row_size) + .set_globally_allocated_address(*output.buffer()); cb_sharded_output = tt_metal::CreateCircularBuffer(program, all_cores, sharded_output_cb_config); } - Buffer *src0_buffer = a.buffer(); - Buffer *dst_buffer = output.buffer(); + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); /** reader */ KernelHandle unary_reader_kernel_id; - std::vector reader_ct_args = { - (std::uint32_t) src0_cb_index - }; + std::vector reader_ct_args = {(std::uint32_t)src0_cb_index}; unary_reader_kernel_id = tt_metal::CreateKernel( program, @@ -555,10 +713,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor */ KernelHandle unary_writer_kernel_id; if (out_sharded) { - vector writer_ct_args = { - (uint32_t) output_cb_index, - (uint32_t) sharded_output_cb_index - }; + vector writer_ct_args = {(uint32_t)output_cb_index, (uint32_t)sharded_output_cb_index}; unary_writer_kernel_id = CreateKernel( program, "tt_eager/tt_dnn/op_library/untilize/kernels/dataflow/writer_unary_unpad_batch_rows_sharded.cpp", @@ -567,9 +722,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor } else { bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 
1 : 0; vector writer_ct_args = { - (uint32_t) out_is_dram, - (uint32_t) (input_cb_data_format == tt::DataFormat::Float32) - }; + (uint32_t)out_is_dram, (uint32_t)(input_cb_data_format == tt::DataFormat::Float32)}; unary_writer_kernel_id = CreateKernel( program, "tt_eager/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_blocks.cpp", @@ -580,8 +733,8 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor /** compute */ vector compute_args = { - (uint32_t) nblocks_per_core, // per_core_block_cnt - (uint32_t) ntiles_per_block, // per_block_ntiles + (uint32_t)nblocks_per_core, // per_core_block_cnt + (uint32_t)ntiles_per_block, // per_block_ntiles }; std::string compute_kernel("tt_eager/tt_dnn/op_library/untilize/kernels/compute/pack_untilize.cpp"); @@ -596,20 +749,13 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor program, compute_kernel, all_cores, - ComputeConfig{ - .fp32_dest_acc_en = fp32_dest_acc_en, - .compile_args = compute_args}); + ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_args}); // reader runtime args vector reader_rt_args = { - ntiles_per_block * nblocks_per_core // ntiles + ntiles_per_block * nblocks_per_core // ntiles }; - tt_metal::SetRuntimeArgs( - program, - unary_reader_kernel_id, - all_cores, - reader_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, all_cores, reader_rt_args); std::vector cores; if (out_sharded) { @@ -619,20 +765,14 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor out_shard_spec.shape[0] / batch, shard_spec.shape[1] * a.element_size(), block_row_size, - batch - }; - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - all_cores, - writer_rt_args - ); + batch}; + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, all_cores, writer_rt_args); } else { uint32_t tile_start_id = 0; uint32_t row_start_id = 0; cores = grid_to_cores(ncores, ncores_x, ncores_y, row_major); - for (uint32_t i = 0; i < cores.size(); ++i){ - CoreCoord &core = cores[i]; + for (uint32_t i = 0; i < cores.size(); ++i) { + CoreCoord& core = cores[i]; // writer runtime args vector writer_rt_args; @@ -646,8 +786,8 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor if (i == last_idx) { row_size_unpadded = last_block_row_size_unpadded; } else if (i > last_idx) { - row_size_unpadded = 0; - num_rows_unpadded = 0; + row_size_unpadded = 0; + num_rows_unpadded = 0; } } else if (a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { block_start_row_offset = 0; @@ -655,8 +795,8 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor if (i == last_idx) { num_rows_unpadded = num_output_rows_unpadded; } else if (i > last_idx) { - row_size_unpadded = 0; - num_rows_unpadded = 0; + row_size_unpadded = 0; + num_rows_unpadded = 0; } } else { if (row_major) { @@ -679,13 +819,13 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor } } if (core.x > end_core.x || core.y > end_core.y) { - row_size_unpadded = 0; - num_rows_unpadded = 0; + row_size_unpadded = 0; + num_rows_unpadded = 0; } } writer_rt_args = { - dst_buffer->address(), // dst_addr + dst_buffer->address(), // dst_addr num_rows_block, block_row_size, 1, @@ -695,35 +835,22 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor row_size_unpadded, num_rows_unpadded, block_start_row_id_offset, - block_start_row_offset - }; + 
block_start_row_offset}; - tt_metal::SetRuntimeArgs( - program, - unary_writer_kernel_id, - core, - writer_rt_args - ); + tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); } } - - - auto override_runtime_arguments_callback = [ - reader_kernel_id=unary_reader_kernel_id, - writer_kernel_id=unary_writer_kernel_id, - cb_src0=cb_src0, - cb_sharded_output=cb_sharded_output, - cores - ] - ( - const void* operation, - Program& program, - const std::vector& input_tensors, - const std::vector>&, - const std::vector& output_tensors - ) { - + auto override_runtime_arguments_callback = [reader_kernel_id = unary_reader_kernel_id, + writer_kernel_id = unary_writer_kernel_id, + cb_src0 = cb_src0, + cb_sharded_output = cb_sharded_output, + cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors) { auto src_buffer = input_tensors.at(0).buffer(); auto dst_buffer = output_tensors.at(0).buffer(); @@ -735,14 +862,23 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor if (out_sharded) { UpdateDynamicCircularBufferAddress(program, cb_sharded_output, *dst_buffer); } else { - for (const CoreCoord& core : cores){ - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (const CoreCoord& core : cores) { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); runtime_args[0] = dst_buffer->address(); } } }; - return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_arguments_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +operation::ProgramWithCallbacks untilize_with_unpadding_multi_core( + const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { + if (a.memory_config().is_sharded()) { + return untilize_with_unpadding_multi_core_sharded(a, output, use_pack_untilize, fp32_dest_acc_en); + } else { + return untilize_with_unpadding_multi_core_interleaved(a, output, use_pack_untilize, fp32_dest_acc_en); + } } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/untilize/single_core/untilize_op_single_core.cpp b/tt_eager/tt_dnn/op_library/untilize/single_core/untilize_op_single_core.cpp index 99e5ac1c040b..1e3f6c2d1998 100644 --- a/tt_eager/tt_dnn/op_library/untilize/single_core/untilize_op_single_core.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/single_core/untilize_op_single_core.cpp @@ -181,7 +181,7 @@ operation::ProgramWithCallbacks untilize_single_core(const Tensor &a, Tensor& ou } -operation::ProgramWithCallbacks untilize_with_unpadding_single_core(const Tensor &a, Tensor& output, const Shape &output_tensor_start, const Shape &output_tensor_end, bool use_pack_untilize, bool fp32_dest_acc_en) { +operation::ProgramWithCallbacks untilize_with_unpadding_single_core(const Tensor &a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { const Shape input_shape = a.get_legacy_shape(); const Shape output_shape = output.get_legacy_shape(); diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp b/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp index 7ca06831cc5c..a0d5bb4fce7d 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_op.cpp @@ -2,17 +2,17 @@ // // SPDX-License-Identifier: Apache-2.0 -#include +#include "tt_dnn/op_library/untilize/untilize_op.hpp" +#include -#include 
"tt_dnn/op_library/untilize/untilize_op.hpp" +#include "tensor/tensor_utils.hpp" #include "tt_dnn/op_library/copy/copy_op.hpp" -#include "tt_dnn/op_library/work_split.hpp" #include "tt_dnn/op_library/math.hpp" -#include "tensor/tensor_utils.hpp" -#include "tt_metal/host_api.hpp" +#include "tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" using namespace tt::constants; @@ -20,10 +20,10 @@ namespace tt { namespace tt_metal { -void Untilize::validate(const std::vector &input_tensors) const { +void Untilize::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to untilize need to be on device!"); - TT_FATAL(input_tensor_a.buffer() != nullptr , "Operands to untilize need to be allocated in buffers on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands to untilize need to be allocated in buffers on device!"); TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only untilize tile major data"); TT_FATAL(input_tensor_a.volume() % TILE_HW == 0); @@ -41,8 +41,9 @@ void Untilize::validate(const std::vector &input_tensors) const { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); uint32_t ntiles = input_tensor_a.volume() / TILE_HW; uint32_t ntiles_per_block = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t nblocks = ceil((float) ntiles / ntiles_per_block); - auto num_cores = untilize_helpers::get_num_cores(input_tensor_a.device()->compute_with_storage_grid_size(), nblocks); + uint32_t nblocks = ceil((float)ntiles / ntiles_per_block); + auto num_cores = + untilize_helpers::get_num_cores(input_tensor_a.device()->compute_with_storage_grid_size(), nblocks); uint32_t fused_height = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]; TT_FATAL(fused_height % num_cores == 0); } else { @@ -51,50 +52,65 @@ void Untilize::validate(const std::vector &input_tensors) const { } } -std::vector Untilize::compute_output_shapes(const std::vector &input_tensors) const { +std::vector Untilize::compute_output_shapes(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); return {input_tensor_a.get_legacy_shape()}; } -std::vector Untilize::create_output_tensors(const std::vector &input_tensors) const { +std::vector Untilize::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - DataType output_dtype = input_tensor.get_dtype() == DataType::BFLOAT8_B ? DataType::BFLOAT16 : input_tensor.get_dtype(); + DataType output_dtype = + input_tensor.get_dtype() == DataType::BFLOAT8_B ? 
DataType::BFLOAT16 : input_tensor.get_dtype(); if (output_mem_config.is_sharded()) { if (input_tensor.memory_config().is_sharded()) { auto mem_config = this->output_mem_config; mem_config.shard_spec = input_tensor.memory_config().shard_spec; - return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), output_dtype, Layout::ROW_MAJOR, input_tensor.device(), mem_config)}; + return {create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + output_dtype, + Layout::ROW_MAJOR, + input_tensor.device(), + mem_config)}; } else { uint32_t ntiles = input_tensor.volume() / TILE_HW; uint32_t ntiles_per_block = input_tensor.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t nblocks = ceil((float) ntiles / ntiles_per_block); - auto num_cores = untilize_helpers::get_num_cores(input_tensor.device()->compute_with_storage_grid_size(), nblocks); - auto shard_grid = num_cores_to_corerange_set(num_cores, input_tensor.device()->compute_with_storage_grid_size(), true); + uint32_t nblocks = ceil((float)ntiles / ntiles_per_block); + auto num_cores = + untilize_helpers::get_num_cores(input_tensor.device()->compute_with_storage_grid_size(), nblocks); + auto shard_grid = + num_cores_to_corerange_set(num_cores, input_tensor.device()->compute_with_storage_grid_size(), true); uint32_t fused_height = input_tensor.volume() / input_tensor.get_legacy_shape()[-1]; std::array shard_shape = {fused_height / num_cores, input_tensor.get_legacy_shape()[-1]}; ShardSpec shard_spec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; - return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), output_dtype, Layout::ROW_MAJOR, input_tensor.device(), mem_config)}; + return {create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + output_dtype, + Layout::ROW_MAJOR, + input_tensor.device(), + mem_config)}; } } else { - return operation::generic_create_output_tensors(*this, input_tensors, output_dtype, Layout::ROW_MAJOR, this->output_mem_config); + return operation::generic_create_output_tensors( + *this, input_tensors, output_dtype, Layout::ROW_MAJOR, this->output_mem_config); } } -operation::ProgramWithCallbacks Untilize::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { +operation::ProgramWithCallbacks Untilize::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto& output_tensor = output_tensors.at(0); switch (this->get_parallelization_strategy(input_tensors)) { case UntilizeOpParallelizationStrategy::MULTI_CORE: return untilize_multi_core(input_tensor_a, output_tensor, use_pack_untilize, this->fp32_dest_acc_en); case UntilizeOpParallelizationStrategy::SINGLE_CORE: - default: - return untilize_single_core(input_tensor_a, output_tensor, use_pack_untilize, this->fp32_dest_acc_en); + default: return untilize_single_core(input_tensor_a, output_tensor, use_pack_untilize, this->fp32_dest_acc_en); } } -UntilizeOpParallelizationStrategy Untilize::get_parallelization_strategy(const std::vector &input_tensors) const { +UntilizeOpParallelizationStrategy Untilize::get_parallelization_strategy( + const std::vector& input_tensors) const { if (this->use_multicore) { return UntilizeOpParallelizationStrategy::MULTI_CORE; } else { @@ -102,56 +118,60 @@ UntilizeOpParallelizationStrategy Untilize::get_parallelization_strategy(const s } } -Tensor untilize(const Tensor 
&input_tensor_a, const MemoryConfig& output_mem_config, bool use_multicore, bool use_pack_untilize) { +Tensor untilize( + const Tensor& input_tensor_a, const MemoryConfig& output_mem_config, bool use_multicore, bool use_pack_untilize) { // No-op (Will do a tensor copy) std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; operation::launch_op( - [output_mem_config, use_multicore, use_pack_untilize] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + [output_mem_config, use_multicore, use_pack_untilize]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { const auto& input_tensor_a = input_tensors.at(0); if (input_tensor_a.get_layout() == Layout::ROW_MAJOR) { log_warning("Perf warning: Trying to untilize non-tilized data."); return {AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config)}; } - bool fp32_dest_acc_en = input_tensor_a.get_dtype() == DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b - return operation::run_without_autoformat(Untilize{output_mem_config, use_multicore, use_pack_untilize, fp32_dest_acc_en}, {input_tensor_a}); - }, {input_tensor_a}, output_tensors); + bool fp32_dest_acc_en = + input_tensor_a.get_dtype() == + DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b + return operation::run_without_autoformat( + Untilize{output_mem_config, use_multicore, use_pack_untilize, fp32_dest_acc_en}, {input_tensor_a}); + }, + {input_tensor_a}, + output_tensors); return output_tensors.at(0); } - -void UntilizeWithUnpadding::validate(const std::vector &input_tensors) const { +void UntilizeWithUnpadding::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!"); - TT_FATAL(input_tensor_a.buffer() != nullptr , "Operands need to be allocated in buffers on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!"); TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Can only untilize tile major data"); - TT_FATAL( - (this->output_tensor_start[0] == 0 && this->output_tensor_start[1] == 0 && this->output_tensor_start[2] == 0 && this->output_tensor_start[3] == 0), - "On device unpadding only supports unpadding at end of dims" - ); - TT_FATAL(input_tensor_a.volume() % TILE_HW == 0); for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) { - TT_FATAL(this->output_tensor_start[i] < input_tensor_a.get_legacy_shape()[i]); + TT_FATAL(input_tensor_a.get_legacy_shape()[i] > 0); TT_FATAL(this->output_tensor_end[i] < input_tensor_a.get_legacy_shape()[i]); - - // Check if start shape is <= end shape - TT_FATAL(this->output_tensor_start[i] <= this->output_tensor_end[i]); } - TT_FATAL(((this->output_tensor_end[-1] - this->output_tensor_start[-1] + 1) % 2 == 0), "Can only unpad to row major tensor of even width"); + TT_FATAL(((this->output_tensor_end[-1] + 1) % 2 == 0), "Can only unpad to row major tensor of even width"); if (input_tensor_a.memory_config().is_sharded()) { if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { TT_FATAL(input_tensor_a.shard_spec().value().grid.ranges().size() == 1); 
TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED); - TT_FATAL(input_tensor_a.volume() / (input_tensor_a.get_legacy_shape()[-2] * input_tensor_a.get_legacy_shape()[-1]) == 1, "Can only write unbatched output interleaved"); + TT_FATAL( + input_tensor_a.volume() / + (input_tensor_a.get_legacy_shape()[-2] * input_tensor_a.get_legacy_shape()[-1]) == + 1, + "Can only write unbatched output interleaved"); } else if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { if (output_mem_config.is_sharded()) { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); } // What else? - } else if(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { + } else if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { auto output_shape = this->compute_output_shapes(input_tensors).at(0); // Minor host code changes required to remove this restriction TT_FATAL(input_tensor_a.shard_spec().value().grid.ranges().size() == 1); @@ -163,8 +183,14 @@ void UntilizeWithUnpadding::validate(const std::vector &input_tensors) c TT_FATAL(input_tensor_a.get_legacy_shape()[-1] == output_shape[-1]); } else { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED); - TT_FATAL(input_tensor_a.volume() / (input_tensor_a.get_legacy_shape()[-2] * input_tensor_a.get_legacy_shape()[-1]) == 1, "Can only write unbatched output interleaved"); - TT_FATAL(input_tensor_a.get_legacy_shape()[-1] - output_shape[-1] < input_tensor_a.shard_spec().value().shape[1]); + TT_FATAL( + input_tensor_a.volume() / + (input_tensor_a.get_legacy_shape()[-2] * input_tensor_a.get_legacy_shape()[-1]) == + 1, + "Can only write unbatched output interleaved"); + TT_FATAL( + input_tensor_a.get_legacy_shape()[-1] - output_shape[-1] < + input_tensor_a.shard_spec().value().shape[1]); } } else { TT_FATAL(false, "Unsupported sharding scheme"); @@ -174,19 +200,20 @@ void UntilizeWithUnpadding::validate(const std::vector &input_tensors) c TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } -std::vector UntilizeWithUnpadding::compute_output_shapes(const std::vector &input_tensors) const { +std::vector UntilizeWithUnpadding::compute_output_shapes(const std::vector& input_tensors) const { std::vector out_shape; auto rank = input_tensors[0].get_legacy_shape().rank(); out_shape.reserve(rank); for (uint32_t i = 0; i < rank; i++) { - out_shape.push_back(this->output_tensor_end[i] - this->output_tensor_start[i] + 1); + out_shape.push_back(this->output_tensor_end[i] + 1); } Shape output_tensor_shape(out_shape); return {output_tensor_shape}; } -std::vector UntilizeWithUnpadding::create_output_tensors(const std::vector &input_tensors) const { +std::vector UntilizeWithUnpadding::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); - DataType output_dtype = input_tensor_a.get_dtype() == DataType::BFLOAT8_B ? DataType::BFLOAT16 : input_tensor_a.get_dtype(); + DataType output_dtype = + input_tensor_a.get_dtype() == DataType::BFLOAT8_B ? 
DataType::BFLOAT16 : input_tensor_a.get_dtype(); if (input_tensor_a.memory_config().is_sharded() && this->output_mem_config.is_sharded()) { auto output_shape = this->compute_output_shapes(input_tensors).at(0); uint32_t fused_height = tt_metal::compute_volume(output_shape) / output_shape[-1]; @@ -201,56 +228,81 @@ std::vector UntilizeWithUnpadding::create_output_tensors(const std::vect shard_spec.shape = shard_shape; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; - return {create_device_tensor(this->compute_output_shapes(input_tensors).at(0), output_dtype, Layout::ROW_MAJOR, input_tensor_a.device(), mem_config)}; + return {create_device_tensor( + this->compute_output_shapes(input_tensors).at(0), + output_dtype, + Layout::ROW_MAJOR, + input_tensor_a.device(), + mem_config)}; } else { - return operation::generic_create_output_tensors(*this, input_tensors, output_dtype, Layout::ROW_MAJOR, this->output_mem_config); + return operation::generic_create_output_tensors( + *this, input_tensors, output_dtype, Layout::ROW_MAJOR, this->output_mem_config); } } -operation::ProgramWithCallbacks UntilizeWithUnpadding::create_program(const std::vector& input_tensors, std::vector &output_tensors) const { +operation::ProgramWithCallbacks UntilizeWithUnpadding::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); auto& output_tensor = output_tensors.at(0); switch (this->get_parallelization_strategy(input_tensors)) { case UntilizeWithUnpaddingOpParallelizationStrategy::MULTI_CORE: - return untilize_with_unpadding_multi_core(input_tensor_a, output_tensor, output_tensor_start, output_tensor_end, use_pack_untilize, this->fp32_dest_acc_en); + return untilize_with_unpadding_multi_core( + input_tensor_a, output_tensor, use_pack_untilize, this->fp32_dest_acc_en); break; case UntilizeWithUnpaddingOpParallelizationStrategy::SINGLE_CORE: - default: return untilize_with_unpadding_single_core(input_tensor_a, output_tensor, output_tensor_start, output_tensor_end, use_pack_untilize, this->fp32_dest_acc_en); + default: + return untilize_with_unpadding_single_core( + input_tensor_a, output_tensor, use_pack_untilize, this->fp32_dest_acc_en); } } -UntilizeWithUnpaddingOpParallelizationStrategy UntilizeWithUnpadding::get_parallelization_strategy(const std::vector &input_tensors) const { - if (input_tensors.at(0).memory_config().is_sharded()) { +UntilizeWithUnpaddingOpParallelizationStrategy UntilizeWithUnpadding::get_parallelization_strategy( + const std::vector& input_tensors) const { + if (input_tensors.at(0).memory_config().is_sharded() || this->use_multicore) { return UntilizeWithUnpaddingOpParallelizationStrategy::MULTI_CORE; } else { return UntilizeWithUnpaddingOpParallelizationStrategy::SINGLE_CORE; } } -Tensor untilize_with_unpadding(const Tensor &input_tensor_a, const Shape &output_tensor_start, const Shape &output_tensor_end, const MemoryConfig& output_mem_config, bool use_pack_untilize) { +Tensor untilize_with_unpadding( + const Tensor& input_tensor_a, + const Shape& output_tensor_end, + const MemoryConfig& output_mem_config, + bool use_pack_untilize, + bool use_multicore) { // No-op (Will do a tensor copy) // TODO: We need to run asserts before this std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; operation::launch_op( - [output_tensor_start, output_tensor_end, output_mem_config, use_pack_untilize] (const std::vector& input_tensors, const std::vector>& 
optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { + [output_tensor_end, output_mem_config, use_pack_untilize]( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { auto& input_tensor_a = input_tensors.at(0); const Shape output_tensor_shape = { - output_tensor_end[0] - output_tensor_start[0] + 1, - output_tensor_end[1] - output_tensor_start[1] + 1, - output_tensor_end[2] - output_tensor_start[2] + 1, - output_tensor_end[3] - output_tensor_start[3] + 1, + output_tensor_end[0] + 1, + output_tensor_end[1] + 1, + output_tensor_end[2] + 1, + output_tensor_end[3] + 1, }; if (input_tensor_a.get_layout() != Layout::TILE) { if (input_tensor_a.get_legacy_shape() == output_tensor_shape) { - log_warning("Perf warning: Untilize with unpadding called on already untilized tensor of target shape"); + log_warning( + "Perf warning: Untilize with unpadding called on already untilized tensor of target shape"); return {AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config)}; } else { TT_FATAL(false, "Cannot untilize and unpad input which is not tilized"); } } - bool fp32_dest_acc_en = input_tensor_a.get_dtype() == DataType::UINT32; // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b - return operation::run_without_autoformat(UntilizeWithUnpadding{output_tensor_start, output_tensor_end, output_mem_config, use_pack_untilize, fp32_dest_acc_en}, {input_tensor_a}); - }, {input_tensor_a}, output_tensors); + // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b + bool fp32_dest_acc_en = input_tensor_a.get_dtype() == DataType::UINT32; + return operation::run_without_autoformat( + UntilizeWithUnpadding{output_tensor_end, output_mem_config, use_pack_untilize, fp32_dest_acc_en}, + {input_tensor_a}); + }, + {input_tensor_a}, + output_tensors); return output_tensors.at(0); } diff --git a/tt_eager/tt_dnn/op_library/untilize/untilize_op.hpp b/tt_eager/tt_dnn/op_library/untilize/untilize_op.hpp index 08170215b9f3..eb30aa62700a 100644 --- a/tt_eager/tt_dnn/op_library/untilize/untilize_op.hpp +++ b/tt_eager/tt_dnn/op_library/untilize/untilize_op.hpp @@ -12,11 +12,9 @@ namespace tt { namespace tt_metal { -#define MAX_PACK_UNTILIZE_WIDTH 8 // pack untilize currently does not support > 8 width +#define MAX_PACK_UNTILIZE_WIDTH 8 // pack untilize currently does not support > 8 width -enum class UntilizeOpParallelizationStrategy { - MULTI_CORE, SINGLE_CORE -}; +enum class UntilizeOpParallelizationStrategy { MULTI_CORE, SINGLE_CORE }; struct Untilize { const MemoryConfig output_mem_config; @@ -27,49 +25,71 @@ struct Untilize { void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; UntilizeOpParallelizationStrategy get_parallelization_strategy(const std::vector &input_tensors) const; static constexpr auto attribute_names = - std::make_tuple("output_mem_config", "use_multicore"); + std::make_tuple("output_mem_config", "use_multicore", "use_pack_untilize", "fp32_dest_acc_en"); const auto attribute_values() const { return std::make_tuple( 
- std::cref(this->output_mem_config), std::cref(this->use_multicore)); + std::cref(this->output_mem_config), + std::cref(this->use_multicore), + std::cref(this->use_pack_untilize), + std::cref(this->fp32_dest_acc_en)); } }; -enum class UntilizeWithUnpaddingOpParallelizationStrategy { - MULTI_CORE, SINGLE_CORE -}; +enum class UntilizeWithUnpaddingOpParallelizationStrategy { MULTI_CORE, SINGLE_CORE }; struct UntilizeWithUnpadding { - const Shape output_tensor_start; const Shape output_tensor_end; const MemoryConfig output_mem_config; + const bool use_multicore; const bool use_pack_untilize; const bool fp32_dest_acc_en; void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; - UntilizeWithUnpaddingOpParallelizationStrategy get_parallelization_strategy(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; + UntilizeWithUnpaddingOpParallelizationStrategy get_parallelization_strategy( + const std::vector &input_tensors) const; - static constexpr auto attribute_names = - std::make_tuple("output_tensor_start", "output_tensor_end", "output_mem_config"); + static constexpr auto attribute_names = std::make_tuple( + "output_tensor_end", "output_mem_config", "use_multicore", "use_pack_untilize", "fp32_dest_acc_en"); const auto attribute_values() const { return std::make_tuple( - std::cref(this->output_tensor_start), std::cref(this->output_tensor_end), std::cref(this->output_mem_config)); + std::cref(this->output_tensor_end), + std::cref(this->output_mem_config), + std::cref(this->use_multicore), + std::cref(this->use_pack_untilize), + std::cref(this->fp32_dest_acc_en)); } }; -operation::ProgramWithCallbacks untilize_multi_core(const Tensor &a, Tensor& output, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); -operation::ProgramWithCallbacks untilize_single_core(const Tensor &a, Tensor& output, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); -operation::ProgramWithCallbacks untilize_with_unpadding_multi_core(const Tensor &a, Tensor& output, const Shape &output_tensor_start, const Shape &output_tensor_end, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); -operation::ProgramWithCallbacks untilize_with_unpadding_single_core(const Tensor &a, Tensor& output, const Shape &output_tensor_start, const Shape &output_tensor_end, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); - -Tensor untilize (const Tensor &a, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, bool use_multicore = true, bool use_pack_untilize = true); -Tensor untilize_with_unpadding(const Tensor &a, const Shape &output_tensor_start, const Shape &output_tensor_end, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, bool use_pack_untilize = true); +operation::ProgramWithCallbacks untilize_multi_core( + const Tensor &a, Tensor &output, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); +operation::ProgramWithCallbacks untilize_single_core( + const Tensor &a, Tensor &output, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); + +operation::ProgramWithCallbacks untilize_with_unpadding_multi_core( + const Tensor &a, Tensor &output, bool 
use_pack_untilize = true, bool fp32_dest_acc_en = false); +operation::ProgramWithCallbacks untilize_with_unpadding_single_core( + const Tensor &a, Tensor &output, bool use_pack_untilize = true, bool fp32_dest_acc_en = false); + +Tensor untilize( + const Tensor &a, + const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + bool use_multicore = true, + bool use_pack_untilize = true); +Tensor untilize_with_unpadding( + const Tensor &a, + const Shape &output_tensor_end, + const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + bool use_multicore = false, + bool use_pack_untilize = true); // NOTE: UntilizeWithHalo is only for sharded input/output struct UntilizeWithHalo { @@ -85,7 +105,8 @@ struct UntilizeWithHalo { void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; static constexpr auto attribute_names = std::make_tuple( "pad_val", "in_b", "in_h", "in_w", "out_shard_size_max_per_core", "stride", "output_mem_config"); @@ -100,7 +121,14 @@ struct UntilizeWithHalo { std::cref(this->output_mem_config)); } }; -Tensor untilize_with_halo(const Tensor &a, const uint32_t pad_val, const uint32_t &in_b, const uint32_t &in_h, const uint32_t &in_w, const uint32_t stride = 1, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +Tensor untilize_with_halo( + const Tensor &a, + const uint32_t pad_val, + const uint32_t &in_b, + const uint32_t &in_h, + const uint32_t &in_w, + const uint32_t stride = 1, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); struct UntilizeWithHaloV2 { const uint32_t pad_val_; @@ -113,10 +141,11 @@ struct UntilizeWithHaloV2 { void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; - static constexpr auto attribute_names = - std::make_tuple("pad_val_", "ncores_nhw_", "max_out_nsticks_per_core_", "out_mem_config_", "remote_read_", "transpose_mcast_"); + static constexpr auto attribute_names = std::make_tuple( + "pad_val_", "ncores_nhw_", "max_out_nsticks_per_core_", "out_mem_config_", "remote_read_", "transpose_mcast_"); const auto attribute_values() const { return std::make_tuple( std::cref(pad_val_), @@ -124,8 +153,7 @@ struct UntilizeWithHaloV2 { std::cref(max_out_nsticks_per_core_), std::cref(out_mem_config_), std::cref(remote_read_), - std::cref(transpose_mcast_) - ); + std::cref(transpose_mcast_)); } }; Tensor untilize_with_halo_v2( diff --git a/tt_eager/tt_dnn/op_library/work_split_tilize.hpp b/tt_eager/tt_dnn/op_library/work_split_tilize.hpp index 84b3b0c065f4..fbad729c998a 100644 --- a/tt_eager/tt_dnn/op_library/work_split_tilize.hpp +++ b/tt_eager/tt_dnn/op_library/work_split_tilize.hpp @@ -8,9 +8,9 @@ #pragma once +#include "tensor/types.hpp" #include "tt_metal/common/core_coord.h" - namespace tt::tt_metal 
{ struct BlockSplit { @@ -47,7 +47,7 @@ inline BlockSplit split_blocks_for_tilize(CoreCoord grid_size, uint32_t nblocks) core_range.insert(range); } else if (nblocks_per_core_cliff > 0) { // Last partial row (excluding last core) and single cliff core - if (ncores_x_cliff > 1) { // Add range only if there are cores before the cliff core + if (ncores_x_cliff > 1) { // Add range only if there are cores before the cliff core auto range = CoreRange{CoreCoord{0, ncores_y - 1}, CoreCoord{ncores_x_cliff - 2, ncores_y - 1}}; core_range.insert(range); } @@ -73,4 +73,165 @@ inline BlockSplit split_blocks_for_tilize(CoreCoord grid_size, uint32_t nblocks) return BlockSplit{ncores, all_cores, core_range, cliff_core_range, nblocks_per_core, nblocks_per_core_cliff}; } -} // namespace tt::tt_metal +// BlockRep represents a repeated sequence of data blocks, mixed blocks, and padding blocks. +// It is convenient to pass to the device kernels because a single data structure made of 4 ints +// can represent pure data rows, pure padding rows or a mixture thereof. +struct BlockRep { + // number of data blocks + uint32_t n_data; + // number of mixed data rows in a mixed block, 0 means no mixed block + uint32_t n_mixed; + // number of padding blocks + uint32_t n_pads; + // total repeat times + uint32_t times; + + BlockRep(uint32_t n_data, uint32_t n_mixed, uint32_t n_pads, uint32_t times) : + n_data(n_data), n_mixed(n_mixed), n_pads(n_pads), times(times) { + if (n_data == 0 && n_mixed == 0) { + n_pads *= times; + times = 1; + } else if (n_pads == 0 && n_mixed == 0) { + n_data *= times; + times = 1; + } + } + + bool has_mixed_block() const { return n_mixed > 0; } + + uint32_t single_rep() const { return n_data + has_mixed_block() + n_pads; } + + uint32_t block_count() const { return single_rep() * times; } + + uint32_t data_row_count() const { return (n_data * 32 + n_mixed) * times; } + + std::pair<std::vector<BlockRep>, std::vector<BlockRep>> split_at(uint32_t idx) const { + // TT_ASSERT(idx <= block_count()); + + std::vector<BlockRep> first; + std::vector<BlockRep> second; + + int rep_idx = idx / single_rep(); + if (rep_idx > 0) { + first.emplace_back(n_data, n_mixed, n_pads, rep_idx); + } + + int within_rep_idx = idx % single_rep(); + bool is_within_rep = within_rep_idx > 0; + if (is_within_rep) { + if (within_rep_idx <= n_data) { + first.emplace_back(within_rep_idx, 0, 0, 1); + second.emplace_back(n_data - within_rep_idx, n_mixed, n_pads, 1); + } else if (within_rep_idx == n_data + 1 && has_mixed_block()) { + first.emplace_back(n_data, n_mixed, 0, 1); + second.emplace_back(0, 0, n_pads, 1); + } else { + within_rep_idx -= n_data + has_mixed_block(); + first.emplace_back(n_data, n_mixed, within_rep_idx, 1); + second.emplace_back(0, 0, n_pads - within_rep_idx, 1); + } + } + + int remaining_times = times - rep_idx - is_within_rep; + if (remaining_times > 0) { + second.emplace_back(n_data, n_mixed, n_pads, remaining_times); + } + + return {first, second}; + } +}; + +// FullRep is a repeated sequence of data rows followed by pure padding. It represents the row +// pattern seen from the outer-most dimension of a 4D tensor when padding is added to the second +// or the third dimension.
+struct FullRep { + uint32_t n_rows; + uint32_t n_pads; + uint32_t times; + + uint32_t pads_mul; + uint32_t times_total; + + BlockRep rep; + BlockRep pad; + + FullRep(uint32_t n_rows, uint32_t n_pads, uint32_t times, uint32_t pads_mul, uint32_t times_total) : + n_rows(n_rows), + n_pads(n_pads), + times(times), + pads_mul(pads_mul), + times_total(times_total), + rep{n_rows / 32, n_rows % 32, n_pads / 32, times}, + pad{0, 0, (n_rows + n_pads) * pads_mul, 1} { + // TT_ASSERT((n_rows + n_pads) % 32 == 0 && "total rows must be divisible by 32"); + } + + std::vector<BlockRep> to_block_reps() const { + std::vector<BlockRep> block_reps; + block_reps.reserve(2 * times_total); + + for (int i = 0; i < times_total; ++i) { + block_reps.push_back(rep); + block_reps.push_back(pad); + } + + return block_reps; + } +}; + +inline std::vector<std::vector<BlockRep>> distribute_work( + const Shape& unpadded, const Padding& padding, uint32_t num_cores, uint32_t blocks_per_core, bool has_cliff, uint32_t nblocks_per_core_cliff) { + auto input_w = unpadded.rank() >= 4 ? unpadded[-4] : 1; + auto input_z = unpadded.rank() >= 3 ? unpadded[-3] : 1; + auto input_y = unpadded.rank() >= 2 ? unpadded[-2] : 1; + + auto padding_w = unpadded.rank() >= 4 ? padding[unpadded.get_normalized_index(-4)].back : 0; + auto padding_z = unpadded.rank() >= 3 ? padding[unpadded.get_normalized_index(-3)].back : 0; + auto padding_y = unpadded.rank() >= 2 ? padding[unpadded.get_normalized_index(-2)].back : 0; + + // total work is a full rep followed by a padding. + auto full_rep_blocks = FullRep(input_y, padding_y, input_z, padding_z, input_w).to_block_reps(); + std::deque<BlockRep> total_work(full_rep_blocks.begin(), full_rep_blocks.end()); + total_work.emplace_back(0, 0, (input_y + padding_y) * (input_z + padding_z) * padding_w, 1); + + std::vector<std::vector<BlockRep>> core_assignments; + core_assignments.reserve(num_cores); + + for (int i = 0; i < num_cores; i++) { + int blocks_to_process = blocks_per_core; + if (i == num_cores - 1 && has_cliff) { + blocks_to_process = nblocks_per_core_cliff; + } + + // Assign blocks to cores + std::vector<BlockRep> core_blocks; + int core_blocks_count = 0; + while (core_blocks_count < blocks_to_process) { + if (total_work.empty()) { + break; + } + + int remaining_core_blocks = blocks_to_process - core_blocks_count; + auto& first = total_work.front(); + if (first.block_count() <= remaining_core_blocks) { + core_blocks.push_back(first); + core_blocks_count += first.block_count(); + total_work.pop_front(); + } else { + auto [head, tail] = first.split_at(remaining_core_blocks); + for (auto& el : head) { + core_blocks.push_back(el); + core_blocks_count += el.block_count(); + } + total_work.pop_front(); + total_work.insert(total_work.begin(), tail.begin(), tail.end()); + } + } + + core_assignments.push_back(core_blocks); + } + + return core_assignments; +} + +} // namespace tt::tt_metal diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp index 2163efa11719..dea7965862cd 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_dm_ops.cpp @@ -246,8 +246,8 @@ namespace tt::tt_metal::detail{ )doc"); m_tensor.def("untilize_with_unpadding", &untilize_with_unpadding, - py::arg("input").noconvert(), py::arg("output_tensor_start"), py::arg("output_tensor_end"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("use_pack_untilize").noconvert() = true, + py::arg("input").noconvert(), py::arg("output_tensor_end"),
py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("use_multicore").noconvert() = false, py::arg("use_pack_untilize").noconvert() = true, R"doc( Changes data layout of input tensor to ROW_MAJOR and unpads/removes elements from the tensor. @@ -259,9 +259,9 @@ namespace tt::tt_metal::detail{ :header: "Argument", "Description", "Data type", "Valid range", "Required" "input", "Input tensor", "Tensor", "Tensor of shape [W, Z, Y, X] where Y%32=0 and X%32=0", "Yes" - "output_tensor_start", "Start indices of input tensor", "List[int[4]]", "Values along each dim must be < input_tensor_shape[i]", "Yes" "output_tensor_end", "End indices of input tensor in output tensor", "List[int[4]]", "Values along each dim must be < input_tensor_shape[i]", "Yes" "pad_value", "Value to pad input tensor", "float", "", "Yes" + "use_multicore", "Whether to use multi-core parallelization", "bool", "Default is false", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); diff --git a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp index b2b246fb258e..2a34d8609a5b 100644 --- a/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/op_library/to_layout/to_layout_op.cpp @@ -85,8 +85,9 @@ Tensor execute( memory_config.value_or(ttnn::get_memory_config(tensor).value_or(ttnn::DRAM_MEMORY_CONFIG)); if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) { + bool use_multicore = true; + if (not requires_padding_change(layout, tensor.get_shape())) { - bool use_multicore = true; if (layout == ttnn::ROW_MAJOR_LAYOUT) { TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); return tt::tt_metal::untilize(tensor, output_memory_config, use_multicore); @@ -118,13 +119,11 @@ Tensor execute( output_tensor_end.push_back(tensor.get_shape()[index] - 1); } - tensor = - tt::tt_metal::untilize_with_unpadding(tensor, {0, 0, 0, 0}, output_tensor_end, output_memory_config); + tensor = tt::tt_metal::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore); return reshape(tensor, ttnn::Shape(tt::tt_metal::Shape{output_shape})); } else if (layout == ttnn::TILE_LAYOUT) { tensor = unsqueeze_to_4D(tensor); - bool use_multicore = true; std::vector padded_4D_output_shape; padded_4D_output_shape.push_back(tensor.get_shape()[-4]); padded_4D_output_shape.push_back(tensor.get_shape()[-3]);