Skip to content

Commit

Permalink
Support all ND shapes for tilize/untilize (#16299)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue
#16188

### Problem description
ND tilize/untilize support shapes with a value of 1 only for dimensions
greater than 4

### What's changed
Reshape the tensors that have shape values other than 1 for dims>4
Delete the nd_support assertion for ND

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12474559635
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12549967742
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
nardoTT authored Jan 5, 2025
1 parent aaf2d73 commit d939510
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 61 deletions.
113 changes: 113 additions & 0 deletions tests/ttnn/unit_tests/test_tilize_untilize_2D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
from models.utility_functions import is_grayskull, is_blackhole, torch_random, skip_for_grayskull


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("use_pack_untilize", [False, True])
@pytest.mark.parametrize("H", [32, 512])
@pytest.mark.parametrize("W", [1024, 256])
def test_untilize_2D(device, in_dtype, use_multicore, use_pack_untilize, H, W):
torch_input_shape = [H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)

output_tt = ttnn.untilize(ttnn_input, use_multicore=use_multicore, use_pack_untilize=use_pack_untilize)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("H", [128, 2048])
@pytest.mark.parametrize("W", [32, 1056])
def test_tilize_2D(device, in_dtype, use_multicore, H, W):
torch_input_shape = [H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)

output_tt = ttnn.tilize(ttnn_input, use_multicore=use_multicore)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("use_pack_untilize", [False, True])
@pytest.mark.parametrize("H", [32, 43])
@pytest.mark.parametrize("W", [64, 76])
def test_untilize_with_unpadding_2D(device, in_dtype, use_multicore, use_pack_untilize, H, W):
torch_input_shape = [H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.TILE_LAYOUT)

output_tt = ttnn.untilize_with_unpadding(
ttnn_input, [H - 1, W - 1], use_multicore=use_multicore, use_pack_untilize=use_pack_untilize
)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("pad_value", [2, 1.3])
@pytest.mark.parametrize("H", [32, 43])
@pytest.mark.parametrize("W", [64, 76])
def test_tilize_with_val_padding_2D(device, in_dtype, use_multicore, H, W, pad_value):
torch_input_shape = [H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)

output_tt = ttnn.tilize_with_val_padding(ttnn_input, [64, 128], pad_value, use_multicore=use_multicore)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing


@pytest.mark.parametrize("in_dtype", [ttnn.bfloat16, ttnn.float32])
@pytest.mark.parametrize("use_multicore", [False, True])
@pytest.mark.parametrize("H", [128, 98])
@pytest.mark.parametrize("W", [78, 1024])
def test_tilize_with_zero_padding_2D(device, in_dtype, use_multicore, H, W):
torch_input_shape = [H, W]

torch_input = torch.randn(torch_input_shape, dtype=torch.bfloat16).bfloat16()

ttnn_input = ttnn.from_torch(torch_input, device=device, dtype=in_dtype, layout=ttnn.ROW_MAJOR_LAYOUT)

output_tt = ttnn.tilize_with_zero_padding(ttnn_input, use_multicore=use_multicore)
output_torch = ttnn.to_torch(output_tt)

passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_input, output_torch)
logger.info(pcc_msg)
assert passing
45 changes: 22 additions & 23 deletions tests/ttnn/unit_tests/test_to_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def test_to_layout_device(device, h, w, input_layout, output_layout):
assert_with_pcc(torch_input_tensor, torch_brought_back)


