Fix ttnn.from_torch for 0D/1D tensors with tile layout (#15882)
### Ticket
#15630

### Problem description
Shape/LegacyShape cannot represent different logical and padded ranks (e.g. a 1D tensor with logical shape [5] whose tile-layout padded shape is the rank-2, tile-aligned [32, 32]), so we had to remove all usages of those classes on the path from a PyTorch tensor to a ttnn tensor.
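
A minimal usage sketch of the conversion this fix enables (the shapes are illustrative and mirror the new `test_to_for_01_rank` unit test added below; converting back with `ttnn.to_torch` is still unsupported for these ranks):

```python
import torch
import ttnn

# 0D and 1D inputs have logical rank 0/1, but TILE_LAYOUT needs a
# rank-2, tile-aligned padded shape (e.g. [32, 32] for the default tile).
torch_scalar = torch.rand((), dtype=torch.bfloat16)    # 0D
torch_vector = torch.rand((5,), dtype=torch.bfloat16)  # 1D

ttnn_scalar = ttnn.from_torch(torch_scalar, layout=ttnn.TILE_LAYOUT)
ttnn_vector = ttnn.from_torch(torch_vector, layout=ttnn.TILE_LAYOUT)
```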

### What's changed
- Major refactoring of the `to_layout` and `pad` ops
- `TensorLayout` fixes for 0D/1D tensors

### Checklist
- [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12398856356)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes
sminakov-tt authored Dec 19, 2024
1 parent a8a812b commit 444b0dc
Showing 26 changed files with 447 additions and 426 deletions.
2 changes: 1 addition & 1 deletion tests/tt_eager/ops/test_tilize_zero_padding.cpp
@@ -46,7 +46,7 @@ int main(int argc, char** argv) {
log_debug(LogTest, "Moving src data to host to validate");
Tensor host_a = a.cpu(); // Move tensor a to host to validate
// TODO: Update when tensor.pad_to_tile() function is added
auto padded_shape = a.get_legacy_shape();
auto padded_shape = a.get_padded_shape();
padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT);
padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH);
Tensor padded_host_a = host_a.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0);
@@ -49,7 +49,7 @@ int main(int argc, char** argv) {
Tensor host_a = a.cpu(); // Move tensor a to host to validate
Tensor g = Tensor(host_a.get_storage(), shape, DataType::BFLOAT16, Layout::ROW_MAJOR);
// TODO: Update when tensor.pad_to_tile() function is added
auto padded_shape = g.get_legacy_shape();
auto padded_shape = g.get_padded_shape();
padded_shape[2] = round_up(padded_shape[2], TILE_HEIGHT);
padded_shape[3] = round_up(padded_shape[3], TILE_WIDTH);
Tensor padded_g = g.pad(padded_shape, ttnn::SimpleShape{0, 0, 0, 0}, 0);
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -1229,14 +1229,14 @@ def test_matmul_with_matched_width_height(device, m_size, k_size, n_size):
def test_matmul_with_matched_width_height_from_1D(device, k_size, n_size):
torch.manual_seed(0)

torch_input_tensor_a = torch.rand((k_size), dtype=torch.bfloat16)
torch_input_tensor_a = torch.rand((1, k_size), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((k_size, n_size), dtype=torch.bfloat16)
torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
output = input_tensor_a @ input_tensor_b
output = ttnn.to_torch(output, torch_rank=1)
output = ttnn.to_torch(output)

assert len(output.shape) == len(torch_output_tensor.shape)
assert output.shape == torch_output_tensor.shape
19 changes: 19 additions & 0 deletions tests/ttnn/unit_tests/test_to_and_from_torch.py
@@ -84,3 +84,22 @@ def test_from_torch_large(device):
x_tensor = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT)
x_tensor = ttnn.to_torch(x_tensor)
assert torch.allclose(torch_x, x_tensor)


@pytest.mark.parametrize(
"shape",
[
(),
(1),
(2),
(0),
],
)
@pytest.mark.parametrize("layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_to_for_01_rank(shape, layout, dtype):
torch_input_tensor = torch.rand(shape, dtype=dtype)
tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
# Conversion in the opposite direction is not yet supported
# torch_output_tensor = ttnn.to_torch(tensor)
# assert torch.allclose(torch_input_tensor, torch_output_tensor)
22 changes: 11 additions & 11 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -115,31 +115,30 @@ Tensor convert_float_vector_to_tt_tensor(
Layout::TILE,
layout);
}
auto result_cpu_spec = TensorSpec(
ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::TILE, tile), MemoryConfig{}));
auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), DataType::FLOAT32);
auto float_tensor = Tensor(OwnedStorage{owned_buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, tile);
auto tile_val = tile.value_or(Tile());
if (shape[2] % tile_val.get_height() != 0 || shape[3] % tile_val.get_width() != 0) {
auto padded_shape = shape;
padded_shape[2] = tt::round_up(shape[2], tile_val.get_height());
padded_shape[3] = tt::round_up(shape[3], tile_val.get_width());

float_tensor = tensor_ops::tensor_pad(
float_tensor, LegacyShape(shape, padded_shape), ttnn::SimpleShape{0, 0, 0, 0}, 0);
if (result_cpu_spec.logical_shape() != result_cpu_spec.padded_shape()) {
float_tensor =
tensor_ops::tensor_pad(float_tensor, result_cpu_spec.padded_shape(), ttnn::SimpleShape{0, 0, 0, 0}, 0);
}
auto output_float_data = owned_buffer::get_as<float>(float_tensor.to(Layout::TILE)).get();
auto output_packed_data =
data_type == DataType::BFLOAT8_B
? pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile)
: pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile);
auto output_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile);
auto tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), result_cpu_spec);
if (device) {
return tensor.to(device, memory_config.value_or(MemoryConfig{}));
}
return tensor;
}
auto result_cpu_spec = TensorSpec(
ttnn::SimpleShape(shape), TensorLayout(data_type, PageConfig(Layout::ROW_MAJOR, tile), MemoryConfig{}));
auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), data_type);
auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR, tile).to(layout);
auto tensor = Tensor(OwnedStorage{owned_buffer}, result_cpu_spec).to(layout);
if (device) {
return tensor.to(device, memory_config.value_or(MemoryConfig{}));
}
@@ -1212,7 +1211,8 @@ void pytensor_module(py::module& m_tensor) {
const std::array<uint32_t, 4>& output_tensor_shape,
const std::array<uint32_t, 4>& input_tensor_start,
float pad_value) {
return self.pad(output_tensor_shape, ttnn::SimpleShape(input_tensor_start), pad_value);
return self.pad(
ttnn::SimpleShape(output_tensor_shape), ttnn::SimpleShape(input_tensor_start), pad_value);
},
R"doc(
Pad TT Tensor with given pad value ``arg2``.
265 changes: 113 additions & 152 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
@@ -42,6 +42,90 @@ inline bool use_multicore_device_tilize(
return num_tiles_in_row <= max_tiles;
}

bool requires_padding_change(const ttnn::Tensor& tensor, ttnn::Layout layout) {
auto tile = tensor.get_tensor_spec().tile();
if (layout == Layout::ROW_MAJOR) {
// There shouldn't be extra paddings for Row Major layout
return tensor.logical_shape() != tensor.padded_shape();
}
// It's okay for conversion to tile layout to preserve arbitrary padding as long as it satisfies the alignment
TensorSpec padded_spec(
tensor.padded_shape(),
TensorLayout(tensor.dtype(), PageConfig(layout, std::move(tile)), tensor.memory_config()));
return tensor.get_padded_shape() != padded_spec.padded_shape();
}

template <typename T>
Tensor to_layout_impl_on_device(
const ttnn::Tensor& tensor_arg,
const ttnn::Layout layout,
const std::optional<ttnn::DataType>& dtype,
ttnn::MemoryConfig output_memory_config,
T* device) {
bool use_multicore_untilize = true;
bool use_multicore_tilize = use_multicore_device_tilize(tensor_arg, dtype);

if (layout == ttnn::ROW_MAJOR_LAYOUT) {
TT_FATAL(
!dtype.has_value() || dtype.value() == tensor_arg.dtype(),
"dtype cannot be different from tensor dtype when converting to ROW_MAJOR_LAYOUT on device!");
}

if (!requires_padding_change(tensor_arg, layout)) {
if (layout == ttnn::ROW_MAJOR_LAYOUT) {
return ttnn::untilize(tensor_arg, output_memory_config, use_multicore_untilize);
}
return ttnn::tilize(tensor_arg, output_memory_config, dtype, use_multicore_tilize);
}

auto tensor_shape = tensor_arg.get_logical_shape();

if (layout == ttnn::ROW_MAJOR_LAYOUT) {
if (tensor_arg.is_sharded()) {
const auto memory_config = tensor_arg.memory_config();
output_memory_config = tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type};
}
SmallVector<uint32_t> output_tensor_end;
for (auto index = 0; index < tensor_shape.rank(); ++index) {
output_tensor_end.push_back(tensor_shape[index] - 1);
}

auto tensor =
ttnn::untilize_with_unpadding(tensor_arg, output_tensor_end, output_memory_config, use_multicore_untilize);
return ttnn::reshape(tensor, tensor_shape);
}

TensorSpec result_spec(
tensor_arg.logical_shape(),
TensorLayout(
tensor_arg.dtype(),
PageConfig(layout, std::move(tensor_arg.tensor_spec().tile())),
tensor_arg.memory_config()));

// ttnn::tilize_with_val_padding doesn't support height sharded tensors
// workaround by applying padding and then tilizing
if (tensor_arg.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
ttnn::SmallVector<std::pair<uint32_t, uint32_t>> pad(result_spec.shape().rank());
auto output_padding = result_spec.shape().padding();
for (size_t i = 0; i < result_spec.padded_shape().rank(); i++) {
pad[i] = {output_padding[i].front, output_padding[i].back};
}
auto tensor = ttnn::pad(0, tensor_arg, tt::stl::Span(pad), 0, true, std::nullopt);
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
}

PadValue pad_value_variant;
if (tensor_arg.get_dtype() == ttnn::DataType::BFLOAT16 or tensor_arg.get_dtype() == ttnn::DataType::FLOAT32) {
pad_value_variant = 0.0f;
} else {
pad_value_variant = (uint32_t)0;
}

auto tensor = ttnn::tilize_with_val_padding(
tensor_arg, result_spec.padded_shape(), pad_value_variant, output_memory_config, dtype, use_multicore_tilize);
return tensor.reshape(tensor_arg.logical_shape());
}

template <typename T>
Tensor to_layout_impl(
const ttnn::Tensor& tensor_arg,
@@ -67,167 +151,44 @@ Tensor to_layout_impl(
return tensor_arg;
}

const std::set<ttnn::Layout> supported_layouts = {
ttnn::ROW_MAJOR_LAYOUT,
ttnn::TILE_LAYOUT,
};

if (supported_layouts.find(layout) == supported_layouts.end()) {
if (layout != ROW_MAJOR_LAYOUT && layout != TILE_LAYOUT) {
TT_THROW("ttnn::to_layout: Unsupported layout conversion from {} to {}!", tensor_arg.get_layout(), layout);
}

const auto requires_padding_change =
[](ttnn::Tensor& tensor, ttnn::Layout layout, const ttnn::Shape& shape) -> bool {
const auto intended_shape = shape;
const auto padded_shape = shape.with_tile_padding();
if (layout == ttnn::ROW_MAJOR_LAYOUT and intended_shape != padded_shape) {
return true;
}
if (layout == ttnn::TILE_LAYOUT) {
auto tile_shape = tensor.tensor_spec().tile().get_tile_shape();
if (padded_shape.rank() < 2 or padded_shape[-1] % tile_shape[1] != 0 or
padded_shape[-2] % tile_shape[0] != 0) {
return true;
}
}
return false;
};

const auto intended_shape = tensor_arg.get_shape();

auto tensor = tensor_arg;
const auto tile = tensor.get_tensor_spec().tile();

SmallVector<uint32_t> output_shape;
if (layout == ttnn::TILE_LAYOUT and intended_shape.rank() < 2) {
output_shape.push_back(1);
tensor = ttnn::reshape(
tensor,
ttnn::Shape(
SmallVector<uint32_t>{1, intended_shape[0]},
SmallVector<uint32_t>{1, tensor_arg.get_shape().with_tile_padding()[0]}));
}
for (auto index = 0; index < intended_shape.rank(); ++index) {
output_shape.push_back(intended_shape[index]);
auto output_memory_config =
memory_config.value_or(ttnn::get_memory_config(tensor_arg).value_or(ttnn::DRAM_MEMORY_CONFIG));

if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) {
return to_layout_impl_on_device(tensor_arg, layout, dtype, std::move(output_memory_config), device);
}

auto padded_output_shape = output_shape;
for (auto index = output_shape.size() - 2; index < output_shape.size(); ++index) {
padded_output_shape[index] = ttnn::pad_to_multiple_of_tile_size(
padded_output_shape[index],
(index == output_shape.size() - 2) ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!");
if (not requires_padding_change(tensor_arg, layout)) {
return device ? tensor_arg.to(layout, device) : tensor_arg.to(layout);
}

auto output_memory_config =
memory_config.value_or(ttnn::get_memory_config(tensor).value_or(ttnn::DRAM_MEMORY_CONFIG));
if (layout == ttnn::ROW_MAJOR_LAYOUT) {
auto tensor = device ? tensor_arg.to(layout, device) : tensor_arg.to(layout);
tensor = tensor.unpad_from_tile(tensor.get_logical_shape());
return tensor.reshape(tensor_arg.logical_shape());
}

if (ttnn::is_tensor_on_device_or_multidevice(tensor_arg)) {
bool use_multicore_untilize = true;
bool use_multicore_tilize = use_multicore_device_tilize(tensor, dtype);

if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
if (layout == ttnn::ROW_MAJOR_LAYOUT) {
TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");
return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize);
} else if (layout == ttnn::TILE_LAYOUT) {
if (tensor.is_sharded()) {
const auto shard_shape = get_memory_config(tensor).value().shard_spec.value().shape;
if (shard_shape[0] % ttnn::TILE_SIZE != 0 or shard_shape[1] % ttnn::TILE_SIZE != 0) {
TT_THROW(
"ttnn::to_layout: Sharded tensor must have shard shape that is a multiple of "
"TILE_SIZE!");
}
}
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
} else {
throw std::runtime_error("ttnn::to_layout: Unsupported layout!");
}
} else if (layout == ttnn::ROW_MAJOR_LAYOUT) {
TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!");

if (tensor.is_sharded()) {
const auto memory_config = tensor.memory_config();
output_memory_config =
tt::tt_metal::MemoryConfig{memory_config.memory_layout, memory_config.buffer_type};
}
SmallVector<uint32_t> output_tensor_end;
for (auto index = 0; index < tensor.get_shape().rank(); ++index) {
output_tensor_end.push_back(tensor.get_shape()[index] - 1);
}

tensor =
ttnn::untilize_with_unpadding(tensor, output_tensor_end, output_memory_config, use_multicore_untilize);
return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});

} else if (layout == ttnn::TILE_LAYOUT) {
SmallVector<uint32_t> padded_output_shape;

for (int index = 0; index < tensor.get_shape().rank(); ++index) {
uint32_t second_last_rank = tensor.get_shape().rank() - 2; // h dim
uint32_t padded_value =
index < second_last_rank
? tensor.get_shape()[index]
: ttnn::pad_to_multiple_of_tile_size(
tensor.get_shape()[index],
index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
padded_output_shape.push_back(padded_value);
}
if (tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
// ttnn::tilize_with_val_padding doesn't support height sharded tensors
// workaround by applying padding and then tilizing
SmallVector<std::pair<uint32_t, uint32_t>> padding = {
{0, 0},
{0, 0},
{0, padded_output_shape[2] - output_shape[2]},
{0, padded_output_shape[3] - output_shape[3]}};
tensor = ttnn::pad(0, tensor, padding, 0, true, std::nullopt);
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
} else {
PadValue pad_value_variant;
if (tensor.get_dtype() == ttnn::DataType::BFLOAT16 or tensor.get_dtype() == ttnn::DataType::FLOAT32) {
pad_value_variant = 0.0f;
} else {
pad_value_variant = (uint32_t)0;
}

tensor = ttnn::tilize_with_val_padding(
tensor, padded_output_shape, pad_value_variant, output_memory_config, dtype, use_multicore_tilize);
}

return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape}));

} else {
TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout);
}
} else {
TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!");
if (not requires_padding_change(tensor, layout, tensor.get_shape())) {
return device ? tensor.to(layout, device) : tensor.to(layout);
} else if (layout == ttnn::ROW_MAJOR_LAYOUT) {
tensor = device ? tensor.to(layout, device) : tensor.to(layout);
tensor = tensor.unpad_from_tile(tensor.get_logical_shape());
return ttnn::reshape(tensor, ttnn::SimpleShape{output_shape});
} else if (layout == ttnn::TILE_LAYOUT) {
SmallVector<uint32_t> padded_output_shape;
SmallVector<uint32_t> padded_input_start;
for (int index = 0; index < tensor.get_shape().rank(); ++index) {
uint32_t second_last_rank = tensor.get_shape().rank() - 2; // h dim
uint32_t padded_value =
index < second_last_rank
? tensor.get_shape()[index]
: ttnn::pad_to_multiple_of_tile_size(
tensor.get_shape()[index],
index == second_last_rank ? tile.get_tile_shape()[0] : tile.get_tile_shape()[1]);
padded_output_shape.push_back(padded_value);
padded_input_start.push_back(0);
}
tensor = tensor.pad(padded_output_shape, ttnn::SimpleShape(std::move(padded_input_start)), 0);
tensor = device ? tensor.to(layout, device) : tensor.to(layout);
return ttnn::reshape(tensor, ttnn::Shape(tt::tt_metal::LegacyShape{output_shape, padded_output_shape}));
} else {
TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout);
}
SmallVector<uint32_t> padded_input_start;
for (int index = 0; index < tensor_arg.get_logical_shape().rank(); ++index) {
padded_input_start.push_back(0);
}
TensorSpec result_spec(
tensor_arg.padded_shape(),
TensorLayout::fromPaddedShape(
tensor_arg.dtype(),
PageConfig(layout, std::move(tensor_arg.tensor_spec().tile())),
tensor_arg.memory_config(),
tensor_arg.logical_shape(),
tensor_arg.padded_shape()));

auto tensor = tensor_arg.pad(result_spec.padded_shape(), ttnn::SimpleShape(std::move(padded_input_start)), 0);
tensor = device ? tensor.to(layout, device) : tensor.to(layout);
return tensor.reshape(result_spec.logical_shape());
}
} // namespace detail
