diff --git a/tests/ttnn/unit_tests/test_to_layout.py b/tests/ttnn/unit_tests/test_to_layout.py
index b84a8f4c5fc..d074b41b570 100644
--- a/tests/ttnn/unit_tests/test_to_layout.py
+++ b/tests/ttnn/unit_tests/test_to_layout.py
@@ -140,3 +140,29 @@ def test_to_layout_device(device, h, w, input_layout, output_layout):
     torch_brought_back = ttnn.to_torch(new_layout_tensor)
 
     assert_with_pcc(torch_input_tensor, torch_brought_back)
+
+
+@pytest.mark.parametrize("shape", [[1, 50, 1, 3, 768], [1, 1370, 1, 3, 1280]])
+@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+def test_to_layout_5D(shape, input_layout, output_layout, device):
+    """Round-trip a 5D tensor through ttnn.to_layout and compare it with the torch input."""
+    torch.manual_seed(2005)
+    input_a = torch.randn(shape, dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
+    output_tensor = ttnn.to_layout(input_tensor, output_layout)
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(input_a, output_tensor)
+
+
+@pytest.mark.parametrize("shape", [[1, 1, 58, 1, 37, 256], [1, 1, 64, 1, 90, 1280]])
+@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
+def test_to_layout_6D(shape, input_layout, output_layout, device):
+    """Round-trip a 6D tensor through ttnn.to_layout and compare it with the torch input."""
+    torch.manual_seed(2005)
+    input_a = torch.randn(shape, dtype=torch.bfloat16)
+    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
+    output_tensor = ttnn.to_layout(input_tensor, output_layout)
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(input_a, output_tensor)
diff --git a/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp b/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp
index e44919022f1..cd783f161b5 100644
--- a/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp
+++ b/ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp
@@ -184,11 +184,7 @@ inline std::vector> distribute_work(
     bool has_cliff,
     uint32_t nblocks_per_core_cliff) {
     TT_FATAL(
-        logical_shape.rank() >= 2 && logical_shape.rank() <= 4,
-        "Only 2D, 3D, and 4D tensors are supported. Shape: {}",
-        "Error",
-        logical_shape,
-        padding);
+        logical_shape.rank() >= 2, "Logical shape rank needs to be >=2. Shape: {}", logical_shape);
 
     auto input_w = logical_shape.rank() >= 4 ? logical_shape[-4] : 1;
     auto input_z = logical_shape.rank() >= 3 ? logical_shape[-3] : 1;
diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
index 4186adc4d96..0e6e1044522 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.cpp
@@ -7,11 +7,78 @@
 #include "device/tilize_with_val_padding_op.hpp"
 #include "ttnn/common/constants.hpp"
 #include "ttnn/run_operation.hpp"
+#include "ttnn/operations/data_movement/common/common.hpp"
+#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
 
 using namespace tt::tt_metal;
 
 namespace ttnn::operations::data_movement {
 
+using OwnedTilizeValArgs = std::tuple<ttnn::Tensor>;
+using BaseTilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;
+
+using MassagedTilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
+using MassagedTilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;
+
+// Returns `original` with its last two dimensions rounded up to multiples of
+// tile_height / tile_width. Both dimensions are padded independently, so a
+// shape that is unaligned in height AND width gets both rounded up.
+ttnn::Shape update_original_shape(const ttnn::Shape& original, uint32_t tile_height, uint32_t tile_width) {
+    const uint32_t rank = original.rank();
+    const uint32_t height_idx = rank - 2;
+    const uint32_t width_idx = rank - 1;
+    if (original[height_idx] % tile_height == 0 && original[width_idx] % tile_width == 0) {
+        return original;
+    }
+    std::vector<uint32_t> padded(rank);
+    for (uint32_t i = 0; i < rank; i++) {
+        padded[i] = original[i];
+    }
+    padded[height_idx] = tt::round_up(original[height_idx], tile_height);
+    padded[width_idx] = tt::round_up(original[width_idx], tile_width);
+    return tt::tt_metal::LegacyShape(padded);
+}
+
+// Wraps base_tilize so inputs of rank > 4 go through squeeze_to_le_4D first and
+// the result is reshaped back to the recorded input shape, padded to the tile size.
+MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) {
+    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
+    return MassagedTilizeVal(MassagedTilizeValParams{
+        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
+        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeValArgs {
+            *original_shape = input_tensor.get_shape();
+            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            return std::make_tuple(squeezed_tensor);
+        },
+        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
+            const auto tile = output.get_tensor_spec().tile();
+            uint32_t tile_height = tile.get_height();
+            uint32_t tile_width = tile.get_width();
+            auto unsqueezed_tensor =
+                ttnn::reshape(output, update_original_shape(*original_shape, tile_height, tile_width));
+            return unsqueezed_tensor;
+        },
+        .operation = std::move(base_tilize)});
+}
+
+// Collapses a shape of rank > 4 to 4D by folding all leading dimensions into
+// dimension 0; shapes of rank <= 4 are returned unchanged.
+tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
+    if (output_shape.rank() > 4) {
+        std::array<uint32_t, 4> output_shape_4d;
+        output_shape_4d[0] = 1;
+        int extra_rank = output_shape.rank() - 4;
+        for (int i = extra_rank; i >= 0; i--) {
+            output_shape_4d[0] *= output_shape[i];
+        }
+        output_shape_4d[1] = output_shape[1 + extra_rank];
+        output_shape_4d[2] = output_shape[2 + extra_rank];
+        output_shape_4d[3] = output_shape[3 + extra_rank];
+        return tt::tt_metal::LegacyShape(output_shape_4d);
+    }
+    return output_shape;
+}
+
 ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
