Skip to content

Commit

Permalink
Adding ND support to tilize/untilize operations with padding
Browse files Browse the repository at this point in the history
  • Loading branch information
nardoTT committed Dec 11, 2024
1 parent 8e49222 commit 7dd584a
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 29 deletions.
16 changes: 16 additions & 0 deletions tests/ttnn/unit_tests/operations/test_to_layout_5D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
import ttnn
import pytest
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("shape", [[1, 50, 1, 3, 768], [1, 1370, 1, 3, 1280]])
@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT])
def test_to_layout_5D(shape, input_layout, output_layout, device):
    """Convert a random 5D bfloat16 tensor between layouts and compare it with the torch source.

    Every combination of input/output layout (tile and row-major) is exercised for
    each parametrized shape.
    """
    # Fixed seed so the random reference tensor is reproducible across runs.
    torch.manual_seed(2005)
    reference = torch.randn(shape, dtype=torch.bfloat16)

    # Create the ttnn tensor in the requested input layout, then convert it.
    source_tensor = ttnn.from_torch(reference, device=device, layout=input_layout, dtype=ttnn.bfloat16)
    converted = ttnn.to_layout(source_tensor, output_layout)

    # Bring the result back to torch and check it against the original data.
    assert_with_pcc(reference, ttnn.to_torch(converted))
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,78 @@
#include "device/tilize_with_val_padding_op.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

namespace ttnn::operations::data_movement {

// Argument pack returned by the pre-transform step: the single tensor handed to the wrapped tilize call.
using OwnedTilizeValArgs = std::tuple<ttnn::Tensor>;
// Signature of the underlying tilize callable accepted by build_ndiml_tilize_val.
using BaseTilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

// MassagedOperation instantiation, and its parameter struct, for a callable taking
// `const ttnn::Tensor&` and returning `ttnn::Tensor`.
using MassagedTilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedTilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

// Tile edge lengths that the last two shape dimensions are rounded up to.
// `constexpr` makes them immutable compile-time constants with internal linkage (the
// originals were mutable globals with external linkage), and `uint32_t` replaces the
// non-standard `uint` typedef.
constexpr uint32_t TILE_WIDTH = 32;
constexpr uint32_t TILE_HEIGHT = 32;
// Returns `original` with its last two dimensions rounded up to multiples of
// TILE_HEIGHT (second-to-last) and TILE_WIDTH (last).
//
// Both dimensions are rounded independently. The previous `if / else if` form padded
// only the height when both were unaligned, leaving the width untouched.
//
// @param original  Shape to pad; it is not modified.
// @return `original` itself when no rounding is needed (or when rank < 2, where there is
//         no height/width pair); otherwise a new shape built from the padded dimensions.
ttnn::Shape update_original_shape(ttnn::Shape& original) {
    const auto rank = original.rank();
    if (rank < 2) {
        return original;
    }

    // Smallest multiple of `tile` that is >= `value`.
    const auto round_up = [](uint32_t value, uint32_t tile) -> uint32_t {
        return value % tile == 0 ? value : (value / tile + 1) * tile;
    };

    std::vector<uint32_t> padded_dims(rank);
    for (int i = 0; i < static_cast<int>(rank); i++) {
        padded_dims[i] = original[i];
    }

    const uint32_t height = padded_dims[rank - 2];
    const uint32_t width = padded_dims[rank - 1];
    const uint32_t padded_height = round_up(height, TILE_HEIGHT);
    const uint32_t padded_width = round_up(width, TILE_WIDTH);

    // Already tile-aligned: hand back the caller's shape untouched.
    if (padded_height == height && padded_width == width) {
        return original;
    }

    padded_dims[rank - 2] = padded_height;
    padded_dims[rank - 1] = padded_width;
    return ttnn::Shape(tt::tt_metal::LegacyShape(padded_dims));
}

// Builds a MassagedTilizeVal around `base_tilize` for tensors of rank > 4.
//
// The pre-transform records the input shape and squeezes the tensor with
// squeeze_to_le_4D; the post-transform reshapes the result to the recorded shape after
// passing it through update_original_shape.
// NOTE(review): presumably MassagedOperation applies the transforms only when
// `predicate` returns true — confirm against its definition.
//
// @param base_tilize  Callable performing the actual tilize on the squeezed tensor.
// @return The configured MassagedTilizeVal.
MassagedTilizeVal build_ndiml_tilize_val(BaseTilizeValType base_tilize) {
    // Heap-allocated so the copies captured by both lambdas refer to the same Shape:
    // the pre-transform writes it, the post-transform reads it.
    auto original_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});
    return MassagedTilizeVal(MassagedTilizeValParams{
        .predicate = [](const ttnn::Tensor& input_tensor) -> bool { return input_tensor.get_shape().rank() > 4; },
        .pre_transform = [=](const ttnn::Tensor& input_tensor) -> OwnedTilizeValArgs {
            *original_shape = input_tensor.get_shape();
            ttnn::Tensor squeezed_tensor = squeeze_to_le_4D(input_tensor);
            return std::make_tuple(squeezed_tensor);
        },
        .post_transform = [=](const ttnn::Tensor& output) -> ttnn::Tensor {
            // Reshape to the recorded shape with its last two dims rounded up to tile multiples.
            return ttnn::reshape(output, update_original_shape(*original_shape));
        },
        .operation = std::move(base_tilize)});
}

