# Add support for padding along width dimension to ttnn.pad (#15985)
### Tickets

- #15511
- #15603 (90% resolved with these changes; to be fully resolved in a future PR)
- #12896

### Problem description

ttnn.pad's row-major sharded implementation only supports padding along the non-width dimensions. The row-major implementation is additionally not fully general with respect to the width dimension, so until now there have been no great options for padding along width. In a future PR coming tomorrow, I'll add input massaging code to convert to row-major and shard as needed for input configurations that aren't currently supported by pad.

### What's changed

- Adds new kernels to support padding along the width dimension.
- For pad operations requiring both NCH and width padding, we use a fused op combining the original height-padding kernels with the new width kernels.
- The previous point required extensive refactoring of the host code. I would like eyes on pad.cpp please @yugaoTT @sminakov-tt.
- Also adds a set of common utility functions for working with sharded tensors:
  - A function for easily creating sharded memory configs from C++ (analogous to the Python `create_sharded_memory_config` utility function created by @ntarafdar).
  - A function for locating elements of a shard by their coordinates within the tensor. I've tested this one in the context of this PR, but it didn't end up being necessary in the final implementation.

### Checklist

- [~] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12327681570)
- [x] [Model regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308045581)
- [x] [Device performance regression CI testing passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308046347)
- [ ] Blackhole Post commit (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes

---------

Co-authored-by: tarafdarTT <[email protected]>
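Conceptually, padding along the width dimension extends each row-major page out to the new width with the fill value. The snippet below is a minimal host-side model of that per-row behavior under assumed bf16-as-uint16 storage; it is illustrative only, not the device kernel or the actual host API from this PR, and all names in it are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative model (not the device kernel): pad each row of a row-major
// buffer from old_width to new_width elements, filling the tail with pad_value.
std::vector<uint16_t> pad_rows_to_width(
    const std::vector<uint16_t>& src, uint32_t old_width, uint32_t new_width, uint16_t pad_value) {
    const uint32_t nrows = src.size() / old_width;
    std::vector<uint16_t> dst(nrows * new_width, pad_value);
    for (uint32_t r = 0; r < nrows; ++r) {
        // Copy the original row; the tail [old_width, new_width) keeps pad_value.
        std::copy(
            src.begin() + r * old_width,
            src.begin() + (r + 1) * old_width,
            dst.begin() + r * new_width);
    }
    return dst;
}
```

On device, the fused op applies this width fill alongside the original height-padding kernels, so a single pad call can cover both NCH and width padding.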
Showing 11 changed files with 745 additions and 75 deletions.
### `ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp` (new file: 42 additions, 0 deletions)
```cpp
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

// This file contains common kernel functions used for debugging
#pragma once
#include "debug/dprint.h"

namespace tt::data_movement::common {

inline void print_bf16_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) {
    volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * elts_per_page;
    for (uint32_t page = 0; page < npages; ++page) {
        DPRINT << start + page << ": ";
        for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) {
            DPRINT << BF16(*ptr) << " ";
        }
        DPRINT << ENDL();
    }
}

inline void print_f32_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) {
    volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(l1_addr) + start * elts_per_page;
    for (uint32_t page = 0; page < npages; ++page) {
        DPRINT << start + page << ": ";
        for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) {
            DPRINT << F32(*ptr) << " ";
        }
        DPRINT << ENDL();
    }
}

inline void print_u8_pages(uint32_t l1_addr, uint32_t bytes_per_page, uint32_t npages, uint32_t start = 0) {
    volatile tt_l1_ptr uint8_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t*>(l1_addr) + start * bytes_per_page;
    for (uint32_t page = 0; page < npages; ++page) {
        DPRINT << start + page << ": ";
        for (uint32_t j = 0; j < bytes_per_page; ++j, ++ptr) {
            DPRINT << SETW(2) << HEX() << "0x" << (uint32_t)*ptr << " ";
        }
        DPRINT << DEC();  // revert to decimal representation
        DPRINT << ENDL();
    }
}

}  // namespace tt::data_movement::common
```
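These helpers print L1 pages one per line, prefixed by the page index. Below is a hedged sketch of how one might be called from a data-movement kernel; the circular-buffer index and page size are assumptions for illustration, and seeing output requires the dprint server to be enabled (e.g. via `TT_METAL_DPRINT_CORES`).

```cpp
// Illustrative only: dump one bf16 page from a circular buffer inside a
// data-movement kernel. The CB index and page size below are assumptions.
#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp"

void kernel_main() {
    constexpr uint32_t cb_id = 0;           // hypothetical circular buffer index
    constexpr uint32_t elts_per_page = 64;  // bf16 elements per page (assumption)

    cb_wait_front(cb_id, 1);                       // wait for one page to arrive
    const uint32_t l1_addr = get_read_ptr(cb_id);  // L1 address of the front page
    tt::data_movement::common::print_bf16_pages(l1_addr, elts_per_page, /*npages=*/1);
    cb_pop_front(cb_id, 1);
}
```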