Fix ttnn.from_torch for 0D/1D tensors with tile layout (#16484)
### Ticket

### Problem description
Part 3 of the original PR adding `ttnn.from_torch` support for 0D/1D
tensors with tile layout, which was previously reverted.

### What's changed
A variety of Shape fixes throughout the codebase.
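
With these fixes, the conversion works for rank-0 and rank-1 inputs. A minimal sketch of the now-supported usage, mirroring the new `test_to_for_01_rank` unit test (per that test, converting back with `ttnn.to_torch` is not yet supported for these ranks):

```python
import torch
import ttnn

# Rank-0 (scalar) and rank-1 tensors can now be tilized via ttnn.from_torch.
scalar = torch.rand((), dtype=torch.bfloat16)    # rank 0
vector = torch.rand((2,), dtype=torch.bfloat16)  # rank 1

t0 = ttnn.from_torch(scalar, layout=ttnn.TILE_LAYOUT)
t1 = ttnn.from_torch(vector, layout=ttnn.TILE_LAYOUT)
```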

### Checklist
- [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12649999245)
- [x] [T3K unit tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12703587849)
- [x] [T3K frequent CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12703592302)
- [x] [MLIR on push CI passes](https://github.com/tenstorrent/tt-mlir/actions/runs/1270370182)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12650502927)
- [x] [Device performance regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12650499944)
- [x] [Demo tests](https://github.com/tenstorrent/tt-metal/actions/runs/12650494382)
- [x] New/Existing tests provide coverage for changes
sminakov-tt authored Jan 10, 2025
1 parent 9c5a0b2 commit b2912fe
Showing 18 changed files with 208 additions and 210 deletions.
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1230,14 +1230,14 @@ def test_matmul_with_matched_width_height(device, m_size, k_size, n_size):
 def test_matmul_with_matched_width_height_from_1D(device, k_size, n_size):
     torch.manual_seed(0)
 
-    torch_input_tensor_a = torch.rand((k_size), dtype=torch.bfloat16)
+    torch_input_tensor_a = torch.rand((1, k_size), dtype=torch.bfloat16)
     torch_input_tensor_b = torch.rand((k_size, n_size), dtype=torch.bfloat16)
     torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
 
     input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
     input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
     output = input_tensor_a @ input_tensor_b
-    output = ttnn.to_torch(output, torch_rank=1)
+    output = ttnn.to_torch(output)
 
     assert len(output.shape) == len(torch_output_tensor.shape)
     assert output.shape == torch_output_tensor.shape
19 changes: 19 additions & 0 deletions tests/ttnn/unit_tests/test_to_and_from_torch.py
@@ -84,3 +84,22 @@ def test_from_torch_large(device):
     x_tensor = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT)
     x_tensor = ttnn.to_torch(x_tensor)
     assert torch.allclose(torch_x, x_tensor)
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (),
+        (1),
+        (2),
+        (0),
+    ],
+)
+@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_to_for_01_rank(shape, layout, dtype):
+    torch_input_tensor = torch.rand(shape, dtype=dtype)
+    tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
+    # Conversion in the opposite direction is not yet supported
+    # torch_output_tensor = ttnn.to_torch(tensor)
+    # assert torch.allclose(torch_input_tensor, torch_output_tensor)
112 changes: 42 additions & 70 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -42,6 +42,19 @@ inline bool use_multicore_device_tilize(
     return num_tiles_in_row <= max_tiles;
 }
 
+bool requires_padding_change(const ttnn::Tensor& tensor, ttnn::Layout layout) {
+    auto tile = tensor.get_tensor_spec().tile();
+    if (layout == Layout::ROW_MAJOR) {
+        // There shouldn't be extra paddings for Row Major layout
+        return tensor.get_logical_shape() != tensor.get_padded_shape();
+    }
+    // It's okay for conversion to tile layout to preserve arbitrary padding as long as it satisfies the alignment
+    TensorSpec padded_spec(
+        tensor.get_padded_shape(),
+        TensorLayout(tensor.get_dtype(), PageConfig(layout, std::move(tile)), tensor.memory_config()));
+    return tensor.get_padded_shape() != padded_spec.padded_shape();
+}
+
 template <typename T>
 Tensor to_layout_impl(
     const ttnn::Tensor& tensor_arg,
@@ -76,56 +89,31 @@
         TT_THROW("ttnn::to_layout: Unsupported layout conversion from {} to {}!", tensor_arg.get_layout(), layout);
     }
 
-    const auto requires_padding_change =
-        [](ttnn::Tensor& tensor, ttnn::Layout layout, const ttnn::Shape& shape) -> bool {
-        const auto intended_shape = shape;
-        const auto padded_shape = shape.with_tile_padding();
-        if (layout == ttnn::ROW_MAJOR_LAYOUT and intended_shape != padded_shape) {
-            return true;
-        }
-        if (layout == ttnn::TILE_LAYOUT) {
-            auto tile_shape = tensor.tensor_spec().tile().get_tile_shape();
-            if (padded_shape.rank() < 2 or padded_shape[-1] % tile_shape[1] != 0 or
-                padded_shape[-2] % tile_shape[0] != 0) {
-                return true;
-            }
-        }
-        return false;
-    };
-
-    const auto intended_shape = tensor_arg.get_shape();
-
     auto tensor = tensor_arg;
     const auto tile = tensor.get_tensor_spec().tile();
 
-    SmallVector<uint32_t> output_shape;
-    if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) {
-        output_shape.push_back(1);
-        tensor = ttnn::reshape(
-            tensor,
-            ttnn::Shape(
-                SmallVector<uint32_t>{1, intended_shape[0]},
-                SmallVector<uint32_t>{1, tensor_arg.get_shape().with_tile_padding()[0]}));
-    }
-    for (auto index = 0; index < intended_shape.rank(); ++index) {
-        output_shape.push_back(intended_shape[index]);
-    }
-
-    auto padded_output_shape = output_shape;
-    for (auto index = output_shape.size() - 2; index < output_shape.size(); ++index) {
-        padded_output_shape[index] = ttnn::pad_to_multiple_of_tile_size(
-            padded_output_shape[index],
-            (index == output_shape.size() - 2) ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
-    }
-
+    auto output_shape = tensor_arg.get_logical_shape();
     auto output_memory_config =
         memory_config.value_or(ttnn::get_memory_config(tensor).value_or(ttnn::DRAM_MEMORY_CONFIG));
 
+    TensorSpec tile_spec(
+        tensor_arg.get_logical_shape(),
+        TensorLayout(tensor_arg.dtype(), PageConfig(Layout::TILE, tile), output_memory_config));
+    auto padded_output_shape = tile_spec.padded_shape();
+
+    if (layout == ttnn::TILE_LAYOUT) {
+        if (tensor.get_padded_shape().size() < 2) {
+            SmallVector<uint32_t> new_padded_shape(2, 1);
+            new_padded_shape[1] = tensor.get_padded_shape()[-1];
+            new_padded_shape[0] = tensor.get_padded_shape()[-2];
+            tensor = tensor.reshape(tensor.get_logical_shape(), SimpleShape(new_padded_shape));
+        }
+    }
+
     if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) {
         bool use_multicore_untilize = true;
         bool use_multicore_tilize = use_multicore_device_tilize(tensor, dtype);
 
-        if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
+        if (not requires_padding_change(tensor, layout)) {
             if (layout == ttnn::ROW_MAJOR_LAYOUT) {
                 TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
                 return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize);
@@ -153,27 +141,15 @@
                     tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type};
             }
             SmallVector<uint32_t> output_tensor_end;
-            for (auto index = 0; index < tensor.get_shape().rank(); ++index) {
-                output_tensor_end.push_back(tensor.get_shape()[index] - 1);
+            for (auto index = 0; index < tensor.get_logical_shape().rank(); ++index) {
+                output_tensor_end.push_back(tensor.get_logical_shape()[index] - 1);
             }
 
             tensor =
                 ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize);
             return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
 
         } else if (layout == ttnn::TILE_LAYOUT) {
-            SmallVector<uint32_t> padded_output_shape;
-
-            for (int index = 0; index < tensor.get_shape().rank(); ++index) {
-                uint32_t second_last_rank = tensor.get_shape().rank() - 2;  // h dim
-                uint32_t padded_value =
-                    index < second_last_rank
-                        ? tensor.get_shape()[index]
-                        : ttnn::pad_to_multiple_of_tile_size(
-                              tensor.get_shape()[index],
-                              index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
-                padded_output_shape.push_back(padded_value);
-            }
             if (tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
                 // ttnn::tilize_with_val_padding doesn't support height sharded tensors
                 // workaround by applying padding and then tilizing
@@ -193,40 +169,36 @@
             }
 
             tensor = ttnn::tilize_with_val_padding(
-                tensor, padded_output_shape, pad_value_variant, output_memory_config, dtype, use_multicore_tilize);
+                tensor,
+                SimpleShape(padded_output_shape),
+                pad_value_variant,
+                output_memory_config,
+                dtype,
+                use_multicore_tilize);
         }
 
-        return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape}));
-
+        return ttnn::reshape(
+            tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape.view(), padded_output_shape.view()}));
     } else {
         TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout);
     }
 } else {
     TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!");
-    if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
+    if (not requires_padding_change(tensor, layout)) {
         return device ? tensor.to(layout, device) : tensor.to(layout);
     } else if (layout == ttnn::ROW_MAJOR_LAYOUT) {
         tensor = device ? tensor.to(layout, device) : tensor.to(layout);
         tensor = tensor.unpad_from_tile(tensor.get_logical_shape());
         return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
     } else if (layout == ttnn::TILE_LAYOUT) {
-        SmallVector<uint32_t> padded_output_shape;
         SmallVector<uint32_t> padded_input_start;
-        for (int index = 0; index < tensor.get_shape().rank(); ++index) {
-            uint32_t second_last_rank = tensor.get_shape().rank() - 2;  // h dim
-            uint32_t padded_value =
-                index < second_last_rank
-                    ? tensor.get_shape()[index]
-                    : ttnn::pad_to_multiple_of_tile_size(
-                          tensor.get_shape()[index],
-                          index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
-            padded_output_shape.push_back(padded_value);
+        for (int index = 0; index < padded_output_shape.rank(); ++index) {
            padded_input_start.push_back(0);
        }
        tensor =
            tensor.pad(ttnn::SimpleShape(padded_output_shape), ttnn::SimpleShape(std::move(padded_input_start)), 0);
        tensor = device ? tensor.to(layout, device) : tensor.to(layout);
-       return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape}));
+       return tensor.reshape(output_shape, padded_output_shape);
    } else {
        TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout);
    }
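
The net effect of the `to_layout_op.cpp` rewrite is that padded output shapes now come from `TensorSpec` instead of hand-rolled loops, with rank-0/1 tensors promoted to 2D before tiling. A simplified Python sketch of the rounding rule encoded by the deleted loops (assuming the default 32x32 tile; the helper names are illustrative, not ttnn API):

```python
TILE_H, TILE_W = 32, 32  # default tile shape; the real value comes from the tensor spec


def round_up(value: int, multiple: int) -> int:
    # Round value up to the nearest multiple.
    return -(-value // multiple) * multiple


def tile_padded_shape(logical_shape):
    # Promote 0D/1D shapes to 2D (() -> (1, 1), (n,) -> (1, n)),
    # then round the last two dims up to tile multiples.
    padded = list(logical_shape)
    while len(padded) < 2:
        padded.insert(0, 1)
    padded[-2] = round_up(padded[-2], TILE_H)
    padded[-1] = round_up(padded[-1], TILE_W)
    return tuple(padded)


print(tile_padded_shape(()))       # (32, 32)
print(tile_padded_shape((2,)))     # (32, 32)
print(tile_padded_shape((5, 40)))  # (32, 64)
```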
12 changes: 6 additions & 6 deletions ttnn/cpp/ttnn/operations/core/work_split/work_split_tilize.hpp
@@ -195,13 +195,13 @@ inline std::vector<std::vector<BlockRep>> distribute_work(
"Only tensors >=2D, tensors are supported. Shape: {}",
logical_shape);

auto input_w = logical_shape.rank() >= 4 ? logical_shape[-4] : 1;
auto input_z = logical_shape.rank() >= 3 ? logical_shape[-3] : 1;
auto input_y = logical_shape.rank() >= 2 ? logical_shape[-2] : 1;
auto input_w = logical_shape[-4];
auto input_z = logical_shape[-3];
auto input_y = logical_shape[-2];

auto padding_w = logical_shape.rank() >= 4 ? padding[padding.get_normalized_index(-4)].back : 0;
auto padding_z = logical_shape.rank() >= 3 ? padding[padding.get_normalized_index(-3)].back : 0;
auto padding_y = logical_shape.rank() >= 2 ? padding[padding.get_normalized_index(-2)].back : 0;
auto padding_w = padding.rank() >= 4 ? padding[padding.get_normalized_index(-4)].back : 0;
auto padding_z = padding.rank() >= 3 ? padding[padding.get_normalized_index(-3)].back : 0;
auto padding_y = padding.rank() >= 2 ? padding[padding.get_normalized_index(-2)].back : 0;

// total work is a full rep followed by a padding.
auto full_rep_blocks = FullRep(input_y, padding_y, input_z, padding_z, input_w, tile_height).to_block_reps();
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
@@ -146,7 +146,7 @@ MassagedConcat build_untilize_rm_retilize_concat(
             auto padded = pad_to_tile_vol(queue_id, output, 0.0f, true, output.memory_config());
             concat_db_print(true, "[DEBUG] padded to tile layout, now tilizing.");
             auto tilized =
-                ttnn::tilize_with_val_padding(padded, padded.get_legacy_shape(), 0.0f, output.memory_config());
+                ttnn::tilize_with_val_padding(padded, padded.get_padded_shape(), 0.0f, output.memory_config());
             concat_db_print(true, "[DEBUG] tilized");
             // need to reshape tilized result to logical concat output shape
             auto reshaped = ttnn::reshape(
@@ -26,20 +26,20 @@ void TilizeWithValPadding::validate(const std::vector<Tensor>& input_tensors) co

     for (auto i = 0; i < input_shape.rank(); i++) {
         TT_FATAL(
-            input_shape[i] <= this->output_tensor_shape[i],
+            input_shape[i] <= this->output_padded_shape[i],
             "Output tensor shape {} must be greater than or equal to input shape {} in each dimension, but is smaller "
             "in dimension {}",
-            this->output_tensor_shape,
+            this->output_padded_shape,
             input_shape,
             i);
     }
 
-    uint32_t num_rows = this->output_tensor_shape[-1];
-    uint32_t inner_dim = this->output_tensor_shape[-2];
+    uint32_t num_rows = this->output_padded_shape[-1];
+    uint32_t inner_dim = this->output_padded_shape[-2];
     TT_FATAL(
         inner_dim % TILE_WIDTH == 0 && num_rows % TILE_HEIGHT == 0,
         "To be tilizable output tensor shape {} must be divisible by tile size ({}, {})",
-        output_tensor_shape,
+        output_padded_shape,
         TILE_WIDTH,
         TILE_HEIGHT);

@@ -52,39 +52,32 @@ void TilizeWithValPadding::validate(const std::vector<Tensor>& input_tensors) co
"Output tensor must have the same memory layout as input tensor");
for (uint32_t i = 0; i < input_tensor_a.get_legacy_shape().rank(); i++) {
if (i != input_shape.rank() - 2) {
TT_FATAL(input_shape[i] == this->output_tensor_shape[i], "Error");
TT_FATAL(input_shape[i] == this->output_padded_shape[i], "Error");
}
}
}
}

std::vector<tt::tt_metal::LegacyShape> TilizeWithValPadding::compute_output_shapes(
std::vector<ttnn::TensorSpec> TilizeWithValPadding::compute_output_specs(
const std::vector<Tensor>& input_tensors) const {
auto input_shape = input_tensors.at(0).get_legacy_shape();
auto dimensions_pads = std::vector<Padding::PadDimension>();
for (auto index = 0; index < input_shape.rank(); index++) {
auto back = this->output_tensor_shape[index] - input_shape[index];
dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = back});
}
const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
return {tt::tt_metal::LegacyShape(this->output_tensor_shape, padding)};
}
const auto& input_tensor = input_tensors.at(0);
auto input_shape = input_tensors.at(0).get_padded_shape();

std::vector<Tensor> TilizeWithValPadding::create_output_tensors(
const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
const auto& input_tensor_a = input_tensors.at(0);
if (input_tensor_a.memory_config().is_sharded()) {
auto output_shape = this->compute_output_shapes(input_tensors).at(0);
auto shard_spec = input_tensor_a.shard_spec().value();
shard_spec.shape[0] = tt::tt_metal::compute_volume(output_shape) / output_shape[-1];
if (input_tensor.memory_config().is_sharded()) {
auto shard_spec = input_tensor.shard_spec().value();
shard_spec.shape[0] = output_padded_shape.volume() / output_padded_shape[-1];
auto mem_config = this->output_mem_config;
mem_config.shard_spec = shard_spec;
return {
create_device_tensor(output_shape, this->output_dtype, Layout::TILE, input_tensor_a.device(), mem_config)};
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
return {TensorSpec(
input_shape,
TensorLayout::fromPaddedShape(
output_dtype, PageConfig(Layout::TILE), mem_config, input_shape, output_padded_shape))};
}

return {TensorSpec(
input_shape,
TensorLayout::fromPaddedShape(
output_dtype, PageConfig(Layout::TILE), output_mem_config, input_shape, output_padded_shape))};
}

// TODO: If pad is called on a tile and output is not tile, we could untilize then pad, and output is RM
@@ -13,16 +13,14 @@
 namespace ttnn::operations::data_movement {
 
 struct TilizeWithValPadding {
-    const tt::tt_metal::LegacyShape output_tensor_shape;
+    const ttnn::SimpleShape output_padded_shape;
     const PadValue pad_value;
     const tt::tt_metal::MemoryConfig output_mem_config;
     const tt::tt_metal::DataType output_dtype;
     const bool use_multicore;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(
-        const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     tt::tt_metal::operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
 };
Expand Down