#0: Address more Stas comments
jaykru-tt committed Dec 14, 2024
1 parent a85673c commit a4d01f9
Showing 5 changed files with 98 additions and 71 deletions.
129 changes: 78 additions & 51 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.cpp
@@ -56,81 +56,108 @@ ttnn::Tensor pad_to_tile_vol(
}
uint32_t wrap_index(int index, int size) { return index < 0 ? size + index : index; }

std::array<uint32_t, 2> compute_block_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const tt::tt_metal::Layout& layout,
const tt::tt_metal::CoreCoord& grid_size,
const tt::tt_metal::ShardOrientation& orientation,
const uint32_t total_num_cores) {
TT_FATAL(grid_size.y * grid_size.x == total_num_cores, "compute_block_sharded_shard_shape received a core grid shape that does not match the total number of cores");
auto adjusted_grid_size = grid_size;
if (orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
// for col major, we partition the width of the tensor along the height of the core grid
std::swap(adjusted_grid_size.x, adjusted_grid_size.y);
}

auto [tensor_height, tensor_width] = squeezed_tensor_hw;
auto tensor_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE
? tt::round_up(tensor_height, adjusted_grid_size.y * tt::constants::TILE_HEIGHT)
: tensor_height;
std::array<uint32_t, 2> shard_shape = {tt::div_up(tensor_height_padded_to_tile, adjusted_grid_size.y),
tt::div_up(tensor_width, adjusted_grid_size.x)};

return shard_shape;
}
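
For intuition, here is a minimal standalone sketch of the same block-sharding arithmetic in plain C++ (no tt-metal dependencies; the 32-row tile height and the example tensor and grid sizes are illustrative assumptions, and div_up/round_up are stand-ins for the tt:: helpers):

#include <array>
#include <cstdint>
#include <iostream>

// Stand-ins for tt::div_up and tt::round_up.
constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
constexpr uint32_t round_up(uint32_t a, uint32_t b) { return div_up(a, b) * b; }

int main() {
    constexpr uint32_t TILE_HEIGHT = 32;
    // Example: a 96x256 tensor in TILE layout, block-sharded over a 2x4 grid, row-major orientation.
    constexpr uint32_t tensor_height = 96, tensor_width = 256;
    constexpr uint32_t grid_y = 2, grid_x = 4;
    // Pad the height so each row of cores receives whole tiles, then split both dimensions.
    const uint32_t height_padded = round_up(tensor_height, grid_y * TILE_HEIGHT);
    const std::array<uint32_t, 2> shard_shape = {div_up(height_padded, grid_y), div_up(tensor_width, grid_x)};
    std::cout << shard_shape[0] << "x" << shard_shape[1] << "\n";  // prints 64x64
}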

std::array<uint32_t, 2> compute_width_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const uint32_t total_num_cores) {
return {squeezed_tensor_hw[0], tt::div_up(squeezed_tensor_hw[1], total_num_cores)};
}

std::array<uint32_t, 2> compute_height_sharded_shard_shape(const std::array<uint32_t, 2>& squeezed_tensor_hw,
const tt::tt_metal::Layout& layout,
const uint32_t total_num_cores) {
auto [tensor_height, tensor_width] = squeezed_tensor_hw;
auto squeezed_height_padded_to_tile = layout == tt::tt_metal::Layout::TILE
? tt::round_up(tensor_height, total_num_cores)
: tensor_height;
return {tt::div_up(squeezed_height_padded_to_tile, total_num_cores), tensor_width};
}
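
Likewise, a short standalone sketch of the width- and height-sharded cases (plain C++; the eight-core count and tensor sizes are made-up illustrative values, and the row-major layout means no tile padding is applied):

#include <array>
#include <cassert>
#include <cstdint>

constexpr uint32_t div_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    constexpr uint32_t total_num_cores = 8;
    constexpr uint32_t tensor_height = 100, tensor_width = 512;  // ROW_MAJOR layout
    // Width sharding: keep the full height, split the width across all cores.
    const std::array<uint32_t, 2> width_shard = {tensor_height, div_up(tensor_width, total_num_cores)};
    assert(width_shard[0] == 100 && width_shard[1] == 64);
    // Height sharding: split the height across all cores, keep the full width.
    const std::array<uint32_t, 2> height_shard = {div_up(tensor_height, total_num_cores), tensor_width};
    assert(height_shard[0] == 13 && height_shard[1] == 512);
    return 0;
}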

ttnn::MemoryConfig create_sharded_memory_config(
const ttnn::Shape& shape,
const ttnn::SimpleShape& logical_shape,
const tt::tt_metal::CoreRangeSet& core_grid,
const ShardStrategy& strategy,
const tt::tt_metal::ShardOrientation& orientation,
bool halo,
bool use_height_and_width_as_shard_shape,
const tt::tt_metal::Layout& layout) {
auto is_tile_layout = layout == tt::tt_metal::Layout::TILE;

auto rank = shape.rank();
std::optional<std::array<uint32_t, 2>> shard_shape,
const tt::tt_metal::Layout& layout,
bool halo) {
auto rank = logical_shape.rank();
TT_FATAL(rank >= 2, "rank of tensor to shard must be at least 2.");

auto tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED;
if (strategy == ShardStrategy::WIDTH) {
ttnn::TensorMemoryLayout tensor_memory_layout;
if (strategy == ShardStrategy::BLOCK) {
tensor_memory_layout = ttnn::TensorMemoryLayout::BLOCK_SHARDED;
} else if (strategy == ShardStrategy::WIDTH) {
tensor_memory_layout = ttnn::TensorMemoryLayout::WIDTH_SHARDED;
} else if (strategy == ShardStrategy::HEIGHT) {
tensor_memory_layout = ttnn::TensorMemoryLayout::HEIGHT_SHARDED;
}

auto shard_orientation = orientation;
auto shard_grid = core_grid;
auto height = logical_shape[-2];
auto width = logical_shape[-1];
std::array<uint32_t, 2> computed_shard_shape;

auto height = shape[-2];
auto width = shape[-1];
std::array<uint32_t, 2> shard_shape;

if (use_height_and_width_as_shard_shape) {
if (shard_orientation == tt::tt_metal::ShardOrientation::ROW_MAJOR) {
shard_shape = {height, width};
} else if (shard_orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
shard_shape = {width, height};
} else {
TT_THROW("Invalid shard orientation");
}
if (shard_shape.has_value()) {
computed_shard_shape = shard_shape.value();
} else {
uint32_t batch_size = 1;
for (int i = 0; i < rank - 2; i++) {
batch_size *= shape[i];
batch_size *= logical_shape[i];
}

auto tensor_height = batch_size * height;
auto tensor_width = width;
auto total_num_cores = shard_grid.num_cores();
auto grid_size = shard_grid.bounding_box().grid_size();

if (tensor_memory_layout == ttnn::TensorMemoryLayout::BLOCK_SHARDED) {
TT_ASSERT(grid_size.y * grid_size.x == total_num_cores, "Invalid CoreRangeSet for block sharding strategy");

if (shard_orientation == tt::tt_metal::ShardOrientation::ROW_MAJOR) {
auto tensor_height_padded =
is_tile_layout ? tt::round_up(tensor_height, grid_size.y * 32) : tensor_height;
shard_shape = {tt::div_up(tensor_height_padded, grid_size.y), tt::div_up(tensor_width, grid_size.x)};
} else if (shard_orientation == tt::tt_metal::ShardOrientation::COL_MAJOR) {
auto tensor_height_padded =
is_tile_layout ? tt::round_up(tensor_height, grid_size.x * 32) : tensor_height;
shard_shape = {tt::div_up(tensor_height_padded, grid_size.x), tt::div_up(tensor_width, grid_size.y)};
} else {
TT_THROW("Invalid shard orientation");
}
} else if (tensor_memory_layout == ttnn::TensorMemoryLayout::HEIGHT_SHARDED) {
auto tensor_height_padded = is_tile_layout ? tt::round_up(tensor_height, total_num_cores) : tensor_height;
shard_shape = {tt::div_up(tensor_height_padded, total_num_cores), tensor_width};
} else if (tensor_memory_layout == ttnn::TensorMemoryLayout::WIDTH_SHARDED) {
shard_shape = {tensor_height, tt::div_up(tensor_width, total_num_cores)};
} else {
TT_THROW("Invalid sharding scheme");
std::array<uint32_t, 2> squeezed_tensor_hw{tensor_height, tensor_width};
auto total_num_cores = core_grid.num_cores();
CoreCoord grid_size = core_grid.bounding_box().grid_size();

switch (strategy) {
case ShardStrategy::BLOCK:
computed_shard_shape = compute_block_sharded_shard_shape(squeezed_tensor_hw, layout, grid_size, orientation, total_num_cores);
break;
case ShardStrategy::WIDTH:
computed_shard_shape = compute_width_sharded_shard_shape(squeezed_tensor_hw, total_num_cores);
break;
case ShardStrategy::HEIGHT:
computed_shard_shape = compute_height_sharded_shard_shape(squeezed_tensor_hw, layout, total_num_cores);
break;
default:
TT_ASSERT(false, "Invalid shard strategy");
}
}

if (is_tile_layout && shard_shape[0] % 32 != 0 && shard_shape[1] % 32 != 0) {
TT_THROW("For sharding tiled tensors, the shard shape must fit neatly into tiles.");
if (layout == tt::tt_metal::Layout::TILE) {
auto [shard_height, shard_width] = computed_shard_shape;
auto tile_divides_shard_height = shard_height % tt::constants::TILE_HEIGHT == 0;
auto tile_divides_shard_width = shard_width % tt::constants::TILE_WIDTH == 0;
TT_FATAL(tile_divides_shard_width && tile_divides_shard_height,
"For sharding tiled tensors, the shard shape must fit neatly into tiles but "
"create_sharded_memory_config got shard width {} and shard height {} while "
"on this architecture we have tile width {} and tile height {}",
shard_width, shard_height, tt::constants::TILE_WIDTH, tt::constants::TILE_HEIGHT);
}

auto shard_spec = tt::tt_metal::ShardSpec(shard_grid, shard_shape, shard_orientation, halo);
auto shard_spec = tt::tt_metal::ShardSpec(core_grid, computed_shard_shape, orientation, halo);
return ttnn::MemoryConfig(tensor_memory_layout, ttnn::BufferType::L1, shard_spec);
}
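
The new override-or-infer behaviour, together with the tile-alignment check above, can be summarized in a small standalone sketch (plain C++; the 32x32 tile size is assumed and resolve_shard_shape is a hypothetical illustration, not the tt-metal implementation):

#include <array>
#include <cstdint>
#include <optional>
#include <stdexcept>

std::array<uint32_t, 2> resolve_shard_shape(std::optional<std::array<uint32_t, 2>> requested,
                                            std::array<uint32_t, 2> inferred,
                                            bool tile_layout) {
    // Prefer the caller-supplied shard shape; otherwise fall back to the inferred one.
    std::array<uint32_t, 2> shape = requested.value_or(inferred);
    if (tile_layout && (shape[0] % 32 != 0 || shape[1] % 32 != 0)) {
        throw std::invalid_argument("shard shape must be a multiple of the 32x32 tile shape");
    }
    return shape;
}

int main() {
    // A 64x64 request tiles cleanly and is accepted as-is.
    auto accepted = resolve_shard_shape(std::array<uint32_t, 2>{64, 64}, {32, 32}, /*tile_layout=*/true);
    (void)accepted;
    try {
        // 48 is not a multiple of 32, so this request is rejected.
        resolve_shard_shape(std::array<uint32_t, 2>{48, 64}, {32, 32}, /*tile_layout=*/true);
    } catch (const std::invalid_argument&) {
    }
    return 0;
}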

13 changes: 9 additions & 4 deletions ttnn/cpp/ttnn/operations/data_movement/common/common.hpp
@@ -158,14 +158,19 @@ ttnn::Tensor pad_to_tile_vol(

enum class ShardStrategy { BLOCK, HEIGHT, WIDTH };

// Helper function for creating a sharded memory configuration for a tensor
// based on its logical shape, a shard strategy and orientation, and a core
// grid. A preferred shard shape may optionally be passed; if not provided,
// the shard shape is inferred from the tensor shape and the shard strategy.
ttnn::MemoryConfig create_sharded_memory_config(
const ttnn::Shape& shape,
const ttnn::SimpleShape& logical_shape,
const tt::tt_metal::CoreRangeSet& core_grid,
const ShardStrategy& strategy,
const tt::tt_metal::ShardOrientation& orientation,
bool halo = false,
bool use_height_and_width_as_shard_shape = false,
const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR);
std::optional<std::array<uint32_t, 2>> shard_shape = std::nullopt,
const tt::tt_metal::Layout& layout = tt::tt_metal::Layout::ROW_MAJOR,
bool halo = false);
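
A hedged usage sketch of the new signature (the enclosing namespace, the CoreRange/CoreRangeSet construction, and the 8x8 grid are assumptions made for illustration; only the parameter order mirrors the declaration above):

using namespace ttnn::operations::data_movement;

// Height-shard a [2, 3, 1024, 256] row-major tensor across an assumed 8x8 core grid,
// letting the helper infer the shard shape (6144 rows / 64 cores = 96 rows per shard).
tt::tt_metal::CoreRangeSet grid(tt::tt_metal::CoreRange({0, 0}, {7, 7}));
auto inferred_config = create_sharded_memory_config(
    ttnn::SimpleShape{std::array<uint32_t, 4>{2, 3, 1024, 256}},
    grid,
    ShardStrategy::HEIGHT,
    tt::tt_metal::ShardOrientation::ROW_MAJOR);

// Or pin an explicit shard shape via the new optional parameter instead of inferring it.
auto pinned_config = create_sharded_memory_config(
    ttnn::SimpleShape{std::array<uint32_t, 4>{2, 3, 1024, 256}},
    grid,
    ShardStrategy::HEIGHT,
    tt::tt_metal::ShardOrientation::ROW_MAJOR,
    std::array<uint32_t, 2>{96, 256});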

std::pair<uint32_t, std::array<uint32_t, 2>> tensor_coord_to_height_sharded_coord(
const std::span<const uint32_t>& tensor_shape,
@@ -40,6 +40,9 @@ void kernel_main() {
// optimization (upcoming as of 12/12/2024) this might be worth
// investigating.

// paulk says that an optimized loop will still be faster.
// TODO(jkruer): get paul's help optimizing this.

// read the input stick into the padded output stick starting after the
// front padding
for (uint32_t i = 0; i < unpadded_stick_bytes; i++) {
23 changes: 7 additions & 16 deletions ttnn/cpp/ttnn/operations/data_movement/pad/pad.cpp
@@ -53,7 +53,7 @@ static ttnn::Tensor pad_impl(
auto input_tensor_shape = input_tensor.get_shape();
const auto rank = input_tensor_shape.rank();

TT_FATAL(rank == 4, "ttnn.pad: input tensor rank is not 4");
TT_ASSERT(rank == 4, "ttnn.pad: input tensor passed to pad_impl must have rank == 4.");

using ShardStrategy = ttnn::operations::data_movement::ShardStrategy;
using ShardOrientation = tt::tt_metal::ShardOrientation;
@@ -73,24 +73,18 @@
auto width_distinct = [](const auto& shape, const auto& other_shape) { return shape[3] != other_shape[3]; };

uint32_t input_w = input_logical_shape[3];

uint32_t output_w = output_padded_shape[3];

if (width_distinct(input_logical_shape, output_padded_shape)) {
ttnn::SmallVector<uint32_t> output_shape_width_padded{
input_logical_shape.begin(), input_logical_shape.end() - 1};
output_shape_width_padded.push_back(output_w);

std::array<uint32_t, 4> output_shape_width_padded{
input_logical_shape[0], input_logical_shape[1], input_logical_shape[2], output_w};
auto width_pad_memory_config = create_sharded_memory_config(
ttnn::Shape{output_shape_width_padded},
ttnn::SimpleShape{output_shape_width_padded},
input_tensor.shard_spec()->grid, // reuse input cores for now: FIXME: can we do better?
// it's complicated because we need the input shards to be local
// to the core holding the output shard currently.
ShardStrategy::HEIGHT, // stay height sharded
ShardOrientation::ROW_MAJOR,
false,
false,
Layout::ROW_MAJOR);
ShardOrientation::ROW_MAJOR);
output_memory_config = width_pad_memory_config;

if (height_distinct(input_logical_shape, output_padded_shape)) {
@@ -119,13 +113,10 @@
"infinite recursion");

auto height_pad_memory_config = create_sharded_memory_config(
ttnn::Shape{output_padded_shape},
ttnn::SimpleShape{output_padded_shape},
input_tensor.shard_spec()->grid,
ShardStrategy::HEIGHT,
ShardOrientation::ROW_MAJOR,
false,
false,
Layout::ROW_MAJOR);
ShardOrientation::ROW_MAJOR);

// then pad height
auto output_tensor_height_padded = pad_impl(
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/tensor/shape/shape_base.hpp
@@ -24,6 +24,7 @@ class ShapeBase {
explicit ShapeBase(const std::array<uint32_t, N>& arr) : value_(arr.begin(), arr.end()) {
init();
}
explicit ShapeBase(std::span<const uint32_t> span) : value_(span.begin(), span.end()) { init(); }
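
As a sanity check on the pattern, a standalone C++20 sketch of why the extra std::span overload is convenient (ShapeSketch is a simplified stand-in, not the real ShapeBase):

#include <array>
#include <cstddef>
#include <cstdint>
#include <span>
#include <vector>

class ShapeSketch {
public:
    template <std::size_t N>
    explicit ShapeSketch(const std::array<uint32_t, N>& arr) : value_(arr.begin(), arr.end()) {}
    // The new overload: accept any contiguous run of dimensions without a fixed-size array.
    explicit ShapeSketch(std::span<const uint32_t> span) : value_(span.begin(), span.end()) {}
    std::size_t rank() const { return value_.size(); }

private:
    std::vector<uint32_t> value_;
};

int main() {
    std::vector<uint32_t> dims{2, 3, 1024, 256};  // rank known only at runtime
    ShapeSketch shape{std::span<const uint32_t>(dims)};
    return shape.rank() == 4 ? 0 : 1;
}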

template <std::size_t N>
bool operator==(const std::array<uint32_t, N>& other) const {