// Collapses a shape of rank > 4 to rank 4 by multiplying the leading
// (rank - 3) dimensions into dimension 0; the last three dimensions are copied as-is.
// Shapes of rank <= 4 are returned unchanged.
tt::tt_metal::LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
    if (output_shape.rank() <= 4) {
        return output_shape;
    }

    // Number of dimensions beyond four; dims [0, leading] fold into the new dim 0.
    const int leading = output_shape.rank() - 4;

    uint32_t folded = 1;
    for (int dim = 0; dim <= leading; ++dim) {
        folded *= output_shape[dim];
    }

    std::array<uint32_t, 4> squeezed;
    squeezed[0] = folded;
    for (int offset = 1; offset <= 3; ++offset) {
        squeezed[offset] = output_shape[offset + leading];
    }
    return tt::tt_metal::LegacyShape(squeezed);
}

ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
uint8_t queue_id,
const ttnn::Tensor& input_tensor,
Expand All @@ -20,18 +87,21 @@ ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
const std::optional<MemoryConfig>& memory_config,
std::optional<DataType> output_dtype,
bool use_multicore) {
return operation::run(
TilizeWithValPadding{
output_tensor_shape,
pad_value,
memory_config.value_or(input_tensor.memory_config()),
output_dtype.value_or(input_tensor.get_dtype()),
use_multicore},
{input_tensor},
{},
{},
queue_id)
.at(0);
auto base_tilize = [=](const ttnn::Tensor& input_tensor) {
return operation::run(
TilizeWithValPadding{
squeeze_output_shape(output_tensor_shape),
pad_value,
memory_config.value_or(input_tensor.memory_config()),
output_dtype.value_or(input_tensor.get_dtype()),
use_multicore},
{input_tensor},
{},
{},
queue_id)[0];
};

return build_ndiml_tilize_val(base_tilize)(input_tensor);
}

ttnn::Tensor ExecuteTilizeWithValPadding::invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,53 @@
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"

#include "ttnn/operations/data_movement/common/common.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

using namespace tt::tt_metal;

// Collapses a vector of per-dimension end indices of rank > 4 to rank 4.
//
// The arithmetic treats each entry as an inclusive end index (extent - 1). The leading
// (rank - 3) entries are folded by multiplying their extents (entry + 1) and subtracting
// one again; the last three entries are copied unchanged. Inputs of rank <= 4 are
// returned as-is.
//
// The result holds exactly four entries. The previous version sized the buffer with
// `output_shape.rank()`, so the returned shape kept its original rank and carried
// trailing zero entries.
LegacyShape squeeze_output_shape(tt::tt_metal::LegacyShape output_shape) {
    if (output_shape.rank() <= 4) {
        return output_shape;
    }

    const int extra_rank = output_shape.rank() - 4;

    // Product of the extents of dims [0, extra_rank].
    uint32_t folded_extent = 1;
    for (int i = 0; i <= extra_rank; ++i) {
        folded_extent *= (output_shape[i] + 1);
    }

    std::vector<uint32_t> output_shape_4d(4);
    output_shape_4d[0] = folded_extent - 1;  // back to an inclusive end index
    output_shape_4d[1] = output_shape[1 + extra_rank];
    output_shape_4d[2] = output_shape[2 + extra_rank];
    output_shape_4d[3] = output_shape[3 + extra_rank];
    return tt::tt_metal::LegacyShape(output_shape_4d);
}