@@ -20,18 +87,21 @@ ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
     const std::optional<MemoryConfig>& memory_config,
     std::optional<DataType> output_dtype,
     bool use_multicore) {
-    return operation::run(
-               TilizeWithValPadding{
-                   output_tensor_shape,
-                   pad_value,
-                   memory_config.value_or(input_tensor.memory_config()),
-                   output_dtype.value_or(input_tensor.get_dtype()),
-                   use_multicore},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+            TilizeWithValPadding{
+                squeeze_output_shape(output_tensor_shape),
+                pad_value,
+                memory_config.value_or(input_tensor.memory_config()),
+                output_dtype.value_or(input_tensor.get_dtype()),
+                use_multicore},
+            {input_tensor},
+            {},
+            {},
+            queue_id)[0];
+    };
+
+    return build_ndiml_tilize_val(base_tilize)(input_tensor);
 }
 
 ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
@@ -54,8 +124,8 @@ ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke(
     using namespace tt::constants;
     auto shape = input_tensor.get_legacy_shape();
 
-    shape[2] = tt::round_up(shape[2], TILE_HEIGHT);
-    shape[3] = tt::round_up(shape[3], TILE_WIDTH);
+    shape[2] = tt::round_up(shape[2], tt::constants::TILE_HEIGHT);
+    shape[3] = tt::round_up(shape[3], tt::constants::TILE_WIDTH);
 
     PadValue pad_value;
     if (input_tensor.get_dtype() == DataType::BFLOAT16 or input_tensor.get_dtype() == DataType::FLOAT32) {
diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
index b2feccd12c9..0215c6869a5 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.cpp
@@ -8,10 +8,58 @@
 #include "ttnn/common/constants.hpp"
 #include "ttnn/run_operation.hpp"
 
+#include "ttnn/operations/data_movement/common/common.hpp"
+#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
+
 using namespace tt::tt_metal;
 
+// Folds a rank > 4 vector of end indices (size - 1 per dimension) into 4D: the
+// leading dimensions are merged as prod(end + 1) - 1 and the last three are
+// copied through. Inputs of rank <= 4 are returned unchanged.
+LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
+    if (output_shape.rank() > 4) {
+        std::vector<uint32_t> output_shape_4d(4);
+        output_shape_4d[0] = 1;
+        int extra_rank = output_shape.rank() - 4;
+        for (int i = extra_rank; i >= 0; i--) {
+            output_shape_4d[0] *= (output_shape[i] + 1);
+        }
+        output_shape_4d[0]--;
+        output_shape_4d[1] = output_shape[1 + extra_rank];
+        output_shape_4d[2] = output_shape[2 + extra_rank];
+        output_shape_4d[3] = output_shape[3 + extra_rank];
+        return tt::tt_metal::LegacyShape(output_shape_4d);
+    }
+    return output_shape;
+}
+
 namespace ttnn::operations::data_movement {
 
+using OwnedUntilizeValArgs = std::tuple<ttnn::Tensor>;
+using BaseUntilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;
+
+using MassagedUntilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
+using MassagedUntilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;
+
+// Wraps base_untilize so inputs of rank > 4 go through squeeze_to_le_4D first
+// and the result is reshaped back to the shape recorded from the input.
+MassagedUntilizeVal build_ndiml_untilize_val(BaseUntilizeValType base_untilize) {
+    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
+
+    return MassagedUntilizeVal(MassagedUntilizeValParams{
+        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
+        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeValArgs {
+            *original_shape = input_tensor.get_shape();
+            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
+            return std::make_tuple(squeezed_tensor);
+        },
+        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
+            auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
+            return unsqueezed_tensor;
+        },
+        .operation = std::move(base_untilize)});
+}
+
 ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
     uint8_t queue_id,
     const ttnn::Tensor& input_tensor,
@@ -22,18 +70,35 @@ ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
     // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b
     bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32;
 
-    return operation::run(
-               UntilizeWithUnpadding{
-                   output_tensor_end,
-                   memory_config.value_or(input_tensor.memory_config()),
-                   use_multicore,
-                   use_pack_untilize,
-                   fp32_dest_acc_en},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    // NOTE: for rank > 4 the caller-supplied output_tensor_end is not used; the
+    // end indices are taken from the input shape (size - 1 per dimension) and
+    // folded to 4D. For rank <= 4 output_tensor_end is passed through as-is.
+    std::vector<uint32_t> output_end_vector;
+    tt::tt_metal::LegacyShape output_end = tt::tt_metal::LegacyShape{};
+    if (input_tensor.get_shape().rank() > 4) {
+        for (auto index = 0; index < input_tensor.get_shape().rank(); ++index) {
+            output_end_vector.push_back(input_tensor.get_shape()[index] - 1);
+        }
+        output_end = squeeze_output_shape(LegacyShape(output_end_vector));
+    } else {
+        output_end = output_tensor_end;
+    }
+
+    auto base_untilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+            UntilizeWithUnpadding{
+                output_end,
+                memory_config.value_or(input_tensor.memory_config()),
+                use_multicore,
+                use_pack_untilize,
+                fp32_dest_acc_en},
+            {input_tensor},
+            {},
+            {},
+            queue_id)[0];
+    };
+
+    return build_ndiml_untilize_val(base_untilize)(input_tensor);
 }
 
 ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(