From 444b0dcd057317039e5a20fe8509bf7dca33d8f8 Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Wed, 18 Dec 2024 23:39:01 -0800 Subject: [PATCH] Fix ttnn.from_torch for 0D/1D tensors with tile layout (#15882) ### Ticket https://github.com/tenstorrent/tt-metal/issues/15630 ### Problem description Since Shape/LegacyShape doesn't support different logical and padded ranks, we had to remove all usages of those classes on the way from pytorch to ttnn tensor. ### What's changed Major refactoring in `to_layout`, `pad` ops TensorLayout fixes for 0D/1D tensors ### Checklist - [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12398856356) - [ ] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [ ] New/Existing tests provide coverage for changes --- .../tt_eager/ops/test_tilize_zero_padding.cpp | 2 +- ...test_tilize_zero_padding_channels_last.cpp | 2 +- .../ttnn/unit_tests/operations/test_matmul.py | 4 +- .../ttnn/unit_tests/test_to_and_from_torch.py | 19 ++ ttnn/cpp/pybind11/pytensor.cpp | 22 +- .../core/to_layout/to_layout_op.cpp | 265 ++++++++---------- .../core/work_split/work_split_tilize.hpp | 9 +- .../data_movement/concat/concat.cpp | 2 +- .../ttnn/operations/data_movement/pad/pad.cpp | 2 +- .../data_movement/reshape_view/reshape.cpp | 51 ++-- .../device/tilize_with_val_padding_op.cpp | 51 ++-- .../device/tilize_with_val_padding_op.hpp | 6 +- .../tilize_with_val_padding.cpp | 14 +- .../tilize_with_val_padding.hpp | 4 +- .../tilize_with_val_padding_pybind.hpp | 8 +- .../experimental/auto_format/auto_format.cpp | 3 +- .../moreh/moreh_helper_functions.cpp | 34 ++- ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp | 103 ++++--- ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp | 9 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 45 ++- ttnn/cpp/ttnn/tensor/tensor.hpp | 9 +- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 135 +++++---- ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 5 +- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 54 +--- ttnn/cpp/ttnn/tensor/tensor_ops.hpp | 2 +- ttnn/cpp/ttnn/tensor/types.hpp | 13 +- 26 files changed, 447 insertions(+), 426 deletions(-) diff --git a/tests/tt_eager/ops/test_tilize_zero_padding.cpp b/tests/tt_eager/ops/test_tilize_zero_padding.cpp index 580bd410295..2cfd265e888 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding.cpp @@ -46,7 +46,7 @@ int main(int argc, char** argv) { log_debug(LogTest, "Moving src data to host to validate"); Tensor host_a = a.cpu(); // Move tensor a to host to validate // TODO: Update when tensor.pad_to_tile() function is added - auto padded_shape = a.get_legacy_shape(); + auto padded_shape = a.get_padded_shape(); padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT); padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH); Tensor padded_host_a = host_a.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0); diff --git a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp index e565c4d8026..cc37bd14bf7 100644 --- a/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp +++ b/tests/tt_eager/ops/test_tilize_zero_padding_channels_last.cpp @@ -49,7 +49,7 @@ int main(int argc, char** argv) { Tensor host_a = a.cpu(); // Move tensor a to host to validate Tensor g = Tensor(host_a.get_storage(), shape, DataType::BFLOAT16, Layout::ROW_MAJOR); // TODO: Update when tensor.pad_to_tile() function is added - auto padded_shape = g.get_legacy_shape(); + auto padded_shape = g.get_padded_shape(); padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT); padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH); Tensor padded_g = g.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0); diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index c411ab46631..ae04c188263 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -1229,14 +1229,14 @@ def test_matmul_with_matched_width_height(device, m_size, k_size, n_size): def test_matmul_with_matched_width_height_from_1D(device, k_size, n_size): torch.manual_seed(0) - torch_input_tensor_a = torch.rand((k_size), dtype=torch.bfloat16) + torch_input_tensor_a = torch.rand((1, k_size), dtype=torch.bfloat16) torch_input_tensor_b = torch.rand((k_size, n_size), dtype=torch.bfloat16) torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b) input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) output = input_tensor_a @ input_tensor_b - output = ttnn.to_torch(output, torch_rank=1) + output = ttnn.to_torch(output) assert len(output.shape) == len(torch_output_tensor.shape) assert output.shape == torch_output_tensor.shape diff --git a/tests/ttnn/unit_tests/test_to_and_from_torch.py b/tests/ttnn/unit_tests/test_to_and_from_torch.py index f3132050e53..539edf79084 100644 --- a/tests/ttnn/unit_tests/test_to_and_from_torch.py +++ b/tests/ttnn/unit_tests/test_to_and_from_torch.py @@ -84,3 +84,22 @@ def test_from_torch_large(device): x_tensor = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT) x_tensor = ttnn.to_torch(x_tensor) assert torch.allclose(torch_x, x_tensor) + + +@pytest.mark.parametrize( + "shape", + [ + (), + (1), + (2), + (0), + ], +) +@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_to_for_01_rank(shape, layout, dtype): + torch_input_tensor = torch.rand(shape, dtype=dtype) + tensor = ttnn.from_torch(torch_input_tensor, layout=layout) + # Conversion in the opposite direction is not yet supported + # torch_output_tensor = ttnn.to_torch(tensor) + # assert torch.allclose(torch_input_tensor, torch_output_tensor) diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 6f9ccc5ab64..01379239ed3 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -115,16 +115,13 @@ Tensor convert_float_vector_to_tt_tensor( Layout::TILE, layout); } + auto result_cpu_spec = TensorSpec( + ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::TILE, tile), MemoryConfig{})); auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), DataType::FLOAT32); auto float_tensor = Tensor(OwnedStorage{owned_buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, tile); - auto tile_val = tile.value_or(Tile()); - if (shape[2] % tile_val.get_height() != 0 || shape[3] % tile_val.get_width() != 0) { - auto padded_shape = shape; - padded_shape[2] = tt::round_up(shape[2], tile_val.get_height()); - padded_shape[3] = tt::round_up(shape[3], tile_val.get_width()); - - float_tensor = tensor_ops::tensor_pad( - float_tensor, LegacyShape(shape, padded_shape), ttnn::SimpleShape{0, 0, 0, 0}, 0); + if (result_cpu_spec.logical_shape() != result_cpu_spec.padded_shape()) { + float_tensor = + tensor_ops::tensor_pad(float_tensor, result_cpu_spec.padded_shape(), ttnn::SimpleShape{0, 0, 0, 0}, 0); } auto output_float_data = owned_buffer::get_as(float_tensor.to(Layout::TILE)).get(); auto output_packed_data = @@ -132,14 +129,16 @@ Tensor convert_float_vector_to_tt_tensor( ? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile) : pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); + auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), result_cpu_spec); if (device) { return tensor.to(device, memory_config.value_or(MemoryConfig{})); } return tensor; } + auto result_cpu_spec = TensorSpec( + ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::ROW_MAJOR, tile), MemoryConfig{})); auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR, tile).to(layout); + auto tensor = Tensor(OwnedStorage{owned_buffer}, result_cpu_spec).to(layout); if (device) { return tensor.to(device, memory_config.value_or(MemoryConfig{})); } @@ -1212,7 +1211,8 @@ void pytensor_module(py::module& m_tensor) { const std::array& output_tensor_shape, const std::array& input_tensor_start, float pad_value) { - return self.pad(output_tensor_shape, ttnn::SimpleShape(input_tensor_start), pad_value); + return self.pad( + ttnn::SimpleShape(output_tensor_shape), ttnn::SimpleShape(input_tensor_start), pad_value); }, R"doc( Pad TT Tensor with given pad value ``arg2``. diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index e9a74e4706b..a5df642fdc2 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -42,6 +42,90 @@ inline bool use_multicore_device_tilize( return num_tiles_in_row <= max_tiles; } +bool requires_padding_change(const ttnn::Tensor& tensor, ttnn::Layout layout) { + auto tile = tensor.get_tensor_spec().tile(); + if (layout == Layout::ROW_MAJOR) { + // There shouldn't be extra paddings for Row Major layout + return tensor.logical_shape() != tensor.padded_shape(); + } + // It's okay for conversion to tile layout to preserve arbitrary padding as long as it satisfies the alignment + TensorSpec padded_spec( + tensor.padded_shape(), + TensorLayout(tensor.dtype(), PageConfig(layout, std::move(tile)), tensor.memory_config())); + return tensor.get_padded_shape() != padded_spec.padded_shape(); +} + +template +Tensor to_layout_impl_on_device( + const ttnn::Tensor& tensor_arg, + const ttnn::Layout layout, + const std::optional& dtype, + ttnn::MemoryConfig output_memory_config, + T* device) { + bool use_multicore_untilize = true; + bool use_multicore_tilize = use_multicore_device_tilize(tensor_arg, dtype); + + if (layout == ttnn::ROW_MAJOR_LAYOUT) { + TT_FATAL( + !dtype.has_value() || dtype.value() == tensor_arg.dtype(), + "dtype cannot be different from tensor dtype when converting to ROW_MAJOR_LAYOUT on device!"); + } + + if (!requires_padding_change(tensor_arg, layout)) { + if (layout == ttnn::ROW_MAJOR_LAYOUT) { + return ttnn::untilize(tensor_arg, output_memory_config, use_multicore_untilize); + } + return ttnn::tilize(tensor_arg, output_memory_config, dtype, use_multicore_tilize); + } + + auto tensor_shape = tensor_arg.get_logical_shape(); + + if (layout == ttnn::ROW_MAJOR_LAYOUT) { + if (tensor_arg.is_sharded()) { + const auto memory_config = tensor_arg.memory_config(); + output_memory_config = tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type}; + } + SmallVector output_tensor_end; + for (auto index = 0; index < tensor_shape.rank(); ++index) { + output_tensor_end.push_back(tensor_shape[index] - 1); + } + + auto tensor = + ttnn::untilize_with_unpadding(tensor_arg, output_tensor_end, output_memory_config, use_multicore_untilize); + return ttnn::reshape(tensor, tensor_shape); + } + + TensorSpec result_spec( + tensor_arg.logical_shape(), + TensorLayout( + tensor_arg.dtype(), + PageConfig(layout, std::move(tensor_arg.tensor_spec().tile())), + tensor_arg.memory_config())); + + // ttnn::tilize_with_val_padding doesn't support height sharded tensors + // workaround by applying padding and then tilizing + if (tensor_arg.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + ttnn::SmallVector> pad(result_spec.shape().rank()); + auto output_padding = result_spec.shape().padding(); + for (size_t i = 0; i < result_spec.padded_shape().rank(); i++) { + pad[i] = {output_padding[i].front, output_padding[i].back}; + } + auto tensor = ttnn::pad(0, tensor_arg, tt::stl::Span(pad), 0, true, std::nullopt); + return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize); + } + + PadValue pad_value_variant; + if (tensor_arg.get_dtype() == ttnn::DataType::BFLOAT16 or tensor_arg.get_dtype() == ttnn::DataType::FLOAT32) { + pad_value_variant = 0.0f; + } else { + pad_value_variant = (uint32_t)0; + } + + auto tensor = ttnn::tilize_with_val_padding( + tensor_arg, result_spec.padded_shape(), pad_value_variant, output_memory_config, dtype, use_multicore_tilize); + return tensor.reshape(tensor_arg.logical_shape()); +} + template Tensor to_layout_impl( const ttnn::Tensor& tensor_arg, @@ -67,167 +151,44 @@ Tensor to_layout_impl( return tensor_arg; } - const std::set supported_layouts = { - ttnn::ROW_MAJOR_LAYOUT, - ttnn::TILE_LAYOUT, - }; - - if (supported_layouts.find(layout) == supported_layouts.end()) { + if (layout != ROW_MAJOR_LAYOUT && layout != TILE_LAYOUT) { TT_THROW("ttnn::to_layout: Unsupported layout conversion from {} to {}!", tensor_arg.get_layout(), layout); } - const auto requires_padding_change = - [](ttnn::Tensor& tensor, ttnn::Layout layout, const ttnn::Shape& shape) -> bool { - const auto intended_shape = shape; - const auto padded_shape = shape.with_tile_padding(); - if (layout == ttnn::ROW_MAJOR_LAYOUT and intended_shape != padded_shape) { - return true; - } - if (layout == ttnn::TILE_LAYOUT) { - auto tile_shape = tensor.tensor_spec().tile().get_tile_shape(); - if (padded_shape.rank() < 2 or padded_shape[-1] % tile_shape[1] != 0 or - padded_shape[-2] % tile_shape[0] != 0) { - return true; - } - } - return false; - }; - - const auto intended_shape = tensor_arg.get_shape(); - - auto tensor = tensor_arg; - const auto tile = tensor.get_tensor_spec().tile(); - - SmallVector output_shape; - if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) { - output_shape.push_back(1); - tensor = ttnn::reshape( - tensor, - ttnn::Shape( - SmallVector{1, intended_shape[0]}, - SmallVector{1, tensor_arg.get_shape().with_tile_padding()[0]})); - } - for (auto index = 0; index < intended_shape.rank(); ++index) { - output_shape.push_back(intended_shape[index]); + auto output_memory_config = + memory_config.value_or(ttnn::get_memory_config(tensor_arg).value_or(ttnn::DRAM_MEMORY_CONFIG)); + + if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) { + return to_layout_impl_on_device(tensor_arg, layout, dtype, std::move(output_memory_config), device); } - auto padded_output_shape = output_shape; - for (auto index = output_shape.size() - 2; index < output_shape.size(); ++index) { - padded_output_shape[index] = ttnn::pad_to_multiple_of_tile_size( - padded_output_shape[index], - (index == output_shape.size() - 2) ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]); + TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!"); + if (not requires_padding_change(tensor_arg, layout)) { + return device ? tensor_arg.to(layout, device) : tensor_arg.to(layout); } - auto output_memory_config = - memory_config.value_or(ttnn::get_memory_config(tensor).value_or(ttnn::DRAM_MEMORY_CONFIG)); + if (layout == ttnn::ROW_MAJOR_LAYOUT) { + auto tensor = device ? tensor_arg.to(layout, device) : tensor_arg.to(layout); + tensor = tensor.unpad_from_tile(tensor.get_logical_shape()); + return tensor.reshape(tensor_arg.logical_shape()); + } - if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) { - bool use_multicore_untilize = true; - bool use_multicore_tilize = use_multicore_device_tilize(tensor, dtype); - - if (not requires_padding_change(tensor, layout, tensor.get_shape())) { - if (layout == ttnn::ROW_MAJOR_LAYOUT) { - TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); - return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize); - } else if (layout == ttnn::TILE_LAYOUT) { - if (tensor.is_sharded()) { - const auto shard_shape = get_memory_config(tensor).value().shard_spec.value().shape; - if (shard_shape[0] % ttnn::TILE_SIZE != 0 or shard_shape[1] % ttnn::TILE_SIZE != 0) { - TT_THROW( - "ttnn::to_layout: Sharded tensor must have shard shape that is a multiple of " - "TILE_SIZE!"); - } - } - return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize); - } else { - throw std::runtime_error("ttnn::to_layout: Unsupported layout!"); - } - } else if (layout == ttnn::ROW_MAJOR_LAYOUT) { - TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); - - if (tensor.is_sharded()) { - const auto memory_config = tensor.memory_config(); - output_memory_config = - tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type}; - } - SmallVector output_tensor_end; - for (auto index = 0; index < tensor.get_shape().rank(); ++index) { - output_tensor_end.push_back(tensor.get_shape()[index] - 1); - } - - tensor = - ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize); - return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); - - } else if (layout == ttnn::TILE_LAYOUT) { - SmallVector padded_output_shape; - - for (int index = 0; index < tensor.get_shape().rank(); ++index) { - uint32_t second_last_rank = tensor.get_shape().rank() - 2; // h dim - uint32_t padded_value = - index < second_last_rank - ? tensor.get_shape()[index] - : ttnn::pad_to_multiple_of_tile_size( - tensor.get_shape()[index], - index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]); - padded_output_shape.push_back(padded_value); - } - if (tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - // ttnn::tilize_with_val_padding doesn't support height sharded tensors - // workaround by applying padding and then tilizing - SmallVector> padding = { - {0, 0}, - {0, 0}, - {0, padded_output_shape[2] - output_shape[2]}, - {0, padded_output_shape[3] - output_shape[3]}}; - tensor = ttnn::pad(0, tensor, padding, 0, true, std::nullopt); - return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize); - } else { - PadValue pad_value_variant; - if (tensor.get_dtype() == ttnn::DataType::BFLOAT16 or tensor.get_dtype() == ttnn::DataType::FLOAT32) { - pad_value_variant = 0.0f; - } else { - pad_value_variant = (uint32_t)0; - } - - tensor = ttnn::tilize_with_val_padding( - tensor, padded_output_shape, pad_value_variant, output_memory_config, dtype, use_multicore_tilize); - } - - return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape})); - - } else { - TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout); - } - } else { - TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!"); - if (not requires_padding_change(tensor, layout, tensor.get_shape())) { - return device ? tensor.to(layout, device) : tensor.to(layout); - } else if (layout == ttnn::ROW_MAJOR_LAYOUT) { - tensor = device ? tensor.to(layout, device) : tensor.to(layout); - tensor = tensor.unpad_from_tile(tensor.get_logical_shape()); - return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape}); - } else if (layout == ttnn::TILE_LAYOUT) { - SmallVector padded_output_shape; - SmallVector padded_input_start; - for (int index = 0; index < tensor.get_shape().rank(); ++index) { - uint32_t second_last_rank = tensor.get_shape().rank() - 2; // h dim - uint32_t padded_value = - index < second_last_rank - ? tensor.get_shape()[index] - : ttnn::pad_to_multiple_of_tile_size( - tensor.get_shape()[index], - index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]); - padded_output_shape.push_back(padded_value); - padded_input_start.push_back(0); - } - tensor = tensor.pad(padded_output_shape, ttnn::SimpleShape(std::move(padded_input_start)), 0); - tensor = device ? tensor.to(layout, device) : tensor.to(layout); - return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape})); - } else { - TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout); - } + SmallVector padded_input_start; + for (int index = 0; index < tensor_arg.get_logical_shape().rank(); ++index) { + padded_input_start.push_back(0); } + TensorSpec result_spec( + tensor_arg.padded_shape(), + TensorLayout::fromPaddedShape( + tensor_arg.dtype(), + PageConfig(layout, std::move(tensor_arg.tensor_spec().tile())), + tensor_arg.memory_config(), + tensor_arg.logical_shape(), + tensor_arg.padded_shape())); + + auto tensor = tensor_arg.pad(result_spec.padded_shape(), ttnn::SimpleShape(std::move(padded_input_start)), 0); + tensor = device ? tensor.to(layout, device) : tensor.to(layout); + return tensor.reshape(result_spec.logical_shape()); } } // namespace detail diff --git a/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp b/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp index cd783f161b5..eaddc1c2d14 100644 --- a/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp +++ b/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp @@ -183,16 +183,15 @@ inline std::vector> distribute_work( uint32_t blocks_per_core, bool has_cliff, uint32_t nblocks_per_core_cliff) { - TT_FATAL( - logical_shape.rank() >= 2, "Logical shape rank needs to be >=2. Shape: {}", "Error", logical_shape, padding); + TT_FATAL(padding.rank() >= 2 && padding.rank() <= 4, "Rank needs to be >=2. Shape: {} {}", logical_shape, padding); auto input_w = logical_shape.rank() >= 4 ? logical_shape[-4] : 1; auto input_z = logical_shape.rank() >= 3 ? logical_shape[-3] : 1; auto input_y = logical_shape.rank() >= 2 ? logical_shape[-2] : 1; - auto padding_w = logical_shape.rank() >= 4 ? padding[padding.get_normalized_index(-4)].back : 0; - auto padding_z = logical_shape.rank() >= 3 ? padding[padding.get_normalized_index(-3)].back : 0; - auto padding_y = logical_shape.rank() >= 2 ? padding[padding.get_normalized_index(-2)].back : 0; + auto padding_w = padding.rank() >= 4 ? padding[-4].back : 0; + auto padding_z = padding.rank() >= 3 ? padding[-3].back : 0; + auto padding_y = padding.rank() >= 2 ? padding[-2].back : 0; // total work is a full rep followed by a padding. auto full_rep_blocks = FullRep(input_y, padding_y, input_z, padding_z, input_w).to_block_reps(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp index c0fcd9fb356..9161f24a2ed 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp @@ -146,7 +146,7 @@ MassagedConcat build_untilize_rm_retilize_concat( auto padded = pad_to_tile_vol(queue_id, output, 0.0f, true, output.memory_config()); concat_db_print(true, "[DEBUG] padded to tile layout, now tilizing."); auto tilized = - ttnn::tilize_with_val_padding(padded, padded.get_legacy_shape(), 0.0f, output.memory_config()); + ttnn::tilize_with_val_padding(padded, padded.get_padded_shape(), 0.0f, output.memory_config()); concat_db_print(true, "[DEBUG] tilized"); // need to reshape tilized result to logical concat output shape auto reshaped = ttnn::reshape( diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp index f54a763b638..f5f2b8cc380 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp @@ -34,7 +34,7 @@ static ttnn::Tensor pad_impl( return input_tensor; } else { return input_tensor.pad( - tt::tt_metal::LegacyShape(output_padded_shape), ttnn::SimpleShape{input_tensor_start}, value); + ttnn::SimpleShape(output_padded_shape), ttnn::SimpleShape(input_tensor_start), value); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 924da1f446b..ecc6d7f1c9b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -310,10 +310,9 @@ ttnn::Shape shape_corrector(const ttnn::Tensor& tensor, const ttnn::Shape& shape ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, const ttnn::Shape& input_shape, - const std::optional &memory_config, + const std::optional& memory_config, const uint8_t queue_id, - const std::optional &pad_value - ) { + const std::optional& pad_value) { MemoryConfig mem_config = memory_config.value_or(tensor.memory_config()); auto layout = tensor.get_layout(); auto tensor_shape = tensor.get_shape(); @@ -337,22 +336,24 @@ ttnn::Tensor ReshapeViewOperation::invoke( //The following case should only be called for the device storage case, the rest is a bandaid //for issue 15317 - + const uint32_t shape_last_dim = shape.rank() >= 1 ? shape[-1] : 1; + const uint32_t tensor_last_dim = tensor_shape.rank() >= 1 ? tensor_shape[-1] : 1; const uint32_t shape_second_last_dim = shape.rank() >= 2 ? shape[-2]:1; const uint32_t tensor_shape_second_last_dim = tensor_shape.rank() >= 2 ? tensor_shape[-2]:1; // Just edit shape if shape has a 0 dimension if (tensor.get_logical_volume() == 0) { - TT_FATAL(shape.logical_shape().volume() == 0 , "Tensor volume is 0, but shape's volume is not"); - TT_FATAL((tensor.storage_type() != StorageType::MULTI_DEVICE && - tensor.storage_type() != StorageType::MULTI_DEVICE_HOST), - "Reshaping a multi-device tensor with 0 volume is not supported"); + TT_FATAL(shape.logical_shape().volume() == 0, "Tensor volume is 0, but shape's volume is not"); + TT_FATAL( + (tensor.storage_type() != StorageType::MULTI_DEVICE && + tensor.storage_type() != StorageType::MULTI_DEVICE_HOST), + "Reshaping a multi-device tensor with 0 volume is not supported"); return tensor.reshape(shape); } TT_FATAL(shape.logical_shape().volume() != 0, "Tensor volume is not 0, but shape volume is 0"); bool this_is_view = - (tensor_shape[-1] == shape[-1]) && (mem_config.is_sharded() == tensor.memory_config().is_sharded()) && + (tensor_last_dim == shape_last_dim) && (mem_config.is_sharded() == tensor.memory_config().is_sharded()) && (mem_config.is_l1() == tensor.memory_config().is_l1()) && ((tensor.get_layout() == ttnn::ROW_MAJOR_LAYOUT) || // Its row major (tensor_shape_second_last_dim == shape_second_last_dim) || // Second last dimension is the same @@ -366,20 +367,19 @@ ttnn::Tensor ReshapeViewOperation::invoke( if (this_is_view) { return PerformView(tensor,shape, tile_first_dim, tile_second_dim); } - if(shape.logical_shape().volume() != tensor.get_logical_volume()) - { - //This is completely incorrect but it is due to issue 15137 or issue 15558 + if (shape.logical_shape().volume() != tensor.get_logical_volume()) { + // This is completely incorrect but it is due to issue 15137 or issue 15558 bool tile_tensor_view_reshape_possible = (layout == ttnn::Layout::TILE and shape.with_tile_padding().rank() >= 2 and - shape.with_tile_padding()[-2] % ttnn::TILE_SIZE == 0 and - shape.with_tile_padding()[-1] % ttnn::TILE_SIZE == 0 and - tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]); + shape.with_tile_padding()[-2] % ttnn::TILE_SIZE == 0 and + shape.with_tile_padding()[-1] % ttnn::TILE_SIZE == 0 and + tensor_shape.with_tile_padding()[-1] == shape.with_tile_padding()[-1]); if (tile_tensor_view_reshape_possible) { // This case has been allowed in the past though it means introducing padding values to the data return tensor.reshape(shape); } - //This is a completely incorrect test but it is due to issue 15558 + // This is a completely incorrect test but it is due to issue 15558 TT_FATAL(false, "Attempting to reshape between two shapes with different volumes"); } // Catch-all @@ -402,21 +402,20 @@ ttnn::Tensor ReshapeViewOperation::invoke( return invoke(tensor, shape,std::nullopt,0,std::nullopt); } -ttnn::Tensor ReshapeViewOperation::invoke( - const ttnn::Tensor& tensor, - const ttnn::SimpleShape& shape, - const std::optional &memory_config, - const uint8_t queue_id, - const std::optional &pad_value - ) { - return invoke(tensor, ttnn::Shape(shape.view()),memory_config,queue_id,pad_value); -} + ttnn::Tensor ReshapeViewOperation::invoke( + const ttnn::Tensor& tensor, + const ttnn::SimpleShape& shape, + const std::optional& memory_config, + const uint8_t queue_id, + const std::optional& pad_value) { + return invoke(tensor, ttnn::Shape(shape.view()), memory_config, queue_id, pad_value); + } ttnn::Tensor ReshapeViewOperation::invoke( const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape ) { - return invoke(tensor, ttnn::Shape(shape.view()),std::nullopt,0,std::nullopt); + return invoke(tensor, ttnn::Shape(shape.view()), std::nullopt, 0, std::nullopt); } ttnn::Tensor ReshapeViewOperation::invoke( diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp index 0677baaaa2c..91eff6af9b3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.cpp @@ -14,7 +14,7 @@ namespace ttnn::operations::data_movement { void TilizeWithValPadding::validate(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); - const auto& input_shape = input_tensor_a.get_legacy_shape(); + const auto& input_shape = input_tensor_a.get_padded_shape(); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!"); TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!"); TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "Can only tilize row major data"); @@ -22,24 +22,23 @@ void TilizeWithValPadding::validate(const std::vector& input_tensors) co input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::UINT32 or input_tensor_a.get_dtype() == DataType::FLOAT32, "Can only tilize bfloat16/float32 or uint32 tensors"); - TT_FATAL(input_shape.rank() >= 2, "Input tensor must be of rank >2, but its shape is {}", input_shape); - for (auto i = 0; i < input_shape.rank(); i++) { + for (int i = -static_cast(input_shape.rank()); i >= 0; i--) { TT_FATAL( - input_shape[i] <= this->output_tensor_shape[i], + input_shape[i] <= this->output_padded_shape[i], "Output tensor shape {} must be greater than or equal to input shape {} in each dimension, but is smaller " "in dimension {}", - this->output_tensor_shape, + this->output_padded_shape, input_shape, i); } - uint32_t num_rows = this->output_tensor_shape[-1]; - uint32_t inner_dim = this->output_tensor_shape[-2]; + uint32_t num_rows = this->output_padded_shape[-1]; + uint32_t inner_dim = this->output_padded_shape[-2]; TT_FATAL( inner_dim % TILE_WIDTH == 0 && num_rows % TILE_HEIGHT == 0, "To be tilizable output tensor shape {} must be divisible by tile size ({}, {})", - output_tensor_shape, + output_padded_shape, TILE_WIDTH, TILE_HEIGHT); @@ -50,41 +49,33 @@ void TilizeWithValPadding::validate(const std::vector& input_tensors) co TT_FATAL( this->output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout, "Output tensor must have the same memory layout as input tensor"); - for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) { + for (uint32_t i = 0; i < input_tensor_a.get_padded_shape().rank(); i++) { if (i != input_shape.rank() - 2) { - TT_FATAL(input_shape[i] == this->output_tensor_shape[i], "Error"); + TT_FATAL(input_shape[i] == this->output_padded_shape[i], "Error"); } } } } -std::vector TilizeWithValPadding::compute_output_shapes( +std::vector TilizeWithValPadding::compute_output_specs( const std::vector& input_tensors) const { - auto input_shape = input_tensors.at(0).get_legacy_shape(); - auto dimensions_pads = std::vector(); - for (auto index = 0; index < input_shape.rank(); index++) { - auto back = this->output_tensor_shape[index] - input_shape[index]; - dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = back}); - } - const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); - return {tt::tt_metal::LegacyShape(this->output_tensor_shape, padding)}; -} - -std::vector TilizeWithValPadding::create_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); + auto input_shape = input_tensor_a.get_padded_shape(); + if (input_tensor_a.memory_config().is_sharded()) { - auto output_shape = this->compute_output_shapes(input_tensors).at(0); auto shard_spec = input_tensor_a.shard_spec().value(); - shard_spec.shape[0] = tt::tt_metal::compute_volume(output_shape) / output_shape[-1]; + shard_spec.shape[0] = output_padded_shape.volume() / output_padded_shape[-1]; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; - return { - create_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor_a.device(), mem_config)}; - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config); + return {TensorSpec( + input_shape, + TensorLayout::fromPaddedShape( + output_dtype, PageConfig(Layout::TILE), mem_config, input_shape, output_padded_shape))}; } + return {TensorSpec( + input_shape, + TensorLayout::fromPaddedShape( + output_dtype, PageConfig(Layout::TILE), output_mem_config, input_shape, output_padded_shape))}; } // TODO: If pad is called on a tile and output is not tile, we could untilize then pad, and output is RM diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.hpp index 2317ba86ab5..3dc805e0f0b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/device/tilize_with_val_padding_op.hpp @@ -13,16 +13,14 @@ namespace ttnn::operations::data_movement { struct TilizeWithValPadding { - const tt::tt_metal::LegacyShape output_tensor_shape; + const ttnn::SimpleShape output_padded_shape; const PadValue pad_value; const tt::tt_metal::MemoryConfig output_mem_config; const tt::tt_metal::DataType output_dtype; const bool use_multicore; void validate(const std::vector& input_tensors) const; - std::vector compute_output_shapes(const std::vector& input_tensors) const; - std::vector create_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; tt::tt_metal::operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp index 0e6e1044522..a5470e32122 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp @@ -66,7 +66,7 @@ MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) { .operation = std::move(base_tilize)}); } -tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) { +ttnn::SimpleShape squeeze_output_shape(ttnn::SimpleShape output_shape) { if (output_shape.rank() > 4) { std::array output_shape_4d; output_shape_4d[0] = 1; @@ -77,7 +77,7 @@ tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_ output_shape_4d[1] = output_shape[1 + extra_rank]; output_shape_4d[2] = output_shape[2 + extra_rank]; output_shape_4d[3] = output_shape[3 + extra_rank]; - return tt::tt_metal::LegacyShape(output_shape_4d); + return ttnn::SimpleShape(output_shape_4d); } return output_shape; } @@ -85,7 +85,7 @@ tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_ ttnn::Tensor ExecuteTilizeWithValPadding::invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const PadValue pad_value, const std::optional& memory_config, std::optional output_dtype, @@ -93,7 +93,7 @@ ttnn::Tensor ExecuteTilizeWithValPadding::invoke( auto base_tilize = [=](const ttnn::Tensor& input_tensor) { return operation::run( TilizeWithValPadding{ - squeeze_output_shape(output_tensor_shape), + squeeze_output_shape(output_padded_shape), pad_value, memory_config.value_or(input_tensor.memory_config()), output_dtype.value_or(input_tensor.get_dtype()), @@ -109,13 +109,13 @@ ttnn::Tensor ExecuteTilizeWithValPadding::invoke( ttnn::Tensor ExecuteTilizeWithValPadding::invoke( const ttnn::Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const PadValue pad_value, const std::optional& memory_config, std::optional output_dtype, bool use_multicore) { return invoke( - DefaultQueueId, input_tensor, output_tensor_shape, pad_value, memory_config, output_dtype, use_multicore); + DefaultQueueId, input_tensor, output_padded_shape, pad_value, memory_config, output_dtype, use_multicore); } ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke( @@ -125,7 +125,7 @@ ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke( std::optional output_dtype, bool use_multicore) { using namespace tt::constants; - auto shape = input_tensor.get_legacy_shape(); + auto shape = input_tensor.get_padded_shape(); shape[2] = tt::round_up(shape[2], tt::constants::TILE_HEIGHT); shape[3] = tt::round_up(shape[3], tt::constants::TILE_WIDTH); diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp index 92f8374e58f..dec6a34333c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp @@ -18,7 +18,7 @@ struct ExecuteTilizeWithValPadding { static ttnn::Tensor invoke( uint8_t queue_id, const ttnn::Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const PadValue pad_value, const std::optional& memory_config = std::nullopt, std::optional output_dtype = std::nullopt, @@ -26,7 +26,7 @@ struct ExecuteTilizeWithValPadding { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const PadValue pad_value, const std::optional& memory_config = std::nullopt, std::optional output_dtype = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_pybind.hpp index 0fc3cc27145..3915564394e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_pybind.hpp @@ -53,7 +53,13 @@ void bind_tilize_with_val_padding(py::module& module) { bool use_multicore, uint8_t queue_id) { return self( - queue_id, input_tensor, output_tensor_shape, value, memory_config, output_dtype, use_multicore); + queue_id, + input_tensor, + Shape(output_tensor_shape).padded_shape(), + value, + memory_config, + output_dtype, + use_multicore); }, py::arg("input_tensor"), py::arg("output_tensor_shape"), diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp index 5e66f403347..165376b82d8 100644 --- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp @@ -113,7 +113,8 @@ Tensor AutoFormat::format_input_tensor( } else { pad_value_variant = (uint32_t)pad_value; } - return ttnn::tilize_with_val_padding(formatted_input, padded_shape, pad_value_variant, mem_config); + return ttnn::tilize_with_val_padding( + formatted_input, Shape(padded_shape).padded_shape(), pad_value_variant, mem_config); } else if (formatted_input.get_layout() == Layout::TILE && target_layout == Layout::ROW_MAJOR) { formatted_input = ttnn::untilize(formatted_input, mem_config); return ttnn::pad( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp index 2b3cbb83cba..bae91413f9a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_helper_functions.cpp @@ -351,11 +351,12 @@ void validate_input_with_dim(const Tensor& input, const int64_t& dim) { void validate_output_with_keepdim(const Tensor& input, const Tensor& output, const int64_t& dim, const bool& keepdim) { auto input_shape = input.get_padded_shape(); auto input_shape_wo_padding = input.get_logical_shape(); - const auto input_rank = input_shape.rank(); + const auto input_rank = input_shape_wo_padding.rank(); + auto padded_dim = dim + input_shape.rank() - input_shape_wo_padding.rank(); const auto output_shape = output.get_padded_shape(); const auto output_shape_wo_padding = output.get_logical_shape(); - const auto output_rank = output_shape.rank(); + const auto output_rank = output_shape_wo_padding.rank(); const bool is_tile_dim = (dim == input_rank - 1 || dim == input_rank - 2); @@ -365,7 +366,7 @@ void validate_output_with_keepdim(const Tensor& input, const Tensor& output, con if (keepdim) { bool ranks_are_equal = (input_rank == output_rank); - input_shape[dim] = (is_tile_dim) ? (TILE_HEIGHT) : (1); + input_shape[padded_dim] = (is_tile_dim) ? (TILE_HEIGHT) : (1); input_shape_wo_padding[dim] = 1; if (!ranks_are_equal) { @@ -387,31 +388,36 @@ void validate_output_with_keepdim(const Tensor& input, const Tensor& output, con expand_to_max_dim(input_dim_wo_padding, input_shape_wo_padding); expand_to_max_dim(output_dim_wo_padding, output_shape_wo_padding); - for (int i = 0; i < input_rank; ++i) { + for (int i = 0; i < input_shape.rank(); ++i) { TT_FATAL(input_dim[i] == output_dim[i], "Error"); + } + for (int i = 0; i < input_shape_wo_padding.rank(); ++i) { TT_FATAL(input_dim_wo_padding[i] == output_dim_wo_padding[i], "Error"); } } else { ttnn::SmallVector expected_output_shape; - ttnn::SmallVector expected_output_shape_wo_padding; for (int i = 0; i < output_shape.rank(); ++i) { - if (i == dim && !is_tile_dim) { + if (i == padded_dim && !is_tile_dim) { expected_output_shape.push_back(1); - expected_output_shape_wo_padding.push_back(1); } expected_output_shape.push_back(output_shape[i]); + } + ttnn::SmallVector expected_output_shape_wo_padding; + for (int i = 0; i < output_shape_wo_padding.rank(); ++i) { + if (i == dim && !is_tile_dim) { + expected_output_shape_wo_padding.push_back(1); + } expected_output_shape_wo_padding.push_back(output_shape_wo_padding[i]); } - log_debug(LogOp, "{}:{} expected_output_shape {}", __func__, __LINE__, expected_output_shape); log_debug( LogOp, "{}:{} expected_output_shape_wo_padding {}", __func__, __LINE__, expected_output_shape_wo_padding); - for (int i = 0; i < input_rank; ++i) { - if (i == dim) { - continue; - } - TT_FATAL(input_shape[i] == expected_output_shape[i], "Error"); - TT_FATAL(input_shape_wo_padding[i] == expected_output_shape_wo_padding[i], "Error"); + + for (int i = 0; i < expected_output_shape.size(); ++i) { + TT_FATAL(i == padded_dim || input_shape[i] == expected_output_shape[i], "Error"); + } + for (int i = 0; i < expected_output_shape_wo_padding.size(); ++i) { + TT_FATAL(i == dim || input_shape_wo_padding[i] == expected_output_shape_wo_padding[i], "Error"); } } } diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp index 339c919571a..a60e916070a 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.cpp @@ -4,6 +4,8 @@ #include "tensor_layout.hpp" +#include "ttnn/tensor/tensor_utils.hpp" + namespace tt::tt_metal { namespace { @@ -18,25 +20,27 @@ size_t round_up(size_t value, size_t multiple) { }; Alignment legacyShapeToAlignment( - const ttnn::Shape& shape, const PageConfig& page_config, const MemoryConfig& memory_config) { - const auto& logical_shape = shape.logical_shape(); - const auto& legacy_padded_shape = shape.padded_shape(); - if (logical_shape == legacy_padded_shape) { + const ttnn::SimpleShape& logical_shape, + const ttnn::SimpleShape& padded_shape, + const PageConfig& page_config, + const MemoryConfig& memory_config) { + if (logical_shape == padded_shape) { return Alignment{}; } - const auto rank = legacy_padded_shape.rank(); + const auto rank = padded_shape.rank(); bool alignment_can_be_2D = true; for (int i = rank - 3; i >= 0; i--) { - alignment_can_be_2D &= logical_shape[i] == legacy_padded_shape[i]; + alignment_can_be_2D &= logical_shape[i] == padded_shape[i]; } // SHARDED if (memory_config.shard_spec.has_value()) { TT_FATAL( alignment_can_be_2D, - "Tensor with shape {} cannot be sharded because alignment will have rank greater than 2!", - shape); + "Tensor with shape {} ({}) cannot be sharded because alignment will have rank greater than 2!", + logical_shape, + padded_shape); if (page_config.get_layout() == Layout::ROW_MAJOR) { const auto& shard_spec = memory_config.shard_spec.value(); if (shard_spec.physical_shard_shape.has_value()) { @@ -52,10 +56,10 @@ Alignment legacyShapeToAlignment( ttnn::SmallVector values(std::min((int)rank, 2)); const auto alignment_size = values.size(); if (alignment_size >= 1) { - values[alignment_size - 1] = legacy_padded_shape[-1]; + values[alignment_size - 1] = padded_shape[-1]; } if (alignment_size == 2) { - values[alignment_size - 2] = legacy_padded_shape[-2]; + values[alignment_size - 2] = padded_shape[-2]; } Alignment result(std::move(values)); return result; @@ -64,11 +68,11 @@ Alignment legacyShapeToAlignment( // INTERLEAVED with (deprecated) non-height/width padding // NOTE: Rank > 2 is guaranteed in this case ttnn::SmallVector values(rank); - values[rank - 1] = legacy_padded_shape[-1]; - values[rank - 2] = legacy_padded_shape[-2]; + values[rank - 1] = padded_shape[-1]; + values[rank - 2] = padded_shape[-2]; for (int i = rank - 3; i >= 0; i--) { - values[i] = legacy_padded_shape[i] * values[i + 1]; + values[i] = padded_shape[i] * values[i + 1]; } for (auto& value : values) { @@ -101,15 +105,40 @@ TensorLayout TensorLayout::fromLegacyPaddedShape( dtype, page_config, memory_config, - CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(legacy_shape, page_config, memory_config)); + CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment( + legacy_shape.logical_shape(), legacy_shape.padded_shape(), page_config, memory_config)); +} + +TensorLayout TensorLayout::fromPaddedShape( + DataType dtype, + const PageConfig& page_config, + const MemoryConfig& memory_config, + const ttnn::SimpleShape& logical_shape, + const ttnn::SimpleShape& padded_shape) { + return TensorLayout( + dtype, + page_config, + memory_config, + CMAKE_UNIQUE_NAMESPACE::legacyShapeToAlignment(logical_shape, padded_shape, page_config, memory_config)); } void TensorLayout::initialize_alignment() { - if (!alignment_.empty()) { + auto default_alignment = page_config_.create_default_alignment(dtype_, memory_config_); + if (alignment_.empty()) { + alignment_ = default_alignment; return; } - alignment_ = page_config_.create_default_alignment(dtype_, memory_config_); + ttnn::SmallVector result(std::max(alignment_.size(), default_alignment.size()), 1); + for (size_t i = 0; i < alignment_.size(); i++) { + result[i + result.size() - alignment_.size()] = alignment_[i]; + } + for (size_t i = 0; i < default_alignment.size(); i++) { + size_t result_idx = i + result.size() - default_alignment.size(); + result[result_idx] = CMAKE_UNIQUE_NAMESPACE::round_up(result[result_idx], default_alignment[i]); + } + + alignment_ = Alignment(std::move(result)); } void TensorLayout::validate_alignment() const { @@ -310,39 +339,31 @@ Size TensorLayout::compute_page_shape(const Size& physical_size) const { } Strides TensorLayout::compute_strides(const ttnn::SimpleShape& shape) const { - const int rank = static_cast(shape.rank()); - const int alignment_rank = static_cast(alignment_.size()); - - Strides strides(rank, 1); - for (int i = rank - 2; i >= 0; i--) { - strides[i] = strides[i + 1] * shape[i + 1]; - - const int alignment_index = i - (rank - alignment_rank) + 1; - if (alignment_index >= 0) { - strides[i] = CMAKE_UNIQUE_NAMESPACE::round_up(strides[i], alignment_[alignment_index]); - } - } - - return strides; + auto padded_shape = compute_padded_shape(shape); + return tt::tt_metal::compute_strides(padded_shape); } ttnn::SimpleShape TensorLayout::compute_padded_shape(const ttnn::SimpleShape& shape) const { - ttnn::SmallVector padded_shape(shape.rank()); + ttnn::SmallVector padded_shape(std::max(shape.rank(), alignment_.size())); int rank_index = static_cast(shape.rank()) - 1; int alignment_index = static_cast(alignment_.size()) - 1; + int padded_shape_index = static_cast(padded_shape.size() - 1); size_t accum_alignment = 1; - for (; rank_index >= 0 && alignment_index >= 0; rank_index--, alignment_index--) { + for (; alignment_index >= 0; rank_index--, alignment_index--, padded_shape_index--) { + uint32_t shape_value = rank_index >= 0 ? shape[rank_index] : 1; + uint32_t alignment_value = alignment_[alignment_index]; + uint32_t& padded_shape_value = padded_shape[padded_shape_index]; + // The last 2 dimensions of a shape are special if (rank_index >= static_cast(shape.rank()) - 2) { - padded_shape[rank_index] = CMAKE_UNIQUE_NAMESPACE::round_up(shape[rank_index], alignment_[alignment_index]); + padded_shape_value = CMAKE_UNIQUE_NAMESPACE::round_up(shape_value, alignment_value); } else { - if (accum_alignment % alignment_[alignment_index] == 0) { + if (accum_alignment % alignment_value == 0) { // Alignment for this dimension is redundant, ignoring - padded_shape[rank_index] = shape[rank_index]; - } else if (alignment_[alignment_index] % accum_alignment == 0) { - padded_shape[rank_index] = - CMAKE_UNIQUE_NAMESPACE::round_up(shape[rank_index], alignment_[alignment_index] / accum_alignment); + padded_shape_value = shape_value; + } else if (alignment_value % accum_alignment == 0) { + padded_shape_value = CMAKE_UNIQUE_NAMESPACE::round_up(shape_value, alignment_value / accum_alignment); } else { TT_THROW( "Padded shape can't be deducted from TensorLayout parameters {} and Shape {}", alignment_, shape); @@ -351,11 +372,11 @@ ttnn::SimpleShape TensorLayout::compute_padded_shape(const ttnn::SimpleShape& sh // Alignment doesn't accumulate on the last dimension of a shape if (rank_index != static_cast(shape.rank()) - 1) { - accum_alignment *= padded_shape[rank_index]; + accum_alignment *= padded_shape_value; } } - for (; rank_index >= 0; rank_index--) { - padded_shape[rank_index] = shape[rank_index]; + for (; rank_index >= 0; rank_index--, padded_shape_index--) { + padded_shape[padded_shape_index] = shape[rank_index]; } return ttnn::SimpleShape(std::move(padded_shape)); } diff --git a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp index 2e9b24cb03a..6625bb19ac6 100644 --- a/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/tensor_layout.hpp @@ -14,7 +14,7 @@ namespace tt::tt_metal { -using Strides = std::vector; +using Strides = ttnn::SmallVector; // TensorLayout describes how a tensor is laid out in memory // It takes datatype, layout (eg. TILE vs. RM), memory (eg. DRAM vs. L1), sharding (ie. how you want to cut your logical @@ -31,6 +31,13 @@ class TensorLayout { const PageConfig& page_config, const MemoryConfig& memory_config, const ttnn::Shape& legacy_shape); + [[deprecated("Use of Padded Shape is deprecated")]] + static TensorLayout fromPaddedShape( + DataType dtype, + const PageConfig& page_config, + const MemoryConfig& memory_config, + const ttnn::SimpleShape& logical_shape, + const ttnn::SimpleShape& padded_shape); Layout get_layout() const { return page_config_.get_layout(); } PageConfig get_page_config() const { return page_config_; } diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 5dd2a70106b..9ab439600b9 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -136,11 +136,14 @@ void Tensor::TensorAttributes::update_main_thread_ref_count(Device* worker, uint } Tensor::Tensor( - Storage storage, const ttnn::Shape& shape, DataType dtype, Layout layout, const std::optional& tile) { + Storage storage, + const ttnn::SimpleShape& logical_shape, + const ttnn::SimpleShape& padded_shape, + DataType dtype, + Layout layout, + const std::optional& tile) { using namespace tt::constants; - - if (tile.has_value() and // - (tile->get_tile_shape()[0] != TILE_WIDTH or tile->get_tile_shape()[1] != TILE_HEIGHT)) { + if (tile.has_value() && (tile->get_tile_shape()[0] != TILE_WIDTH || tile->get_tile_shape()[1] != TILE_HEIGHT)) { tt::log_warning( "only matmul op and ccl all-gather currently supports the customized tile shape: {}", tile->get_tile_shape()); @@ -156,10 +159,18 @@ Tensor::Tensor( init( std::move(storage), TensorSpec( - shape.logical_shape(), - TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(layout, tile), memory_config, shape))); + logical_shape, + TensorLayout::fromLegacyPaddedShape( + dtype, + PageConfig(layout, tile), + memory_config, + ttnn::Shape(logical_shape.view(), padded_shape.view())))); } +Tensor::Tensor( + Storage storage, const ttnn::Shape& shape, DataType dtype, Layout layout, const std::optional& tile) : + Tensor(std::move(storage), shape.logical_shape(), shape.padded_shape(), dtype, layout, tile) {} + Tensor::Tensor(Storage storage, TensorSpec tensor_spec) { init(std::move(storage), std::move(tensor_spec)); } void Tensor::init(Storage storage, TensorSpec tensor_spec) { @@ -654,12 +665,18 @@ template std::vector Tensor::to_vector() const; template std::vector Tensor::to_vector() const; template std::vector Tensor::to_vector() const; -Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config,uint8_t cq_id, +Tensor Tensor::to( + Device* target_device, + const MemoryConfig& mem_config, + uint8_t cq_id, const std::vector& sub_device_ids) const { return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config,uint8_t cq_id, +Tensor Tensor::to( + distributed::MeshDevice* mesh_device, + const MemoryConfig& mem_config, + uint8_t cq_id, const std::vector& sub_device_ids) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); return tensor_ops::tensor_to(*this, workers_to_use, mem_config, cq_id, sub_device_ids); @@ -701,10 +718,8 @@ const std::string Tensor::write_to_string() const { return tensor_impl::to_strin void Tensor::print() const { tensor_ops::tensor_print(*this); } Tensor Tensor::pad( - const tt::tt_metal::LegacyShape& output_tensor_shape, - const ttnn::SimpleShape& input_tensor_start, - float pad_value) const { - return tensor_ops::tensor_pad(*this, output_tensor_shape, input_tensor_start, pad_value); + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) const { + return tensor_ops::tensor_pad(*this, output_padded_shape, input_tensor_start, pad_value); } Tensor Tensor::unpad(const ttnn::SimpleShape& output_tensor_start, const ttnn::SimpleShape& output_tensor_end) const { @@ -987,7 +1002,8 @@ void write_tensor( "Error"); std::visit( tt::stl::overloaded{ - [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const DeviceStorage& device_storage) { + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids]( + const DeviceStorage& device_storage) { // Copying from host to a single device. void* host_data = std::visit( tt::stl::overloaded{ @@ -1014,7 +1030,8 @@ void write_tensor( /*blocking=*/false, sub_device_ids); }, - [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const MultiDeviceStorage& device_storage) { + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids]( + const MultiDeviceStorage& device_storage) { // Copying from host to multi-device. TT_ASSERT( std::holds_alternative(async_safe_tensor.get_storage()), diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 6827c421320..726fa4bb1ce 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -98,6 +98,13 @@ struct Tensor { DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor( + Storage storage, + const ttnn::SimpleShape& logical_shape, + const ttnn::SimpleShape& padded_shape, + DataType dtype, + Layout layout, + const std::optional& tile = std::nullopt); Tensor(Storage storage, TensorSpec tensor_spec); // Constructors to initialize unpopulated tensor with workers and storage specified. Use this when creating tensor @@ -198,7 +205,7 @@ struct Tensor { Tensor to(Layout target_layout, distributed::MeshDevice* mesh_device) const; Tensor pad( - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) const; diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 3f731c97c65..51c03865798 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -175,7 +175,7 @@ void validate_on_device_dtype_and_layout( Tensor pad_bfloat8_b( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { auto tile = tensor.get_tensor_spec().tile(); @@ -189,19 +189,22 @@ Tensor pad_bfloat8_b( auto float_tensor = Tensor( OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) - .pad(output_tensor_shape, input_tensor_start, pad_value); + .pad(output_padded_shape, input_tensor_start, pad_value); // Convert back to BFLOAT8_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - float_tensor.get_legacy_shape(), - DataType::BFLOAT8_B, - tensor.get_layout(), - tile); + TensorSpec output_spec( + tensor.logical_shape(), + TensorLayout::fromPaddedShape( + DataType::BFLOAT8_B, + tensor.get_tensor_spec().page_config(), + MemoryConfig{}, + tensor.logical_shape(), + output_padded_shape)); + return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_spec); } Tensor unpad_bfloat8_b( @@ -234,7 +237,7 @@ Tensor unpad_bfloat8_b( Tensor pad_bfloat4_b( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { auto tile = tensor.get_tensor_spec().tile(); @@ -248,19 +251,22 @@ Tensor pad_bfloat4_b( auto float_tensor = Tensor( OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tile) - .pad(output_tensor_shape, input_tensor_start, pad_value); + .pad(output_padded_shape, input_tensor_start, pad_value); // Convert back to BFLOAT4_B auto output_float_data = owned_buffer::get_as(float_tensor).get(); auto output_packed_data = pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - float_tensor.get_legacy_shape(), - DataType::BFLOAT4_B, - tensor.get_layout(), - tile); + TensorSpec output_spec( + tensor.logical_shape(), + TensorLayout::fromPaddedShape( + DataType::BFLOAT4_B, + tensor.get_tensor_spec().page_config(), + MemoryConfig{}, + tensor.logical_shape(), + output_padded_shape)); + return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_spec); } Tensor unpad_bfloat4_b( @@ -875,7 +881,15 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { raise_unsupported_storage(); } return Tensor( - storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tensor_spec().tile()); + storage, + TensorSpec( + tensor.get_logical_shape(), + TensorLayout::fromPaddedShape( + tensor.get_dtype(), + PageConfig(target_layout, tensor.get_tensor_spec().tile()), + MemoryConfig{}, + tensor.get_logical_shape(), + tensor.get_padded_shape()))); }, output_storage); } @@ -918,58 +932,62 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { if (ttnn::distributed::is_multi_device_tensor(tensor)) { return transform(tensor, [&](const Tensor& device_tensor) { - return pad(device_tensor, output_shape, input_tensor_start, pad_value); + return pad(device_tensor, output_padded_shape, input_tensor_start, pad_value); }); } + auto output_spec = TensorSpec( + tensor.get_logical_shape(), + TensorLayout::fromPaddedShape( + tensor.get_dtype(), + tensor.get_tensor_spec().page_config(), + MemoryConfig{}, + tensor.get_logical_shape(), + output_padded_shape)); + auto pad_value_ = static_cast(pad_value); - const auto input_shape = tensor.get_legacy_shape(); + const auto input_padded_shape = tensor.get_padded_shape(); const auto input_strides = tensor.strides(); - const auto input_data_type = tensor.get_dtype(); - - auto pad = [&input_shape, &input_strides, &input_data_type, &output_shape, &input_tensor_start, &pad_value_]( - const auto& input_buffer) { - auto compute_stride = [](const tt::tt_metal::LegacyShape& shape, uint32_t index) { - uint32_t stride = 1; - for (auto i = index + 1; i < shape.rank(); i++) { - stride *= shape[i]; - } - return stride; - }; + auto output_strides = output_spec.compute_strides(); + auto tensor_padded_shape = tensor.padded_shape(); + auto pad = [&](const auto& input_buffer) { ttnn::SmallVector> pad_size{}; - ttnn::SmallVector input_strides{}; - ttnn::SmallVector output_strides{}; - ttnn::SmallVector input_indices(input_shape.rank(), 0); + ttnn::SmallVector input_indices(tensor.padded_shape().rank(), 0); + + for (int index = 0; index < output_padded_shape.rank(); index++) { + uint32_t out_dim = output_padded_shape[index]; + + int tensor_idx = + index + static_cast(tensor_padded_shape.size()) - static_cast(output_padded_shape.size()); + uint32_t tensor_dim = tensor_idx >= 0 ? tensor_padded_shape[tensor_idx] : 1; + + int start_idx = + index + static_cast(input_tensor_start.size()) - static_cast(output_padded_shape.size()); + uint32_t start = start_idx >= 0 ? input_tensor_start[start_idx] : 0; - for (auto index = 0; index < output_shape.rank(); index++) { // Check if input tensor fits in output tensor given the input tensor start indices - TT_ASSERT( - input_shape[index] + input_tensor_start[index] <= output_shape[index], "Input tensor is out of bounds"); + TT_ASSERT(tensor_dim + start <= out_dim, "Input tensor is out of bounds"); // Figure out pad size on each dim - pad_size.push_back( - {input_tensor_start[index], output_shape[index] - input_shape[index] - input_tensor_start[index]}); - - input_strides.push_back(compute_stride(input_shape, index)); - output_strides.push_back(compute_stride(output_shape, index)); + pad_size.push_back({start, out_dim - tensor_dim - start}); } auto flat_output_index = 0; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); + auto output_buffer = owned_buffer::create(output_spec.padded_shape().volume()); std::function pad_to_tile = [&](std::size_t dim) -> void { for (auto i = 0; i < pad_size[dim][0] * output_strides[dim]; i++) { output_buffer[flat_output_index++] = pad_value_; } - for (auto i = 0; i < input_shape[dim]; i++) { + for (auto i = 0; i < input_padded_shape[dim]; i++) { input_indices[dim] = i; - if (dim == input_shape.rank() - 1) { + if (dim == input_padded_shape.rank() - 1) { auto flat_input_index = compute_flat_input_index(input_indices, input_strides); output_buffer[flat_output_index++] = input_buffer[flat_input_index]; } else { @@ -1006,61 +1024,56 @@ Tensor pad( } }, tensor.get_storage()); - return Tensor( - OwnedStorage{output_buffer}, - output_shape, - tensor.get_dtype(), - tensor.get_layout(), - tensor.get_tensor_spec().tile()); + return Tensor(OwnedStorage{output_buffer}, output_spec); } template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); template <> Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { - return pad_bfloat8_b(tensor, output_shape, input_tensor_start, pad_value); + return pad_bfloat8_b(tensor, output_padded_shape, input_tensor_start, pad_value); } template <> Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { - return pad_bfloat4_b(tensor, output_shape, input_tensor_start, pad_value); + return pad_bfloat4_b(tensor, output_padded_shape, input_tensor_start, pad_value); } template diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index a6db8b14388..e0bb1149ef7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -99,6 +99,9 @@ static ttnn::SmallVector to_4D_shape(const tt::tt_metal::LegacyShape& template typename BufferType> inline std::vector convert_layout_row_major_to_tile( const Size& shape, const Tile& tile, const BufferType& data_to_convert) { + if (shape.width() * shape.height() == 0) { + return std::vector(); + } TT_FATAL( (shape.height() % tile.get_tile_shape()[0] == 0 && shape.width() % tile.get_tile_shape()[1] == 0), "Unsupported shape for tensor conversion from row-major to tile layout. The tensor shape height and width must " @@ -211,7 +214,7 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout); template Tensor pad( const Tensor& tensor, - const tt::tt_metal::LegacyShape& output_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index c2df9f3e430..b5ca7d6dbb7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -254,12 +254,12 @@ void tensor_print(const Tensor& input_tensor) { Tensor tensor_pad( const Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_padded_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value) { ZoneScoped; GraphTracker::instance().track_function_start( - "Tensor::pad", input_tensor, output_tensor_shape, input_tensor_start, pad_value); + "Tensor::pad", input_tensor, output_padded_shape, input_tensor_start, pad_value); TT_ASSERT( input_tensor.storage_type() == StorageType::OWNED or input_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST or @@ -273,17 +273,7 @@ Tensor tensor_pad( return input_tensor; } - auto input_shape = input_tensor.get_legacy_shape(); - auto dimensions_pads = std::vector(); - for (auto index = 0; index < input_shape.rank(); index++) { - auto front = input_tensor_start[index]; - auto back = output_tensor_shape[index] - (input_tensor_start[index] + input_shape[index]); - dimensions_pads.push_back(Padding::PadDimension{.front = front, .back = back}); - } - const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); - auto output_shape_with_padding = tt::tt_metal::LegacyShape(output_tensor_shape, padding); - - auto output = tensor_impl::pad_wrapper(input_tensor, output_shape_with_padding, input_tensor_start, pad_value); + auto output = tensor_impl::pad_wrapper(input_tensor, output_padded_shape, input_tensor_start, pad_value); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -306,30 +296,26 @@ Tensor tensor_unpad( Tensor tensor_pad_to_tile(const Tensor& input_tensor, float pad_value) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::pad_to_tile", input_tensor, pad_value); - uint32_t height = input_tensor.get_legacy_shape()[-2]; - uint32_t width = input_tensor.get_legacy_shape()[-1]; + uint32_t height = input_tensor.get_logical_shape()[-2]; + uint32_t width = input_tensor.get_logical_shape()[-1]; uint32_t padded_height = round_up(height, constants::TILE_HEIGHT); uint32_t padded_width = round_up(width, constants::TILE_WIDTH); ttnn::SmallVector shape; - ttnn::SmallVector padded_shape; ttnn::SmallVector input_tensor_start; - for (auto index = 0; index < input_tensor.get_legacy_shape().rank() - 2; index++) { - shape.push_back(input_tensor.get_legacy_shape().without_padding()[index]); - padded_shape.push_back(input_tensor.get_legacy_shape()[index]); + for (auto index = 0; index < input_tensor.get_logical_shape().rank() - 2; index++) { + shape.push_back(input_tensor.get_logical_shape()[index]); input_tensor_start.push_back(0); } - shape.push_back(height); - shape.push_back(width); - padded_shape.push_back(padded_height); - padded_shape.push_back(padded_width); + shape.push_back(padded_height); + shape.push_back(padded_width); input_tensor_start.push_back(0); input_tensor_start.push_back(0); auto output = input_tensor.pad( - tt::tt_metal::LegacyShape(shape, padded_shape), ttnn::SimpleShape{std::move(input_tensor_start)}, pad_value); + ttnn::SimpleShape(std::move(shape)), ttnn::SimpleShape{std::move(input_tensor_start)}, pad_value); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -368,19 +354,7 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape); - const auto& new_padded_shape = new_shape.padded_shape(); const auto tile = input_tensor.get_tensor_spec().tile(); - TT_ASSERT( - input_tensor.volume() == new_padded_shape.volume(), - "{} != {}", - input_tensor.volume(), - new_padded_shape.volume()); - if (input_tensor.get_layout() == Layout::TILE) { - TT_ASSERT( - new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 && - new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 && - "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!"); - } auto output = std::visit( [&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor { using T = std::decay_t; @@ -424,12 +398,12 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) if (new_shape[-1] == 0 || shard_shape[1] == 0) { mul_div = 0; } else { - mul_div = new_shape[-1] > shard_shape[1] ? - (new_shape[-1] / shard_shape[1]) : - (shard_shape[1] / new_shape[-1]); + mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1]) + : (shard_shape[1] / new_shape[-1]); } - shard_spec.shape[0] = new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div; + shard_spec.shape[0] = + new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div; shard_spec.shape[1] = new_shape[-1]; shard_spec_buffer.page_shape = {1, new_shape[-1]}; diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index b65af33cb42..392ae6e7665 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -45,7 +45,7 @@ void tensor_print(const Tensor& input_tensor); Tensor tensor_pad( const Tensor& input_tensor, - const tt::tt_metal::LegacyShape& output_tensor_shape, + const ttnn::SimpleShape& output_tensor_shape, const ttnn::SimpleShape& input_tensor_start, float pad_value); diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 33fd91e8b40..284520834d0 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -206,14 +206,13 @@ class LegacyShape { } } explicit LegacyShape(tt::stl::Span shape, tt::stl::Span shape_with_tile_padding) : - rank_(shape.size()), dimensions_{}, padding_{shape.size()} { - TT_ASSERT( - shape.size() == shape_with_tile_padding.size(), - "Shape and shape_with_tile_padding must have the same size"); - for (auto index = 0; index < shape.size(); index++) { - auto padded_dimension = shape_with_tile_padding[index]; + rank_(shape_with_tile_padding.size()), dimensions_{}, padding_{shape_with_tile_padding.size()} { + for (int index = 0; index < shape_with_tile_padding.size(); index++) { + int shape_index = index + static_cast(shape.size()) - static_cast(shape_with_tile_padding.size()); + int dimension = shape_index >= 0 ? shape[shape_index] : 1; + int padded_dimension = shape_with_tile_padding[index]; this->dimensions_[index] = padded_dimension; - this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]}; + this->padding_[index] = {.front = 0, .back = static_cast(padded_dimension - dimension)}; } } explicit LegacyShape(const ttnn::SmallVector& shape, const ttnn::SmallVector& shape_with_tile_padding)