diff --git a/tests/ttnn/unit_tests/operations/test_silu_sharded.py b/tests/ttnn/unit_tests/operations/test_silu_sharded.py new file mode 100644 index 000000000000..40b3d560d7bd --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_silu_sharded.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import math +from typing import Union, Tuple + +import torch +import torch.nn as nn +import ttnn + +from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import skip_for_wormhole_b0 + + +TILE_WIDTH = 32 + + +def get_shard_grid_from_num_cores(ncores: Union[int, Tuple[int, int]]) -> ttnn.experimental.tensor.CoreRangeSet: + max_grid_size = (9, 12) ## (y, x) + if isinstance(ncores, int): + if ncores % max_grid_size[1] == 0: + core_grid = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1]) + grid_coord = ttnn.experimental.tensor.CoreCoord(core_grid.x - 1, core_grid.y - 1) + return ttnn.experimental.tensor.CoreRangeSet( + {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord)} + ) + else: + if ncores < max_grid_size[1]: + core_grid = ttnn.CoreGrid(y=1, x=ncores) + grid_coord = ttnn.experimental.tensor.CoreCoord(core_grid.x - 1, 0) + return ttnn.experimental.tensor.CoreRangeSet( + {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord)} + ) + else: + core_grid_1 = ttnn.CoreGrid(y=ncores // max_grid_size[1], x=max_grid_size[1]) + core_grid_2 = ttnn.CoreGrid(y=ncores // max_grid_size[1] + 1, x=ncores % max_grid_size[1]) + grid_coord_1 = ttnn.experimental.tensor.CoreCoord(core_grid_1.x - 1, core_grid_1.y - 1) + grid_coord_2 = ttnn.experimental.tensor.CoreCoord(core_grid_2.x - 1, core_grid_2.y - 1) + return ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), grid_coord_1), + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, grid_coord_2.y), grid_coord_2 + ), + } + ) + elif isinstance(ncores, tuple): + ncores_h, ncores_w = ncores + assert ncores_h <= max_grid_size[0] + assert ncores_w <= max_grid_size[1] + return ttnn.experimental.tensor.CoreRangeSet( + { + ttnn.experimental.tensor.CoreRange( + ttnn.experimental.tensor.CoreCoord(0, 0), + ttnn.experimental.tensor.CoreCoord(ncores_w - 1, ncores_h - 1), + ) + } + ) + else: + raise ValueError("Invalid ncores") + + +@pytest.mark.parametrize( + "input_shape", + [ + [2, 8, 8, 640], + [2, 16, 16, 640], + [1, 16, 16, 640], + [2, 8, 8, 1280], + [2, 16, 16, 1280], + ], +) +@pytest.mark.parametrize( + "shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK, ttnn.ShardStrategy.WIDTH] +) +def test_silu_multi_core(device, input_shape, shard_strategy): + ## input shape is N C H W + batch_size, height, width, num_channels = input_shape + torch.manual_seed(0) + input = torch.rand(input_shape, dtype=torch.bfloat16) + + torch_result = nn.functional.silu(input) + + tt_input = input + num_bytes = 2 ## only BFLOAT16 is supported + + ## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape + ncores = None + max_grid_size = (9, 12) ## (y, x) + if shard_strategy == ttnn.ShardStrategy.HEIGHT: + ## nsticks per shard should be divisible by in_w + max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1]) + nshards = max_nshards + while nshards > 0: + if batch_size * height % nshards == 0: + break + nshards -= 1 + ncores = nshards + elif shard_strategy == ttnn.ShardStrategy.WIDTH: + ## nsticks per shard should be divisible by in_w + max_nshards_w = min(num_channels, max_grid_size[1]) + nshards_w = max_nshards_w + while nshards_w > 0: + ## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B + if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0: + break + nshards_w -= 1 + ncores = nshards_w + elif shard_strategy == ttnn.ShardStrategy.BLOCK: + max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW + max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C + ## find nshards_h along NHW + nshards_h = max_nshards_h + while nshards_h > 0: + if batch_size * height % nshards_h == 0: + break + nshards_h -= 1 + ## find nshards_w along C + nshards_w = max_nshards_w + while nshards_w > 0: + ## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B + if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0: + break + nshards_w -= 1 + if nshards_w == 0 or nshards_h == 0: + raise ValueError("nshards_h or nshards_w is 0") + ncores = (nshards_h, nshards_w) + + shard_grid = get_shard_grid_from_num_cores(ncores) + shard_orientation = ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR + + if shard_strategy == ttnn.ShardStrategy.BLOCK: + tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED + elif shard_strategy == ttnn.ShardStrategy.HEIGHT: + tensor_memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED + elif shard_strategy == ttnn.ShardStrategy.WIDTH: + tensor_memory_layout = ttnn.types.TensorMemoryLayout.WIDTH_SHARDED + + ## input shard + if shard_strategy == ttnn.ShardStrategy.BLOCK: + shard_height = math.ceil(batch_size * height * width / ncores[0]) + shard_width = math.ceil(num_channels / ncores[1]) + elif shard_strategy == ttnn.ShardStrategy.HEIGHT: + shard_height = math.ceil(batch_size * height * width / ncores) + shard_width = num_channels + elif shard_strategy == ttnn.ShardStrategy.WIDTH: + shard_height = math.ceil(batch_size * height * width) + shard_width = math.ceil(num_channels / ncores) + shard_shape = (shard_height, shard_width) + + shard_spec = ttnn.experimental.tensor.ShardSpec(shard_grid, shard_shape, shard_orientation, False) + in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec) + + ## output shard + shard_shape = (shard_height, shard_width) + shard_spec = ttnn.experimental.tensor.ShardSpec(shard_grid, shard_shape, shard_orientation, False) + + input_tensor = ttnn.from_torch(tt_input, device=device, memory_config=ttnn.L1_MEMORY_CONFIG) + input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config) + + output_tensor = ttnn.silu(input_tensor, memory_config=in_sharded_mem_config) + output_tensor = ttnn.to_memory_config(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.to_torch(output_tensor) + + ## compare the results + assert_with_pcc(torch_result, output_tensor, 0.999) diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 44c09ffdea47..4b67abbac63a 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -257,18 +257,70 @@ void EltwiseUnary::validate(const std::vector &input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to eltwise unary need to be on device!"); TT_FATAL(input_tensor_a.buffer() != nullptr , "Operands to eltwise unary need to be allocated in buffers on device!"); - TT_FATAL((input_tensor_a.get_layout() == Layout::TILE), "Inputs to eltwise unary must be tilized"); - TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Eltwise unary does not currently support sharding"); - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Eltwise unary does not currently support sharding"); } std::vector EltwiseUnary::compute_output_shapes(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors.at(0); - return {input_tensor.get_legacy_shape()}; + const auto input_shape = input_tensor.get_legacy_shape().without_padding(); + + uint32_t out_n = input_shape[0]; + uint32_t out_h = input_shape[1]; + uint32_t out_w = input_shape[2]; + uint32_t out_c = input_shape[3]; + const auto out_dims = std::vector({ out_n, out_h, out_w, out_c }); //in the NHWC format + auto out_shape = Shape{out_dims}; + + return {out_shape}; } std::vector EltwiseUnary::create_output_tensors(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors.at(0); + if (output_mem_config.is_sharded()) { + if (input_tensor.memory_config().is_sharded()) { + auto mem_config = output_mem_config; + auto input_shard_spec = input_tensor.memory_config().shard_spec.value(); + auto output_shape = compute_output_shapes(input_tensors).at(0); + if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + auto ncores = input_shard_spec.num_cores(); + array output_shard_shape = {div_up(output_shape[0] * output_shape[1] * output_shape[2], ncores), output_shape[-1]}; + auto output_shard_spec = input_shard_spec; + output_shard_spec.shape = output_shard_shape; + mem_config.shard_spec = output_shard_spec; + log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); + log_debug(LogOp, "output_shard_spec: {}", output_shard_spec); + return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)}; + }else if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { + auto ncores = input_shard_spec.num_cores(); + array output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2], div_up(output_shape[-1],ncores)}; + auto output_shard_spec = input_shard_spec; + output_shard_spec.shape = output_shard_shape; + mem_config.shard_spec = output_shard_spec; + log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); + log_debug(LogOp, "output_shard_spec: {}", output_shard_spec); + return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)}; + } + else if (input_tensor.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + auto shard_grid = input_shard_spec.grid.ranges(); + TT_FATAL(shard_grid.size() == 1, "Block sharded input should have only one CoreRange"); + auto core_range = *shard_grid.begin(); + uint32_t ncores_w = core_range.end.x + 1; + uint32_t ncores_h = core_range.end.y + 1; + // array output_shard_shape = {output_shape[0] * output_shape[1] * output_shape[2] / ncores_h, output_shape[-1] / ncores_w}; + // auto output_shard_spec = input_shard_spec; + // output_shard_spec.shape = output_shard_shape; + // mem_config.shard_spec = output_shard_spec; + auto output_shard_spec = mem_config.shard_spec.value(); + auto output_shard_shape = output_shard_spec.shape; + log_debug(LogOp, "ncores_w, ncores_h: {} {}", ncores_w, ncores_h); + log_debug(LogOp, "output_shard_shape: {}", output_shard_shape); + return {create_sharded_device_tensor(output_shape, input_tensor.get_dtype(), input_tensor.get_layout(), input_tensor.device(), mem_config)}; + } else { + TT_FATAL(false, "input memory config is not HEIGHT or WIDTH or BLOCK sharded"); + } + } else { + TT_FATAL(false, "Output memory config is sharded but input memory config is not sharded"); + } + } return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config); } diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 99ee4cd1dfaa..6826d2fbda79 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -149,6 +149,12 @@ inline Tensor run_eltwise_unary( TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); Shape pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); FormatParams input_format_params = {.pad_shape = pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; + if(output_mem_config.is_sharded() && (output_mem_config.memory_layout == + TensorMemoryLayout::HEIGHT_SHARDED || output_mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED || output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED)){ + return operation::run_without_autoformat( + EltwiseUnary{ops_chain, output_mem_config}, {input_tensor}) + .at(0); + } return operation::run_with_autoformat( EltwiseUnary{ops_chain, output_mem_config}, {input_tensor}, {input_format_params}, {Layout::TILE}) .at(0); diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/kernels/dataflow/reader_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/kernels/dataflow/reader_unary_op.cpp new file mode 100644 index 000000000000..b152174840b3 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/kernels/dataflow/reader_unary_op.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" + +void kernel_main() { + uint32_t num_tiles_per_core = get_arg_val(0); + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + + constexpr uint32_t onetile = 1; + for (uint32_t i = 0; i < num_tiles_per_core; ++ i) { + cb_push_back(cb_id_in0, onetile); + } +} diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/multi_core/eltwise_unary_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/multi_core/eltwise_unary_op_multi_core.cpp index 2c146825f08b..1b4faa29b676 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/multi_core/eltwise_unary_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/multi_core/eltwise_unary_op_multi_core.cpp @@ -17,8 +17,220 @@ namespace tt { namespace tt_metal { +operation::ProgramWithCallbacks eltwise_unary_sharded_multi_core(const Tensor &input, Tensor &output, const std::vector op_chain){ + Program program = CreateProgram(); + Device *device = input.device(); + + auto shard_spec = input.shard_spec().value(); + auto all_cores = shard_spec.grid; + uint32_t ncores = shard_spec.num_cores(); + + auto out_shard_spec = output.shard_spec().value(); + TT_FATAL(out_shard_spec.num_cores() == ncores, "Output tensor should have same number of cores {} as input tensor {}", out_shard_spec.num_cores(), ncores); + + DataFormat act_df = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); //fix this later. the vaklue is already there. + DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + + uint32_t input_tile_size = tt::tt_metal::detail::TileSize(act_df); + uint32_t output_tile_size = tt::tt_metal::detail::TileSize(out_df); + + uint32_t num_tile_per_core = shard_spec.numel() * datum_size(act_df) /input_tile_size; + TT_ASSERT((shard_spec.numel() * datum_size(act_df)) % input_tile_size == 0, "Shard size should be multiple of the 1024"); + + uint32_t ncores_x, ncores_nhw; + if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + ncores_x = device->compute_with_storage_grid_size().x; + ncores_nhw = ncores; + }else if (input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { + ncores_x = device->compute_with_storage_grid_size().x; + ncores_nhw = ncores; + }else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + ncores_x = all_cores.ranges().begin()->end.x + 1; + ncores_nhw = all_cores.ranges().begin()->end.y + 1; + } else { + TT_FATAL(false, "Unsupported sharding layout"); + } + + uint32_t in_cb_id = CB::c_in0; + uint32_t buffering_factor = 1; // data is already fully buffered in the CBs since its sharded + uint32_t aligned_input_tile_nbytes = round_up_to_mul32(input_tile_size); + uint32_t in_cb_pagesize = aligned_input_tile_nbytes; + uint32_t in_cb_npages = num_tile_per_core * buffering_factor; + CircularBufferConfig cb_src0_config = CircularBufferConfig( + in_cb_pagesize * in_cb_npages, + {{in_cb_id, act_df}}) + .set_page_size(in_cb_id, in_cb_pagesize) + .set_globally_allocated_address(*input.buffer()); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + // output sharded CB with upsampled data + uint32_t out_cb_id = CB::c_out0; + CircularBufferConfig out_cb_config = CircularBufferConfig( + in_cb_pagesize * in_cb_npages, + {{out_cb_id, out_df}}) + .set_page_size(out_cb_id, in_cb_pagesize) + .set_globally_allocated_address(*output.buffer()); + auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, out_cb_config); + + log_debug(LogOp, "input_cb: {}, npages: {}, pagesize: {}", in_cb_id, in_cb_npages, in_cb_pagesize); + log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x); + log_debug(LogOp, "input_tile_size: {}", input_tile_size); + + auto src_buffer = input.buffer(); + auto dst_buffer = output.buffer(); + + bool src_is_dram = src_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + in_cb_id, + out_cb_id, + }; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = { + (std::uint32_t) out_cb_id, + (std::uint32_t) dst_is_dram + }; + + CoreRange temp_core({0, 0}, {0, 0}); + std::map kernel_defines; + + tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/eltwise_unary/kernels/dataflow/reader_unary_op.cpp", + all_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); + + vector compute_kernel_args_group_1 = { + num_tile_per_core, // per_core_block_cnt + 1 // per_core_block_size + }; + + bool fp32_dest_acc_en = false; + bool math_approx_mode = std::all_of(op_chain.begin(), op_chain.end(), [](const auto& u) {return eltwise_unary_op_utils::get_op_approx_mode(u.op_type);}); + std::map unary_defines = eltwise_unary_op_utils::get_block_defines(op_chain); + auto eltwise_unary_kernel_group_1_id = tt_metal::CreateKernel( + program, + "tt_metal/kernels/compute/eltwise_sfpu.cpp", + all_cores, + tt_metal::ComputeConfig{ + .math_fidelity = MathFidelity::LoFi, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_kernel_args_group_1, + .defines = unary_defines + } + ); + + uint32_t start_input_stick_id = 0; + if(input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED){ + for (int32_t core = 0; core < ncores_nhw; ++core) { + CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical + tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core_coord, + { + (uint32_t)(num_tile_per_core), + } + ); + tt_metal::SetRuntimeArgs( + program, + eltwise_unary_kernel_group_1_id, + core_coord, + { + } + ); + } + }else if(input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED){ + for (int32_t core = 0; core < ncores_nhw; ++core) { + CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical + tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core_coord, + { + (uint32_t)(num_tile_per_core), + } + ); + tt_metal::SetRuntimeArgs( + program, + eltwise_unary_kernel_group_1_id, + core_coord, + { + } + ); + } + } + else if(input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED){ + ncores_nhw = 1; + ncores_x = 1; + for (int32_t core = 0; core < ncores_nhw; ++core) { + for (int32_t core_x = 0; core_x < ncores_x; ++core_x) { + CoreCoord core_coord(core_x, core); // logical + tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core_coord, + { + (uint32_t)(num_tile_per_core), + (uint32_t)(input_tile_size) + } + ); + tt_metal::SetRuntimeArgs( + program, + eltwise_unary_kernel_group_1_id, + core_coord, + { + } + ); + } + } + }else{ + TT_FATAL(false, "Only width, height and block memory is supported by this fuction"); + } + + auto override_runtime_args_callback = [ + unary_reader_kernel_id, + ncores_nhw, + ncores_x + ] + ( + const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers + ) { + + std::vector src_addrs(input_buffers.size()); + for(uint32_t i = 0; i < input_buffers.size(); i++) { + src_addrs[i] = input_buffers.at(0)->address(); + } + + auto dst_buffer = output_buffers.at(0); + + for (uint32_t i = 0, num_tiles_written = 0; i < ncores_nhw; i++){ + CoreCoord core = {i / ncores_x, i % ncores_nhw}; + + { + auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + std::copy(src_addrs.begin(), src_addrs.end(), runtime_args.begin() + 4); + } + + { + auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + + operation::ProgramWithCallbacks eltwise_unary_multi_core(const Tensor &a, Tensor &output, const std::vector op_chain) { tt_metal::Program program{}; + if(a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED || a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED){ + return eltwise_unary_sharded_multi_core(a, output, op_chain); + } tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index 2375b5197aa8..ae8b3812394d 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -50,7 +50,7 @@ def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs) input_tensor, ranks=(2, 3, 4), dtypes=(ttnn.bfloat16, ttnn.bfloat8_b), - layouts=(ttnn.TILE_LAYOUT,), + layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT), can_be_on_device=True, can_be_on_cpu=False, )