Skip to content

Commit

Permalink
Minor improvements, fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
sminakov-tt committed Jan 10, 2025
1 parent 81fa970 commit 8df5ce0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ bool requires_padding_change(const ttnn::Tensor& tensor, ttnn::Layout layout) {
auto tile = tensor.get_tensor_spec().tile();
if (layout == Layout::ROW_MAJOR) {
// There shouldn't be extra paddings for Row Major layout
return tensor.logical_shape() != tensor.padded_shape();
return tensor.get_logical_shape() != tensor.get_padded_shape();
}
// It's okay for conversion to tile layout to preserve arbitrary padding as long as it satisfies the alignment
TensorSpec padded_spec(
tensor.padded_shape(),
TensorLayout(tensor.dtype(), PageConfig(layout, std::move(tile)), tensor.memory_config()));
tensor.get_padded_shape(),
TensorLayout(tensor.get_dtype(), PageConfig(layout, std::move(tile)), tensor.memory_config()));
return tensor.get_padded_shape() != padded_spec.padded_shape();
}

Expand Down Expand Up @@ -105,7 +105,7 @@ Tensor to_layout_impl(
SmallVector<uint32_t> new_padded_shape(2, 1);
new_padded_shape[1] = tensor.get_padded_shape()[-1];
new_padded_shape[0] = tensor.get_padded_shape()[-2];
tensor = tensor.reshape(tensor.logical_shape(), SimpleShape(new_padded_shape));
tensor = tensor.reshape(tensor.get_logical_shape(), SimpleShape(new_padded_shape));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Tensor AutoFormat::format_input_tensor(
pad_value_variant = (uint32_t)pad_value;
}
return ttnn::tilize_with_val_padding(
formatted_input, Shape(padded_shape).padded_shape(), pad_value_variant, mem_config);
formatted_input, padded_shape.padded_shape(), pad_value_variant, mem_config);
} else if (formatted_input.get_layout() == Layout::TILE && target_layout == Layout::ROW_MAJOR) {
formatted_input = ttnn::untilize(formatted_input, mem_config);
return ttnn::pad(
Expand Down
14 changes: 7 additions & 7 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,24 +325,24 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap
ZoneScoped;
GraphTracker::instance().track_function_start("Tensor::unpad_from_tile", input_tensor, output_tensor_shape);

for (auto index = 0; index < input_tensor.get_legacy_shape().rank() - 2; index++) {
for (auto index = -3; index >= -static_cast<int>(input_tensor.get_padded_shape().rank()); index--) {
TT_ASSERT(
input_tensor.get_legacy_shape().without_padding()[index] == output_tensor_shape[index],
"Input shape must match output shape apart from last 2 dims");
}
TT_ASSERT(
input_tensor.get_legacy_shape()[-2] % constants::TILE_HEIGHT == 0 &&
input_tensor.get_legacy_shape()[-1] % constants::TILE_WIDTH == 0,
input_tensor.get_padded_shape()[-2] % constants::TILE_HEIGHT == 0 &&
input_tensor.get_padded_shape()[-1] % constants::TILE_WIDTH == 0,
"Last 2 dims of input shape must be multiples of 32");
TT_ASSERT(
input_tensor.get_legacy_shape()[-2] - constants::TILE_HEIGHT < output_tensor_shape[-2] &&
input_tensor.get_legacy_shape()[-1] - constants::TILE_WIDTH < output_tensor_shape[-1],
input_tensor.get_padded_shape()[-2] - constants::TILE_HEIGHT < output_tensor_shape[-2] &&
input_tensor.get_padded_shape()[-1] - constants::TILE_WIDTH < output_tensor_shape[-1],
"Last 2 dims of output must be within range to have been padded to input");
ttnn::SmallVector<uint32_t> output_tensor_start{};
ttnn::SmallVector<uint32_t> output_tensor_end{};
for (auto index = 0; index < input_tensor.get_legacy_shape().rank(); index++) {
for (auto index = 0; index < input_tensor.get_padded_shape().rank(); index++) {
output_tensor_start.push_back(0);
output_tensor_end.push_back(output_tensor_shape[index]);
output_tensor_end.push_back(index < output_tensor_shape.rank() ? output_tensor_shape[index] : 1);
}
auto output = input_tensor.unpad(
ttnn::SimpleShape(std::move(output_tensor_start)), ttnn::SimpleShape(std::move(output_tensor_end)));
Expand Down

0 comments on commit 8df5ce0

Please sign in to comment.