diff --git a/tests/ttnn/unit_tests/test_tilize_untilize_2D.py b/tests/ttnn/unit_tests/test_tilize_untilize_2D.py
new file mode 100644
index 00000000000..d71b8160d0f
--- /dev/null
+++ b/tests/ttnn/unit_tests/test_tilize_untilize_2D.py
@@ -0,0 +1,113 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from loguru import logger
+import pytest
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
+from models.utility_functions import is_grayskull, is_blackhole, torch_random, skip_for_grayskull
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("use_pack_untilize", [False, True])
+@pytest.mark.parametrize("H", [32, 512])
+@pytest.mark.parametrize("W", [1024, 256])
+def test_untilize_2D(device, in_dtype, use_multicore, use_pack_untilize, H, W):
+    torch_input_shape = [H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)
+
+    output_tt = ttnn.untilize(ttnn_input, use_multicore=use_multicore, use_pack_untilize=use_pack_untilize)
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("H", [128, 2048])
+@pytest.mark.parametrize("W", [32, 1056])
+def test_tilize_2D(device, in_dtype, use_multicore, H, W):
+    torch_input_shape = [H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)
+
+    output_tt = ttnn.tilize(ttnn_input, use_multicore=use_multicore)
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("use_pack_untilize", [False, True])
+@pytest.mark.parametrize("H", [32, 43])
+@pytest.mark.parametrize("W", [64, 76])
+def test_untilize_with_unpadding_2D(device, in_dtype, use_multicore, use_pack_untilize, H, W):
+    torch_input_shape = [H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)
+
+    output_tt = ttnn.untilize_with_unpadding(
+        ttnn_input, [H - 1, W - 1], use_multicore=use_multicore, use_pack_untilize=use_pack_untilize
+    )
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("pad_value", [2, 1.3])
+@pytest.mark.parametrize("H", [32, 43])
+@pytest.mark.parametrize("W", [64, 76])
+def test_tilize_with_val_padding_2D(device, in_dtype, use_multicore, H, W, pad_value):
+    torch_input_shape = [H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)
+
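+    # Pad the [H, W] input up to the fixed tile-aligned shape [64, 128] with pad_value.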
+    output_tt = ttnn.tilize_with_val_padding(ttnn_input, [64, 128], pad_value, use_multicore=use_multicore)
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing
+
+
+@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
+@pytest.mark.parametrize("use_multicore", [False, True])
+@pytest.mark.parametrize("H", [128, 98])
+@pytest.mark.parametrize("W", [78, 1024])
+def test_tilize_with_zero_padding_2D(device, in_dtype, use_multicore, H, W):
+    torch_input_shape = [H, W]
+
+    torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()
+
+    ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)
+
+    output_tt = ttnn.tilize_with_zero_padding(ttnn_input, use_multicore=use_multicore)
+    output_torch = ttnn.to_torch(output_tt)
+
+    passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
+    logger.info(pcc_msg)
+    assert passing
diff --git a/tests/ttnn/unit_tests/test_to_layout.py b/tests/ttnn/unit_tests/test_to_layout.py
index 52144c89abb..fd5601dc0f2 100644
--- a/tests/ttnn/unit_tests/test_to_layout.py
+++ b/tests/ttnn/unit_tests/test_to_layout.py
@@ -142,8 +142,8 @@ def test_to_layout_device(device, h, w, input_layout, output_layout):
     assert_with_pcc(torch_input_tensor, torch_brought_back)


-@pytest.mark.parametrize("shape", [[1, 50, 1, 3, 768], [1, 1370, 1, 3, 1280]])
-@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+@pytest.mark.parametrize("shape", [[3, 50, 1, 3, 768], [3, 1370, 1, 32, 1280]])
+@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
 @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
 def test_to_layout_5D(shape, input_layout, output_layout, device):
     torch.manual_seed(2005)
@@ -154,7 +154,7 @@ def test_to_layout_5D(shape, input_layout, output_layout, device):
     assert_with_pcc(input_a, output_tensor)


-@pytest.mark.parametrize("shape", [[1, 1, 58, 1, 37, 256], [1, 1, 64, 1, 90, 1280]])
+@pytest.mark.parametrize("shape", [[4, 7, 58, 1, 37, 256], [1, 3, 64, 1, 32, 1280]])
 @pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
 @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
 def test_to_layout_6D(shape, input_layout, output_layout, device):
@@ -166,26 +166,25 @@ def test_to_layout_6D(shape, input_layout, output_layout, device):
     assert_with_pcc(input_a, output_tensor)


-@pytest.mark.skip("Skipping due to hang on to_layout to tile where input shape has 1 in it")
-@pytest.mark.parametrize(
-    "config",
-    [
-        [[3, 1370, 1, 1, 1280], ttnn.ROW_MAJOR_LAYOUT],  # hang
-        [[3, 50, 1, 1, 768], ttnn.ROW_MAJOR_LAYOUT],  # hang
-        [[3, 50, 1, 1, 1024], ttnn.ROW_MAJOR_LAYOUT],  # hang
-        [[3, 197, 1, 1, 768], ttnn.ROW_MAJOR_LAYOUT],  # hang
-        [[3, 197, 1, 1, 1024], ttnn.ROW_MAJOR_LAYOUT],  # hang
-    ],
-)
-@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG])
-def test_to_layout_hangs(config, memory_config, device):
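+# ND shapes with inner singleton dims that previously hung on to_layout to tile; now exercised instead of skipped.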
+@pytest.mark.parametrize("shape", [[3, 50, 1, 1, 768], [3, 50, 1, 1, 1024], [3, 197, 1, 1, 768], [3, 197, 1, 1, 1024]])
+@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+def test_to_layout_nd_hangs(shape, input_layout, output_layout, device):
     torch.manual_seed(2005)
-    torch_input = torch.randn(config[0], dtype=torch.bfloat16)
+    input_a = torch.randn(shape, dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
+    output_tensor = ttnn.to_layout(input_tensor, output_layout)
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(input_a, output_tensor)

-    tt_input = ttnn.from_torch(
-        torch_input, dtype=ttnn.DataType.BFLOAT16, layout=config[1], device=device, memory_config=memory_config
-    )
-    tt_output = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)
-    tt_output = ttnn.to_torch(tt_output)
-    assert_with_pcc(torch_input, tt_output, 0.9999)
+
+@pytest.mark.parametrize("shape", [[1, 768], [3, 230], [32, 768], [32, 143]])
+@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+def test_to_layout_for_2D(shape, input_layout, output_layout, device):
+    torch.manual_seed(2005)
+    input_a = torch.randn(shape, dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
+    output_tensor = ttnn.to_layout(input_tensor, output_layout)
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(input_a, output_tensor)
diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
index d7f8cb45ea4..90e508aa346 100644
--- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
+++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -23,20 +23,6 @@ namespace core {

 namespace detail {

-inline bool validate_nd_support(const ttnn::Tensor& tensor_arg, const ttnn::Layout layout) {
-    const auto initial_shape = tensor_arg.get_shape();
-    if (initial_shape.rank() > 4 && tensor_arg.get_layout() != layout) {
-        for (int i = 0; i < initial_shape.rank() - 4; i++) {
-            TT_FATAL(
-                initial_shape[i] == 1,
-                "For ND tensors, shape dimensions greater than 4 should be 1, shape at index{} is {}",
-                i,
-                initial_shape[i]);
-        }
-    }
-    return true;
-}
-
 // Issue #8617: Limitations on tensor width for multicore device tilize
 inline bool use_multicore_device_tilize(
     const Tensor& input, const std::optional<DataType>& output_dtype) {
@@ -142,7 +128,6 @@ Tensor to_layout_impl(
     if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
         if (layout == ttnn::ROW_MAJOR_LAYOUT) {
             TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
-            validate_nd_support(tensor_arg, layout);
             return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize);
         } else if (layout == ttnn::TILE_LAYOUT) {
             if (tensor.is_sharded()) {
@@ -153,7 +138,6 @@ Tensor to_layout_impl(
                     "TILE_SIZE!");
                 }
             }
-            validate_nd_support(tensor_arg, layout);
             return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
         } else {
             throw std::runtime_error("ttnn::to_layout: Unsupported layout!");
@@ -171,7 +155,6 @@ Tensor to_layout_impl(
             output_tensor_end.push_back(tensor.get_shape()[index] - 1);
         }

-        validate_nd_support(tensor_arg, layout);
         tensor =
             ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize);
         return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
@@ -198,7 +181,6 @@ Tensor to_layout_impl(
                 {0, padded_output_shape[2] - output_shape[2]},
                 {0, padded_output_shape[3] - output_shape[3]}};
             tensor = ttnn::pad(0, tensor, padding, 0, true, std::nullopt);
-            validate_nd_support(tensor_arg, layout);
             return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
         } else {
             PadValue pad_value_variant;
@@ -208,7 +190,6 @@ Tensor to_layout_impl(
                 pad_value_variant = (uint32_t)0;
             }

-            validate_nd_support(tensor_arg, layout);
             tensor = ttnn::tilize_with_val_padding(
                 tensor, padded_output_shape, pad_value_variant, output_memory_config, dtype, use_multicore_tilize);
         }
diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
index 16a027eb9de..72402de22e6 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
@@ -5,24 +5,52 @@
 #include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
 #include "ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.hpp"
+#include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
+#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"

 namespace ttnn {
 namespace operations {
 namespace data_movement {
-ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor) {
-    auto shape = tensor.get_shape();
+
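+// Folds every leading dimension beyond the last three into dim 0, e.g. a rank-5 shape [2, 3, 5, 32, 32] becomes [6, 5, 32, 32].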
+ttnn::Shape squeeze_shape_to_4D(ttnn::Shape shape) {
     if (shape.rank() <= 4) {
+        return shape;
+    }
+    std::array<uint32_t, 4> shape_4d;
+    shape_4d[0] = 1;
+    int extra_rank = shape.rank() - 4;
+    for (int i = extra_rank; i >= 0; i--) {
+        shape_4d[0] *= shape[i];
+    }
+    shape_4d[1] = shape[1 + extra_rank];
+    shape_4d[2] = shape[2 + extra_rank];
+    shape_4d[3] = shape[3 + extra_rank];
+    return ttnn::Shape(shape_4d);
+}
+
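+// Reduces a rank >= 4 tensor to rank 4: leading unit dims are squeezed away where possible,
+// and any remaining extra leading dims are folded into dim 0 with a reshape.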
+ttnn::Tensor squeeze_from_ND_to_4D(const ttnn::Tensor& tensor) {
+    auto shape = tensor.get_shape();
+    auto rank = shape.rank();
+    TT_FATAL(shape.rank() >= 4, "Tensor has to be of rank 4 or larger! Instead is {}", shape.rank());
+    if (rank == 4) {
         return tensor;
-    } else {
-        auto rank = shape.rank();
+    }
+    int i = 0;
+    // This is a workaround for now, it will be fixed in another PR
+    if (shape[i] == 1) {
         auto squeezed = tensor;
-        while (rank > 4) {
+        while (rank > 4 && shape[i] == 1) {
             squeezed = ttnn::squeeze(squeezed, 0);
             rank = squeezed.get_shape().rank();
+            i++;
         }
-        return squeezed;
+        if (rank <= 4) {
+            return squeezed;
+        }
+        return ttnn::reshape(squeezed, squeeze_shape_to_4D(shape));
     }
-};
+    return ttnn::reshape(tensor, squeeze_shape_to_4D(shape));
+}

 ttnn::Tensor pad_to_tile_vol(
     uint8_t queue_id,
diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
index 78938828448..6c95291da9a 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
@@ -11,7 +11,9 @@
 namespace ttnn {
 namespace operations {
 namespace data_movement {
-ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor);
+
+ttnn::Shape squeeze_shape_to_4D(ttnn::Shape output_shape);
+ttnn::Tensor squeeze_from_ND_to_4D(const ttnn::Tensor& tensor);

 ttnn::Tensor pad_to_tile_vol(
     uint8_t queue_id,
@@ -148,7 +150,6 @@ class MassagedOperation {
     OpType operation_;
 };

-ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor);
 ttnn::Tensor pad_to_tile_vol(
     uint8_t queue_id,
     const ttnn::Tensor& tensor,
diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
index c0eb410fa2f..2b3da8eab0f 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
@@ -25,7 +25,7 @@ MassagedTilize build_ndiml_tilize(BaseTilizeType base_tilize) {
         .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
         .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeArgs {
             *original_shape = input_tensor.get_shape();
-            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
             return std::make_tuple(squeezed_tensor);
         },
         .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
index 0e6e1044522..641c4c26afe 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
@@ -52,15 +52,14 @@ MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) {
         .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
         .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeValArgs {
             *original_shape = input_tensor.get_shape();
-            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
             return std::make_tuple(squeezed_tensor);
         },
         .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
             const auto tile = output.get_tensor_spec().tile();
             uint32_t tile_height = tile.get_height();
             uint32_t tile_width = tile.get_width();
-            auto unsqueezed_tensor =
-                ttnn::reshape(output, update_original_shape(*original_shape, tile_height, tile_width));
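+            // Reshape straight back to the original ND logical shape recorded in pre_transform.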
+            auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
             return unsqueezed_tensor;
         },
         .operation = std::move(base_tilize)});
@@ -127,8 +126,8 @@ ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke(
     using namespace tt::constants;

     auto shape = input_tensor.get_legacy_shape();
-    shape[2] = tt::round_up(shape[2], tt::constants::TILE_HEIGHT);
-    shape[3] = tt::round_up(shape[3], tt::constants::TILE_WIDTH);
+    shape[-2] = tt::round_up(shape[-2], tt::constants::TILE_HEIGHT);
+    shape[-1] = tt::round_up(shape[-1], tt::constants::TILE_WIDTH);

     PadValue pad_value;
     if (input_tensor.get_dtype() == DataType::BFLOAT16 or input_tensor.get_dtype() == DataType::FLOAT32) {
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp
index 1e9d303def5..f6e4f4a2e47 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp
@@ -25,7 +25,7 @@ MassagedUntilize build_ndiml_untilize(BaseUntilizeType base_untilize) {
         .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
         .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeArgs {
             *original_shape = input_tensor.get_shape();
-            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
             return std::make_tuple(squeezed_tensor);
         },
         .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
index 0215c6869a5..7797376c7f4 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
@@ -13,7 +13,7 @@

 using namespace tt::tt_metal;

-LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
+LegacyShape squeeze_vector_shape(tt::tt_metal::LegacyShape output_shape) {
     if (output_shape.rank() > 4) {
         std::vector<uint32_t> output_shape_4d(output_shape.rank());
         output_shape_4d[0] = 1;
@@ -45,7 +45,7 @@ MassagedUntilizeVal build_ndiml_untilize_val(BaseUntilizeValType base_untilize)
         .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
         .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeValArgs {
             *original_shape = input_tensor.get_shape();
-            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
             return std::make_tuple(squeezed_tensor);
         },
         .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
@@ -71,7 +71,7 @@ ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
         for (auto index = 0; index < input_tensor.get_shape().rank(); ++index) {
             output_end_vector.push_back(input_tensor.get_shape()[index] - 1);
         }
-        output_end = squeeze_output_shape(LegacyShape(output_end_vector));
+        output_end = squeeze_vector_shape(LegacyShape(output_end_vector));
     } else {
         output_end = output_tensor_end;
     }