diff --git a/tests/ttnn/unit_tests/operations/test_pad.py b/tests/ttnn/unit_tests/operations/test_pad.py index 60a3edc8e33..fa1373d59fe 100644 --- a/tests/ttnn/unit_tests/operations/test_pad.py +++ b/tests/ttnn/unit_tests/operations/test_pad.py @@ -128,6 +128,121 @@ def run_pad_rm_sharded(device, n, c, h, w, padding, torch_padding, value, shard_ assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) +def to_torch_padding(padspec): + def flatten_to_tuple(padding): + return tuple(sum(padding, ())) + + def ttnn_pad_spec_to_padding(padspec): + input_tensor_start = padspec["input_tensor_start"] + pad_to_shape = padspec["pad_to_shape"] + input_shape = padspec["input_shape"] + + padding = [] + for i in range(len(pad_to_shape)): + this_dim_padding = (input_tensor_start[i], pad_to_shape[i] - input_shape[i] - input_tensor_start[i]) + padding.append(this_dim_padding) + return padding + + torch_padding = flatten_to_tuple(reversed(ttnn_pad_spec_to_padding(padspec))) + return torch_padding + + +@pytest.mark.parametrize( + "input_shape, pad_to_shape, input_tensor_start, pad_value, input_sharded_memory_config_args", + [ + [ + (1, 1, 1, 4), + (1, 1, 1, 16), + (0, 0, 0, 0), + 3.0, + {"core_grid": ttnn.CoreGrid(x=1, y=1), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # a reduced version of esmal's test case for UNet + (1, 1, 4, 4), + (1, 1, 4, 16), + (0, 0, 0, 0), + 3.0, + {"core_grid": ttnn.CoreGrid(x=1, y=1), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # width padding across large core grid, 3 sticks per core + (1, 1, 3 * 64, 4), + (1, 1, 3 * 64, 16), + (0, 0, 0, 0), + 0.0, + {"core_grid": ttnn.CoreGrid(x=8, y=8), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # width padding across large core grid, 3 sticks per core, n300 version + (1, 1, 3 * 8 * 7, 4), + (1, 1, 3 * 8 * 7, 16), + (0, 0, 0, 0), + 0.0, + {"core_grid": ttnn.CoreGrid(x=8, y=7), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # width padding only, reduced core grid + (1, 1, 12, 8), + (1, 1, 12, 64), + (0, 0, 0, 0), + 3.0, + {"core_grid": ttnn.CoreGrid(x=2, y=6), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # height and width padding, small core grid + (1, 1, 2, 4), + (1, 1, 4, 8), + (0, 0, 0, 0), + 3.0, + {"core_grid": ttnn.CoreGrid(x=1, y=2), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + [ + # borys's second test case + (1, 2, 3, 4), + (1, 2, 32, 32), + (0, 0, 0, 0), + 3.0, + {"core_grid": ttnn.CoreGrid(x=1, y=6), "strategy": ttnn.ShardStrategy.HEIGHT}, + ], + ], +) +def test_pad_rm_sharded_stickwise( + device, input_shape, pad_to_shape, input_tensor_start, pad_value, input_sharded_memory_config_args +): + core_grid_x_ok = device.core_grid.x >= input_sharded_memory_config_args["core_grid"].x + core_grid_y_ok = device.core_grid.y >= input_sharded_memory_config_args["core_grid"].y + device_core_grid_ok = core_grid_x_ok and core_grid_y_ok + if not device_core_grid_ok: + pytest.skip("core grid for this test is not compatible with the device") + + input_shard_memory_config = ttnn.create_sharded_memory_config(input_shape, **input_sharded_memory_config_args) + + torch_input_tensor = torch.ones(input_shape, dtype=torch.float32) + ttnn_input_tensor = ttnn.from_torch( + torch_input_tensor, dtype=ttnn.float32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device + ) + ttnn_sharded_input_tensor = ttnn.to_memory_config(ttnn_input_tensor, input_shard_memory_config) + + padded_tensor = ttnn.pad(ttnn_sharded_input_tensor, pad_to_shape, input_tensor_start, pad_value) + + tt_output_tensor = ttnn.to_memory_config(padded_tensor, 
ttnn.L1_MEMORY_CONFIG) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + torch_output_tensor = ttnn.to_torch(tt_output_tensor) + + padspec = { + "input_shape": input_shape, + "pad_to_shape": pad_to_shape, + "input_tensor_start": input_tensor_start, + } + torch_padded_tensor = torch.nn.functional.pad( + torch_input_tensor, to_torch_padding(padspec), mode="constant", value=pad_value + ) + + assert torch_output_tensor.shape == torch_padded_tensor.shape + assert_with_pcc(torch_padded_tensor, torch_output_tensor, 0.99) + + @pytest.mark.parametrize("n", [20]) @pytest.mark.parametrize("c", [3]) @pytest.mark.parametrize("h", [224]) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp index e28802fd1a7..16a027eb9de 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp @@ -55,6 +55,152 @@ ttnn::Tensor pad_to_tile_vol( return tensor; } uint32_t wrap_index(int index, int size) { return index < 0 ? size + index : index; } + +std::array compute_block_sharded_shard_shape(const std::array& squeezed_tensor_hw, + const tt::tt_metal::Layout& layout, + const tt::tt_metal::CoreCoord& grid_size, + const tt::tt_metal::ShardOrientation& orientation, + const uint32_t total_num_cores) { + TT_FATAL(grid_size.y * grid_size.x == total_num_cores, "compute_block_sharded_shard_shape received a core grid shape that does not match the total number of cores"); + auto adjusted_grid_size = grid_size; + if (orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) { + // for col major, we partition the width of the tensor along the height of the core grid + std::swap(adjusted_grid_size.x, adjusted_grid_size.y); + } + + auto [tensor_height, tensor_width] = squeezed_tensor_hw; + auto tensor_height_padded_to_tile = + layout == tt::tt_metal::Layout::TILE + ? tt::round_up(tensor_height, adjusted_grid_size.y * tt::constants::TILE_HEIGHT) + : tensor_height; + std::array shard_shape = {tt::div_up(tensor_height_padded_to_tile, adjusted_grid_size.y), + tt::div_up(tensor_width, adjusted_grid_size.x)}; + + return shard_shape; +} + +std::array compute_width_sharded_shard_shape(const std::array& squeezed_tensor_hw, + const uint32_t total_num_cores) { + return {squeezed_tensor_hw[0], tt::div_up(squeezed_tensor_hw[1], total_num_cores)}; +} + +std::array compute_height_sharded_shard_shape(const std::array& squeezed_tensor_hw, + const tt::tt_metal::Layout& layout, + const uint32_t total_num_cores) { + auto [tensor_height, tensor_width] = squeezed_tensor_hw; + auto squeezed_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE + ? 
tt::round_up(tensor_height, total_num_cores) + : tensor_height; + return {tt::div_up(squeezed_height_padded_to_tile, total_num_cores), tensor_width}; +} + +ttnn::MemoryConfig create_sharded_memory_config( + const ttnn::SimpleShape& logical_shape, + const tt::tt_metal::CoreRangeSet& core_grid, + const ShardStrategy& strategy, + const tt::tt_metal::ShardOrientation& orientation, + std::optional> shard_shape, + const tt::tt_metal::Layout& layout, + bool halo) { + auto rank = logical_shape.rank(); + TT_FATAL(rank >= 2, "rank of tensor to shard must be at least 2."); + + ttnn::TensorMemoryLayout tensor_memory_layout; + if (strategy == ShardStrategy::BLOCK) { + tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED; + } else if (strategy == ShardStrategy::WIDTH) { + tensor_memory_layout = ttnn::TensorMemoryLayout::WIDTH_SHARDED; + } else if (strategy == ShardStrategy::HEIGHT) { + tensor_memory_layout = ttnn::TensorMemoryLayout::HEIGHT_SHARDED; + } + + auto height = logical_shape[-2]; + auto width = logical_shape[-1]; + std::array computed_shard_shape; + + if (shard_shape.has_value()) { + computed_shard_shape = shard_shape.value(); + } else { + uint32_t batch_size = 1; + for (int i = 0; i < rank - 2; i++) { + batch_size *= logical_shape[i]; + } + + auto tensor_height = batch_size * height; + auto tensor_width = width; + std::array squeezed_tensor_hw{tensor_height, tensor_width}; + auto total_num_cores = core_grid.num_cores(); + CoreCoord grid_size = core_grid.bounding_box().grid_size(); + + switch (strategy) { + case ShardStrategy::BLOCK: + computed_shard_shape = compute_block_sharded_shard_shape(squeezed_tensor_hw, layout, grid_size, orientation, total_num_cores); + break; + case ShardStrategy::WIDTH: + computed_shard_shape = compute_width_sharded_shard_shape(squeezed_tensor_hw, total_num_cores); + break; + case ShardStrategy::HEIGHT: + computed_shard_shape = compute_height_sharded_shard_shape(squeezed_tensor_hw, layout, total_num_cores); + break; + default: + TT_ASSERT(false, "Invalid shard strategy"); + } + } + + if (layout == tt::tt_metal::Layout::TILE) { + auto [shard_height, shard_width] = computed_shard_shape; + auto tile_divides_shard_height = shard_height % tt::constants::TILE_HEIGHT == 0; + auto tile_divides_shard_width = shard_width % tt::constants::TILE_WIDTH == 0; + TT_FATAL(tile_divides_shard_width && tile_divides_shard_height, + "For sharding tiled tensors, the shard shape must fit neatly into tiles but " + "create_sharded_memory_config got shard width {} and shard height {} while " + "on this architecture we have tile width {} and tile height {}", + computed_shard_shape[0], computed_shard_shape[1], tt::constants::TILE_WIDTH, tt::constants::TILE_HEIGHT); + } + + auto shard_spec = tt::tt_metal::ShardSpec(core_grid, computed_shard_shape, orientation, halo); + return ttnn::MemoryConfig(tensor_memory_layout, ttnn::BufferType::L1, shard_spec); +} + +std::pair> tensor_coord_to_height_sharded_coord( + const std::span& tensor_shape, + const std::span& shard_shape, + const std::span& tensor_coord) { + std::array tensor_shape_2d{0, 0}; + for (size_t i = 0; i < tensor_shape.size(); i++) { + if (i == tensor_shape.size() - 1) { + // width dimension, goes unmodified + tensor_shape_2d[1] = tensor_shape[i]; + } else { + // height dimension, squeeze into 2D shape + if (tensor_shape_2d[0] == 0) { + // first time we've seen this dimension + tensor_shape_2d[0] = tensor_shape[i]; + } else { + tensor_shape_2d[0] *= tensor_shape[i]; + } + } + } + + std::array tensor_coord_2d{0, 
tensor_coord.back()}; + uint32_t height_2d = 0; + for (size_t i = 0; i < tensor_coord.size() - 1; i++) { + std::vector page_shapes(tensor_shape.begin() + i + 1, tensor_shape.end() - 1); + auto component_sum = + tensor_coord[i] * std::accumulate(page_shapes.begin(), page_shapes.end(), 1, std::multiplies()); + height_2d += component_sum; + } + tensor_coord_2d[0] = height_2d; + + uint32_t shard_height = shard_shape[0]; + uint32_t w_in_shard = tensor_coord_2d[1]; + uint32_t h_in_shard = height_2d % shard_height; + uint32_t which_shard = height_2d / shard_height; + + std::array shard_coord{h_in_shard, w_in_shard}; + return std::make_pair(which_shard, shard_coord); +} + } // namespace data_movement } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp index c70b0afd296..78938828448 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp @@ -156,6 +156,27 @@ ttnn::Tensor pad_to_tile_vol( const bool use_multicore, const std::optional& memory_config); +enum class ShardStrategy { BLOCK, HEIGHT, WIDTH }; + +// Helper function for creating a sharded memory configuration for a tensor +// based on its logical shape, a shard strategy and orientation, and a core +// grid. Optionally, you may pass a preferred shard shape to use. If not +// provided, the shard shape will be inferred from the tensor shape and the +// shard strategy. +ttnn::MemoryConfig create_sharded_memory_config( + const ttnn::SimpleShape& logical_shape, + const tt::tt_metal::CoreRangeSet& core_grid, + const ShardStrategy& strategy, + const tt::tt_metal::ShardOrientation& orientation, + std::optional> shard_shape = std::nullopt, + const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR, + bool halo = false); + +std::pair> tensor_coord_to_height_sharded_coord( + const std::span& tensor_shape, + const std::span& shard_shape, + const std::span& tensor_coord); + } // namespace data_movement } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp new file mode 100644 index 00000000000..48dff936eea --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp @@ -0,0 +1,42 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +// This file contains common kernel functions used for debugging +#pragma once +#include "debug/dprint.h" +namespace tt::data_movement::common { +inline void print_bf16_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * elts_per_page; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } +} + +inline void print_f32_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(l1_addr) + start * elts_per_page; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) { + DPRINT << F32(*ptr) << " "; + } + DPRINT << ENDL(); + } +} + +inline void print_u8_pages(uint32_t l1_addr, uint32_t bytes_per_page, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint8_t* ptr = reinterpret_cast(l1_addr) + start * bytes_per_page; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < bytes_per_page; ++j, ++ptr) { + DPRINT << SETW(2) << HEX() << "0x" << (uint32_t)*ptr << " "; + } + DPRINT << DEC(); // revert to decimal representation + DPRINT << ENDL(); + } +} +} // namespace tt::data_movement::common diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp new file mode 100644 index 00000000000..831707d0bf0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp" + +#define u8_l1_ptr volatile tt_l1_ptr uint8_t* +#define u8_vol_ptr volatile uint8_t* +#define u8_ptr uint8_t* + +void kernel_main() { + constexpr uint32_t unpadded_stick_bytes = get_compile_time_arg_val(0); + constexpr uint32_t padded_stick_bytes = get_compile_time_arg_val(1); + constexpr uint32_t unpadded_shard_height = get_compile_time_arg_val(2); + constexpr uint32_t padded_shard_height = get_compile_time_arg_val(3); + constexpr uint32_t W_front_pad_bytes = get_compile_time_arg_val(4); + + constexpr uint32_t input_shard_cb = get_compile_time_arg_val(5); + constexpr uint32_t output_shard_cb = get_compile_time_arg_val(6); + constexpr uint32_t unpadded_stick_step = get_compile_time_arg_val(7); + constexpr uint32_t padded_stick_step = get_compile_time_arg_val(8); + + uint32_t input_shard_base_addr = get_write_ptr(input_shard_cb); + uint32_t output_shard_base_addr = get_write_ptr(output_shard_cb); + + auto input_stick_ptr = reinterpret_cast(input_shard_base_addr); + auto output_stick_ptr = reinterpret_cast(output_shard_base_addr); + + // fill the sticks that aren't entirely padding with data from the input tensor + for (uint32_t h = 0; h < unpadded_shard_height; h++) { + cb_wait_front(output_shard_cb, 1); // wait for writer to fill this stick with padding + + // FIXME: this isn't aligned. we need to do a memcpy for now. 
we can try + // to do a noc_async_read later on with a trick. + // + // currently small noc transfers are slow, but once runtime drops an + // optimization (upcoming as of 12/12/2024) this might be worth + // investigating. + + // paulk says that an optimized loop will still be faster. + // TODO(jkruer): get paul's help optimizing this. + + // read the input stick into the padded output stick starting after the + // front padding + for (uint32_t i = 0; i < unpadded_stick_bytes; i++) { + output_stick_ptr[W_front_pad_bytes + i] = input_stick_ptr[i]; + } + + cb_pop_front(output_shard_cb, 1); + + input_stick_ptr += unpadded_stick_step; + output_stick_ptr += padded_stick_step; + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_sharded_stickwise.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_sharded_stickwise.cpp new file mode 100644 index 00000000000..8395a2f19c0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_sharded_stickwise.cpp @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "dataflow_api.h" + +#define u16_l1_ptr volatile tt_l1_ptr uint16_t* +#define u32_l1_ptr volatile tt_l1_ptr uint32_t* + +template +inline __attribute__((always_inline)) void fill_cb_with_padding_value( + const uint32_t cb, const uint32_t padding_value_as_u32) { + constexpr uint32_t num_elts = + num_bytes / padding_value_num_bytes; // constexpr so that this division happens once on host + uint32_t cb_write_addr = get_write_ptr(cb); + + if constexpr (padding_value_num_bytes == 4) { + u32_l1_ptr cb_write_addr_as_u32 = reinterpret_cast(cb_write_addr); + for (uint32_t i = 0; i < num_elts; i++) { + cb_write_addr_as_u32[i] = padding_value_as_u32; + } + } else if constexpr (padding_value_num_bytes == 2) { + uint16_t padding_value_as_u16 = static_cast(padding_value_as_u32); + u16_l1_ptr cb_write_addr_as_u16 = reinterpret_cast(cb_write_addr); + for (uint32_t i = 0; i < num_elts; i++) { + cb_write_addr_as_u16[i] = padding_value_as_u16; + } + } else { + static_assert( + padding_value_num_bytes == 2 || padding_value_num_bytes == 4, "padding_value_num_bytes is not 2 or 4"); + } +} + +void kernel_main() { + constexpr uint32_t padded_stick_bytes = get_compile_time_arg_val(0); + constexpr uint32_t padded_shard_height = get_compile_time_arg_val(1); + constexpr uint32_t padding_value_as_u32 = get_compile_time_arg_val(2); + constexpr uint32_t padding_value_num_bytes = get_compile_time_arg_val(3); + + constexpr auto output_shard_cb = get_compile_time_arg_val(4); + constexpr auto padding_value_cb = get_compile_time_arg_val(5); + + cb_reserve_back(output_shard_cb, padded_shard_height); + uint32_t output_shard_base_addr = get_write_ptr(output_shard_cb); + + fill_cb_with_padding_value(padding_value_cb, padding_value_as_u32); + uint32_t padding_value_base_addr = get_read_ptr(padding_value_cb); + + uint64_t output_stick_noc_addr = get_noc_addr(output_shard_base_addr); + for (uint32_t h = 0; h < padded_shard_height; h++) { + noc_async_write(padding_value_base_addr, output_stick_noc_addr, padded_stick_bytes); + noc_async_write_barrier(); + + cb_push_back(output_shard_cb, 1); + + output_stick_noc_addr += padded_stick_bytes; + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp index 675c0b5d622..0dbbf5fc4ae 
100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_op.cpp @@ -12,6 +12,10 @@ void Pad::validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const { using namespace tt::constants; const auto& input_tensor = input_tensors.at(0); + auto logical_rank = input_tensor.logical_shape().rank(); + auto padded_rank = input_tensor.padded_shape().rank(); + TT_FATAL(logical_rank == padded_rank, "ttnn.pad: logical and padded shapes must have the same rank"); + TT_FATAL(input_tensor.logical_shape().rank() <= 4, "ttnn.pad: input tensor rank currently must be 4 or less"); TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operand to pad needs to be on device!"); TT_FATAL(input_tensor.buffer() != nullptr, "Operand to pad needs to be allocated in a buffer on device!"); TT_FATAL(input_tensor.get_layout() == Layout::TILE || input_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); @@ -48,11 +52,11 @@ void Pad::validate_with_output_tensors( } if (input_tensor.is_sharded()) { - TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); - TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); + TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "ttnn.pad: For sharded inputs, only height-sharding is supported."); + TT_FATAL(input_tensor.get_layout() == Layout::ROW_MAJOR, "ttnn.pad: Only row-major sharded inputs are supported."); - TT_FATAL(this->output_mem_config.is_sharded(), "Error"); - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL(this->output_mem_config.is_sharded(), "ttnn.pad: For sharded inputs, the output must be sharded."); + TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "ttnn.pad: for sharded inputs, only height-sharding is supported for the output."); } } @@ -77,8 +81,36 @@ operation::ProgramWithCallbacks Pad::create_program( auto& output_tensor = output_tensors.at(0); if (input_tensor.get_layout() == Layout::ROW_MAJOR) { if (input_tensor.is_sharded()) { - return detail::pad_rm_sharded( - input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value); + uint32_t input_tot_h = std::accumulate( + input_tensor.get_logical_shape().view().begin(), + input_tensor.get_logical_shape().view().end() - 1, + 1, + std::multiplies()); + uint32_t input_w = input_tensor.get_logical_shape()[3]; + + uint32_t output_tot_h = std::accumulate( + output_tensor.get_logical_shape().view().begin(), + output_tensor.get_logical_shape().view().end() - 1, + 1, + std::multiplies()); + uint32_t output_w = output_tensor.get_logical_shape()[3]; + + if (input_w != output_w and input_tot_h != output_tot_h) { + TT_THROW( + "ttnn.pad: Unsupported sharded row-major padding configuration: pad_impl did not decompose padding " + "correctly."); + return {}; + } else if (input_w != output_w) { + return detail::pad_rm_sharded_width_only( + input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value); + } else if (input_tot_h != output_tot_h) { + return detail::pad_rm_sharded_height_only( + input_tensor, output_tensor, this->output_tensor_shape, this->input_tensor_start, this->pad_value); + } else { + // for no padding, we just use the height-only padding program + return detail::pad_rm_sharded_height_only( + input_tensor, output_tensor, 
this->output_tensor_shape, this->input_tensor_start, this->pad_value); + } } else { if (use_multicore) { return detail::pad_rm_reader_writer_multi_core_v2( diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp index af7ea10d225..a5383efbd40 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp @@ -6,10 +6,12 @@ #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/math.hpp" #include "tt_metal/common/constants.hpp" +#include "tt_metal/common/core_coord.hpp" #include "tt_metal/detail/util.hpp" #include "tt_metal/host_api.hpp" #include "tt_log.h" #include "ttnn/operation.hpp" +#include "ttnn/operations/data_movement/common/common.hpp" using namespace tt::constants; using namespace tt::tt_metal; @@ -1397,7 +1399,7 @@ inline std::vector, std::vector>> get_ return ret_val; } -operation::ProgramWithCallbacks pad_rm_sharded( +operation::ProgramWithCallbacks pad_rm_sharded_height_only( const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_shape, @@ -1575,4 +1577,141 @@ operation::ProgramWithCallbacks pad_rm_sharded( return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } +operation::ProgramWithCallbacks pad_rm_sharded_width_only( + const Tensor& input_tensor, + Tensor& output, + const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& input_tensor_start, + float pad_value) { + Program program{}; + + TT_ASSERT( + output.shard_spec().has_value() and output.shard_spec()->shape[1] == output_tensor_shape[-1], + "ttnn.pad: pad_rm_sharded_width_only expects sharded output parameter with shard width equal to the width of " + "the requested output tensor. 
Ensure pad_impl is calling this program factory correctly."); + + uint32_t W = input_tensor.logical_shape()[-1]; + uint32_t W_padded = output_tensor_shape[3]; + + auto unpadded_stick_bytes = W * input_tensor.element_size(); + auto padded_stick_bytes = W_padded * input_tensor.element_size(); + + Device *device = input_tensor.device(); + + // input shard spec + auto input_shard_spec = input_tensor.shard_spec().value(); + uint32_t shard_height_unpadded = input_shard_spec.shape[0]; + + // output shard spec + auto shard_spec_padded = output.shard_spec().value(); + uint32_t shard_height_padded = shard_spec_padded.shape[0]; + + auto& all_cores_padded = shard_spec_padded.grid; + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t input_shard_cb_index = tt::CBIndex::c_0; + tt::tt_metal::CircularBufferConfig input_shard_cb_config = + tt::tt_metal::CircularBufferConfig( + shard_height_unpadded * unpadded_stick_bytes, {{input_shard_cb_index, input_cb_data_format}}) + .set_page_size(input_shard_cb_index, unpadded_stick_bytes) + .set_globally_allocated_address(*input_tensor.buffer()); + auto input_shard_cb = tt::tt_metal::CreateCircularBuffer(program, total_cores, input_shard_cb_config); + + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_shard_cb_index = tt::CBIndex::c_16; + tt::tt_metal::CircularBufferConfig output_shard_cb_config = + tt::tt_metal::CircularBufferConfig( + shard_height_padded * padded_stick_bytes, {{output_shard_cb_index, output_cb_data_format}}) + .set_page_size(output_shard_cb_index, padded_stick_bytes) + .set_globally_allocated_address(*output.buffer()); + auto output_shard_cb = tt::tt_metal::CreateCircularBuffer(program, total_cores, output_shard_cb_config); + + // construct const buffer with the pad_value + tt::DataFormat pad_val_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t pad_val_cb_index = tt::CBIndex::c_1; + tt::tt_metal::CircularBufferConfig cb_pad_val_config = + tt::tt_metal::CircularBufferConfig(padded_stick_bytes, {{pad_val_cb_index, pad_val_cb_data_format}}) + .set_page_size(pad_val_cb_index, padded_stick_bytes); + auto pad_val_cb = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_pad_val_config); + + uint32_t W_padding_front_bytes = input_tensor_start[-3] * input_tensor.element_size(); + + uint32_t padding_value_as_u32; + if (input_tensor.get_dtype() == tt::tt_metal::DataType::BFLOAT16) { + uint16_t bfloat_pad_value_bits = bfloat16(pad_value).to_uint16(); + padding_value_as_u32 = *reinterpret_cast(&bfloat_pad_value_bits); + } else if (input_tensor.get_dtype() == tt::tt_metal::DataType::FLOAT32) { + padding_value_as_u32 = *reinterpret_cast(&pad_value); + } else { + TT_THROW("ttnn.pad: unsupported data type for pad_rm_sharded_stickwise"); + } + + // FIXME: assumes that this was sharded using DRAM alignment so that gaps are left in the tensor. + // if this changes, we should change the stick step to be 16B (L1 alignment). 
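(Illustration, not part of the patch: a minimal standalone sketch of the stick-step rounding described in the FIXME above, assuming a 32-byte DRAM alignment and the float32 W=4 -> W=16 case from the new tests; the real alignment comes from hal.get_alignment at runtime, and the local round_up below is only a stand-in for tt::round_up.)

#include <cstdint>
#include <cstdio>

// local stand-in for tt::round_up: round `value` up to the next multiple of `multiple`
static uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
    const uint32_t dram_alignment_bytes = 32;                 // assumed alignment for this sketch
    const uint32_t unpadded_stick_bytes = 4 * sizeof(float);  // W = 4, float32 -> 16 bytes
    const uint32_t padded_stick_bytes = 16 * sizeof(float);   // W_padded = 16 -> 64 bytes
    // sticks sit at aligned offsets inside the shard, so the kernels advance by the
    // rounded-up step rather than by the raw stick size
    std::printf("unpadded_stick_step = %u\n", round_up(unpadded_stick_bytes, dram_alignment_bytes));  // 32
    std::printf("padded_stick_step   = %u\n", round_up(padded_stick_bytes, dram_alignment_bytes));    // 64
    return 0;
}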
+ auto dram_alignment_bytes = tt::tt_metal::hal.get_alignment(tt::tt_metal::HalMemType::DRAM); + uint32_t padded_stick_step = tt::round_up( + padded_stick_bytes, dram_alignment_bytes); // round padded_stick bytes to a multiple of dram_alignment_bytes + uint32_t unpadded_stick_step = tt::round_up( + unpadded_stick_bytes, + dram_alignment_bytes); // round unpadded_stick bytes to a multiple of dram_alignment_bytes + + std::vector reader_ct_args = { + unpadded_stick_bytes, + padded_stick_bytes, + shard_height_unpadded, + shard_height_padded, + W_padding_front_bytes, + input_shard_cb_index, + output_shard_cb_index, + unpadded_stick_step, + padded_stick_step}; + + std::vector writer_ct_args = { + padded_stick_bytes, + shard_height_padded, + padding_value_as_u32, + output.element_size(), + output_shard_cb_index, + pad_val_cb_index, + padded_stick_step}; + + KernelHandle reader_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/reader_pad_dims_rm_sharded_stickwise.cpp", + all_cores_padded, + tt::tt_metal::ReaderDataMovementConfig(reader_ct_args)); + + KernelHandle writer_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/pad/device/kernels/dataflow/writer_pad_dims_rm_sharded_stickwise.cpp", + all_cores_padded, + tt::tt_metal::WriterDataMovementConfig(writer_ct_args)); + + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, all_cores_padded, {}); + tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, all_cores_padded, {}); + + auto override_runtime_args_callback = [ + input_shard_cb, + output_shard_cb + ] + ( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors + ) { + auto input_buffer = input_tensors.at(0).buffer(); + auto output_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, input_shard_cb, *input_buffer); + UpdateDynamicCircularBufferAddress(program, output_shard_cb, *output_buffer); + }; + return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} } // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp index 4fb223964f6..db5236092b3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.hpp @@ -48,11 +48,18 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core_v2( const ttnn::SimpleShape& input_tensor_start, const float pad_value); -operation::ProgramWithCallbacks pad_rm_sharded( +operation::ProgramWithCallbacks pad_rm_sharded_height_only( const Tensor& a, Tensor& output, const tt::tt_metal::LegacyShape& output_tensor_shape, const ttnn::SimpleShape& input_tensor_start, const float pad_value); +operation::ProgramWithCallbacks pad_rm_sharded_width_only( + const Tensor& a, + Tensor& output, + const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& input_tensor_start, + float pad_value); + } // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index 284ea39d352..f54a763b638 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -7,44 +7,134 @@ 
#include "ttnn/common/constants.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/run_operation.hpp" - +#include "ttnn/operations/data_movement/common/common.hpp" #include "ttnn/operations/data_movement/pad/device/pad_op.hpp" namespace ttnn::operations::data_movement { namespace { -template +template +bool eq_spans(const ArrayType& a, const ArrayType& b) { + return std::equal(a.begin(), a.end(), b.begin(), b.end()); +} + static ttnn::Tensor pad_impl( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const ShapeType& output_padded_shape, - const ShapeType& input_tensor_start, + std::span output_padded_shape, + std::span input_tensor_start, const float value, const bool use_multicore, const std::optional& memory_config_arg) { + auto input_logical_shape = input_tensor.logical_shape().view(); // on host if (input_tensor.storage_type() != StorageType::DEVICE) { - if (input_tensor.get_legacy_shape() == output_padded_shape) { + if (eq_spans(input_logical_shape, output_padded_shape)) { return input_tensor; } else { return input_tensor.pad( - tt::tt_metal::LegacyShape(output_padded_shape), ttnn::SimpleShape(input_tensor_start), value); + tt::tt_metal::LegacyShape(output_padded_shape), ttnn::SimpleShape{input_tensor_start}, value); } } + // on device else { - const auto input_tensor_shape = input_tensor.get_shape(); + auto input_tensor_shape = input_tensor.get_shape(); const auto rank = input_tensor_shape.rank(); - TT_FATAL(rank == 4, "Tensor rank is not 4"); + TT_FATAL(rank == 4, "ttnn.pad: input tensor passed to pad_impl must have rank == 4, but got rank {}.", rank); + + using ShardStrategy = ttnn::operations::data_movement::ShardStrategy; + using ShardOrientation = tt::tt_metal::ShardOrientation; + using Layout = tt::tt_metal::Layout; + + auto output_memory_config = memory_config_arg.value_or(input_tensor.memory_config()); + + if (input_tensor.is_sharded()) { + auto total_height = [](const auto& shape) { + return std::accumulate(shape.begin(), shape.end() - 1, 1, std::multiplies()); + }; + + auto height_distinct = [&total_height](const auto& shape, const auto& other_shape) { + return total_height(shape) != total_height(other_shape); + }; + + auto width_distinct = [](const auto& shape, const auto& other_shape) { return shape[3] != other_shape[3]; }; + + uint32_t input_w = input_logical_shape[3]; + uint32_t output_w = output_padded_shape[3]; + + if (width_distinct(input_logical_shape, output_padded_shape)) { + std::array output_shape_width_padded{ + input_logical_shape[0], input_logical_shape[1], input_logical_shape[2], output_w}; + auto width_pad_memory_config = create_sharded_memory_config( + ttnn::SimpleShape{output_shape_width_padded}, + input_tensor.shard_spec()->grid, // reuse input cores for now: FIXME: can we do better? + // it's complicated because we need the input shards to be local + // to the core holding the output shard currently. + ShardStrategy::HEIGHT, // stay height sharded + ShardOrientation::ROW_MAJOR); + output_memory_config = width_pad_memory_config; + + if (height_distinct(input_logical_shape, output_padded_shape)) { + // we will decompose the padding into two parts and run two + // separate pads. 
+ ttnn::SmallVector adjusted_input_tensor_start{0, 0, 0, input_tensor_start[3]}; + + TT_ASSERT( + not(height_distinct(input_logical_shape, output_shape_width_padded) and + width_distinct(input_logical_shape, output_shape_width_padded)), + "infinite recursion"); + + // pad width + auto output_tensor_width_padded = pad_impl( + queue_id, + input_tensor, + output_shape_width_padded, + adjusted_input_tensor_start, + value, + use_multicore, + width_pad_memory_config); + + TT_ASSERT( + not(height_distinct(output_padded_shape, output_shape_width_padded) and + width_distinct(output_padded_shape, output_shape_width_padded)), + "infinite recursion"); + + auto height_pad_memory_config = create_sharded_memory_config( + ttnn::SimpleShape{output_padded_shape}, + input_tensor.shard_spec()->grid, + ShardStrategy::HEIGHT, + ShardOrientation::ROW_MAJOR); + + // then pad height + auto output_tensor_height_padded = pad_impl( + queue_id, + output_tensor_width_padded, + output_padded_shape, + input_tensor_start, + value, + use_multicore, + memory_config_arg.value_or(height_pad_memory_config)); + output_tensor_width_padded.deallocate(); // dealloc temporary width padded tensor + return output_tensor_height_padded; + } + } + } + + auto output_w = output_padded_shape[3]; + TT_ASSERT( + !input_tensor.is_sharded() || output_w == output_memory_config.shard_spec->shape[1], + "output_w != output_memory_config.shard_spec().shape[1]"); + + tt::tt_metal::LegacyShape output_padded_legacy_shape{output_padded_shape}; - auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); auto output_tensor = operation::run( - Pad{tt::tt_metal::LegacyShape(output_padded_shape), - ttnn::SimpleShape(input_tensor_start), + Pad{output_padded_legacy_shape, + ttnn::SimpleShape{input_tensor_start}, value, - memory_config, + output_memory_config, use_multicore}, {input_tensor}, {}, @@ -56,7 +146,6 @@ static ttnn::Tensor pad_impl( } } -template static ttnn::Tensor pad_impl( uint8_t queue_id, const ttnn::Tensor& input_tensor, @@ -79,7 +168,7 @@ static ttnn::Tensor pad_impl( padding.insert(padding.begin(), 4 - original_rank, {0, 0}); auto input_shape_with_tile_padding = input_tensor_4D.get_shape().with_tile_padding(); - ShapeType output_padded_shape; + std::vector output_padded_shape(padding.size(), 0); for (size_t i = 0; i < padding.size(); i++) { output_padded_shape[i] = padding[i].first + input_shape_with_tile_padding[i] + padding[i].second; } @@ -102,12 +191,12 @@ static ttnn::Tensor pad_impl( } // Performing actual padding - ShapeType pad_front_array; + std::vector pad_front_array(padding.size(), 0); for (size_t i = 0; i < pad_front.size(); i++) { pad_front_array[i] = pad_front[i]; } - return pad_impl( + return pad_impl( queue_id, input_tensor_4D, output_padded_shape, pad_front_array, value, use_multicore, memory_config_arg); } @@ -125,56 +214,17 @@ ttnn::Tensor ExecutePad::invoke( const int original_rank = input_tensor.get_shape().rank(); ttnn::SmallVector> padding_vec(padding.begin(), padding.end()); - ttnn::Tensor output_tensor; - if (input_tensor.storage_type() != StorageType::DEVICE) { - switch (original_rank) { - case 1: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 2: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 3: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - 
break; - case 4: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 5: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 6: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 7: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - case 8: - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - break; - default: TT_THROW("Unsupported tensor rank of {}. Needs to be between 1 and 8 inclusively.", original_rank); - } - } else { - output_tensor = pad_impl( - queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); - } + ttnn::Tensor output_tensor = + pad_impl(queue_id, input_tensor, std::move(padding_vec), value, use_multicore, memory_config_arg); // output_tensor is currently 4D. We have to squeeze back to the original rank - auto to_vec = [](const auto& arr) { return ttnn::SmallVector(arr.begin(), arr.end()); }; - auto shape = to_vec(output_tensor.get_shape().value); + auto to_vec = [](const auto& arr) { return ttnn::SmallVector{arr.begin(), arr.end()}; }; + auto output_shape = to_vec(output_tensor.get_shape().value); auto padded_shape = to_vec(output_tensor.get_shape().with_tile_padding().value); - if (auto rank_diff = shape.size() - original_rank; rank_diff) { - auto remove_first_elements = [](auto& source, size_t n) { source.erase(source.begin(), source.begin() + n); }; - remove_first_elements(shape, rank_diff); - remove_first_elements(padded_shape, rank_diff); - auto squeezedShape = ttnn::Shape(tt::tt_metal::LegacyShape(shape, padded_shape)); + if (const auto rank_diff = output_shape.size() - original_rank; rank_diff) { + auto remove_prefix = [](auto& source, size_t n) { source.erase(source.begin(), source.begin() + n); }; + remove_prefix(output_shape, rank_diff); + remove_prefix(padded_shape, rank_diff); + auto squeezedShape = ttnn::Shape(tt::tt_metal::LegacyShape(output_shape, padded_shape)); output_tensor = ttnn::reshape(output_tensor, squeezedShape); } @@ -193,7 +243,7 @@ ttnn::Tensor ExecutePad::invoke( const float value, \ const bool use_multicore, \ const std::optional& memory_config_arg) { \ - return pad_impl( \ + return pad_impl( \ queue_id, input_tensor, output_padded_shape, input_tensor_start, value, use_multicore, memory_config_arg); \ } \ \ @@ -203,7 +253,7 @@ ttnn::Tensor ExecutePad::invoke( const ShapeType& input_tensor_start, \ const float value, \ const std::optional& memory_config_arg) { \ - return pad_impl( \ + return pad_impl( \ DefaultQueueId, input_tensor, output_padded_shape, input_tensor_start, value, false, memory_config_arg); \ } \ \ @@ -212,7 +262,7 @@ ttnn::Tensor ExecutePad::invoke( const ShapeType& output_padded_shape, \ const ShapeType& input_tensor_start, \ const float value) { \ - return pad_impl( \ + return pad_impl( \ DefaultQueueId, input_tensor, output_padded_shape, input_tensor_start, value, false, std::nullopt); \ } diff --git a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp index 113d5c3ec1b..3863210a4f2 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape_base.hpp @@ -24,6 +24,7 @@ class ShapeBase { explicit ShapeBase(const std::array& arr) : 
value_(arr.begin(), arr.end()) { init(); } + explicit ShapeBase(std::span<const uint32_t> span) : value_(span.begin(), span.end()) { init(); } template <std::size_t N> bool operator==(const std::array<uint32_t, N>& other) const {