Commit

Fixup
sminakov-tt committed Dec 11, 2024
1 parent 38ea1ee commit a01e26a
Showing 3 changed files with 19 additions and 13 deletions.
@@ -44,7 +44,7 @@ ttnn::Tensor convert_tile_to_rm(
             (tensor.get_dtype() == DataType::BFLOAT8_B)),
         "illegal dimensions for a bfloat8 tensor");
     auto new_tensor = (tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(tensor, DataType::BFLOAT16) : tensor;
-    new_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr);
+    new_tensor = ttnn::to_layout(tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
     new_tensor = ReshapeViewOperation::invoke(new_tensor, shape, memory_config, queue_id, pad_value);
     new_tensor =
         ttnn::to_layout(new_tensor, ttnn::TILE_LAYOUT, new_tensor.get_dtype(), memory_config, (Device*)nullptr);
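The one-line fix above changes the target dtype passed to ttnn::to_layout from tensor.get_dtype() to std::nullopt: for a BFLOAT8_B input, new_tensor has just been typecast to BFLOAT16, and forcing the conversion target back to the original dtype would undo that typecast. (The call still passes tensor rather than new_tensor as the first argument, which appears to predate this change.) Below is a minimal standalone sketch of the nullopt-means-preserve convention; the types and the to_layout stand-in are hypothetical, not the real ttnn API.

```cpp
#include <iostream>
#include <optional>

// Hypothetical stand-ins for illustration only -- not ttnn types.
enum class DataType { BFLOAT16, BFLOAT8_B };
enum class Layout { TILE, ROW_MAJOR };

struct Tensor {
    DataType dtype;
    Layout layout;
};

// Sketch of the convention the fix relies on: an empty dtype means
// "keep the input's current dtype" instead of forcing a conversion.
Tensor to_layout(const Tensor& t, Layout target, std::optional<DataType> dtype = std::nullopt) {
    return Tensor{dtype.value_or(t.dtype), target};
}

int main() {
    // A BFLOAT8_B tensor has just been typecast to BFLOAT16, as in convert_tile_to_rm.
    Tensor typecast{DataType::BFLOAT16, Layout::TILE};

    // Old pattern: passing the original dtype (BFLOAT8_B) undoes the typecast.
    Tensor old_result = to_layout(typecast, Layout::ROW_MAJOR, DataType::BFLOAT8_B);

    // Fixed pattern: std::nullopt preserves the typecast BFLOAT16 dtype.
    Tensor new_result = to_layout(typecast, Layout::ROW_MAJOR);

    std::cout << (old_result.dtype == DataType::BFLOAT16) << "\n";  // 0: dtype reverted
    std::cout << (new_result.dtype == DataType::BFLOAT16) << "\n";  // 1: dtype preserved
}
```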
17 changes: 12 additions & 5 deletions ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -962,16 +962,23 @@ Tensor pad(
     auto pad_value_ = static_cast<T>(pad_value);
     const auto input_padded_shape = tensor.get_padded_shape();
     const auto input_strides = tensor.strides();
+    auto output_strides = output_spec.compute_strides();
+    auto tensor_padded_shape = tensor.padded_shape();
 
     auto pad = [&](const auto& input_buffer) {
         ttnn::SmallVector<std::array<uint32_t, 2>> pad_size{};
-        auto output_strides = output_spec.compute_strides();
         ttnn::SmallVector<uint32_t> input_indices(tensor.padded_shape().rank(), 0);
 
-        for (auto index = 0; index < output_spec.padded_shape().rank(); index++) {
-            uint32_t out_dim = output_spec.padded_shape()[index];
-            uint32_t tensor_dim = index < tensor.padded_shape().size() ? tensor.padded_shape()[index] : 1;
-            uint32_t start = index < input_tensor_start.size() ? input_tensor_start[index] : 0;
+        for (int index = 0; index < output_padded_shape.rank(); index++) {
+            uint32_t out_dim = output_padded_shape[index];
+
+            int tensor_idx =
+                index + static_cast<int>(tensor_padded_shape.size()) - static_cast<int>(output_padded_shape.size());
+            uint32_t tensor_dim = tensor_idx >= 0 ? tensor_padded_shape[tensor_idx] : 1;
+
+            int start_idx =
+                index + static_cast<int>(input_tensor_start.size()) - static_cast<int>(output_padded_shape.size());
+            uint32_t start = start_idx >= 0 ? input_tensor_start[start_idx] : 0;
 
             // Check if input tensor fits in output tensor given the input tensor start indices
             TT_ASSERT(tensor_dim + start <= out_dim, "Input tensor is out of bounds");
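The rewritten loop in pad aligns the input shape and the start offsets to the output shape from the innermost (trailing) dimension rather than from the front, so an input of lower rank than the output maps onto the trailing output dimensions, and the missing leading dimensions behave as size 1 with start offset 0. A standalone sketch of that index arithmetic, using plain std::vector in place of ttnn::SmallVector (map_dims is an illustrative name, not a ttnn function):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// For each output dimension, find the matching input dimension by aligning
// the shapes at the innermost (last) dimension; dimensions the input lacks
// are treated as size 1 with start offset 0.
void map_dims(const std::vector<uint32_t>& input_shape,
              const std::vector<uint32_t>& start,
              const std::vector<uint32_t>& output_shape) {
    for (int index = 0; index < static_cast<int>(output_shape.size()); index++) {
        int tensor_idx = index + static_cast<int>(input_shape.size()) - static_cast<int>(output_shape.size());
        uint32_t tensor_dim = tensor_idx >= 0 ? input_shape[tensor_idx] : 1;

        int start_idx = index + static_cast<int>(start.size()) - static_cast<int>(output_shape.size());
        uint32_t start_off = start_idx >= 0 ? start[start_idx] : 0;

        std::cout << "out dim " << index << ": size " << output_shape[index]
                  << ", input size " << tensor_dim << ", start " << start_off << "\n";
    }
}

int main() {
    // Rank-2 input padded into a rank-4 output: the 30x40 input lines up
    // with the last two output dims; the two leading dims act as size 1.
    map_dims({30, 40}, {0, 0}, {1, 1, 32, 64});
}
```

This is the same trailing-edge alignment rule that NumPy-style broadcasting uses, which is why a rank-2 input lines up with the last two dimensions of a rank-4 output instead of the first two.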
13 changes: 6 additions & 7 deletions ttnn/cpp/ttnn/tensor/types.hpp
@@ -206,14 +206,13 @@ class LegacyShape {
         }
     }
     explicit LegacyShape(tt::stl::Span<const uint32_t> shape, tt::stl::Span<const uint32_t> shape_with_tile_padding) :
-        rank_(shape.size()), dimensions_{}, padding_{shape.size()} {
-        TT_ASSERT(
-            shape.size() == shape_with_tile_padding.size(),
-            "Shape and shape_with_tile_padding must have the same size");
-        for (auto index = 0; index < shape.size(); index++) {
-            auto padded_dimension = shape_with_tile_padding[index];
+        rank_(shape_with_tile_padding.size()), dimensions_{}, padding_{shape_with_tile_padding.size()} {
+        for (int index = 0; index < shape_with_tile_padding.size(); index++) {
+            int shape_index = index + static_cast<int>(shape.size()) - static_cast<int>(shape_with_tile_padding.size());
+            int dimension = shape_index >= 0 ? shape[shape_index] : 1;
+            int padded_dimension = shape_with_tile_padding[index];
             this->dimensions_[index] = padded_dimension;
-            this->padding_[index] = {.front = 0, .back = padded_dimension - shape[index]};
+            this->padding_[index] = {.front = 0, .back = static_cast<size_t>(padded_dimension - dimension)};
         }
     }
     explicit LegacyShape(const ttnn::SmallVector<uint32_t>& shape, const ttnn::SmallVector<uint32_t>& shape_with_tile_padding)
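With this change the constructor derives rank_ and padding_ from shape_with_tile_padding and right-aligns the logical shape against it, so the two spans no longer need equal rank (the old TT_ASSERT is gone); leading dimensions absent from the logical shape default to 1, making their back padding padded_dim - 1. A standalone sketch of the resulting padding computation (back_padding is a hypothetical helper, not ttnn code):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// The logical shape is right-aligned against the padded shape; dimensions
// it lacks are treated as implicit leading 1s, exactly as in the new
// LegacyShape constructor.
std::vector<size_t> back_padding(const std::vector<int>& shape,
                                 const std::vector<int>& shape_with_tile_padding) {
    std::vector<size_t> back(shape_with_tile_padding.size());
    for (int index = 0; index < static_cast<int>(shape_with_tile_padding.size()); index++) {
        int shape_index = index + static_cast<int>(shape.size()) - static_cast<int>(shape_with_tile_padding.size());
        int dimension = shape_index >= 0 ? shape[shape_index] : 1;  // implicit leading 1
        back[index] = static_cast<size_t>(shape_with_tile_padding[index] - dimension);
    }
    return back;
}

int main() {
    // Logical {30, 40} inside padded {1, 1, 32, 64}: the leading dims pad
    // 1 -> 1 (back = 0), the trailing dims pad 30 -> 32 and 40 -> 64.
    for (size_t b : back_padding({30, 40}, {1, 1, 32, 64})) std::cout << b << " ";  // 0 0 2 24
    std::cout << "\n";
}
```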
