diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py b/tests/ttnn/unit_tests/operations/test_moreh_getitem.py similarity index 95% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py rename to tests/ttnn/unit_tests/operations/test_moreh_getitem.py index df1608dab3f..f46f03813ca 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_getitem.py @@ -71,7 +71,7 @@ def test_getitem_RAW_MJOR_one_index(shape_index_dim, dtype, index_size, device): elif index_dim == 4: tt_cpu = x[:, :, :, :, idx] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, [dev_idx], [index_dim]) + tt_npu = ttnn.operations.moreh.getitem(dev_x, [dev_idx], [index_dim]) assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape) tt_dev = tt_npu.cpu().to_torch() @@ -132,7 +132,7 @@ def test_getitem_RAW_MAJOR_two_indices(shape_index_dims, dtype, index_size, devi tt_cpu = x[:, indices[0], indices[1]] if index_dims == (2, 3): tt_cpu = x[:, :, indices[0], indices[1]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims) assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape) tt_dev = tt_npu.cpu().to_torch() @@ -191,7 +191,7 @@ def test_getitem_RAW_MAJOR_three_indices(shape_index_dims, dtype, index_size, de tt_cpu = x[indices[0], indices[1], indices[2]] if index_dims == (1, 2, 3): tt_cpu = x[:, indices[0], indices[1], indices[2]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims) assert list(tt_npu.get_legacy_shape()) == list(tt_cpu.shape) tt_dev = tt_npu.cpu().to_torch() @@ -300,7 +300,7 @@ def test_getitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major elif index_dim == 4: tt_cpu = x[:, :, :, :, 
idx] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, [dev_idx], [index_dim]) + tt_npu = ttnn.operations.moreh.getitem(dev_x, [dev_idx], [index_dim]) tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT) cpu_5d_shape = to_output_5d_shape(shape, [index_dim], index_size) @@ -392,7 +392,7 @@ def test_getitem_tilized_two_indices(shape_index_dims, dtype, index_size, row_ma if index_dims == (3, 4): tt_cpu = x[:, :, :, indices[0], indices[1]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims) tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT) output_5d_shape = to_output_5d_shape(shape, index_dims, index_size) @@ -478,7 +478,7 @@ def test_getitem_tilized_three_indices(shape_index_dims, dtype, index_size, row_ if index_dims == (2, 3, 4): tt_cpu = x[:, :, indices[0], indices[1], indices[2]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims) tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT) output_5d_shape = to_output_5d_shape(shape, index_dims, index_size) @@ -559,7 +559,7 @@ def test_getitem_tilized_four_indices(shape_index_dims, dtype, index_size, row_m if index_dims == (1, 2, 3, 4): tt_cpu = x[:, indices[0], indices[1], indices[2], indices[3]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, dev_indices, index_dims) tt_npu = tt_npu.cpu().to(ttnn.Layout.ROW_MAJOR) output_5d_shape = to_output_5d_shape(shape, index_dims, index_size) @@ -634,7 +634,7 @@ def test_getitem_tilized_five_indices(shape_index_dims, dtype, index_size, row_m tt_cpu = x[indices[0], indices[1], indices[2], indices[3], indices[4]] - tt_npu = ttnn.experimental.operations.primary.moreh_getitem(dev_x, dev_indices, index_dims) + tt_npu = ttnn.operations.moreh.getitem(dev_x, 
dev_indices, index_dims) tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT) output_5d_shape = to_output_5d_shape(shape, index_dims, index_size) diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 71de4e73086..f46152cbf0b 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -342,6 +342,12 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp ) # Split src and python bindings diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt index 578bc9d81b3..83add399c48 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt @@ -77,9 +77,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/moreh_cumsum/moreh_cumsum_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_sgd/moreh_sgd.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp CACHE INTERNAL "tt_dnn sources to reuse in ttnn build" ) diff --git 
a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp deleted file mode 100644 index 5b782efb2d7..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp +++ /dev/null @@ -1,227 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/run_operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -void MorehGetitem::validate_with_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - // validate input tensor - auto& input_tensor = input_tensors.at(0); - auto input_layout = input_tensor.get_layout(); - - TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to getitem need to be on device!"); - TT_FATAL(input_tensor.buffer() != nullptr, "Operands to getitem need to be allocated in buffers on device!"); - auto dtype = input_tensor.get_dtype(); - TT_FATAL(dtype == DataType::INT32 || dtype == DataType::BFLOAT16, "Error"); - - // validate index tensors - uint32_t index_size = input_tensors.at(1).get_legacy_shape()[-1]; - for (uint32_t i = 1; i < input_tensors.size(); i++) { - auto& index_tensor = input_tensors.at(i); - TT_FATAL(index_tensor.storage_type() == StorageType::DEVICE, "Operands to getitem need to be on device!"); - TT_FATAL(index_tensor.buffer() != nullptr, "Operands to getitem need to be allocated in buffers on device!"); - TT_FATAL(index_tensor.get_dtype() == DataType::INT32, "Error"); - - auto index_shape = 
index_tensor.get_legacy_shape(); - auto index_layout = index_tensor.get_layout(); - if (index_layout == Layout::ROW_MAJOR) { - TT_FATAL(index_shape.rank() == 1, "Error"); - } else if (index_layout == Layout::TILE) { - TT_FATAL(index_shape.rank() == 5, "Error"); - } - TT_FATAL( - !(input_layout == Layout::ROW_MAJOR && index_layout == Layout::TILE), - "input layout ROW_MAJOR and index layout TILE not supported"); - TT_FATAL(index_size == index_shape[-1], "The shapes of all index tensors must be identical!"); - } - - if (input_layout == Layout::ROW_MAJOR) { - for (auto dim : this->index_dims) { - TT_FATAL(dim != 4, "getitem for ROW_MAJOR layout not support W index tensor!"); - } - } - - uint32_t dim_start = this->index_dims.front(); - uint32_t i = 0; - for (auto dim : this->index_dims) { - TT_FATAL( - dim_start + i == dim, - fmt::format("The value of index_dims={} must be consecutive integers.", this->index_dims)); - i++; - } - - if (output_tensors.empty() || !output_tensors.at(0).has_value()) { - // If the user decided to not use any optional output tensors, then this would be empty or would be a nullptr. 
- return; - } - TT_FATAL(output_tensors.size() == 1, "Must have 1 output tensor"); - TT_FATAL(dtype == output_tensors.front().value().get_dtype(), "Error"); -} - -std::vector MorehGetitem::compute_output_shapes(const std::vector& input_tensors) const { - auto input_shape = input_tensors.at(0).get_legacy_shape(); - auto output_shape = input_shape; - auto layout = input_tensors.at(0).get_layout(); - - if (layout == Layout::TILE) { - // compute output shape - // ex) - // input: (10, 20, 30, 40) - // index_tensor: [(100), (100)] - // index_dims = 1,2 - // output: (10, 1, 100, 40) - - auto dim_offset = 5 - input_shape.rank(); - auto dimensions_pads = std::vector(); - std::vector output_size_vec; - - for (int dim = 0; dim < output_shape.size(); dim++) { - dimensions_pads.push_back(output_shape.padding()[dim]); - output_size_vec.push_back(output_shape[dim]); - } - - auto index = input_tensors.at(1); - uint32_t index_size = index.get_legacy_shape()[-1]; - uint32_t index_size_without_padding = index.get_legacy_shape().without_padding()[-1]; - - uint32_t last_dim = this->index_dims.back() + dim_offset; - - for (uint32_t i = 0; i < this->index_dims.size(); i++) { - uint32_t out_put_dim = this->index_dims.at(i); - uint32_t dim = out_put_dim + dim_offset; - auto index = input_tensors.at(i + 1); - - if (dim == 3 || dim == 4) { - dimensions_pads[out_put_dim] = Padding::PadDimension{.front = 0, .back = 31}; - output_size_vec[out_put_dim] = 32; - } else { - output_size_vec[out_put_dim] = 1; - } - } - - if (last_dim == 3 || last_dim == 4) { - output_size_vec[this->index_dims.back()] = round_up_to_mul32(index_size); - uint32_t padding_back = round_up_to_mul32(index_size_without_padding) - index_size_without_padding; - dimensions_pads[this->index_dims.back()] = Padding::PadDimension{.front = 0, .back = padding_back}; - } else { - output_size_vec[this->index_dims.back()] = index_size_without_padding; - } - - const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); - 
output_shape = Shape(output_size_vec, padding); - } else { - // compute output shape - // ex) - // input: (10, 20, 30, 40) - // index_tensor: [(100), (100)] - // index_dims = 1,2 - // output: (10, 100, 40) - std::vector output_size_vec; - - auto input_shape = input_tensors.at(0).get_legacy_shape(); - uint32_t input_rank = input_shape.rank(); - - auto index = input_tensors.at(1); - uint32_t index_size = index.get_legacy_shape()[0]; - - uint32_t start_dim = this->index_dims.front(); - uint32_t last_dim = this->index_dims.back(); - for (uint32_t input_dim = 0; input_dim < input_rank; input_dim++) { - if (input_dim < start_dim) { - output_size_vec.push_back(input_shape[input_dim]); - } else if (start_dim == input_dim) { - output_size_vec.push_back(index_size); - } else if (last_dim < input_dim) { - output_size_vec.push_back(input_shape[input_dim]); - } - } - - output_shape = Shape(output_size_vec); - } - - return {output_shape}; -} - -std::vector MorehGetitem::create_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - if (!output_tensors.empty() && output_tensors.at(0).has_value()) { - return {output_tensors.at(0).value()}; - } - - auto dtype = input_tensors.at(0).get_dtype(); - auto layout = input_tensors.at(0).get_layout(); - Tensor output = - operation::generic_create_output_tensors(*this, input_tensors, dtype, layout, this->output_mem_config).at(0); - - return {output}; -} - -operation::ProgramWithCallbacks MorehGetitem::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - auto& input = input_tensors.at(0); - auto& output = output_tensors.at(0); - - std::vector index_tensors; - for (uint32_t i = 1; i < input_tensors.size(); i++) { - index_tensors.push_back(input_tensors.at(i)); - } - - if (input.get_layout() == Layout::ROW_MAJOR) { - return {moreh_getitem_rm(input, index_tensors, this->index_dims, output, this->core_range)}; - } - - return {moreh_getitem_tilized(input, 
index_tensors, this->index_dims, output, this->core_range)}; -} - -Tensor moreh_getitem( - const Tensor& input_tensor, - const std::vector& index_tensors, - const std::vector& index_dims, - std::optional output_tensor, - const MemoryConfig& output_mem_config) { - auto device = input_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - std::vector new_input_tensors; - new_input_tensors.push_back(input_tensor); - new_input_tensors.insert(new_input_tensors.end(), index_tensors.begin(), index_tensors.end()); - - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output(new_input_tensors))}; - - operation::launch_op( - [index_dims, all_cores, output_mem_config]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehGetitem{.index_dims = index_dims, .core_range = all_cores, .output_mem_config = output_mem_config}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - new_input_tensors, - output_tensors, - {}, - {output_tensor}); - - return output_tensors.at(0); -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp deleted file mode 100644 index 1f63897e6fb..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
- * - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include - -#include "ttnn/operation.hpp" -#include "ttnn/tensor/tensor.hpp" - -namespace tt { -namespace operations { -namespace primary { - -using namespace tt_metal; - -operation::ProgramWithCallbacks moreh_getitem_rm( - const Tensor &input, - const std::vector &index_tensors, - const std::vector &index_dims, - const Tensor &output, - const CoreRange core_range); - -operation::ProgramWithCallbacks moreh_getitem_tilized( - const Tensor &input, - const std::vector &index_tensors, - const std::vector &index_dims, - const Tensor &output, - const CoreRange core_range); - -struct MorehGetitem { - const std::vector index_dims; - const CoreRange core_range; // unused for now - const MemoryConfig output_mem_config; - - void validate_with_output_tensors( - const std::vector &input_tensors, const std::vector> &output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors( - const std::vector &input_tensors, const std::vector> &output_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector &input_tensors, std::vector &output_tensors) const; - static constexpr auto attribute_names = std::make_tuple("index_dims", "output_mem_config"); - const auto attribute_values() const { - return std::make_tuple(std::cref(this->index_dims), std::cref(this->output_mem_config)); - } -}; - -Tensor moreh_getitem( - const Tensor &input_tensor, - const std::vector &index_tensors, - const std::vector &index_dims, - std::optional output_tensor = std::nullopt, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp index bd28d7408b5..574f0ec19b2 100644 --- 
a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp @@ -12,7 +12,6 @@ #include "ttnn/deprecated/tt_dnn/op_library/moreh_bmm_backward/moreh_bmm_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_groupnorm/moreh_groupnorm_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_groupnorm_backward/moreh_groupnorm_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp" @@ -510,15 +509,6 @@ void py_module(py::module& m_primary) { py::arg("compute_kernel_config").noconvert() = std::nullopt, "Performs mean backward operation. Returns an input_grad tensor."); - m_primary.def( - "moreh_getitem", - &moreh_getitem, - py::arg("input_tensor").noconvert(), - py::arg("index_tensors").noconvert(), - py::arg("index_dims").noconvert(), - py::arg("output_tensor").noconvert() = std::nullopt, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - "Performs a getitem operation. Returns an output tensor."); } } // namespace diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp new file mode 100644 index 00000000000..b9930777208 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_getitem_device_operation.hpp" + +#include + +#include "tt_dnn/op_library/moreh_helper_functions.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::moreh::moreh_getitem { +void MorehGetItemOperation::validate_inputs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor = tensor_args.input; + auto input_layout = input_tensor.get_layout(); + const auto& index_tensors = tensor_args.index_tensors; + const auto& output_tensor = tensor_args.output; + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to getitem need to be on device!"); + TT_FATAL(input_tensor.buffer() != nullptr, "Operands to getitem need to be allocated in buffers on device!"); + auto dtype = input_tensor.get_dtype(); + TT_FATAL(dtype == DataType::INT32 || dtype == DataType::BFLOAT16, "Input tensor must be of type INT32 or BFLOAT16!"); + + // validate index tensors + uint32_t index_size = index_tensors[0].get_shape()[-1]; + for (uint32_t i = 0; i < index_tensors.size(); i++) { + auto& index_tensor = index_tensors[i]; + TT_FATAL(index_tensor.storage_type() == StorageType::DEVICE, "Operands to getitem need to be on device!"); + TT_FATAL(index_tensor.buffer() != nullptr, "Operands to getitem need to be allocated in buffers on device!"); + TT_FATAL(index_tensor.get_dtype() == DataType::INT32, "Index tensor must be of type INT32!"); + + auto index_shape = index_tensor.get_shape(); + auto index_layout = index_tensor.get_layout(); + if (index_layout == Layout::ROW_MAJOR) { + TT_FATAL(index_shape.rank() == 1, "Index tensor must be 1D for ROW_MAJOR layout!"); + } else if (index_layout == Layout::TILE) { + TT_FATAL(index_shape.rank() == 5, "Index tensor must be 5D for TILE layout!"); + } + TT_FATAL( + !(input_layout == Layout::ROW_MAJOR && index_layout == Layout::TILE), + "input layout ROW_MAJOR and index layout TILE not supported"); + 
TT_FATAL(index_size == index_shape[-1], "The shapes of all index tensors must be identical!"); + } + + if (input_layout == Layout::ROW_MAJOR) { + for (auto dim : operation_attributes.index_dims) { + TT_FATAL(dim != 4, "getitem for ROW_MAJOR layout not support W index tensor!"); + } + } + + uint32_t dim_start = operation_attributes.index_dims.front(); + uint32_t i = 0; + for (auto dim : operation_attributes.index_dims) { + TT_FATAL( + dim_start + i == dim, + fmt::format("The value of index_dims={} must be consecutive integers.", operation_attributes.index_dims)); + i++; + } + if (!output_tensor.has_value()) { + // If the user decided to not use any optional output tensors, then this would be empty or would be a nullptr. + return; + } + TT_ASSERT(output_tensor->buffer() != nullptr, "Must have 1 output tensor."); + TT_FATAL(dtype == output_tensor.value().get_dtype(), "Output tensor must have the same dtype as input tensor!"); +} +MorehGetItemOperation::program_factory_t MorehGetItemOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + auto& input_tensor = tensor_args.input; + auto input_layout = input_tensor.get_layout(); + if (input_layout == Layout::ROW_MAJOR) { + return MorehGetItemRmFactory(); + } else { + return MorehGetItemTilizedFactory(); + } +} + +void MorehGetItemOperation::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +void MorehGetItemOperation::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +MorehGetItemOperation::shape_return_value_t MorehGetItemOperation::compute_output_shapes( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor = tensor_args.input; + 
const auto index_dims = operation_attributes.index_dims; + auto input_layout = input_tensor.get_layout(); + const auto& index_tensors = tensor_args.index_tensors; + auto input_shape = input_tensor.get_shape(); + auto output_shape = input_shape; + auto layout = input_tensor.get_layout(); + + if (layout == Layout::TILE) { + // compute output shape + // ex) + // input: (10, 20, 30, 40) + // index_tensor: [(100), (100)] + // index_dims = 1,2 + // output: (10, 1, 100, 40) + auto dim_offset = 5 - input_shape.rank(); + auto dimensions_pads = std::vector(); + std::vector output_size_vec; + for (int dim = 0; dim < output_shape.size(); dim++) { + dimensions_pads.push_back(output_shape.value.padding()[dim]); + output_size_vec.push_back(output_shape.value[dim]); + } + + auto index = index_tensors[0]; + uint32_t index_size = index.get_shape()[-1]; + uint32_t index_size_without_padding = index.get_shape().value.without_padding()[-1]; + + uint32_t last_dim = index_dims.back() + dim_offset; + + for (uint32_t i = 0; i < index_dims.size(); i++) { + uint32_t out_put_dim = index_dims[i]; + uint32_t dim = out_put_dim + dim_offset; + auto index = index_tensors[i]; + + if (dim == 3 || dim == 4) { + dimensions_pads[out_put_dim] = Padding::PadDimension{.front = 0, .back = 31}; + output_size_vec[out_put_dim] = 32; + } else { + output_size_vec[out_put_dim] = 1; + } + } + + if (last_dim == 3 || last_dim == 4) { + output_size_vec[index_dims.back()] = round_up_to_mul32(index_size); + uint32_t padding_back = round_up_to_mul32(index_size_without_padding) - index_size_without_padding; + dimensions_pads[index_dims.back()] = Padding::PadDimension{.front = 0, .back = padding_back}; + } else { + output_size_vec[index_dims.back()] = index_size_without_padding; + } + + const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); + output_shape = Shape(tt::tt_metal::Shape(output_size_vec, padding)); + + } else { + // compute output shape + // ex) + // input: (10, 20, 30, 40) + // index_tensor: 
[(100), (100)] + // index_dims = 1,2 + // output: (10, 100, 40) + std::vector output_size_vec; + + auto input_shape = input_tensor.get_shape(); + uint32_t input_rank = input_shape.rank(); + + auto index = index_tensors[0]; + uint32_t index_size = index.get_shape()[0]; + + uint32_t start_dim = operation_attributes.index_dims.front(); + uint32_t last_dim = operation_attributes.index_dims.back(); + for (uint32_t input_dim = 0; input_dim < input_rank; input_dim++) { + if (input_dim < start_dim) { + output_size_vec.push_back(input_shape[input_dim]); + } else if (start_dim == input_dim) { + output_size_vec.push_back(index_size); + } else if (last_dim < input_dim) { + output_size_vec.push_back(input_shape[input_dim]); + } + } + + output_shape = Shape(output_size_vec); + } + return {output_shape}; +}; + +MorehGetItemOperation::tensor_return_value_t MorehGetItemOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + if (tensor_args.output.has_value()) { + log_debug(tt::LogOp, "{}:{} use output tensor", __func__, __LINE__); + return {tensor_args.output.value()}; + } + log_debug(tt::LogOp, "{}:{} create output tensor", __func__, __LINE__); + const auto& output_shape = compute_output_shapes(operation_attributes, tensor_args); + return create_device_tensor( + output_shape, + tensor_args.input.get_dtype(), + tensor_args.input.get_layout(), + tensor_args.input.device(), + operation_attributes.output_memory_config); +}; + +std::tuple +MorehGetItemOperation::invoke( + const Tensor& input, + const std::vector& index_tensors, + const std::vector index_dims, + const std::optional& output, + const std::optional output_memory_config) { + operation_attributes_t operation_attributes = {index_dims, output_memory_config.value_or(input.memory_config())}; + tensor_args_t tensor_args = {input, index_tensors, output}; + return {operation_attributes, tensor_args}; +} +} // namespace ttnn::operations::moreh::moreh_getitem diff 
--git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp new file mode 100644 index 00000000000..274cdfc196c --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp @@ -0,0 +1,100 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/tensor/types.hpp" + +namespace ttnn::operations::moreh::moreh_getitem { +struct MorehGetItemOperation { + struct operation_attributes_t { + const std::vector index_dims; + // const CoreRange core_range; + const MemoryConfig output_memory_config; + }; + + struct tensor_args_t { + const Tensor& input; + const std::vector& index_tensors; + const std::optional& output; + }; + + using shape_return_value_t = ttnn::Shape; + using tensor_return_value_t = Tensor; + + struct MorehGetItemRmFactory { + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + std::size_t num_cores; + uint32_t core_h; + std::vector index_dims; + uint32_t input_dim_offset; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + struct MorehGetItemTilizedFactory { + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + std::size_t num_cores; + uint32_t core_h; + std::vector index_dims; + uint32_t 
input_dim_offset; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; + + static void validate_inputs(const operation_attributes_t&, const tensor_args_t&); + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); + static std::tuple invoke( + const Tensor& input, + const std::vector& index_tensors, + const std::vector index_dims, + const std::optional& output, + // const CoreRange core_range, + const std::optional output_memory_config); +}; +} // namespace ttnn::operations::moreh::moreh_getitem + +namespace ttnn::prim { +constexpr auto moreh_getitem = ttnn:: + register_operation<"ttnn::prim::moreh_getitem", ttnn::operations::moreh::moreh_getitem::MorehGetItemOperation>(); +} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/reader_moreh_getitem.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/reader_moreh_getitem.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/reader_moreh_getitem.cpp rename to 
ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/reader_moreh_getitem.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/writer_moreh_getitem.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/writer_moreh_getitem.cpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/writer_moreh_getitem.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/writer_moreh_getitem.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp similarity index 52% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp index 1fea8bba490..66a55474d24 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/moreh_getitem_rm.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp @@ -2,42 +2,38 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/run_operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp" +#include "moreh_getitem_device_operation.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" -#include "tt_metal/host_api.hpp" -#include "ttnn/tensor/tensor_impl.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -struct IndexInfo -{ + +struct IndexInfo { bool is_defined; bool is_dram; uint32_t address; uint32_t unit_size; }; -operation::ProgramWithCallbacks 
moreh_getitem_rm( - const Tensor &input, - const std::vector &index_tensors, - const std::vector &index_dims, - const Tensor &output, - const CoreRange core_range) { - - log_debug(LogTest, "moreh_getitem_rm"); - - auto input_shape = input.get_legacy_shape(); - auto output_shape = output.get_legacy_shape(); +namespace ttnn::operations::moreh::moreh_getitem { +MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOperation::MorehGetItemRmFactory::create( + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &output_tensor) { + using namespace tt; + using namespace tt::tt_metal; + using namespace tt::operations::primary; + + auto input = tensor_args.input; + auto index_tensors = tensor_args.index_tensors; + auto output = output_tensor; + auto index_dims = operation_attributes.index_dims; + auto output_memory_config = operation_attributes.output_memory_config; + // auto core_range = operation_attributes.core_range; + auto device = input.device(); + auto grid_coord = device->compute_with_storage_grid_size(); + const CoreRange allCores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); + auto core_range = allCores; + + auto input_shape = input.get_shape(); + auto output_shape = output.get_shape(); std::array new_input_shape{}; std::array new_output_shape{}; @@ -59,30 +55,29 @@ operation::ProgramWithCallbacks moreh_getitem_rm( uint32_t index_end_dim = index_dims.back(); Tensor input_5d = input; - input_5d = input_5d.reshape(input_5d_shape); + input_5d = input_5d.reshape(input_5d_shape.value); - auto input_5d_shape_without_padding = input_5d_shape.without_padding(); + auto input_5d_shape_without_padding = input_5d_shape.value.without_padding(); IndexInfo index_info[5] = {0}; - for (uint32_t i = 0 ; i < index_tensors.size(); i++) { + for (uint32_t i = 0; i < index_tensors.size(); i++) { auto dim = index_dims[i] + input_dim_offset; - auto index = index_tensors.at(i); + auto index = index_tensors[i]; 
index_info[dim].is_defined = true; - index_info[dim].address = index_tensors.at(i).buffer()->address(); - index_info[dim].is_dram = is_dram(index_tensors.at(i)); - index_info[dim].unit_size = index.get_legacy_shape()[-1] * index.element_size(); + index_info[dim].address = index_tensors[i].buffer()->address(); + index_info[dim].is_dram = is_dram(index_tensors[i]); + index_info[dim].unit_size = index.get_shape().value[-1] * index.element_size(); } - uint32_t index_size = index_tensors.front().get_legacy_shape()[-1]; + uint32_t index_size = index_tensors.front().get_shape().value[-1]; uint32_t input_unit_size = input_5d_shape[-1] * input_5d.element_size(); uint32_t output_unit_size = input_unit_size; // split work uint32_t num_units = output.volume() / output_shape[-1]; - log_debug(LogTest, "num_units {}", num_units); uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; @@ -93,34 +88,35 @@ operation::ProgramWithCallbacks moreh_getitem_rm( Program program = Program(); // create circular buffers - auto src_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto index_cb_data_format = tt_metal::datatype_to_dataformat_converter(index_tensors.at(0).get_dtype()); - auto output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + auto src_cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); + auto index_cb_data_format = datatype_to_dataformat_converter(index_tensors[0].get_dtype()); + auto output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); - auto src_cb_index = tt::CB::c_in0; + auto src_cb_index = CB::c_in0; auto rounded_input_page_size = round_up_to_mul32(input_unit_size); - auto cb_src0_config = tt_metal::CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) - .set_page_size(src_cb_index, rounded_input_page_size); - auto cb_src0 = 
tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + auto cb_src0_config = + CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) + .set_page_size(src_cb_index, rounded_input_page_size); + auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config); for (uint32_t dim = 0; dim < 5; dim++) { if (!index_info[dim].is_defined) continue; - auto src1_cb_index = tt::CB::c_in1 + dim; + auto src1_cb_index = CB::c_in1 + dim; auto index_page_size = round_up_to_mul32(index_info[dim].unit_size); - auto cb_index_config = tt_metal::CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) - .set_page_size(src1_cb_index, index_page_size); - auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_index_config); + auto cb_index_config = + CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) + .set_page_size(src1_cb_index, index_page_size); + auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config); } - auto out_cb_index = tt::CB::c_out0; + auto out_cb_index = CB::c_out0; auto rounded_output_page_size = round_up_to_mul32(input_unit_size); auto cb_out0_config = - tt_metal::CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}}) + CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}}) .set_page_size(out_cb_index, rounded_input_page_size); - auto cb_out0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_out0_config); - + auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config); // create read/wrtie kernel auto src_is_dram = is_dram(input_5d); @@ -131,7 +127,7 @@ operation::ProgramWithCallbacks moreh_getitem_rm( auto reader_kernel_id = CreateReadKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/reader_moreh_getitem.cpp", + 
"ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/reader_moreh_getitem.cpp", all_cores, { src_is_dram, @@ -144,15 +140,15 @@ operation::ProgramWithCallbacks moreh_getitem_rm( reader_defines); auto writer_kernel_id = CreateWriteKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_rm/kernels/writer_moreh_getitem.cpp", + "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_kernels/writer_moreh_getitem.cpp", all_cores, {dst_is_dram}, writer_defines); uint32_t input_stick_idx_stride_h = 1; - uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.without_padding()[3]; - uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.without_padding()[2]; - uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.without_padding()[1]; + uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.value.without_padding()[3]; + uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.value.without_padding()[2]; + uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.value.without_padding()[1]; // Set Runtime Args auto core_x_offset = core_range.start_coord.x; @@ -233,48 +229,53 @@ operation::ProgramWithCallbacks moreh_getitem_rm( start_id += num_units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h, index_dims, input_dim_offset]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(output_buffers.size() == 1); - - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - IndexInfo index_info[5] = {0}; - - for (uint32_t i = 0; i < index_dims.size(); i++) { - auto dim = index_dims[i] + input_dim_offset; - auto index_buffer = input_buffers.at(i + 1); - - index_info[dim].address = 
index_buffer->address(); - } - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = index_info[0].address; - runtime_args[2] = index_info[1].address; - runtime_args[3] = index_info[2].address; - runtime_args[4] = index_info[3].address; - runtime_args[5] = index_info[4].address; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return {std::move(program), {reader_kernel_id, writer_kernel_id, num_cores, core_h, index_dims, input_dim_offset}}; } -} // namespace primary -} // namespace operations -} // namespace tt +void MorehGetItemOperation::MorehGetItemRmFactory::override_runtime_arguments( + cached_program_t &cached_program, + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &tensor_return_value) { + auto &program = cached_program.program; + auto &reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto &writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto num_cores = cached_program.shared_variables.num_cores; + auto core_h = cached_program.shared_variables.core_h; + auto index_dims = cached_program.shared_variables.index_dims; + auto input_dim_offset = cached_program.shared_variables.input_dim_offset; + + TT_ASSERT(tensor_return_value.buffer() != nullptr); + + auto src_buffer = tensor_args.input.buffer(); + auto dst_buffer = tensor_return_value.buffer(); + auto index_tensors = tensor_args.index_tensors; + IndexInfo index_info[5] = {0}; + + for (uint32_t i = 0; i < index_dims.size(); i++) { + auto dim = index_dims[i] + input_dim_offset; + auto index_buffer = index_tensors[i]; + + 
index_info[dim].address = index_buffer.buffer()->address(); + } + + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = index_info[0].address; + runtime_args[2] = index_info[1].address; + runtime_args[3] = index_info[2].address; + runtime_args[4] = index_info[3].address; + runtime_args[5] = index_info[4].address; + } + + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + } +} +} // namespace ttnn::operations::moreh::moreh_getitem diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_factory.cpp similarity index 61% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_factory.cpp index 41d42b6da35..92f15b704ec 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/moreh_getitem_tilized.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_factory.cpp @@ -2,22 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/run_operation.hpp" -#include "ttnn/tensor/tensor_impl.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" -#include "tt_metal/host_api.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; +#include +#include -namespace tt { -namespace operations { -namespace primary { 
+#include "moreh_getitem_device_operation.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" struct IndexInfo { bool is_defined; @@ -26,18 +15,33 @@ struct IndexInfo { uint32_t unit_size; }; -operation::ProgramWithCallbacks moreh_getitem_tilized( - const Tensor &input, - const std::vector &index_tensors, - const std::vector &index_dims, - const Tensor &output, - const CoreRange core_range) { - log_debug(LogTest, "moreh_getitem_tilized"); - - auto input_shape = input.get_legacy_shape(); - auto input_shape_without_padding = input_shape.without_padding(); - auto output_shape = output.get_legacy_shape(); - auto output_shape_without_padding = output_shape.without_padding(); +namespace ttnn::operations::moreh::moreh_getitem { +MorehGetItemOperation::MorehGetItemTilizedFactory::cached_program_t +MorehGetItemOperation::MorehGetItemTilizedFactory::create( + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &output_tensor) { + using namespace tt; + using namespace tt::tt_metal; + using namespace tt::operations::primary; + + auto input = tensor_args.input; + auto index_tensors = tensor_args.index_tensors; + auto output = output_tensor; + auto index_dims = operation_attributes.index_dims; + auto output_memory_config = operation_attributes.output_memory_config; + auto TILE_HEIGHT = constants::TILE_HEIGHT; + auto TILE_WIDTH = constants::TILE_WIDTH; + // auto core_range = operation_attributes.core_range; + auto device = input.device(); + auto grid_coord = device->compute_with_storage_grid_size(); + const CoreRange allCores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); + auto core_range = allCores; + + auto input_shape = input.get_shape(); + auto input_shape_without_padding = input_shape.value.without_padding(); + auto output_shape = output.get_shape(); + auto output_shape_without_padding = output_shape.value.without_padding(); std::array new_input_shape{}; std::array new_output_shape{}; @@ 
-49,7 +53,7 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( auto input_dim_offset = 5 - input_shape.rank(); for (auto index = 0; index < input_shape.rank(); index++) { new_input_shape[index + input_dim_offset] = input_shape_without_padding[index]; - new_input_padded_shape[index + input_dim_offset] = input_shape[index]; + new_input_padded_shape[index + input_dim_offset] = input_shape.value[index]; } new_output_shape.fill(1); @@ -57,7 +61,7 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( auto output_dim_offset = 5 - input_shape.rank(); for (auto index = 0; index < output_shape.rank(); index++) { new_output_shape[index + output_dim_offset] = output_shape_without_padding[index]; - new_output_padded_shape[index + output_dim_offset] = output_shape[index]; + new_output_padded_shape[index + output_dim_offset] = output_shape.value[index]; } Shape input_5d_shape(new_input_shape, new_input_padded_shape); @@ -70,8 +74,8 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( } } - auto input_5d_without_padding = input_5d_shape.without_padding(); - auto output_5d_without_padding = output_5d_shape.without_padding(); + auto input_5d_shape_without_padding = input_5d_shape.value.without_padding(); + auto output_5d_shape_without_padding = output_5d_shape.value.without_padding(); auto index_layout = index_tensors.front().get_layout(); bool is_row_major_index = (index_layout == Layout::ROW_MAJOR); @@ -82,7 +86,7 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( for (uint32_t i = 0; i < index_tensors.size(); i++) { auto dim = index_dims[i] + input_dim_offset; - auto index = index_tensors.at(i); + auto index = index_tensors[i]; index_info[dim].is_defined = true; index_info[dim].address = index.buffer()->address(); @@ -90,20 +94,17 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( index_info[dim].unit_size = index.element_size(); } - uint32_t index_size = index_tensors.at(0).get_legacy_shape().without_padding()[-1]; + uint32_t index_size = 
index_tensors[0].get_shape().value.without_padding()[-1]; uint32_t input_unit_size = input.element_size(); uint32_t output_unit_size = output.element_size(); - // split work - auto input_5d_shape_without_padding = input_5d_shape.without_padding(); - auto output_5d_shape_without_padding = output_5d_shape.without_padding(); uint32_t alignment_size = 32; uint32_t num_elements_per_alignment = alignment_size / output_unit_size; uint32_t num_units = - output_5d_shape_without_padding[0] * output_5d_shape_without_padding[1] * output_5d_shape_without_padding[2] * - output_5d_shape_without_padding[3] * ((output_5d_shape_without_padding[4] + num_elements_per_alignment - 1) / num_elements_per_alignment); - log_debug(LogTest, "num_units {}", num_units); + output_5d_shape_without_padding[0] * output_5d_shape_without_padding[1] * + output_5d_shape_without_padding[2] * output_5d_shape_without_padding[3] * + ((output_5d_shape_without_padding[4] + num_elements_per_alignment - 1) / num_elements_per_alignment); uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; @@ -115,41 +116,41 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( Program program = Program(); // create circular buffers - auto src_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto index_cb_data_format = tt_metal::datatype_to_dataformat_converter(index_tensors.at(0).get_dtype()); - auto output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + auto src_cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); + auto index_cb_data_format = datatype_to_dataformat_converter(index_tensors[0].get_dtype()); + auto output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); - auto src_cb_index = tt::CB::c_in0; + auto src_cb_index = CB::c_in0; auto rounded_input_page_size = round_up_to_mul32(input_unit_size); auto cb_src0_config = - 
tt_metal::CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) + CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) .set_page_size(src_cb_index, rounded_input_page_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config); for (uint32_t dim = 0; dim < 5; dim++) { if (!index_info[dim].is_defined) continue; - auto src1_cb_index = tt::CB::c_in1 + dim; + auto src1_cb_index = CB::c_in1 + dim; auto index_page_size = 1024 * 4; auto cb_index_config = - tt_metal::CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) + CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) .set_page_size(src1_cb_index, index_page_size); - auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_index_config); + auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config); } - auto out_cb0_index = tt::CB::c_out0; + auto out_cb0_index = CB::c_out0; auto rounded_output_page_size = round_up_to_mul32(output_unit_size); auto cb_out0_config = - tt_metal::CircularBufferConfig(rounded_output_page_size, {{out_cb0_index, output_cb_data_format}}) + CircularBufferConfig(rounded_output_page_size, {{out_cb0_index, output_cb_data_format}}) .set_page_size(out_cb0_index, rounded_output_page_size); - auto cb_out0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_out0_config); + auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config); - auto out_cb1_index = tt::CB::c_out1; + auto out_cb1_index = CB::c_out1; auto cb_out1_config = - tt_metal::CircularBufferConfig(rounded_output_page_size, {{out_cb1_index, output_cb_data_format}}) + CircularBufferConfig(rounded_output_page_size, {{out_cb1_index, output_cb_data_format}}) .set_page_size(out_cb1_index, rounded_output_page_size); - auto cb_out1 = tt_metal::CreateCircularBuffer(program, all_cores, 
cb_out1_config); + auto cb_out1 = CreateCircularBuffer(program, all_cores, cb_out1_config); // create read/wrtie kernel auto src_is_dram = is_dram(input); @@ -166,7 +167,8 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( auto reader_kernel_id = CreateReadKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize_w.cpp", + "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/" + "reader_moreh_getitem_tilize_w.cpp", all_cores, { src_is_dram, @@ -179,7 +181,8 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( reader_defines); auto writer_kernel_id = CreateWriteKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize_w.cpp", + "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/" + "writer_moreh_getitem_tilize_w.cpp", all_cores, {dst_is_dram}, writer_defines); @@ -189,19 +192,19 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( uint32_t num_alignment_width = div_up(output_5d_shape_without_padding[4], num_elements_per_alignment); uint32_t output_num_stick_width = div_up(output_5d_shape_without_padding[4], face_width); - uint32_t input_num_tile_c = input_5d_shape[1]; - uint32_t input_num_tile_d = input_5d_shape[2]; - uint32_t input_num_tile_height = input_5d_shape[3] / TILE_HEIGHT; - uint32_t input_num_tile_width = input_5d_shape[4] / TILE_WIDTH; + uint32_t input_num_tile_c = input_5d_shape.value[1]; + uint32_t input_num_tile_d = input_5d_shape.value[2]; + uint32_t input_num_tile_height = input_5d_shape.value[3] / TILE_HEIGHT; + uint32_t input_num_tile_width = input_5d_shape.value[4] / TILE_WIDTH; uint32_t input_noc_id_stride_h = input_num_tile_width; uint32_t input_noc_id_stride_d = input_noc_id_stride_h * input_num_tile_height; uint32_t input_noc_id_stride_c = input_noc_id_stride_d * input_num_tile_d; uint32_t input_noc_id_stride_n = 
input_noc_id_stride_c * input_num_tile_c; - uint32_t output_num_tile_c = output_5d_shape[1]; - uint32_t output_num_tile_d = output_5d_shape[2]; - uint32_t output_num_tile_height = output_5d_shape[3] / TILE_HEIGHT; - uint32_t output_num_tile_width = output_5d_shape[4] / TILE_WIDTH; + uint32_t output_num_tile_c = output_5d_shape.value[1]; + uint32_t output_num_tile_d = output_5d_shape.value[2]; + uint32_t output_num_tile_height = output_5d_shape.value[3] / TILE_HEIGHT; + uint32_t output_num_tile_width = output_5d_shape.value[4] / TILE_WIDTH; uint32_t output_noc_id_stride_h = output_num_tile_width; uint32_t output_noc_id_stride_d = output_noc_id_stride_h * output_num_tile_height; @@ -210,9 +213,9 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( uint32_t input_stick_idx_stride_w = 1; uint32_t input_stick_idx_stride_h = input_num_stick_width; - uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.without_padding()[3]; - uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.without_padding()[2]; - uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.without_padding()[1]; + uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.value.without_padding()[3]; + uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.value.without_padding()[2]; + uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.value.without_padding()[1]; // Set Runtime Args auto core_x_offset = core_range.start_coord.x; @@ -314,74 +317,31 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( start_id += num_units_per_core; } - auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, - writer_kernel_id = writer_kernel_id, - num_cores, - core_h, - index_dims, - input_dim_offset]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(output_buffers.size() == 1); - - auto 
src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - IndexInfo index_info[5] = {0}; - - for (uint32_t i = 0; i < index_dims.size(); i++) { - auto dim = index_dims[i] + input_dim_offset; - auto index_buffer = input_buffers.at(i + 1); - - index_info[dim].address = index_buffer->address(); - } - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = index_info[0].address; - runtime_args[2] = index_info[1].address; - runtime_args[3] = index_info[2].address; - runtime_args[4] = index_info[3].address; - runtime_args[4] = index_info[4].address; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + std::move(program), {reader_kernel_id, writer_kernel_id, num_cores, core_h, index_dims, input_dim_offset}}; + } else { // compute index info + IndexInfo index_info[5] = {0}; for (uint32_t i = 0; i < index_tensors.size(); i++) { auto dim = index_dims[i] + input_dim_offset; - auto index = index_tensors.at(i); + auto index = index_tensors[i]; index_info[dim].is_defined = true; - index_info[dim].address = index_tensors.at(i).buffer()->address(); - index_info[dim].is_dram = is_dram(index_tensors.at(i)); - index_info[dim].unit_size = index.get_legacy_shape()[-1] * index.element_size(); + index_info[dim].address = index_tensors[i].buffer()->address(); + index_info[dim].is_dram = is_dram(index_tensors[i]); + index_info[dim].unit_size = index.get_shape().value[-1] * index.element_size(); } - uint32_t index_size = index_tensors.at(0).get_legacy_shape().without_padding()[-1]; + uint32_t index_size = index_tensors[0].get_shape().value.without_padding()[-1]; uint32_t input_unit_size = 16 * input.element_size(); 
uint32_t output_unit_size = 16 * output.element_size(); - // split work - auto input_5d_shape_without_padding = input_5d_shape.without_padding(); - auto output_5d_shape_without_padding = output_5d_shape.without_padding(); uint32_t num_units = output_5d_shape_without_padding[0] * output_5d_shape_without_padding[1] * - output_5d_shape_without_padding[2] * output_5d_shape_without_padding[3] * ((output_5d_shape_without_padding[4] + 15) / 16); - log_debug("num_units {}", num_units); + output_5d_shape_without_padding[2] * output_5d_shape_without_padding[3] * + ((output_5d_shape_without_padding[4] + 15) / 16); uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; @@ -393,36 +353,36 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( Program program = Program(); // create circular buffers - auto src_cb_data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto index_cb_data_format = tt_metal::datatype_to_dataformat_converter(index_tensors.at(0).get_dtype()); - auto output_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + auto src_cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); + auto index_cb_data_format = datatype_to_dataformat_converter(index_tensors[0].get_dtype()); + auto output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); - auto src_cb_index = tt::CB::c_in0; + auto src_cb_index = CB::c_in0; auto rounded_input_page_size = round_up_to_mul32(input_unit_size); auto cb_src0_config = - tt_metal::CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) + CircularBufferConfig(rounded_input_page_size, {{src_cb_index, src_cb_data_format}}) .set_page_size(src_cb_index, rounded_input_page_size); - auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config); 
for (uint32_t dim = 0; dim < 5; dim++) { if (!index_info[dim].is_defined) continue; - auto src1_cb_index = tt::CB::c_in1 + dim; + auto src1_cb_index = CB::c_in1 + dim; // auto index_page_size = round_up_to_mul32(index_info[dim].unit_size); auto index_page_size = 1024 * 4; auto cb_index_config = - tt_metal::CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) + CircularBufferConfig(index_page_size, {{src1_cb_index, index_cb_data_format}}) .set_page_size(src1_cb_index, index_page_size); - auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_index_config); + auto cb_src1 = CreateCircularBuffer(program, all_cores, cb_index_config); } - auto out_cb_index = tt::CB::c_out0; + auto out_cb_index = CB::c_out0; auto rounded_output_page_size = round_up_to_mul32(input_unit_size); auto cb_out0_config = - tt_metal::CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}}) + CircularBufferConfig(rounded_input_page_size, {{out_cb_index, output_cb_data_format}}) .set_page_size(out_cb_index, rounded_input_page_size); - auto cb_out0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_out0_config); + auto cb_out0 = CreateCircularBuffer(program, all_cores, cb_out0_config); // create read/wrtie kernel auto src_is_dram = is_dram(input); @@ -439,7 +399,8 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( auto reader_kernel_id = CreateReadKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize.cpp", + "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/" + "reader_moreh_getitem_tilize.cpp", all_cores, { src_is_dram, @@ -452,7 +413,8 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( reader_defines); auto writer_kernel_id = CreateWriteKernel( program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize.cpp", + 
"ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/" + "writer_moreh_getitem_tilize.cpp", all_cores, {dst_is_dram}, writer_defines); @@ -461,20 +423,19 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( uint32_t input_num_stick_width = div_up(input_5d_shape_without_padding[4], face_width); uint32_t output_num_stick_width = div_up(output_5d_shape_without_padding[4], face_width); - uint32_t input_num_tile_c = input_5d_shape[1]; - uint32_t input_num_tile_d = input_5d_shape[2]; - uint32_t input_num_tile_height = input_5d_shape[3] / TILE_HEIGHT; - uint32_t input_num_tile_width = input_5d_shape[4] / TILE_WIDTH; - + uint32_t input_num_tile_c = input_5d_shape.value[1]; + uint32_t input_num_tile_d = input_5d_shape.value[2]; + uint32_t input_num_tile_height = input_5d_shape.value[3] / TILE_HEIGHT; + uint32_t input_num_tile_width = input_5d_shape.value[4] / TILE_WIDTH; uint32_t input_noc_id_stride_h = input_num_tile_width; uint32_t input_noc_id_stride_d = input_noc_id_stride_h * input_num_tile_height; uint32_t input_noc_id_stride_c = input_noc_id_stride_d * input_num_tile_d; uint32_t input_noc_id_stride_n = input_noc_id_stride_c * input_num_tile_c; - uint32_t output_num_tile_c = output_5d_shape[1]; - uint32_t output_num_tile_d = output_5d_shape[2]; - uint32_t output_num_tile_height = output_5d_shape[3] / TILE_HEIGHT; - uint32_t output_num_tile_width = output_5d_shape[4] / TILE_WIDTH; + uint32_t output_num_tile_c = output_5d_shape.value[1]; + uint32_t output_num_tile_d = output_5d_shape.value[2]; + uint32_t output_num_tile_height = output_5d_shape.value[3] / TILE_HEIGHT; + uint32_t output_num_tile_width = output_5d_shape.value[4] / TILE_WIDTH; uint32_t output_noc_id_stride_h = output_num_tile_width; uint32_t output_noc_id_stride_d = output_noc_id_stride_h * output_num_tile_height; @@ -483,14 +444,13 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( uint32_t input_stick_idx_stride_w = 1; uint32_t input_stick_idx_stride_h = 
input_num_stick_width; - uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.without_padding()[3]; - uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.without_padding()[2]; - uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.without_padding()[1]; + uint32_t input_stick_idx_stride_d = input_stick_idx_stride_h * input_5d_shape.value.without_padding()[3]; + uint32_t input_stick_idx_stride_c = input_stick_idx_stride_d * input_5d_shape.value.without_padding()[2]; + uint32_t input_stick_idx_stride_n = input_stick_idx_stride_c * input_5d_shape.value.without_padding()[1]; // Set Runtime Args auto core_x_offset = core_range.start_coord.x; auto core_y_offset = core_range.start_coord.y; - uint32_t g1_numcores = core_group_1.num_cores(); uint32_t g2_numcores = core_group_2.num_cores(); @@ -550,13 +510,12 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( output_5d_shape_without_padding[4], output_num_stick_width, - //etc + // etc start_id, num_units_per_core, input_unit_size, input.element_size(), }; - vector writer_args = { // buffers output.buffer()->address(), @@ -584,53 +543,55 @@ operation::ProgramWithCallbacks moreh_getitem_tilized( start_id += num_units_per_core; } - auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, - writer_kernel_id = writer_kernel_id, - num_cores, - core_h, - index_dims, - input_dim_offset]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(output_buffers.size() == 1); - - auto src_buffer = input_buffers.at(0); - auto dst_buffer = output_buffers.at(0); - - IndexInfo index_info[5] = {0}; - - for (uint32_t i = 0; i < index_dims.size(); i++) { - auto dim = index_dims[i] + input_dim_offset; - auto index_buffer = input_buffers.at(i + 1); - - index_info[dim].address = index_buffer->address(); - } - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / 
core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - runtime_args[1] = index_info[0].address; - runtime_args[2] = index_info[1].address; - runtime_args[3] = index_info[2].address; - runtime_args[4] = index_info[3].address; - runtime_args[5] = index_info[4].address; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + + return { + std::move(program), {reader_kernel_id, writer_kernel_id, num_cores, core_h, index_dims, input_dim_offset}}; } } -} // namespace primary -} // namespace operations -} // namespace tt +void MorehGetItemOperation::MorehGetItemTilizedFactory::override_runtime_arguments( + cached_program_t &cached_program, + const operation_attributes_t &operation_attributes, + const tensor_args_t &tensor_args, + tensor_return_value_t &tensor_return_value) { + auto &program = cached_program.program; + auto &reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto &writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto num_cores = cached_program.shared_variables.num_cores; + auto core_h = cached_program.shared_variables.core_h; + auto index_dims = cached_program.shared_variables.index_dims; + auto input_dim_offset = cached_program.shared_variables.input_dim_offset; + + TT_ASSERT(tensor_return_value.buffer() != nullptr); + + auto src_buffer = tensor_args.input.buffer(); + auto dst_buffer = tensor_return_value.buffer(); + auto index_tensors = tensor_args.index_tensors; + IndexInfo index_info[5] = {0}; + for (uint32_t i = 0; i < index_dims.size(); i++) { + auto dim = index_dims[i] + input_dim_offset; + auto index_buffer = index_tensors[i]; + + index_info[dim].address = index_buffer.buffer()->address(); + } + + for (uint32_t icore = 0; icore < num_cores; icore++) { + 
CoreCoord core = {icore / core_h, icore % core_h}; + + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + runtime_args[1] = index_info[0].address; + runtime_args[2] = index_info[1].address; + runtime_args[3] = index_info[2].address; + runtime_args[4] = index_info[3].address; + runtime_args[5] = index_info[4].address; + } + + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + } +} +} // namespace ttnn::operations::moreh::moreh_getitem diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize.cpp similarity index 98% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize.cpp index d7d08d3524a..a754b1f8a77 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" -#include 
"ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp" +#include "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp" void kernel_main() { uint32_t i = 0; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize_w.cpp similarity index 99% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize_w.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize_w.cpp index 95c3262299b..ab10a316f6f 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/reader_moreh_getitem_tilize_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/reader_moreh_getitem_tilize_w.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp" +#include "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp" void kernel_main() { uint32_t i = 0; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize.cpp similarity index 95% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize.cpp index 51c73a04a94..9951d58fff0 100644 --- 
a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp" +#include "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp" void kernel_main() { uint32_t i = 0; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize_w.cpp similarity index 97% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize_w.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize_w.cpp index 5c4537beb57..9b0d7004cca 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/writer_moreh_getitem_tilize_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/writer_moreh_getitem_tilize_w.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "dataflow_api.h" -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_getitem/moreh_getitem_tilized/kernels/common.hpp" +#include "ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_tilized_kernels/common.hpp" void kernel_main() { uint32_t i = 0; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp new file mode 100644 index 00000000000..8121e9798eb --- /dev/null +++ 
b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.cpp @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_getitem.hpp" + +namespace ttnn::operations::moreh::moreh_getitem { +Tensor MorehGetItem::invoke( + const Tensor& input, + const std::vector<Tensor>& index_tensors, + const std::vector<uint32_t> index_dims, + const std::optional<Tensor>& output, + // const CoreRange core_range, + const std::optional<MemoryConfig> output_memory_config) { + return ttnn::prim::moreh_getitem(input, index_tensors, index_dims, output, output_memory_config); +} +} // namespace ttnn::operations::moreh::moreh_getitem diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp new file mode 100644 index 00000000000..60c4405d032 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/decorators.hpp" +#include "ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.hpp" + +namespace ttnn::operations::moreh::moreh_getitem { +struct MorehGetItem { + static Tensor invoke( + const Tensor& input, + const std::vector<Tensor>& index_tensors, + const std::vector<uint32_t> index_dims, + const std::optional<Tensor>& output, + // const CoreRange core_range, + const std::optional<MemoryConfig> output_memory_config); + }; +} // namespace ttnn::operations::moreh::moreh_getitem + +namespace ttnn { +constexpr auto moreh_getitem = ttnn::register_operation_with_auto_launch_op<"ttnn::moreh_getitem", ttnn::operations::moreh::moreh_getitem::MorehGetItem>(); +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.cpp new file mode 100644 index 00000000000..c2a999b1551 --- /dev/null +++ 
b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.cpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_getitem_pybind.hpp" + +#include "pybind11/decorators.hpp" +#include "ttnn/operations/moreh/moreh_getitem/moreh_getitem.hpp" + +namespace ttnn::operations::moreh::moreh_getitem { +void bind_moreh_getitem_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::moreh_getitem, + "Moreh moreh_getitem operation", + ttnn::pybind_arguments_t{ + py::arg("input"), + py::arg("index_tensors"), + py::arg("index_dims"), + py::kw_only(), + py::arg("output") = std::nullopt, + py::arg("output_memory_config") = std::nullopt}); +} +} // namespace ttnn::operations::moreh::moreh_getitem diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp new file mode 100644 index 00000000000..961bd937428 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::moreh::moreh_getitem { +void bind_moreh_getitem_operation(py::module& module); +} // namespace ttnn::operations::moreh::moreh_getitem diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp index 5a943455f8f..a22c53a04d6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp @@ -6,10 +6,12 @@ #include "ttnn/operations/moreh/moreh_adam/moreh_adam_pybind.hpp" #include "ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.hpp" +#include "ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp" namespace ttnn::operations::moreh { void bind_moreh_operations(py::module &module) { moreh_arange::bind_moreh_arange_operation(module); moreh_adam::bind_moreh_adam_operation(module); + moreh_getitem::bind_moreh_getitem_operation(module); } } // namespace ttnn::operations::moreh diff --git a/ttnn/ttnn/operations/moreh.py b/ttnn/ttnn/operations/moreh.py index f80153fd09d..e3a58daef9e 100644 --- a/ttnn/ttnn/operations/moreh.py +++ b/ttnn/ttnn/operations/moreh.py @@ -6,3 +6,4 @@ arange = ttnn._ttnn.operations.moreh.moreh_arange adam = ttnn._ttnn.operations.moreh.moreh_adam +getitem = ttnn._ttnn.operations.moreh.moreh_getitem