#6271: Add sharded support for silu
Aswinmcw committed Mar 18, 2024
1 parent 1267b58 commit d9a8db1
Showing 6 changed files with 464 additions and 5 deletions.
168 changes: 168 additions & 0 deletions tests/ttnn/unit_tests/operations/test_silu_sharded.py
@@ -0,0 +1,168 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import math
from typing import Union, Tuple

import torch
import torch.nn as nn
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0


TILE_WIDTH = 32


def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet:
    max_grid_size = (9, 12)  ## (y, x)
    if isinstance(ncores, int):
        if ncores % max_grid_size[1] == 0:
            core_grid = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1])
            grid_coord = ttnn.experimental.tensor.CoreCoord(core_grid.x - 1, core_grid.y - 1)
            return ttnn.experimental.tensor.CoreRangeSet(
                {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord)}
            )
        else:
            if ncores < max_grid_size[1]:
                core_grid = ttnn.CoreGrid(y=1, x=ncores)
                grid_coord = ttnn.experimental.tensor.CoreCoord(core_grid.x - 1, 0)
                return ttnn.experimental.tensor.CoreRangeSet(
                    {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord)}
                )
            else:
                core_grid_1 = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1])
                core_grid_2 = ttnn.CoreGrid(y=ncores // max_grid_size[1] + 1, x=ncores % max_grid_size[1])
                grid_coord_1 = ttnn.experimental.tensor.CoreCoord(core_grid_1.x - 1, core_grid_1.y - 1)
                grid_coord_2 = ttnn.experimental.tensor.CoreCoord(core_grid_2.x - 1, core_grid_2.y - 1)
                return ttnn.experimental.tensor.CoreRangeSet(
                    {
                        ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord_1),
                        ttnn.experimental.tensor.CoreRange(
                            ttnn.experimental.tensor.CoreCoord(0, grid_coord_2.y), grid_coord_2
                        ),
                    }
                )
    elif isinstance(ncores, tuple):
        ncores_h, ncores_w = ncores
        assert ncores_h <= max_grid_size[0]
        assert ncores_w <= max_grid_size[1]
        return ttnn.experimental.tensor.CoreRangeSet(
            {
                ttnn.experimental.tensor.CoreRange(
                    ttnn.experimental.tensor.CoreCoord(0, 0),
                    ttnn.experimental.tensor.CoreCoord(ncores_w - 1, ncores_h - 1),
                )
            }
        )
    else:
        raise ValueError("Invalid ncores")
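## Example (traced from the branches above): get_shard_grid_from_num_cores(40)
## takes the non-divisible branch (40 % 12 == 4, 40 >= 12) and returns two core
## ranges: rows 0..2 across all 12 columns (36 cores) plus columns 0..3 of
## row 3 (4 cores), for 40 cores in total.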


@pytest.mark.parametrize(
    "input_shape",
    [
        [2, 8, 8, 640],
        [2, 16, 16, 640],
        [1, 16, 16, 640],
        [2, 8, 8, 1280],
        [2, 16, 16, 1280],
    ],
)
@pytest.mark.parametrize(
    "shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK, ttnn.ShardStrategy.WIDTH]
)
def test_silu_multi_core(device, input_shape, shard_strategy):
    ## input shape is N H W C
    batch_size, height, width, num_channels = input_shape
    torch.manual_seed(0)
    input = torch.rand(input_shape, dtype=torch.bfloat16)

    torch_result = nn.functional.silu(input)

    tt_input = input
    num_bytes = 2  ## only BFLOAT16 is supported

    ## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape
    ncores = None
    max_grid_size = (9, 12)  ## (y, x)
    if shard_strategy == ttnn.ShardStrategy.HEIGHT:
        ## nsticks per shard should be divisible by in_w
        max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
        nshards = max_nshards
        while nshards > 0:
            if batch_size * height % nshards == 0:
                break
            nshards -= 1
        ncores = nshards
    elif shard_strategy == ttnn.ShardStrategy.WIDTH:
        ## shard width should evenly divide num_channels and stay tile-aligned
        max_nshards_w = min(num_channels, max_grid_size[1])
        nshards_w = max_nshards_w
        while nshards_w > 0:
            ## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B
            if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0:
                break
            nshards_w -= 1
        ncores = nshards_w
    elif shard_strategy == ttnn.ShardStrategy.BLOCK:
        max_nshards_h = min(batch_size * height, max_grid_size[0])  ## height along NHW
        max_nshards_w = min(num_channels, max_grid_size[1])  ## width along C
        ## find nshards_h along NHW
        nshards_h = max_nshards_h
        while nshards_h > 0:
            if batch_size * height % nshards_h == 0:
                break
            nshards_h -= 1
        ## find nshards_w along C
        nshards_w = max_nshards_w
        while nshards_w > 0:
            ## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B
            if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0:
                break
            nshards_w -= 1
        if nshards_w == 0 or nshards_h == 0:
            raise ValueError("nshards_h or nshards_w is 0")
        ncores = (nshards_h, nshards_w)
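    ## Worked example (values traced from the loops above): for input
    ## [2, 16, 16, 640] with BLOCK sharding, batch_size * height = 32 and
    ## max_nshards_h = 9, so nshards_h = 8 (the largest divisor of 32 <= 9);
    ## nshards_w = 10, since 640 % 10 == 0 and ceil(640 * 2 / 10) = 128 is a
    ## multiple of TILE_WIDTH. ncores = (8, 10), giving a 64 x 64 shard below.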

    shard_grid = get_shard_grid_from_num_cores(ncores)
    shard_orientation = ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR

    if shard_strategy == ttnn.ShardStrategy.BLOCK:
        tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
    elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
        tensor_memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED
    elif shard_strategy == ttnn.ShardStrategy.WIDTH:
        tensor_memory_layout = ttnn.types.TensorMemoryLayout.WIDTH_SHARDED

    ## input shard
    if shard_strategy == ttnn.ShardStrategy.BLOCK:
        shard_height = math.ceil(batch_size * height * width / ncores[0])
        shard_width = math.ceil(num_channels / ncores[1])
    elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
        shard_height = math.ceil(batch_size * height * width / ncores)
        shard_width = num_channels
    elif shard_strategy == ttnn.ShardStrategy.WIDTH:
        shard_height = math.ceil(batch_size * height * width)
        shard_width = math.ceil(num_channels / ncores)
    shard_shape = (shard_height, shard_width)

    shard_spec = ttnn.experimental.tensor.ShardSpec(shard_grid, shard_shape, shard_orientation, False)
    in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec)

    ## output shard
    shard_shape = (shard_height, shard_width)
    shard_spec = ttnn.experimental.tensor.ShardSpec(shard_grid, shard_shape, shard_orientation, False)

    input_tensor = ttnn.from_torch(tt_input, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
    input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config)

    output_tensor = ttnn.silu(input_tensor, memory_config=in_sharded_mem_config)
    output_tensor = ttnn.to_memory_config(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG)
    output_tensor = ttnn.to_torch(output_tensor)

    ## compare the results
    assert_with_pcc(torch_result, output_tensor, 0.999)
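Assuming the repository's standard pytest setup, the new cases can be run locally with: pytest tests/ttnn/unit_tests/operations/test_silu_sharded.py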
65 changes: 61 additions & 4 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
@@ -257,18 +257,75 @@ void EltwiseUnary::validate(const std::vector<Tensor> &input_tensors) const {
    const auto& input_tensor_a = input_tensors.at(0);
    TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to eltwise unary need to be on device!");
    TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands to eltwise unary need to be allocated in buffers on device!");
    TT_FATAL((input_tensor_a.get_layout() == Layout::TILE), "Inputs to eltwise unary must be tilized");
    TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Eltwise unary does not currently support sharding");
    TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Eltwise unary does not currently support sharding");
}

std::vector<Shape> EltwiseUnary::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    if (output_mem_config.is_sharded()) {
        const auto input_shape = input_tensor.get_legacy_shape().without_padding();

        uint32_t out_n = input_shape[0];
        uint32_t out_h = input_shape[1];
        uint32_t out_w = input_shape[2];
        uint32_t out_c = input_shape[3];
        const auto out_dims = std::vector<uint32_t>({out_n, out_h, out_w, out_c});  // in the NHWC format
        auto out_shape = Shape{out_dims};

        return {out_shape};
    } else {
        return {input_tensor.get_legacy_shape()};
    }
}

std::vector<Tensor> EltwiseUnary::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
    const auto& input_tensor = input_tensors.at(0);
    if (output_mem_config.is_sharded()) {
        if (input_tensor.memory_config().is_sharded()) {
            auto mem_config = output_mem_config;
            auto input_shard_spec = input_tensor.memory_config().shard_spec.value();
            auto output_shape = compute_output_shapes(input_tensors).at(0);
            if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
                auto ncores = input_shard_spec.num_cores();
                array<uint32_t, 2> output_shard_shape = {div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]};
                auto output_shard_spec = input_shard_spec;
                output_shard_spec.shape = output_shard_shape;
                mem_config.shard_spec = output_shard_spec;
                log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
                log_debug(LogOp, "output_shard_spec: {}", output_shard_spec);
                return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)};
            } else if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
                auto ncores = input_shard_spec.num_cores();
                array<uint32_t, 2> output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2], div_up(output_shape[-1], ncores)};
                auto output_shard_spec = input_shard_spec;
                output_shard_spec.shape = output_shard_shape;
                mem_config.shard_spec = output_shard_spec;
                log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
                log_debug(LogOp, "output_shard_spec: {}", output_shard_spec);
                return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)};
            } else if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
                auto shard_grid = input_shard_spec.grid.ranges();
                TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange");
                auto core_range = *shard_grid.begin();
                uint32_t ncores_w = core_range.end.x + 1;
                uint32_t ncores_h = core_range.end.y + 1;
                // array<uint32_t, 2> output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] / ncores_h, output_shape[-1] / ncores_w};
                // auto output_shard_spec = input_shard_spec;
                // output_shard_spec.shape = output_shard_shape;
                // mem_config.shard_spec = output_shard_spec;
                auto output_shard_spec = mem_config.shard_spec.value();
                auto output_shard_shape = output_shard_spec.shape;
                log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h);
                log_debug(LogOp, "output_shard_shape: {}", output_shard_shape);
                return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)};
            } else {
                TT_FATAL(false, "input memory config is not HEIGHT or WIDTH or BLOCK sharded");
            }
        } else {
            TT_FATAL(false, "Output memory config is sharded but input memory config is not sharded");
        }
    }
    return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config);
}
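To make the shard-shape arithmetic above easier to check, here is a minimal Python sketch (a hypothetical helper, not part of this commit; it assumes NHWC shapes and the same rounding as div_up):

from math import ceil

def output_shard_shape(shape_nhwc, layout, ncores):
    ## Mirrors create_output_tensors above: HEIGHT splits N*H*W across cores
    ## with full channels per core; WIDTH keeps all N*H*W sticks per core and
    ## splits the channels instead.
    n, h, w, c = shape_nhwc
    if layout == "HEIGHT_SHARDED":
        return (ceil(n * h * w / ncores), c)
    if layout == "WIDTH_SHARDED":
        return (n * h * w, ceil(c / ncores))
    ## BLOCK_SHARDED takes its shard shape from the output memory config instead.
    raise ValueError("unsupported layout for this sketch")

print(output_shard_shape((2, 16, 16, 640), "HEIGHT_SHARDED", 8))   ## (64, 640)
print(output_shard_shape((2, 16, 16, 640), "WIDTH_SHARDED", 10))   ## (512, 64)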

6 changes: 6 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
@@ -149,6 +149,12 @@ inline Tensor run_eltwise_unary(
    TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified");
    Shape pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape());
    FormatParams input_format_params = {.pad_shape = pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE};
    if (output_mem_config.is_sharded() &&
        (output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED ||
         output_mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED ||
         output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)) {
        return operation::run_without_autoformat(
                   EltwiseUnary{ops_chain, output_mem_config}, {input_tensor})
            .at(0);
    }
    return operation::run_with_autoformat(
               EltwiseUnary{ops_chain, output_mem_config}, {input_tensor}, {input_format_params}, {Layout::TILE})
        .at(0);
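A likely rationale for the new branch: autoformat pads and re-tilizes tensors through interleaved buffers, which would not preserve a sharded L1 layout, so sharded outputs are routed through run_without_autoformat and handed to the op as-is.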
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
    uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
    constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);

    constexpr uint32_t onetile = 1;
    for (uint32_t i = 0; i < num_tiles_per_core; ++i) {
        cb_push_back(cb_id_in0, onetile);
    }
}