Skip to content

Commit

Permalink
#9744: add ttnn to autoformat pad
Browse files Browse the repository at this point in the history
  • Loading branch information
ntarafdar committed Jul 2, 2024
1 parent 730d207 commit fa67366
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions tt_eager/tt_dnn/op_library/auto_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "tt_dnn/op_library/copy/copy_op.hpp"
#include "tt_dnn/op_library/data_transfer/data_transfer_op.hpp"
#include "tt_dnn/op_library/layout_conversion/layout_conversion_op.hpp"
#include "tt_dnn/op_library/pad/pad_op.hpp"
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "tt_dnn/op_library/tilize/tilize_op.hpp"
#include "tt_dnn/op_library/transpose/transpose_op.hpp"
#include "tt_dnn/op_library/unpad/unpad_op.hpp"
Expand Down Expand Up @@ -93,14 +93,14 @@ Tensor AutoFormat::format_input_tensor(
}
} else if (!convert_layout && pad_input) {
if (formatted_input.get_layout() == Layout::ROW_MAJOR || formatted_input.get_layout() == Layout::TILE) {
return pad(formatted_input, padded_shape, {0, 0, 0, 0}, pad_value, mem_config);
return ttnn::pad((const ttnn::Tensor) formatted_input, ttnn::Shape(padded_shape), ttnn::Shape({0, 0, 0, 0}), pad_value, mem_config);
}
} else if (convert_layout && pad_input) {
if (formatted_input.get_layout() == Layout::ROW_MAJOR && target_layout == Layout::TILE) {
return tilize_with_val_padding(formatted_input, padded_shape, pad_value, mem_config);
} else if (formatted_input.get_layout() == Layout::TILE && target_layout == Layout::ROW_MAJOR) {
formatted_input = untilize(formatted_input, mem_config);
return pad(formatted_input, padded_shape, {0, 0, 0, 0}, pad_value, mem_config);
return ttnn::pad((const ttnn::Tensor) formatted_input, ttnn::Shape(padded_shape), ttnn::Shape({0, 0, 0, 0}), pad_value, mem_config);
}
}
// Fall back to host conversions
Expand Down
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,8 +1607,8 @@ std::vector<Tensor> _prod_bw(
// dim 0
Tensor tensor_1_temp = reciprocal_input;
if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) {
std::vector<std::pair<uint32_t, uint32_t>> padding = {{0, 0},
{0, 32 - (reciprocal_input.get_legacy_shape()[0] % 32)},
std::vector<std::pair<uint32_t, uint32_t>> padding = {{0, (32 - (reciprocal_input.get_legacy_shape()[0] % 32))},
{0, 0},
{0, 0},
{0, 0}};
tensor_1_temp = ttnn::pad(reciprocal_input, padding, 0, std::nullopt);
Expand Down

0 comments on commit fa67366

Please sign in to comment.