@pytest.mark.parametrize("shape", [[1, 50, 1, 3, 768], [1, 1370, 1, 3, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@pytest.mark.parametrize("shape", [[3, 50, 1, 3, 768], [3, 1370, 1, 32, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_to_layout_5D(shape, input_layout, output_layout, device):
torch.manual_seed(2005)
Expand All @@ -154,7 +154,7 @@ def test_to_layout_5D(shape, input_layout, output_layout, device):
assert_with_pcc(input_a, output_tensor)


@pytest.mark.parametrize("shape", [[1, 1, 58, 1, 37, 256], [1, 1, 64, 1, 90, 1280]])
@pytest.mark.parametrize("shape", [[4, 7, 58, 1, 37, 256], [1, 3, 64, 1, 32, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_to_layout_6D(shape, input_layout, output_layout, device):
Expand All @@ -166,26 +166,25 @@ def test_to_layout_6D(shape, input_layout, output_layout, device):
assert_with_pcc(input_a, output_tensor)


@pytest.mark.skip("Skipping due to hang on to_layout to tile where input shape has 1 in it")
@pytest.mark.parametrize(
"config",
[
[[3, 1370, 1, 1, 1280], ttnn.ROW_MAJOR_LAYOUT], # hang
[[3, 50, 1, 1, 768], ttnn.ROW_MAJOR_LAYOUT], # hang
[[3, 50, 1, 1, 1024], ttnn.ROW_MAJOR_LAYOUT], # hang
[[3, 197, 1, 1, 768], ttnn.ROW_MAJOR_LAYOUT], # hang
[[3, 197, 1, 1, 1024], ttnn.ROW_MAJOR_LAYOUT], # hang
],
)
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG])
def test_to_layout_hangs(config, memory_config, device):
@pytest.mark.parametrize("shape", [[3, 50, 1, 1, 768], [3, 50, 1, 1, 1024], [3, 197, 1, 1, 768], [3, 197, 1, 1, 1024]])
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_to_layout_nd_hangs(shape, input_layout, output_layout, device):
torch.manual_seed(2005)
torch_input = torch.randn(config[0], dtype=torch.bfloat16)
input_a = torch.randn(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
output_tensor = ttnn.to_layout(input_tensor, output_layout)
output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(input_a, output_tensor)

tt_input = ttnn.from_torch(
torch_input, dtype=ttnn.DataType.BFLOAT16, layout=config[1], device=device, memory_config=memory_config
)
tt_output = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT)
tt_output = ttnn.to_torch(tt_output)

assert_with_pcc(torch_input, tt_output, 0.9999)
@pytest.mark.parametrize("shape", [[1, 768], [3, 230], [32, 768], [32, 143]])
@pytest.mark.parametrize("input_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
def test_to_layout_for_2D(shape, input_layout, output_layout, device):
torch.manual_seed(2005)
input_a = torch.randn(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
output_tensor = ttnn.to_layout(input_tensor, output_layout)
output_tensor = ttnn.to_torch(output_tensor)
assert_with_pcc(input_a, output_tensor)
19 changes: 0 additions & 19 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,6 @@ namespace core {

namespace detail {

inline bool validate_nd_support(const ttnn::Tensor& tensor_arg, const ttnn::Layout layout) {
const auto initial_shape = tensor_arg.get_shape();
if (initial_shape.rank() > 4 && tensor_arg.get_layout() != layout) {
for (int i = 0; i < initial_shape.rank() - 4; i++) {
TT_FATAL(
initial_shape[i] == 1,
"For ND tensors, shape dimensions greater than 4 should be 1, shape at index{} is {}",
i,
initial_shape[i]);
}
}
return true;
}

// Issue #8617: Limitations on tensor width for multicore device tilize
inline bool use_multicore_device_tilize(
const Tensor& input, const std::optional<tt::tt_metal::DataType>& output_dtype) {
Expand Down Expand Up @@ -142,7 +128,6 @@ Tensor to_layout_impl(
if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
if (layout == ttnn::ROW_MAJOR_LAYOUT) {
TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
validate_nd_support(tensor_arg, layout);
return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize);
} else if (layout == ttnn::TILE_LAYOUT) {
if (tensor.is_sharded()) {
Expand All @@ -153,7 +138,6 @@ Tensor to_layout_impl(
"TILE_SIZE!");
}
}
validate_nd_support(tensor_arg, layout);
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
} else {
throw std::runtime_error("ttnn::to_layout: Unsupported layout!");
Expand All @@ -171,7 +155,6 @@ Tensor to_layout_impl(
output_tensor_end.push_back(tensor.get_shape()[index] - 1);
}

validate_nd_support(tensor_arg, layout);
tensor =
ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize);
return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
Expand All @@ -198,7 +181,6 @@ Tensor to_layout_impl(
{0, padded_output_shape[2] - output_shape[2]},
{0, padded_output_shape[3] - output_shape[3]}};
tensor = ttnn::pad(0, tensor, padding, 0, true, std::nullopt);
validate_nd_support(tensor_arg, layout);
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
} else {
PadValue pad_value_variant;
Expand All @@ -208,7 +190,6 @@ Tensor to_layout_impl(
pad_value_variant = (uint32_t)0;
}

validate_nd_support(tensor_arg, layout);
tensor = ttnn::tilize_with_val_padding(
tensor, padded_output_shape, pad_value_variant, output_memory_config, dtype, use_multicore_tilize);
}
Expand Down
42 changes: 35 additions & 7 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,52 @@
#include "ttnn/cpp/ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/squeeze/squeeze.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp"

namespace ttnn {
namespace operations {
namespace data_movement {
ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor) {
auto shape = tensor.get_shape();

ttnn::Shape squeeze_shape_to_4D(ttnn::Shape shape) {
if (shape.rank() <= 4) {
return shape;
}
std::array<uint32_t, 4> shape_4d;
shape_4d[0] = 1;
int extra_rank = shape.rank() - 4;
for (int i = extra_rank; i >= 0; i--) {
shape_4d[0] *= shape[i];
}
shape_4d[1] = shape[1 + extra_rank];
shape_4d[2] = shape[2 + extra_rank];
shape_4d[3] = shape[3 + extra_rank];
return ttnn::Shape(shape_4d);
}

ttnn::Tensor squeeze_from_ND_to_4D(const ttnn::Tensor& tensor) {
auto shape = tensor.get_shape();
auto rank = shape.rank();
TT_FATAL(shape.rank() >= 4, "Tensor has to be of rank larger than 4! Instead is {}", shape.rank());
if (rank == 4) {
return tensor;
} else {
auto rank = shape.rank();
}
int i = 0;
// This is a workaround for now, it will be fixed in another PR
if (shape[i] == 1) {
auto squeezed = tensor;
while (rank > 4) {
while (rank > 4 && shape[i] == 1) {
squeezed = ttnn::squeeze(squeezed, 0);
rank = squeezed.get_shape().rank();
i++;
}
return squeezed;
if (rank <= 4) {
return squeezed;
}
return ttnn::reshape(squeezed, squeeze_shape_to_4D(shape));
}
};
return ttnn::reshape(tensor, squeeze_shape_to_4D(shape));
}

ttnn::Tensor pad_to_tile_vol(
uint8_t queue_id,
Expand Down
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
namespace ttnn {
namespace operations {
namespace data_movement {
ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor);

ttnn::Shape squeeze_shape_to_4D(ttnn::Shape output_shape);
ttnn::Tensor squeeze_from_ND_to_4D(const ttnn::Tensor& tensor);

ttnn::Tensor pad_to_tile_vol(
uint8_t queue_id,
Expand Down Expand Up @@ -148,7 +150,6 @@ class MassagedOperation {
OpType operation_;
};

ttnn::Tensor squeeze_to_le_4D(const ttnn::Tensor& tensor);
ttnn::Tensor pad_to_tile_vol(
uint8_t queue_id,
const ttnn::Tensor& tensor,
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ MassagedTilize build_ndiml_tilize(BaseTilizeType base_tilize) {
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,14 @@ MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) {
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeValArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
const auto tile = output.get_tensor_spec().tile();
uint32_t tile_height = tile.get_height();
uint32_t tile_width = tile.get_width();
auto unsqueezed_tensor =
ttnn::reshape(output, update_original_shape(*original_shape, tile_height, tile_width));
auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
return unsqueezed_tensor;
},
.operation = std::move(base_tilize)});
Expand Down Expand Up @@ -127,8 +126,8 @@ ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke(
using namespace tt::constants;
auto shape = input_tensor.get_legacy_shape();

shape[2] = tt::round_up(shape[2], tt::constants::TILE_HEIGHT);
shape[3] = tt::round_up(shape[3], tt::constants::TILE_WIDTH);
shape[-2] = tt::round_up(shape[-2], tt::constants::TILE_HEIGHT);
shape[-1] = tt::round_up(shape[-1], tt::constants::TILE_WIDTH);

PadValue pad_value;
if (input_tensor.get_dtype() == DataType::BFLOAT16 or input_tensor.get_dtype() == DataType::FLOAT32) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ MassagedUntilize build_ndiml_untilize(BaseUntilizeType base_untilize) {
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

using namespace tt::tt_metal;

LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
LegacyShape squeeze_vector_shape(tt::tt_metal::LegacyShape output_shape) {
if (output_shape.rank() > 4) {
std::vector<uint32_t> output_shape_4d(output_shape.rank());
output_shape_4d[0] = 1;
Expand Down Expand Up @@ -45,7 +45,7 @@ MassagedUntilizeVal build_ndiml_untilize_val(BaseUntilizeValType base_untilize)
.predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
.pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeValArgs {
*original_shape = input_tensor.get_shape();
ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
ttnn::Tensor squeezed_tensor = squeeze_from_ND_to_4D(input_tensor);
return std::make_tuple(squeezed_tensor);
},
.post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
Expand All @@ -71,7 +71,7 @@ ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
for (auto index = 0; index < input_tensor.get_shape().rank(); ++index) {
output_end_vector.push_back(input_tensor.get_shape()[index] - 1);
}
output_end = squeeze_output_shape(LegacyShape(output_end_vector));
output_end = squeeze_vector_shape(LegacyShape(output_end_vector));
} else {
output_end = output_tensor_end;
}
Expand Down

0 comments on commit d939510

Please sign in to comment.