Add support for padding along width dimension to ttnn.pad (#15985)
### Tickets
- #15511 
- #15603 (90% resolved by these changes; to be fully resolved in a future PR)
- #12896

### Problem description
ttnn.pad's row-major sharded implementation has so far only supported padding along the non-width dimensions. The row-major implementation is additionally not fully general with respect to the width dimension, so until now there has been no good option for padding along width. In a future PR coming tomorrow, I'll add input-massaging code to convert to row-major and shard as needed for input configurations that pad doesn't currently support.
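
As a rough sketch of what this enables (mirroring the new unit tests below, and assuming an open `device` handle such as the ttnn pytest fixture provides), padding a height-sharded row-major tensor along width now looks like:

```python
import torch
import ttnn

# Height-sharded, row-major input of shape (1, 1, 12, 8) on a 2x6 core grid.
torch_input = torch.ones((1, 1, 12, 8), dtype=torch.float32)
shard_config = ttnn.create_sharded_memory_config(
    (1, 1, 12, 8),
    core_grid=ttnn.CoreGrid(x=2, y=6),
    strategy=ttnn.ShardStrategy.HEIGHT,
)
tt_input = ttnn.from_torch(
    torch_input, dtype=ttnn.float32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
)
tt_input = ttnn.to_memory_config(tt_input, shard_config)

# Pad the width dimension from 8 to 64 with a constant fill value of 3.0.
tt_padded = ttnn.pad(tt_input, (1, 1, 12, 64), (0, 0, 0, 0), 3.0)
```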

### What's changed
- Adds new kernels to support padding along the width dimension.
- For pad operations requiring both NCH and width padding, we fuse the original height-padding kernels with the new width kernels into a single op.
- The previous point required extensive refactoring of the host code. I would like eyes on pad.cpp please @yugaoTT @sminakov-tt.
- Also adds a set of common utility functions for working with sharded tensors:
  - A function for easily creating sharded memory configs from C++ (analogous to the Python `create_sharded_memory_config` utility function created by @ntarafdar).
  - A function for locating elements of a shard by their coordinates within the tensor. I've tested this one in the context of this PR, but it didn't end up being necessary in the final implementation.

### Checklist
- [~] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12327681570)
- [x] [Model regression CI testing
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308045581)
- [x] [Device performance regression CI testing
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12308046347)
- [ ] Blackhole Post commit (if applicable)
- [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass
- [ ] New/Existing tests provide coverage for changes

---------

Co-authored-by: tarafdarTT <[email protected]>
jaykru-tt and ntarafdar authored Dec 14, 2024
1 parent ed413ee commit a3801c4
Showing 11 changed files with 745 additions and 75 deletions.
115 changes: 115 additions & 0 deletions tests/ttnn/unit_tests/operations/test_pad.py
@@ -128,6 +128,121 @@ def run_pad_rm_sharded(device, n, c, h, w, padding, torch_padding, value, shard_
assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999)


def to_torch_padding(padspec):
def flatten_to_tuple(padding):
return tuple(sum(padding, ()))

def ttnn_pad_spec_to_padding(padspec):
input_tensor_start = padspec["input_tensor_start"]
pad_to_shape = padspec["pad_to_shape"]
input_shape = padspec["input_shape"]

padding = []
for i in range(len(pad_to_shape)):
this_dim_padding = (input_tensor_start[i], pad_to_shape[i] - input_shape[i] - input_tensor_start[i])
padding.append(this_dim_padding)
return padding

torch_padding = flatten_to_tuple(reversed(ttnn_pad_spec_to_padding(padspec)))
return torch_padding
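
# Worked example (for illustration): a padspec with input_shape (1, 1, 2, 4),
# pad_to_shape (1, 1, 4, 8), and input_tensor_start (0, 0, 0, 0) yields
# per-dimension padding [(0, 0), (0, 0), (0, 2), (0, 4)], which reverses and
# flattens to the torch.nn.functional.pad argument (0, 4, 0, 2, 0, 0, 0, 0).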


@pytest.mark.parametrize(
"input_shape, pad_to_shape, input_tensor_start, pad_value, input_sharded_memory_config_args",
[
[
(1, 1, 1, 4),
(1, 1, 1, 16),
(0, 0, 0, 0),
3.0,
{"core_grid": ttnn.CoreGrid(x=1, y=1), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# a reduced version of esmal's test case for UNet
(1, 1, 4, 4),
(1, 1, 4, 16),
(0, 0, 0, 0),
3.0,
{"core_grid": ttnn.CoreGrid(x=1, y=1), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# width padding across large core grid, 3 sticks per core
(1, 1, 3 * 64, 4),
(1, 1, 3 * 64, 16),
(0, 0, 0, 0),
0.0,
{"core_grid": ttnn.CoreGrid(x=8, y=8), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# width padding across large core grid, 3 sticks per core, n300 version
(1, 1, 3 * 8 * 7, 4),
(1, 1, 3 * 8 * 7, 16),
(0, 0, 0, 0),
0.0,
{"core_grid": ttnn.CoreGrid(x=8, y=7), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# width padding only, reduced core grid
(1, 1, 12, 8),
(1, 1, 12, 64),
(0, 0, 0, 0),
3.0,
{"core_grid": ttnn.CoreGrid(x=2, y=6), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# height and width padding, small core grid
(1, 1, 2, 4),
(1, 1, 4, 8),
(0, 0, 0, 0),
3.0,
{"core_grid": ttnn.CoreGrid(x=1, y=2), "strategy": ttnn.ShardStrategy.HEIGHT},
],
[
# borys's second test case
(1, 2, 3, 4),
(1, 2, 32, 32),
(0, 0, 0, 0),
3.0,
{"core_grid": ttnn.CoreGrid(x=1, y=6), "strategy": ttnn.ShardStrategy.HEIGHT},
],
],
)
def test_pad_rm_sharded_stickwise(
device, input_shape, pad_to_shape, input_tensor_start, pad_value, input_sharded_memory_config_args
):
core_grid_x_ok = device.core_grid.x >= input_sharded_memory_config_args["core_grid"].x
core_grid_y_ok = device.core_grid.y >= input_sharded_memory_config_args["core_grid"].y
device_core_grid_ok = core_grid_x_ok and core_grid_y_ok
if not device_core_grid_ok:
pytest.skip("core grid for this test is not compatible with the device")

input_shard_memory_config = ttnn.create_sharded_memory_config(input_shape, **input_sharded_memory_config_args)

torch_input_tensor = torch.ones(input_shape, dtype=torch.float32)
ttnn_input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=ttnn.float32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device
)
ttnn_sharded_input_tensor = ttnn.to_memory_config(ttnn_input_tensor, input_shard_memory_config)

padded_tensor = ttnn.pad(ttnn_sharded_input_tensor, pad_to_shape, input_tensor_start, pad_value)

tt_output_tensor = ttnn.to_memory_config(padded_tensor, ttnn.L1_MEMORY_CONFIG)
tt_output_tensor = ttnn.from_device(tt_output_tensor)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)

padspec = {
"input_shape": input_shape,
"pad_to_shape": pad_to_shape,
"input_tensor_start": input_tensor_start,
}
torch_padded_tensor = torch.nn.functional.pad(
torch_input_tensor, to_torch_padding(padspec), mode="constant", value=pad_value
)

assert torch_output_tensor.shape == torch_padded_tensor.shape
assert_with_pcc(torch_padded_tensor, torch_output_tensor, 0.99)


@pytest.mark.parametrize("n", [20])
@pytest.mark.parametrize("c", [3])
@pytest.mark.parametrize("h", [224])
146 changes: 146 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
@@ -55,6 +55,152 @@ ttnn::Tensor pad_to_tile_vol(
return tensor;
}
uint32_t wrap_index(int index, int size) { return index < 0 ? size + index : index; }

std::array<uint32_t, 2> compute_block_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const tt::tt_metal::Layout& layout,
const tt::tt_metal::CoreCoord& grid_size,
const tt::tt_metal::ShardOrientation& orientation,
const uint32_t total_num_cores) {
TT_FATAL(grid_size.y * grid_size.x == total_num_cores, "compute_block_sharded_shard_shape received a core grid shape that does not match the total number of cores");
auto adjusted_grid_size = grid_size;
if (orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
// for col major, we partition the width of the tensor along the height of the core grid
std::swap(adjusted_grid_size.x, adjusted_grid_size.y);
}

auto [tensor_height, tensor_width] = squeezed_tensor_hw;
auto tensor_height_padded_to_tile =
layout == tt::tt_metal::Layout::TILE
? tt::round_up(tensor_height, adjusted_grid_size.y * tt::constants::TILE_HEIGHT)
: tensor_height;
std::array<uint32_t, 2> shard_shape = {tt::div_up(tensor_height_padded_to_tile, adjusted_grid_size.y),
tt::div_up(tensor_width, adjusted_grid_size.x)};

return shard_shape;
}

std::array<uint32_t, 2> compute_width_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const uint32_t total_num_cores) {
return {squeezed_tensor_hw[0], tt::div_up(squeezed_tensor_hw[1], total_num_cores)};
}

std::array<uint32_t, 2> compute_height_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const tt::tt_metal::Layout& layout,
const uint32_t total_num_cores) {
auto [tensor_height, tensor_width] = squeezed_tensor_hw;
auto squeezed_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE
? tt::round_up(tensor_height, total_num_cores)
: tensor_height;
return {tt::div_up(squeezed_height_padded_to_tile, total_num_cores), tensor_width};
}

ttnn::MemoryConfig create_sharded_memory_config(
const ttnn::SimpleShape& logical_shape,
const tt::tt_metal::CoreRangeSet& core_grid,
const ShardStrategy& strategy,
const tt::tt_metal::ShardOrientation& orientation,
std::optional<std::array<uint32_t, 2>> shard_shape,
const tt::tt_metal::Layout& layout,
bool halo) {
auto rank = logical_shape.rank();
TT_FATAL(rank >= 2, "rank of tensor to shard must be at least 2.");

ttnn::TensorMemoryLayout tensor_memory_layout;
if (strategy == ShardStrategy::BLOCK) {
tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED;
} else if (strategy == ShardStrategy::WIDTH) {
tensor_memory_layout = ttnn::TensorMemoryLayout::WIDTH_SHARDED;
} else if (strategy == ShardStrategy::HEIGHT) {
tensor_memory_layout = ttnn::TensorMemoryLayout::HEIGHT_SHARDED;
}

auto height = logical_shape[-2];
auto width = logical_shape[-1];
std::array<uint32_t, 2> computed_shard_shape;

if (shard_shape.has_value()) {
computed_shard_shape = shard_shape.value();
} else {
uint32_t batch_size = 1;
for (int i = 0; i < rank - 2; i++) {
batch_size *= logical_shape[i];
}

auto tensor_height = batch_size * height;
auto tensor_width = width;
std::array<uint32_t, 2> squeezed_tensor_hw{tensor_height, tensor_width};
auto total_num_cores = core_grid.num_cores();
CoreCoord grid_size = core_grid.bounding_box().grid_size();

switch (strategy) {
case ShardStrategy::BLOCK:
computed_shard_shape = compute_block_sharded_shard_shape(squeezed_tensor_hw, layout, grid_size, orientation, total_num_cores);
break;
case ShardStrategy::WIDTH:
computed_shard_shape = compute_width_sharded_shard_shape(squeezed_tensor_hw, total_num_cores);
break;
case ShardStrategy::HEIGHT:
computed_shard_shape = compute_height_sharded_shard_shape(squeezed_tensor_hw, layout, total_num_cores);
break;
default:
TT_ASSERT(false, "Invalid shard strategy");
}
}

if (layout == tt::tt_metal::Layout::TILE) {
auto [shard_height, shard_width] = computed_shard_shape;
auto tile_divides_shard_height = shard_height % tt::constants::TILE_HEIGHT == 0;
auto tile_divides_shard_width = shard_width % tt::constants::TILE_WIDTH == 0;
TT_FATAL(tile_divides_shard_width && tile_divides_shard_height,
"For sharding tiled tensors, the shard shape must fit neatly into tiles but "
"create_sharded_memory_config got shard width {} and shard height {} while "
"on this architecture we have tile width {} and tile height {}",
shard_width, shard_height, tt::constants::TILE_WIDTH, tt::constants::TILE_HEIGHT);
}

auto shard_spec = tt::tt_metal::ShardSpec(core_grid, computed_shard_shape, orientation, halo);
return ttnn::MemoryConfig(tensor_memory_layout, ttnn::BufferType::L1, shard_spec);
}

std::pair<uint32_t, std::array<uint32_t, 2>> tensor_coord_to_height_sharded_coord(
const std::span<const uint32_t>& tensor_shape,
const std::span<const uint32_t>& shard_shape,
const std::span<const uint32_t>& tensor_coord) {
std::array<uint32_t, 2> tensor_shape_2d{0, 0};
for (size_t i = 0; i < tensor_shape.size(); i++) {
if (i == tensor_shape.size() - 1) {
// width dimension, goes unmodified
tensor_shape_2d[1] = tensor_shape[i];
} else {
// height dimension, squeeze into 2D shape
if (tensor_shape_2d[0] == 0) {
// first time we've seen this dimension
tensor_shape_2d[0] = tensor_shape[i];
} else {
tensor_shape_2d[0] *= tensor_shape[i];
}
}
}

std::array<uint32_t, 2> tensor_coord_2d{0, tensor_coord.back()};
uint32_t height_2d = 0;
for (size_t i = 0; i < tensor_coord.size() - 1; i++) {
std::vector<uint32_t> page_shapes(tensor_shape.begin() + i + 1, tensor_shape.end() - 1);
auto component_sum =
tensor_coord[i] * std::accumulate(page_shapes.begin(), page_shapes.end(), 1, std::multiplies<uint32_t>());
height_2d += component_sum;
}
tensor_coord_2d[0] = height_2d;

uint32_t shard_height = shard_shape[0];
uint32_t w_in_shard = tensor_coord_2d[1];
uint32_t h_in_shard = height_2d % shard_height;
uint32_t which_shard = height_2d / shard_height;

std::array<uint32_t, 2> shard_coord{h_in_shard, w_in_shard};
return std::make_pair(which_shard, shard_coord);
}

} // namespace data_movement
} // namespace operations
} // namespace ttnn
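
For intuition, the height-sharded coordinate lookup above can be restated in Python. This is an illustrative sketch only (not part of the diff), assuming a row-major, height-sharded layout in which each shard spans the full tensor width:

```python
from math import prod

def height_sharded_coord(tensor_shape, shard_shape, tensor_coord):
    """Map an N-D tensor coordinate to (shard index, (row, column) within the
    shard), mirroring tensor_coord_to_height_sharded_coord above."""
    # Collapse all leading dimensions into a single 2D row index; the width
    # coordinate passes through unchanged because each shard spans full width.
    height_2d = sum(
        c * prod(tensor_shape[i + 1:-1], start=1)
        for i, c in enumerate(tensor_coord[:-1])
    )
    which_shard, h_in_shard = divmod(height_2d, shard_shape[0])
    return which_shard, (h_in_shard, tensor_coord[-1])

# Element (0, 0, 7, 5) of a (1, 1, 12, 8) tensor with shard shape (3, 8)
# lands in shard 2 at row 1, column 5.
assert height_sharded_coord((1, 1, 12, 8), (3, 8), (0, 0, 7, 5)) == (2, (1, 5))
```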
21 changes: 21 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
@@ -156,6 +156,27 @@ ttnn::Tensor pad_to_tile_vol(
const bool use_multicore,
const std::optional<MemoryConfig>& memory_config);

enum class ShardStrategy { BLOCK, HEIGHT, WIDTH };

// Helper function for creating a sharded memory configuration for a tensor
// based on its logical shape, a shard strategy and orientation, and a core
// grid. Optionally, you may pass a preferred shard shape to use. If not
// provided, the shard shape will be inferred from the tensor shape and the
// shard strategy.
ttnn::MemoryConfig create_sharded_memory_config(
const ttnn::SimpleShape& logical_shape,
const tt::tt_metal::CoreRangeSet& core_grid,
const ShardStrategy& strategy,
const tt::tt_metal::ShardOrientation& orientation,
std::optional<std::array<uint32_t, 2>> shard_shape = std::nullopt,
const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR,
bool halo = false);

std::pair<uint32_t, std::array<uint32_t, 2>> tensor_coord_to_height_sharded_coord(
const std::span<const uint32_t>& tensor_shape,
const std::span<const uint32_t>& shard_shape,
const std::span<const uint32_t>& tensor_coord);

} // namespace data_movement
} // namespace operations
} // namespace ttnn
42 changes: 42 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/common/kernels/debug.hpp
@@ -0,0 +1,42 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

// This file contains common kernel functions used for debugging
#pragma once
#include "debug/dprint.h"
namespace tt::data_movement::common {
inline void print_bf16_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) {
volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * elts_per_page;
for (uint32_t page = 0; page < npages; ++page) {
DPRINT << start + page << ": ";
for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) {
DPRINT << BF16(*ptr) << " ";
}
DPRINT << ENDL();
}
}

inline void print_f32_pages(uint32_t l1_addr, uint32_t elts_per_page, uint32_t npages, uint32_t start = 0) {
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(l1_addr) + start * elts_per_page;
for (uint32_t page = 0; page < npages; ++page) {
DPRINT << start + page << ": ";
for (uint32_t j = 0; j < elts_per_page; ++j, ++ptr) {
DPRINT << F32(*ptr) << " ";
}
DPRINT << ENDL();
}
}

inline void print_u8_pages(uint32_t l1_addr, uint32_t bytes_per_page, uint32_t npages, uint32_t start = 0) {
volatile tt_l1_ptr uint8_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint8_t*>(l1_addr) + start * bytes_per_page;
for (uint32_t page = 0; page < npages; ++page) {
DPRINT << start + page << ": ";
for (uint32_t j = 0; j < bytes_per_page; ++j, ++ptr) {
DPRINT << SETW(2) << HEX() << "0x" << (uint32_t)*ptr << " ";
}
DPRINT << DEC(); // revert to decimal representation
DPRINT << ENDL();
}
}
} // namespace tt::data_movement::common