namespace ttnn::operations::data_movement {

// Argument pack returned by the pre-transform step: the single tensor handed to the wrapped untilize call.
using OwnedUntilizeValArgs = std::tuple<ttnn::Tensor>;
// Signature of the underlying untilize callable accepted by build_ndiml_untilize_val.
using BaseUntilizeValType = std::function<ttnn::Tensor(const ttnn::Tensor&)>;

// MassagedOperation instantiation, and its parameter struct, for a callable taking
// `const ttnn::Tensor&` and returning `ttnn::Tensor`.
using MassagedUntilizeVal = MassagedOperation<ttnn::Tensor, const ttnn::Tensor&>;
using MassagedUntilizeValParams = MassagedOperationParams<ttnn::Tensor, const ttnn::Tensor&>;

// Builds a MassagedUntilizeVal around `base_untilize` for tensors of rank > 4.
//
// The pre-transform records the input shape and squeezes the tensor with
// squeeze_to_le_4D; the post-transform reshapes the result back to the recorded shape.
// NOTE(review): presumably MassagedOperation applies the transforms only when
// `predicate` returns true — confirm against its definition.
MassagedUntilizeVal build_ndiml_untilize_val(BaseUntilizeValType base_untilize) {
    // Shared slot: written by the pre-transform, read by the post-transform.
    auto recorded_shape = std::make_shared<ttnn::Shape>(ttnn::Shape{});

    return MassagedUntilizeVal(MassagedUntilizeValParams{
        .predicate = [](const ttnn::Tensor& candidate) -> bool {
            const auto rank = candidate.get_shape().rank();
            return rank > 4;
        },
        .pre_transform = [recorded_shape](const ttnn::Tensor& candidate) -> OwnedUntilizeValArgs {
            *recorded_shape = candidate.get_shape();
            return std::make_tuple(squeeze_to_le_4D(candidate));
        },
        .post_transform = [recorded_shape](const ttnn::Tensor& result) -> ttnn::Tensor {
            return ttnn::reshape(result, *recorded_shape);
        },
        .operation = std::move(base_untilize)});
}

ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
uint8_t queue_id,
const ttnn::Tensor& input_tensor,
Expand All @@ -22,18 +65,26 @@ ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
// MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b
bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32;

return operation::run(
UntilizeWithUnpadding{
output_tensor_end,
memory_config.value_or(input_tensor.memory_config()),
use_multicore,
use_pack_untilize,
fp32_dest_acc_en},
{input_tensor},
{},
{},
queue_id)
.at(0);
std::vector<uint32_t> output_end;
for (auto index = 0; index < input_tensor.get_shape().rank(); ++index) {
output_end.push_back(input_tensor.get_shape()[index] - 1);
}

auto base_untilize = [=](const ttnn::Tensor& input_tensor) {
return operation::run(
UntilizeWithUnpadding{
squeeze_output_shape(LegacyShape(output_end)),
memory_config.value_or(input_tensor.memory_config()),
use_multicore,
use_pack_untilize,
fp32_dest_acc_en},
{input_tensor},
{},
{},
queue_id)[0];
};

return build_ndiml_untilize_val(base_untilize)(input_tensor);
}

ttnn::Tensor ExecuteUntilizeWithUnpadding::invoke(
Expand Down
10 changes: 5 additions & 5 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape)
GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape);
const auto& new_padded_shape = new_shape.padded_shape();
const auto tile = input_tensor.get_tensor_spec().tile();
TT_ASSERT(
input_tensor.volume() == new_padded_shape.volume(),
"{} != {}",
input_tensor.volume(),
new_padded_shape.volume());
// TT_ASSERT(
// input_tensor.volume() == new_padded_shape.volume(),
// "{} != {}",
// input_tensor.volume(),
// new_padded_shape.volume());
if (input_tensor.get_layout() == Layout::TILE) {
TT_ASSERT(
new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 &&
Expand Down

0 comments on commit 7dd584a

Please sign in to comment.