diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp
index 226f7913e13c..125562f57023 100644
--- a/tt_eager/tensor/tensor.hpp
+++ b/tt_eager/tensor/tensor.hpp
@@ -392,6 +392,15 @@ Tensor create_device_tensor(
     Device *device,
     const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED});
 
+static Tensor create_device_tensor(
+    const ttnn::Shape &shape,
+    DataType dtype,
+    Layout layout,
+    Device *device,
+    const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
+    return create_device_tensor(shape.value(), dtype, layout, device, memory_config);
+}
+
 // template <typename T>
 // void *get_host_buffer(const Tensor &tensor);
 void *get_raw_host_data_ptr(const Tensor &tensor);
diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp
new file mode 100644
index 000000000000..3b9482158f9d
--- /dev/null
+++ b/ttnn/cpp/ttnn/device_operation.hpp
@@ -0,0 +1,160 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "third_party/magic_enum/magic_enum.hpp"
+#include "tt_dnn/op_library/operation_history.hpp"
+#include "tt_stl/concepts.hpp"
+#include "tt_stl/reflection.hpp"
+#include "tt_stl/unique_any.hpp"
+
+namespace ttnn {
+
+namespace device_operation {
+
+template <typename... attributes_t>
+struct CachedProgram {
+    tt::tt_metal::Program program;
+    // Cached program needs to share attributes between create and override_runtime_arguments functions
+    std::tuple<attributes_t...> attributes;
+
+    CachedProgram(tt::tt_metal::Program&& program, attributes_t... attributes) :
+        program{std::move(program)}, attributes{std::tuple{attributes...}} {}
+};
+
+template <typename program_manager_t>
+concept ProgramManagerConcept = requires { [](auto&&... args) { program_manager_t::create(args...); }; };
+
+template <typename program_manager_t>
+concept CacheableProgramManagerConcept = ProgramManagerConcept<program_manager_t> and requires {
+    [](auto&&... args) { program_manager_t::override_runtime_arguments(args...); };
+};
+
+template <typename operation_t>
+concept DeviceOperationConcept = requires {
+    [](const typename operation_t::operation_attributes_t& attributes,
+       const typename operation_t::tensor_args_t& tensor_args) {
+        const auto program_manager = operation_t::select_program_manager(attributes, tensor_args);
+
+        operation_t::validate(program_manager, attributes, tensor_args);
+
+        using shape_return_t = typename operation_t::shape_return_t;
+        static_assert(std::same_as<
+            decltype(operation_t::compute_output_shapes(program_manager, attributes, tensor_args)),
+            shape_return_t>);
+
+        using tensor_return_value_t = typename operation_t::tensor_return_value_t;
+        static_assert(std::same_as<
+            decltype(operation_t::create_output_tensors(program_manager, attributes, tensor_args)),
+            tensor_return_value_t>);
+    };
+};
+
+template <typename operation_t>
+concept DeviceOperationWithCustomProgramCacheConcept = DeviceOperationConcept<operation_t> and requires {
+    [](auto&& program_manager,
+       const typename operation_t::operation_attributes_t& attributes,
+       const typename operation_t::tensor_args_t& tensor_args) {
+        operation_t::compute_program_hash(program_manager, attributes, tensor_args);
+    };
+};
+
+template <typename operation_t, typename program_manager_t>
+    requires ProgramManagerConcept<program_manager_t>
+constexpr auto create_or_get_program_from_cache(
+    auto& program_cache, const typename operation_t::operation_attributes_t& attributes, auto&&... args) {
+    if constexpr (CacheableProgramManagerConcept<program_manager_t>) {
+        auto program_hash = [&]() {
+            const auto& tensor_args = std::get<0>(std::forward_as_tuple(args...));
+            if constexpr (DeviceOperationWithCustomProgramCacheConcept<operation_t>) {
+                ZoneScopedN("Compute Custom Program Hash");
+                return operation_t::compute_program_hash(program_manager_t{}, attributes, tensor_args);
+            } else {
+                ZoneScopedN("Compute Default Program Hash");
+                return tt::stl::hash::hash_objects_with_default_seed(
+                    typeid(operation_t).hash_code(), attributes, tensor_args);
+            }
+        }();
+
+        using cached_program_t = decltype(program_manager_t::create(attributes, std::forward<decltype(args)>(args)...));
+
+        auto cache_hit = program_cache.contains(program_hash);
+        if (not cache_hit) {
+            program_cache.insert(
+                program_hash, program_manager_t::create(attributes, std::forward<decltype(args)>(args)...));
+            auto& cached_program = program_cache.template get<cached_program_t>(program_hash);
+            return std::reference_wrapper{cached_program.program};
+        } else {
+            auto& cached_program = program_cache.template get<cached_program_t>(program_hash);
+            program_manager_t::override_runtime_arguments(
+                cached_program, attributes, std::forward<decltype(args)>(args)...);
+            return std::reference_wrapper{cached_program.program};
+        }
+
+    } else {
+        return program_manager_t::create(attributes, std::forward<decltype(args)>(args)...);
+    }
+}
+
+struct void_t {};
+
+template <typename operation_t>
+    requires DeviceOperationConcept<operation_t>
+constexpr typename operation_t::tensor_return_value_t run(
+    const typename operation_t::operation_attributes_t& attributes,
+    const typename operation_t::tensor_args_t& tensor_args) {
+    auto program_manager = operation_t::select_program_manager(attributes, tensor_args);
+
+    operation_t::validate(program_manager, attributes, tensor_args);
+
+    using tensor_return_value_t = typename operation_t::tensor_return_value_t;
+    auto tensor_return_value = [&program_manager, &attributes, &tensor_args]() {
+        ZoneScopedN("Create Output Tensors");
+        if constexpr (std::is_same_v<tensor_return_value_t, void>) {
+            operation_t::create_output_tensors(program_manager, attributes, tensor_args);
+            return void_t{};
+        } else {
+            return operation_t::create_output_tensors(program_manager, attributes, tensor_args);
+        }
+    }();
+
+    auto cq_id = 0;
+    auto device = tensor_args.input_tensor_a.device();
+    auto& queue = device->command_queue(cq_id);
+
+    auto program = std::visit(
+        [&device, &attributes, &tensor_args, &tensor_return_value](auto&& program_manager)
+            -> std::variant<tt::tt_metal::Program, std::reference_wrapper<tt::tt_metal::Program>> {
+            ZoneScopedN("Create Or Get Program From the Cache");
+            using program_manager_t = std::decay_t<decltype(program_manager)>;
+            if constexpr (std::is_same_v<tensor_return_value_t, void>) {
+                return create_or_get_program_from_cache<operation_t, program_manager_t>(
+                    device->program_cache, attributes, tensor_args);
+            } else {
+                return create_or_get_program_from_cache<operation_t, program_manager_t>(
+                    device->program_cache, attributes, tensor_args, tensor_return_value);
+            }
+        },
+        program_manager);
+
+    std::visit(
+        [&queue](auto&& program) {
+            ZoneScopedN("Enqueue Program");
+            tt::tt_metal::EnqueueProgram(queue, program, false);
+        },
+        program);
+
+    if constexpr (not std::is_same_v<tensor_return_value_t, void>) {
+        return tensor_return_value;
+    }
+}
+
+}  // namespace device_operation
+
+}  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
index f69744b8a048..6410215581cb 100644
--- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
+++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
@@ -18,73 +18,13 @@ namespace ttnn {
 namespace operations {
 namespace binary {
-
-enum class BinaryProgramType {
-    ElementWiseMultiCore,
-    BroadcastWidthMultiCore,
-    BroadcastHeightMultiCore,
-    BroadcastHeightAndWidthMultiCore,
-};
-
-inline
BinaryProgramType get_program_type(const Binary& operation, const std::vector& input_tensors) { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - - const auto& input_shape_a = input_tensor_a.get_shape(); - const auto& input_shape_b = input_tensor_b.get_shape(); - - auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1; - auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1; - auto height_a = input_shape_a[-2]; - auto width_a = input_shape_a[-1]; - - auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1; - auto batch_size_1_b = input_shape_b.rank() >= 3 ? input_shape_b[-3] : 1; - auto height_b = input_shape_b[-2]; - auto width_b = input_shape_b[-1]; - - /* - fmt::print("input_shape_a: {}, input_shape_b: {}\n", input_shape_a, input_shape_b); - fmt::print( - "batch_size_0_a: {}, batch_size_1_a: {}, height_a: {}, width_a: {}\n", - batch_size_0_a, - batch_size_1_a, - height_a, - width_a); - fmt::print( - "batch_size_0_b: {}, batch_size_1_b: {}, height_b: {}, width_b: {}\n", - batch_size_0_b, - batch_size_1_b, - height_b, - width_b); - */ - - if (batch_size_0_a == batch_size_0_b and batch_size_1_a == batch_size_1_b and height_a == height_b and - width_a == width_b) { - return BinaryProgramType::ElementWiseMultiCore; - } else if (height_b == 1 or width_b == 1) { - if (operation.dtype != input_tensor_a.get_dtype()) { - TT_THROW("ttnn::operations::binary::Binary: cannot change dtype when broadcasting"); - } - if (height_b == 1 and width_b == 1) { - // fmt::print("BinaryProgramType::BroadcastHeightAndWidthMultiCore\n"); - return BinaryProgramType::BroadcastHeightAndWidthMultiCore; - } else if (height_b == 1) { - // fmt::print("BinaryProgramType::BroadcastHeightMultiCore\n"); - return BinaryProgramType::BroadcastHeightMultiCore; - } else if (width_b == 1) { - // fmt::print("BinaryProgramType::BroadcastWidthMultiCore\n"); - return BinaryProgramType::BroadcastWidthMultiCore; - } - } - TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); -} - -void Binary::validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const { - auto program_type = get_program_type(*this, input_tensors); - - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); +/* static */ void Binary::validate( + const program_manager_t& program_manager, + const operation_attributes_t& attributes, + const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + const auto& output_tensor = tensor_args.output_tensor; const auto& input_shape_a = input_tensor_a.get_shape(); const auto& input_shape_b = input_tensor_b.get_shape(); @@ -123,10 +63,10 @@ void Binary::validate_with_output_tensors(const std::vector &input_tenso TT_FATAL( (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE), "Inputs to eltwise binary must be tilized"); - if (this->in_place) { - TT_FATAL(input_tensor_a.memory_config().memory_layout == this->memory_config.memory_layout); - TT_FATAL(input_tensor_a.memory_config().buffer_type == this->memory_config.buffer_type); - TT_FATAL(input_tensor_a.get_dtype() == this->dtype); + if (attributes.in_place) { + TT_FATAL(input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout); + TT_FATAL(input_tensor_a.memory_config().buffer_type == 
attributes.memory_config.buffer_type); + TT_FATAL(input_tensor_a.get_dtype() == attributes.dtype); } if (input_tensor_a.memory_config().is_sharded()) { if (input_tensor_a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { @@ -139,70 +79,73 @@ void Binary::validate_with_output_tensors(const std::vector &input_tenso TT_FATAL(input_tensor_a.memory_config().memory_layout == input_tensor_b.memory_config().memory_layout); TT_FATAL(input_tensor_a.shard_spec().value() == input_tensor_b.shard_spec().value()); } - if (this->memory_config.is_sharded()) { - TT_FATAL(input_tensor_a.memory_config().memory_layout == this->memory_config.memory_layout); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } else if (input_tensor_b.memory_config().is_sharded()) { TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); - if (this->memory_config.is_sharded()) { - TT_FATAL(input_tensor_b.memory_config().memory_layout == this->memory_config.memory_layout); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(input_tensor_b.memory_config().memory_layout == attributes.memory_config.memory_layout); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } else { TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); - if (this->memory_config.is_sharded()) { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + if (attributes.memory_config.is_sharded()) { + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); uint32_t num_blocks = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; auto core_grid = input_tensor_a.device()->compute_with_storage_grid_size(); uint32_t num_cores = core_grid.x * core_grid.y; TT_FATAL(num_blocks < num_cores or num_blocks % num_cores == 0); } else { - TT_FATAL(this->memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED); } } - if (program_type != BinaryProgramType::ElementWiseMultiCore) { - TT_FATAL(not this->activations.has_value()); - } - - if (!output_tensors.empty()) { - TT_FATAL(output_tensors.size() == 1, "Must have 1 output tensors"); - - if(output_tensors.at(0).has_value()) { - TT_FATAL(!this->in_place, "Operation is configured as in_place. First input is used as output. Provided output tensor is ignored"); - } + std::visit( + [&attributes](auto&& program_manager) { + if constexpr (std::is_same_v) { + TT_FATAL(not attributes.activations.has_value()); + } + }, + program_manager); + + if (output_tensor.has_value()) { + TT_FATAL( + not attributes.in_place, + "Operation is configured as in_place. First input is used as output. 
Provided output tensor is "
+            "ignored");
     }
 }
 
-std::vector<Shape> Binary::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor_a = input_tensors.at(0);
-    const auto& input_tensor_b = input_tensors.at(1);
-    if (input_tensor_a.get_shape().rank() >= input_tensor_b.get_shape().rank()) {
-        return {input_tensor_a.get_legacy_shape()};
-    }
-    return {input_tensor_b.get_legacy_shape()};
+/* static */ Binary::shape_return_t Binary::compute_output_shapes(
+    const program_manager_t&, const operation_attributes_t&, const tensor_args_t& tensor_args) {
+    const auto& input_tensor_a = tensor_args.input_tensor_a;
+    return input_tensor_a.shape();
 }
 
-std::vector<Tensor> Binary::create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
-    const auto& input_tensor_a = input_tensors.at(0);
-    const auto& input_tensor_b = input_tensors.at(1);
-    if (this->in_place) {
+/* static */ Binary::tensor_return_value_t Binary::create_output_tensors(
+    const program_manager_t& program_manager,
+    const operation_attributes_t& operation_attributes,
+    const tensor_args_t& tensor_args) {
+    auto output_shape = compute_output_shapes(program_manager, operation_attributes, tensor_args);
+    const auto& input_tensor_a = tensor_args.input_tensor_a;
+    const auto& input_tensor_b = tensor_args.input_tensor_b;
+    const auto& output_tensor = tensor_args.output_tensor;
+    if (operation_attributes.in_place) {
         return {input_tensor_a};
     } else {
-        if (!output_tensors.empty() && output_tensors.at(0).has_value()) {
-            return {output_tensors.at(0).value()};
+        if (output_tensor.has_value()) {
+            return output_tensor.value();
         }
-        auto program_type = get_program_type(*this, input_tensors);
-
-        if (program_type == BinaryProgramType::ElementWiseMultiCore) {
-            if (this->memory_config.is_sharded()) {
+        if (std::holds_alternative<ElementWiseMultiCore>(program_manager)) {
+            if (operation_attributes.memory_config.is_sharded()) {
                 ShardSpec shard_spec{CoreRangeSet({}), {0, 0}};
                 if (input_tensor_a.memory_config().is_sharded()) {
                     shard_spec = input_tensor_a.shard_spec().value();
@@ -218,96 +161,90 @@ std::vector<Tensor> Binary::create_output_tensors(const std::vector<Tensor>& inp
                         num_blocks / target_num_cores * TILE_HEIGHT, input_tensor_a.get_legacy_shape()[-1]};
                     shard_spec.orientation = ShardOrientation::ROW_MAJOR;
                 }
-                auto memory_config = this->memory_config;
+                auto memory_config = operation_attributes.memory_config;
                 memory_config.shard_spec = shard_spec;
-                return {create_device_tensor(
-                    this->compute_output_shapes(input_tensors).at(0),
-                    this->dtype,
+                return create_device_tensor(
+                    output_shape,
+                    operation_attributes.dtype,
                     Layout::TILE,
                     input_tensor_a.device(),
-                    memory_config)};
+                    memory_config);
             }
         } else {
-            if (this->memory_config.is_sharded()) {
+            if (operation_attributes.memory_config.is_sharded()) {
                 ShardSpec shard_spec{CoreRangeSet({}), {0, 0}};
                 if (input_tensor_a.memory_config().is_sharded()) {
                     // Derive output shard_spec based on input
                     shard_spec = input_tensor_a.shard_spec().value();
                 }
-                auto memory_config = this->memory_config;
+                auto memory_config = operation_attributes.memory_config;
                 memory_config.shard_spec = shard_spec;
-                return {create_device_tensor(
-                    this->compute_output_shapes(input_tensors).at(0),
-                    this->dtype,
+                return create_device_tensor(
+                    output_shape,
+                    operation_attributes.dtype,
                     Layout::TILE,
                     input_tensor_a.device(),
-                    memory_config)};
+                    memory_config);
            }
        }
-        return operation::generic_create_output_tensors(
-            *this, input_tensors, this->dtype, Layout::TILE,
this->memory_config); + return create_device_tensor( + output_shape, + operation_attributes.dtype, + Layout::TILE, + input_tensor_a.device(), + operation_attributes.memory_config); } } -const std::optional binary_op_type_to_bcast_op_math(const BinaryOpType binary_op_type) { - switch (binary_op_type) { - case BinaryOpType::ADD: return tt::tt_metal::BcastOpMath::ADD; - case BinaryOpType::SUB: return tt::tt_metal::BcastOpMath::SUB; - case BinaryOpType::MUL: return tt::tt_metal::BcastOpMath::MUL; - default: return std::nullopt; - } -} +/* static */ Binary::program_manager_t Binary::select_program_manager( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + + const auto& input_shape_a = input_tensor_a.get_shape(); + const auto& input_shape_b = input_tensor_b.get_shape(); -operation::ProgramWithCallbacks Binary::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - const auto& output_tensor = output_tensors.at(0); + auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1; + auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1; + auto height_a = input_shape_a[-2]; + auto width_a = input_shape_a[-1]; - std::vector activations; - if (this->activations.has_value()) { - const auto activations_as_strings = this->activations.value(); - std::transform( - activations_as_strings.begin(), - activations_as_strings.end(), - std::back_inserter(activations), - [](const std::string& activation) { return string_to_unary_with_param(activation); }); - } + auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1; + auto batch_size_1_b = input_shape_b.rank() >= 3 ? 
input_shape_b[-3] : 1; + auto height_b = input_shape_b[-2]; + auto width_b = input_shape_b[-1]; - auto program_type = get_program_type(*this, input_tensors); - auto bcast_op_math = binary_op_type_to_bcast_op_math(this->binary_op_type); - if (bcast_op_math.has_value()) { - switch (program_type) { - case BinaryProgramType::ElementWiseMultiCore: - return eltwise_binary_multi_core( - input_tensor_a, input_tensor_b, output_tensor, this->binary_op_type, activations); - case BinaryProgramType::BroadcastHeightAndWidthMultiCore: - return bcast_multi_core_hw( - input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value(), false /* in-place */); - case BinaryProgramType::BroadcastHeightMultiCore: - return bcast_multi_core_h(input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value()); - case BinaryProgramType::BroadcastWidthMultiCore: - return bcast_multi_core_w(input_tensor_a, input_tensor_b, output_tensor, bcast_op_math.value()); - default: TT_THROW("Invalid program type"); + if (batch_size_0_a == batch_size_0_b and batch_size_1_a == batch_size_1_b and height_a == height_b and + width_a == width_b) { + return ElementWiseMultiCore{}; + } else if (height_b == 1 or width_b == 1) { + if (operation_attributes.dtype != input_tensor_a.get_dtype()) { + TT_THROW("ttnn::operations::binary::Binary: cannot change dtype when broadcasting"); } - } else { - switch (program_type) { - case BinaryProgramType::ElementWiseMultiCore: - return eltwise_binary_multi_core( - input_tensor_a, input_tensor_b, output_tensor, this->binary_op_type, activations); - default: TT_THROW("Invalid program type"); + if (height_b == 1 and width_b == 1) { + // fmt::print("BinaryProgramType::BroadcastHeightAndWidthMultiCore\n"); + return BroadcastHeightAndWidthMultiCore{}; + } else if (height_b == 1) { + // fmt::print("BinaryProgramType::BroadcastHeightMultiCore\n"); + return BroadcastHeightMultiCore{}; + } else if (width_b == 1) { + // fmt::print("BinaryProgramType::BroadcastWidthMultiCore\n"); + return BroadcastWidthMultiCore{}; } } + TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); } -const operation::Hash Binary::compute_program_hash(const std::vector& input_tensors) const { - const auto& input_tensor_a = input_tensors.at(0); - const auto& input_tensor_b = input_tensors.at(1); - auto program_type = get_program_type(*this, input_tensors); +/* static */ tt::stl::hash::hash_t Binary::compute_program_hash( + const program_manager_t& program_manager, + const operation_attributes_t& attributes, + const tensor_args_t& tensor_args) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; operation::Hash hash = tt::stl::hash::hash_objects_with_default_seed( - typeid(*this).hash_code(), - this, - program_type, + typeid(Binary).hash_code(), + attributes, + std::visit([](auto&& program_manager) { return typeid(program_manager).hash_code(); }, program_manager), input_tensor_a.dtype(), std::get(input_tensor_a.storage()).memory_config(), input_tensor_b.dtype(), @@ -315,30 +252,6 @@ const operation::Hash Binary::compute_program_hash(const std::vector& in return hash; } -operation::OpPerformanceModel Binary::create_op_performance_model( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - std::vector& output_tensors) const { - // GS specific parameters - // 80 B/cycle unpacker BW shared - // 128 datums per cycle math, but unpacker cant keep up - constexpr int num_cores = 9 * 12; - - int total_bytes = 0; - for (const auto& t 
: input_tensors) { - total_bytes += t.volume() * t.element_size(); - } - int ideal_eltwise_cycles = total_bytes / 80 / num_cores; - - operation::OpPerformanceModel result(input_tensors, output_tensors, ideal_eltwise_cycles); -#if 0 - tt::log_info(tt::LogOp, "Binary PerfModel:"); - tt::log_info(tt::LogOp, "\t Data (Bytes): {}", total_bytes); - tt::log_info(tt::LogOp, "\t ideal_eltwise_cycles: {}", ideal_eltwise_cycles); -#endif - return result; -} - } // namespace binary } // namespace operations diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp index e9ea111d4476..26c75e1ce120 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp @@ -6,6 +6,7 @@ #include #include +#include #include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" @@ -19,6 +20,8 @@ #include "tt_metal/impl/dispatch/command_queue.hpp" #include "ttnn/core.hpp" #include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/op_library/binary/element_wise_multi_core.hpp" #include "ttnn/types.hpp" namespace ttnn { @@ -32,37 +35,107 @@ using BinaryOpType = tt::tt_metal::BinaryOpType; constexpr uint8_t DefaultQueueId = 0; struct Binary { - BinaryOpType binary_op_type; - bool in_place; - const std::optional> activations; - const MemoryConfig memory_config; - const DataType dtype; - std::optional compute_kernel_config; - - void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector &input_tensors, std::vector &output_tensors) const; - - const operation::Hash compute_program_hash(const std::vector &input_tensors) const; - - operation::OpPerformanceModel create_op_performance_model( - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - std::vector &output_tensors) const; - - static constexpr auto attribute_names = std::forward_as_tuple( - "binary_op_type", "in_place", "activations", "memory_config", "dtype", "compute_kernel_config"); - const auto attribute_values() const { - return std::forward_as_tuple( - this->binary_op_type, - this->in_place, - this->activations, - this->memory_config, - this->dtype, - this->compute_kernel_config); - } + struct operation_attributes_t { + BinaryOpType binary_op_type; + bool in_place; + const std::optional> activations; + const MemoryConfig memory_config; + const DataType dtype; + std::optional compute_kernel_config; + + static constexpr auto attribute_names = std::forward_as_tuple( + "binary_op_type", "in_place", "activations", "memory_config", "dtype", "compute_kernel_config"); + const auto attribute_values() const { + return std::forward_as_tuple( + this->binary_op_type, + this->in_place, + this->activations, + this->memory_config, + this->dtype, + this->compute_kernel_config); + } + }; + struct tensor_args_t { + const Tensor& input_tensor_a; + const Tensor& input_tensor_b; + std::optional output_tensor; + + static constexpr auto attribute_names = + std::forward_as_tuple("input_tensor_a", "input_tensor_b", "output_tensor"); + const auto attribute_values() const { + return std::forward_as_tuple(this->input_tensor_a, this->input_tensor_b, this->output_tensor); + } + }; + using shape_return_t = ttnn::Shape; + using 
tensor_return_value_t = Tensor; + + struct ElementWiseMultiCore { + static auto create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return) { + return element_wise_multi_core::create(operation_attributes, tensor_args, tensor_return); + } + static void override_runtime_arguments( + auto& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return) { + element_wise_multi_core::override_runtime_arguments( + cached_program, operation_attributes, tensor_args, tensor_return); + } + }; + + struct BroadcastWidthMultiCore { + static auto create(const operation_attributes_t&, const tensor_args_t&, tensor_return_value_t&) { + return device_operation::CachedProgram{tt::tt_metal::Program()}; + } + static void override_runtime_arguments( + auto& cached_program, + const operation_attributes_t&, + const tensor_args_t&, + tensor_return_value_t&) {} + }; + + struct BroadcastHeightMultiCore { + static auto create(const operation_attributes_t&, const tensor_args_t&, tensor_return_value_t&) { + return device_operation::CachedProgram{tt::tt_metal::Program()}; + } + static void override_runtime_arguments( + auto& cached_program, + const operation_attributes_t&, + const tensor_args_t&, + tensor_return_value_t&) {} + }; + + struct BroadcastHeightAndWidthMultiCore { + static auto create(const operation_attributes_t&, const tensor_args_t&, tensor_return_value_t&) { + return device_operation::CachedProgram{tt::tt_metal::Program()}; + } + static void override_runtime_arguments( + auto& cached_program, + const operation_attributes_t&, + const tensor_args_t&, + tensor_return_value_t&) {} + }; + + using program_manager_t = std::variant< + ElementWiseMultiCore, + BroadcastWidthMultiCore, + BroadcastHeightMultiCore, + BroadcastHeightAndWidthMultiCore>; + + static program_manager_t select_program_manager(const operation_attributes_t&, const tensor_args_t&); + + static void validate(const program_manager_t&, const operation_attributes_t&, const tensor_args_t&); + + static shape_return_t compute_output_shapes(const program_manager_t&, const operation_attributes_t&, const tensor_args_t&); + + static tensor_return_value_t create_output_tensors( + const program_manager_t&, const operation_attributes_t& operation_attributes, const tensor_args_t&); + + static tt::stl::hash::hash_t compute_program_hash( + const program_manager_t&, const operation_attributes_t&, const tensor_args_t&); }; } // namespace binary diff --git a/ttnn/cpp/ttnn/op_library/binary/element_wise_multi_core.hpp b/ttnn/cpp/ttnn/op_library/binary/element_wise_multi_core.hpp new file mode 100644 index 000000000000..539955bff231 --- /dev/null +++ b/ttnn/cpp/ttnn/op_library/binary/element_wise_multi_core.hpp @@ -0,0 +1,468 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +namespace ttnn { + +namespace operations { +namespace binary { +namespace element_wise_multi_core { + +template +inline __attribute__((always_inline)) void set_eltwise_binary_runtime_args( + Program& program, + const Tensor& a, + const Tensor& b, + const Tensor& output, + const KernelHandle binary_reader_kernel_id, + const KernelHandle unary_writer_kernel_id, + const KernelHandle eltwise_binary_kernel_id, + const CBHandle cb_src0, + const CBHandle cb_src1, + const CBHandle cb_output, + const CoreCoord compute_with_storage_grid_size, + const uint32_t src0_single_tile_size, + const uint32_t src1_single_tile_size, + const uint32_t dst_single_tile_size) { + using namespace tt; + using namespace tt::tt_metal; + + auto src_buffer_a = a.buffer(); + auto src_buffer_b = b.buffer(); + auto dst_buffer = output.buffer(); + + CoreRangeSet all_cores({}), core_group_1({}), core_group_2({}); + + std::optional shard_spec = std::nullopt; + bool src0_sharded = a.memory_config().is_sharded(); + bool src1_sharded = b.memory_config().is_sharded(); + bool out_sharded = output.memory_config().is_sharded(); + + bool block_sharded = false; + if (src0_sharded) { + shard_spec = a.shard_spec().value(); + block_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } else if (src1_sharded) { + shard_spec = b.shard_spec().value(); + block_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } else if (out_sharded) { + shard_spec = output.shard_spec().value(); + block_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } + + uint32_t num_tiles = a.volume() / TILE_HW; + + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores, num_tiles_per_core_group_1, num_tiles_per_core_group_2; + uint32_t num_cores_total = num_cores_x * num_cores_y; + + uint32_t block_size_per_core_group_1 = 1, block_size_per_core_group_2 = 1, max_block_size = 1; + + uint32_t block_cnt_per_core_group_1, block_cnt_per_core_group_2; + + bool row_major; + uint32_t block_height = 0, block_width = 0, block_size = 0, output_width = 0, last_unpadded_block_height = 0, + last_unpadded_block_width = 0; + CoreCoord end_core; + vector cores; + + if (shard_spec.has_value()) { + all_cores = shard_spec.value().grid; + num_cores = all_cores.num_cores(); + core_group_1 = all_cores; + core_group_2 = CoreRangeSet({}); + num_tiles_per_core_group_1 = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + num_tiles_per_core_group_2 = 0; + block_size_per_core_group_1 = find_max_block_size(num_tiles_per_core_group_1); + max_block_size = block_size_per_core_group_1; + + block_cnt_per_core_group_1 = num_tiles_per_core_group_1 / block_size_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2 / block_size_per_core_group_2; + row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR; + if (block_sharded) { + block_height = shard_spec.value().shape[0] / TILE_HEIGHT; + block_width = shard_spec.value().shape[1] / TILE_WIDTH; + block_size = block_width * block_height; + end_core = (*shard_spec.value().grid.ranges().begin()).end; + output_width = 
output.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t output_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT; + last_unpadded_block_height = block_height - (round_up(output_height, block_height) - output_height); + last_unpadded_block_width = block_width - (round_up(output_width, block_width) - output_width); + } + auto bbox = core_group_1.bounding_box(); + cores = grid_to_cores_with_noop(bbox.end.x, bbox.end.y, num_cores_x, num_cores_y, row_major); + } else { + row_major = true; + std::tie( + num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2) = + split_work_to_cores(compute_with_storage_grid_size, num_tiles, row_major); + block_cnt_per_core_group_1 = num_tiles_per_core_group_1; + block_cnt_per_core_group_2 = num_tiles_per_core_group_2; + cores = grid_to_cores(num_cores_x * num_cores_y, num_cores_x, num_cores_y, row_major); + } + + uint32_t g1_numcores = core_group_1.num_cores(); + uint32_t g2_numcores = core_group_2.num_cores(); + + std::vector> binary_reader_args; + std::vector> eltwise_binary_args; + std::vector> unary_writer_args; + if constexpr (initialize_args) { + binary_reader_args = {cores.size(), std::vector(4)}; + eltwise_binary_args = {cores.size(), std::vector(2)}; + if (block_sharded and not out_sharded) + unary_writer_args = {cores.size(), std::vector(7)}; + else + unary_writer_args = {cores.size(), std::vector(3)}; + } + + auto& cached_reader_args = GetRuntimeArgs(program, binary_reader_kernel_id); + auto& cached_eltwise_args = GetRuntimeArgs(program, eltwise_binary_kernel_id); + auto& cached_writer_args = GetRuntimeArgs(program, unary_writer_kernel_id); + + for (uint32_t i = 0, num_tiles_read = 0; i < num_cores_total; ++i) { + const CoreCoord& core = cores.at(i); + uint32_t num_tiles_per_core = 0; + uint32_t block_cnt_per_core = 0; + uint32_t block_size_per_core = 0; + if (i < g1_numcores) { + num_tiles_per_core = num_tiles_per_core_group_1; + block_cnt_per_core = block_cnt_per_core_group_1; + block_size_per_core = block_size_per_core_group_1; + } else if (i < num_cores) { + num_tiles_per_core = num_tiles_per_core_group_2; + block_cnt_per_core = block_cnt_per_core_group_2; + block_size_per_core = block_size_per_core_group_2; + } else { + // Zero out non-working cores RT args. Only necessary in override + // since initialization pushes zero vectors to unused cores. 
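+                // Descriptive note: the indices zeroed below are the per-core work counts, namely
+                // reader arg 2 (num_tiles_per_core), compute arg 0 (block_cnt_per_core), and writer
+                // arg 1 (tile count, or block height in the block-sharded writer), so kernels on
+                // cores outside the active groups do no work when the cached program is re-launched.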
+ if constexpr (!initialize_args) { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[2] = 0; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = 0; + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[1] = 0; + } + continue; + } + if constexpr (initialize_args) { + binary_reader_args[i] = { + src_buffer_a->address(), src_buffer_b->address(), num_tiles_per_core, num_tiles_read}; + eltwise_binary_args[i] = {block_cnt_per_core, block_size_per_core}; + } else { + auto& reader_args = cached_reader_args.at(core.x).at(core.y); + reader_args[0] = src_buffer_a->address(); + reader_args[1] = src_buffer_b->address(); + reader_args[2] = num_tiles_per_core; + reader_args[3] = num_tiles_read; + auto& eltwise_args = cached_eltwise_args.at(core.x).at(core.y); + eltwise_args[0] = block_cnt_per_core; + eltwise_args[1] = block_size_per_core; + } + if (block_sharded and not out_sharded) { + uint32_t block_start_width_offset; + uint32_t block_start_height_offset; + uint32_t unpadded_block_height = block_height; + uint32_t unpadded_block_width = block_width; + if (row_major) { + block_start_width_offset = core.x * block_width; + block_start_height_offset = core.y * block_height; + if (core.x == end_core.x) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.y == end_core.y) { + unpadded_block_height = last_unpadded_block_height; + } + } else { + block_start_width_offset = core.y * block_width; + block_start_height_offset = core.x * block_height; + if (core.y == end_core.y) { + unpadded_block_width = last_unpadded_block_width; + } + if (core.x == end_core.x) { + unpadded_block_height = last_unpadded_block_height; + } + } + if constexpr (initialize_args) { + unary_writer_args[i] = { + dst_buffer->address(), + block_height, + block_width, + unpadded_block_height, + unpadded_block_width, + output_width, + block_size, + block_start_height_offset * output_width + block_start_width_offset, + 0}; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = block_height; + writer_args[2] = block_width; + writer_args[3] = unpadded_block_height; + writer_args[4] = unpadded_block_width; + writer_args[5] = output_width; + writer_args[6] = block_size; + writer_args[7] = block_start_height_offset * output_width + block_start_width_offset; + writer_args[8] = 0; + } + } else { + if constexpr (initialize_args) { + unary_writer_args[i] = {dst_buffer->address(), num_tiles_per_core, num_tiles_read}; + } else { + auto& writer_args = cached_writer_args.at(core.x).at(core.y); + writer_args[0] = dst_buffer->address(); + writer_args[1] = num_tiles_per_core; + writer_args[2] = num_tiles_read; + } + } + num_tiles_read += num_tiles_per_core; + } + + if constexpr (initialize_args) { + SetRuntimeArgs(program, binary_reader_kernel_id, cores, binary_reader_args); + SetRuntimeArgs(program, eltwise_binary_kernel_id, cores, eltwise_binary_args); + SetRuntimeArgs(program, unary_writer_kernel_id, cores, unary_writer_args); + } + + if (src0_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer_a); + UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * src0_single_tile_size); + } + if (src1_sharded) { + UpdateDynamicCircularBufferAddress(program, cb_src1, *src_buffer_b); + UpdateCircularBufferTotalSize(program, cb_src1, num_tiles_per_core_group_1 * src1_single_tile_size); + } + if (out_sharded) { + 
UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer); + UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size); + } +} + +inline auto create(const auto& operation_attributes, const auto& tensor_args, auto& tensor_return) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& a = tensor_args.input_tensor_a; + const auto& b = tensor_args.input_tensor_b; + auto& output = tensor_return; + const auto& op_type = operation_attributes.binary_op_type; + + std::vector fused_activations; + if (operation_attributes.activations.has_value()) { + const auto activations_as_strings = operation_attributes.activations.value(); + std::transform( + activations_as_strings.begin(), + activations_as_strings.end(), + std::back_inserter(fused_activations), + [](const std::string& activation) { return string_to_unary_with_param(activation); }); + } + + Program program{}; + + tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); + tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); + + tt_metal::Buffer* src0_buffer = a.buffer(); + tt_metal::Buffer* src1_buffer = b.buffer(); + + tt_metal::Device* device = a.device(); + + std::optional shard_spec = std::nullopt; + bool src0_sharded = a.memory_config().is_sharded(); + bool src1_sharded = b.memory_config().is_sharded(); + bool out_sharded = output.memory_config().is_sharded(); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + bool block_sharded = false; + + if (src0_sharded) { + shard_spec = a.shard_spec().value(); + block_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } else if (src1_sharded) { + shard_spec = b.shard_spec().value(); + block_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } else if (out_sharded) { + shard_spec = output.shard_spec().value(); + block_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + } + + uint32_t max_block_size = 1, num_tiles_per_shard = 0; + if (shard_spec.has_value()) { + num_tiles_per_shard = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW; + max_block_size = find_max_block_size(num_tiles_per_shard); + } + + tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = src0_sharded ? 
num_tiles_per_shard : 2 * max_block_size; + tt_metal::CircularBufferConfig cb_src0_config = + tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}}) + .set_page_size(src0_cb_index, src0_single_tile_size); + if (src0_sharded) { + cb_src0_config = cb_src0_config.set_globally_allocated_address(*a.buffer()); + } + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src0_config); + + uint32_t src1_cb_index = 1; + num_input_tiles = src1_sharded ? num_tiles_per_shard : 2 * max_block_size; + tt_metal::CircularBufferConfig cb_src1_config = + tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) + .set_page_size(src1_cb_index, src1_single_tile_size); + if (src1_sharded) { + cb_src1_config = cb_src1_config.set_globally_allocated_address(*b.buffer()); + } + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); + + std::map eltwise_defines = + eltwise_binary_op_utils::get_defines(op_type, output.get_dtype(), fused_activations); + + if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN0_0") != eltwise_defines.end()) { + tt_metal::CircularBufferConfig cb_interm_config = + tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{CB::c_intermed0, src0_cb_data_format}}) + .set_page_size(CB::c_intermed0, src0_single_tile_size); + auto cb_interm = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm_config); + } + if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN1_0") != eltwise_defines.end()) { + tt_metal::CircularBufferConfig cb_interm2_config = + tt_metal::CircularBufferConfig(1 * src1_single_tile_size, {{CB::c_intermed1, src1_cb_data_format}}) + .set_page_size(CB::c_intermed1, src1_single_tile_size); + auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config); + } + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = (out_sharded || block_sharded) ? num_tiles_per_shard : 2 * max_block_size; + tt_metal::CircularBufferConfig cb_output_config = + tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}}) + .set_page_size(output_cb_index, dst_single_tile_size); + if (out_sharded) { + cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); + } + auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_output_config); + + std::map reader_defines; + if (src0_sharded) { + reader_defines["IN0_SHARDED"] = "1"; + } + if (src1_sharded) { + reader_defines["IN1_SHARDED"] = "1"; + } + std::map writer_defines; + if (out_sharded) { + writer_defines["OUT_SHARDED"] = "1"; + } + + bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(std::uint32_t)src0_is_dram, (std::uint32_t)src1_is_dram}; + + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram}; + + KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/eltwise_binary/kernels/dataflow/reader_binary_interleaved_start_id.cpp", + all_device_cores, + tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines)); + + KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( + program, + (block_sharded and not out_sharded) + ? "tt_eager/tt_dnn/op_library/sharded/kernels/dataflow/writer_unary_sharded_blocks_interleaved_start_id.cpp" + : "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + all_device_cores, + tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + + bool fp32_dest_acc_en = dst_cb_data_format == tt::DataFormat::UInt32 || + dst_cb_data_format == tt::DataFormat::Int32 || + dst_cb_data_format == tt::DataFormat::Float32; + auto eltwise_binary_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/eltwise_binary/kernels/compute/eltwise_binary.cpp", + all_device_cores, + tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .defines = eltwise_defines}); + + set_eltwise_binary_runtime_args( + program, + a, + b, + output, + binary_reader_kernel_id, + unary_writer_kernel_id, + eltwise_binary_kernel_id, + cb_src0, + cb_src1, + cb_output, + compute_with_storage_grid_size, + src0_single_tile_size, + src1_single_tile_size, + dst_single_tile_size); + + return device_operation::CachedProgram{ + std::move(program), + binary_reader_kernel_id, + unary_writer_kernel_id, + eltwise_binary_kernel_id, + cb_src0, + cb_src1, + cb_output, + compute_with_storage_grid_size, + src0_single_tile_size, + src1_single_tile_size, + dst_single_tile_size}; +} + +inline void override_runtime_arguments( + auto& cached_program, auto& operation_attributes, auto& tensor_args, auto& tensor_return) { + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto& input_tensor_b = tensor_args.input_tensor_b; + auto& output_tensor = tensor_return; + + auto&& [binary_reader_kernel_id, unary_writer_kernel_id, eltwise_binary_kernel_id, cb_src0, cb_src1, cb_output, compute_with_storage_grid_size, src0_single_tile_size, src1_single_tile_size, dst_single_tile_size] = + cached_program.attributes; + + set_eltwise_binary_runtime_args( + cached_program.program, + input_tensor_a, + input_tensor_b, + output_tensor, + binary_reader_kernel_id, + unary_writer_kernel_id, + eltwise_binary_kernel_id, + cb_src0, + cb_src1, + cb_output, + compute_with_storage_grid_size, + src0_single_tile_size, + src1_single_tile_size, + dst_single_tile_size); +} + +} // namespace element_wise_multi_core + +} // namespace binary + +} // namespace operations + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/binary.hpp b/ttnn/cpp/ttnn/operations/binary.hpp index 2029547fd80a..b58178eb70fc 100644 --- a/ttnn/cpp/ttnn/operations/binary.hpp +++ b/ttnn/cpp/ttnn/operations/binary.hpp @@ -84,14 +84,9 @@ struct ExecuteBinary { dtype = optional_output_tensor.value().get_dtype(); } - auto output_tensors = operation::run( - Binary{binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt}, - {input_tensor_a, input_tensor_b}, - {}, - {optional_output_tensor}, - queue_id); - - return output_tensors.at(0); + return ttnn::device_operation::run( + {binary_op_type, in_place, activations, output_memory_config, dtype, std::nullopt}, + {input_tensor_a, input_tensor_b, 
optional_output_tensor}); } template