Adding ND support for tilize/untilize with padding (#15933)
### Ticket
#15935

### Problem description
`tilize_with_val_padding` and `untilize_with_unpadding` currently support only tensors of rank 2-4; ND tensors (rank > 4) need to be supported as well.

### What's changed
Wrap the base `tilize_with_val_padding` and `untilize_with_unpadding` calls in a `MassagedOperation`: when the input rank exceeds 4, the tensor is squeezed to at most 4D (leading dims folded into dim 0), the existing 4D op runs, and the output is reshaped back to the original ND shape (rounded up to tile boundaries on the tilize side). The rank restriction in `distribute_work` is relaxed from 2-4D to rank >= 2, and 5D/6D `to_layout` tests are added.
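
As a usage sketch (not part of the diff), the newly supported path mirrors the added tests; the open/close-device calls and the rank-5 shape are assumptions taken from standard ttnn usage and the new 5D test:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Rank-5 input, previously rejected by the rank <= 4 check.
torch_input = torch.randn([1, 50, 1, 3, 768], dtype=torch.bfloat16)
tensor = ttnn.from_torch(torch_input, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16)

# to_layout dispatches to tilize_with_val_padding / untilize_with_unpadding;
# the ND wrapper squeezes to 4D, runs the 4D op, and restores the shape.
tiled = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT)
round_tripped = ttnn.to_layout(tiled, ttnn.ROW_MAJOR_LAYOUT)

assert ttnn.to_torch(round_tripped).shape == torch_input.shape
ttnn.close_device(device)
```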

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12280234070
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12280247340
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
nardoTT authored Dec 17, 2024
1 parent b80a975 commit 9208f77
Showing 4 changed files with 181 additions and 31 deletions.
tests/ttnn/unit_tests/test_to_layout.py: 24 additions, 0 deletions
@@ -140,3 +140,27 @@ def test_to_layout_device(device, h, w, input_layout, output_layout):
    torch_brought_back = ttnn.to_torch(new_layout_tensor)

    assert_with_pcc(torch_input_tensor, torch_brought_back)


@pytest.mark.parametrize("shape", [[1, 50, 1, 3, 768], [1, 1370, 1, 3, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_to_layout_5D(shape, input_layout, output_layout, device):
    torch.manual_seed(2005)
    input_a = torch.randn(shape, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
    output_tensor = ttnn.to_layout(input_tensor, output_layout)
    output_tensor = ttnn.to_torch(output_tensor)
    assert_with_pcc(input_a, output_tensor)


@pytest.mark.parametrize("shape", [[1, 1, 58, 1, 37, 256], [1, 1, 64, 1, 90, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_to_layout_6D(shape, input_layout, output_layout, device):
    torch.manual_seed(2005)
    input_a = torch.randn(shape, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16)
    output_tensor = ttnn.to_layout(input_tensor, output_layout)
    output_tensor = ttnn.to_torch(output_tensor)
    assert_with_pcc(input_a, output_tensor)
@@ -184,11 +184,7 @@ inline std::vector<std::vector<BlockRep>> distribute_work(
    bool has_cliff,
    uint32_t nblocks_per_core_cliff) {
    TT_FATAL(
-        logical_shape.rank() >= 2 && logical_shape.rank() <= 4,
-        "Only 2D, 3D, and 4D tensors are supported. Shape: {}",
-        "Error",
-        logical_shape,
-        padding);
+        logical_shape.rank() >= 2, "Logical shape rank needs to be >=2. Shape: {}", "Error", logical_shape, padding);

    auto input_w = logical_shape.rank() >= 4 ? logical_shape[-4] : 1;
    auto input_z = logical_shape.rank() >= 3 ? logical_shape[-3] : 1;
@@ -7,11 +7,81 @@
#include "device/tilize_with_val_padding_op.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {

using OwnedTilizeValArgs = std::tuple<ttnn::Tensor>;
using BaseTilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

using MassagedTilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedTilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

ttnn::Shape update_original_shape(const ttnn::Shape& original, uint32_t tile_height, uint32_t tile_width) {
    std::vector<uint32_t> updated(original.rank());
    for (int i = 0; i < original.rank(); i++) {
        updated[i] = original[i];
    }
    // Round the last two dims up to the next tile boundary; both the height
    // and the width dim may need padding.
    uint32_t height_idx = original.rank() - 2;
    uint32_t width_idx = original.rank() - 1;
    if (updated[height_idx] % tile_height != 0) {
        updated[height_idx] = (updated[height_idx] / tile_height + 1) * tile_height;
    }
    if (updated[width_idx] % tile_width != 0) {
        updated[width_idx] = (updated[width_idx] / tile_width + 1) * tile_width;
    }
    return tt::tt_metal::LegacyShape(updated);
}
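
For intuition, the rounding above is plain round-up-to-multiple arithmetic on the last two dims; a standalone Python sketch, where `pad_to_tiles` is a hypothetical helper and the 32x32 tile size is an assumed default:

```python
TILE_HEIGHT, TILE_WIDTH = 32, 32  # assumed default tile size

def pad_to_tiles(shape):
    # Mirror update_original_shape: round the last two dims up to tile boundaries.
    padded = list(shape)
    padded[-2] = (padded[-2] + TILE_HEIGHT - 1) // TILE_HEIGHT * TILE_HEIGHT
    padded[-1] = (padded[-1] + TILE_WIDTH - 1) // TILE_WIDTH * TILE_WIDTH
    return padded

print(pad_to_tiles([1, 50, 1, 3, 768]))  # -> [1, 50, 1, 32, 768]
```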

MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) {
    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
    return MassagedTilizeVal(MassagedTilizeValParams{
        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeValArgs {
            // Remember the ND shape, then fold leading dims so the base op sees <= 4D.
            *original_shape = input_tensor.get_shape();
            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
            return std::make_tuple(squeezed_tensor);
        },
        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
            // Restore the ND shape, padded up to tile boundaries.
            const auto tile = output.get_tensor_spec().tile();
            uint32_t tile_height = tile.get_height();
            uint32_t tile_width = tile.get_width();
            auto unsqueezed_tensor =
                ttnn::reshape(output, update_original_shape(*original_shape, tile_height, tile_width));
            return unsqueezed_tensor;
        },
        .operation = std::move(base_tilize)});
}
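
The control flow of the MassagedOperation wrapper, reduced to a Python sketch over torch tensors (illustrative only; `with_nd_support` is a hypothetical name, and the real pre/post transforms are the `squeeze_to_le_4D` and `ttnn::reshape` calls above):

```python
import torch

def with_nd_support(base_op):
    # Run base_op on tensors of rank <= 4; for higher ranks, fold the leading
    # dims into dim 0 first and restore the original shape afterwards.
    def wrapped(t):
        if t.dim() <= 4:
            return base_op(t)
        original = t.shape
        squeezed = t.reshape(-1, *original[-3:])    # pre_transform
        return base_op(squeezed).reshape(original)  # post_transform
    return wrapped

tilize_nd = with_nd_support(lambda t: t)  # identity stands in for the 4D base op
assert tilize_nd(torch.zeros(1, 1, 58, 1, 37, 256)).shape == (1, 1, 58, 1, 37, 256)
```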

tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
    if (output_shape.rank() > 4) {
        std::array<uint32_t, 4> output_shape_4d;
        output_shape_4d[0] = 1;
        int extra_rank = output_shape.rank() - 4;
        // Fold all leading dims into dim 0, keeping the last three dims.
        for (int i = extra_rank; i >= 0; i--) {
            output_shape_4d[0] *= output_shape[i];
        }
        output_shape_4d[1] = output_shape[1 + extra_rank];
        output_shape_4d[2] = output_shape[2 + extra_rank];
        output_shape_4d[3] = output_shape[3 + extra_rank];
        return tt::tt_metal::LegacyShape(output_shape_4d);
    }
    return output_shape;
}
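
The fold is easiest to check on one of the new 6D test shapes; the same computation in Python (`squeeze_to_4d` is a hypothetical helper for illustration):

```python
from math import prod

def squeeze_to_4d(shape):
    # Mirror squeeze_output_shape: fold all leading dims into dim 0,
    # keeping the last three dims as-is.
    if len(shape) <= 4:
        return list(shape)
    return [prod(shape[:-3]), *shape[-3:]]

print(squeeze_to_4d([1, 1, 58, 1, 37, 256]))  # -> [58, 1, 37, 256]
```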

ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
    uint8_t queue_id,
    const ttnn::Tensor& input_tensor,
@@ -20,18 +90,21 @@ ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
    const std::optional<MemoryConfig>& memory_config,
    std::optional<DataType> output_dtype,
    bool use_multicore) {
-    return operation::run(
-               TilizeWithValPadding{
-                   output_tensor_shape,
-                   pad_value,
-                   memory_config.value_or(input_tensor.memory_config()),
-                   output_dtype.value_or(input_tensor.get_dtype()),
-                   use_multicore},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+                   TilizeWithValPadding{
+                       squeeze_output_shape(output_tensor_shape),
+                       pad_value,
+                       memory_config.value_or(input_tensor.memory_config()),
+                       output_dtype.value_or(input_tensor.get_dtype()),
+                       use_multicore},
+                   {input_tensor},
+                   {},
+                   {},
+                   queue_id)[0];
+    };
+
+    return build_ndiml_tilize_val(base_tilize)(input_tensor);
}

ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
@@ -54,8 +127,8 @@ ttnn::Tensor ExecuteTilizeWithZeroPadding::invoke(
    using namespace tt::constants;
    auto shape = input_tensor.get_legacy_shape();

-    shape[2] = tt::round_up(shape[2], TILE_HEIGHT);
-    shape[3] = tt::round_up(shape[3], TILE_WIDTH);
+    shape[2] = tt::round_up(shape[2], tt::constants::TILE_HEIGHT);
+    shape[3] = tt::round_up(shape[3], tt::constants::TILE_WIDTH);

    PadValue pad_value;
    if (input_tensor.get_dtype() == DataType::BFLOAT16 or input_tensor.get_dtype() == DataType::FLOAT32) {
@@ -8,10 +8,53 @@
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"

#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
    if (output_shape.rank() > 4) {
        std::array<uint32_t, 4> output_shape_4d;
        output_shape_4d[0] = 1;
        int extra_rank = output_shape.rank() - 4;
        // output_shape holds inclusive end indices (size - 1 per dim), so
        // convert the leading dims to sizes with +1, fold them into dim 0,
        // then convert back to an end index with -1.
        for (int i = extra_rank; i >= 0; i--) {
            output_shape_4d[0] *= (output_shape[i] + 1);
        }
        output_shape_4d[0]--;
        output_shape_4d[1] = output_shape[1 + extra_rank];
        output_shape_4d[2] = output_shape[2 + extra_rank];
        output_shape_4d[3] = output_shape[3 + extra_rank];
        return tt::tt_metal::LegacyShape(output_shape_4d);
    }
    return output_shape;
}
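
Unlike the tilize-side helper, this one folds inclusive end coordinates rather than sizes, hence the +1/-1 bookkeeping; a small Python check of the same arithmetic (`squeeze_end_coords` is a hypothetical name for illustration):

```python
from math import prod

def squeeze_end_coords(end):
    # `end` holds inclusive end indices (size - 1 per dim): convert the
    # leading dims to sizes with +1, fold them, then convert back with -1.
    if len(end) <= 4:
        return list(end)
    return [prod(e + 1 for e in end[:-3]) - 1, *end[-3:]]

# Shape [2, 3, 1, 37, 256] has end coords [1, 2, 0, 36, 255]; folding the
# leading dims gives shape [6, 1, 37, 256], i.e. end coords [5, 0, 36, 255].
print(squeeze_end_coords([1, 2, 0, 36, 255]))  # -> [5, 0, 36, 255]
```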

namespace ttnn::operations::data_movement {

using OwnedUntilizeValArgs = std::tuple<ttnn::Tensor>;
using BaseUntilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

using MassagedUntilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedUntilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

MassagedUntilizeVal build_ndiml_untilize_val(BaseUntilizeValType base_untilize) {
    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});

    return MassagedUntilizeVal(MassagedUntilizeValParams{
        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedUntilizeValArgs {
            // Remember the ND shape, then fold leading dims so the base op sees <= 4D.
            *original_shape = input_tensor.get_shape();
            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
            return std::make_tuple(squeezed_tensor);
        },
        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
            // Untilize drops the tile padding, so the original logical shape is restored directly.
            auto unsqueezed_tensor = ttnn::reshape(output, *original_shape);
            return unsqueezed_tensor;
        },
        .operation = std::move(base_untilize)});
}

ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
    uint8_t queue_id,
    const ttnn::Tensor& input_tensor,
@@ -22,18 +65,32 @@ ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
    // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b
    bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32;

-    return operation::run(
-               UntilizeWithUnpadding{
-                   output_tensor_end,
-                   memory_config.value_or(input_tensor.memory_config()),
-                   use_multicore,
-                   use_pack_untilize,
-                   fp32_dest_acc_en},
-               {input_tensor},
-               {},
-               {},
-               queue_id)
-        .at(0);
+    std::vector<uint32_t> output_end_vector;
+    tt::tt_metal::LegacyShape output_end = tt::tt_metal::LegacyShape{};
+    if (input_tensor.get_shape().rank() > 4) {
+        // For ND inputs, derive the end coordinates from the logical shape
+        // and squeeze them to 4D alongside the tensor itself.
+        for (auto index = 0; index < input_tensor.get_shape().rank(); ++index) {
+            output_end_vector.push_back(input_tensor.get_shape()[index] - 1);
+        }
+        output_end = squeeze_output_shape(LegacyShape(output_end_vector));
+    } else {
+        output_end = output_tensor_end;
+    }
+
+    auto base_untilize = [=](const ttnn::Tensor& input_tensor) {
+        return operation::run(
+                   UntilizeWithUnpadding{
+                       output_end,
+                       memory_config.value_or(input_tensor.memory_config()),
+                       use_multicore,
+                       use_pack_untilize,
+                       fp32_dest_acc_en},
+                   {input_tensor},
+                   {},
+                   {},
+                   queue_id)[0];
+    };
+
+    return build_ndiml_untilize_val(base_untilize)(input_tensor);
}

ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
