diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp index 8efea1c20f4..6149343ad9b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp @@ -149,11 +149,11 @@ std::map get_defines( defines["ELTWISE_OP"] = op_name.c_str(); defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str(); if (fused_activations.has_value()) { - if (op_type == BinaryOpType::ADD and fused_activations.value().size() == 1 and - fused_activations.value().at(0).op_type == UnaryOpType::RELU) { + if (op_type == BinaryOpType::ADD and fused_activations->size() == 1 and + fused_activations->at(0).op_type == UnaryOpType::RELU and not input_tensor_a_activation.has_value()) { defines["PACK_RELU"] = "1"; } else { - defines.merge(ttnn::operations::unary::utils::get_block_defines(fused_activations.value(), "0", idst)); + defines.merge(ttnn::operations::unary::utils::get_block_defines(*fused_activations, "0", idst)); } } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp index 1439186fbd4..5320cd6a70f 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp @@ -11,44 +11,93 @@ namespace ttnn::operations::binary_ng { template Tensor BinaryNg::invoke( uint8_t queue_id, - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { return ttnn::prim::binary_ng( - queue_id, input_tensor_a, input_tensor_b, binary_op_type, output_dtype, memory_config, optional_output_tensor); + queue_id, + input_tensor_a, + input_tensor_b, + binary_op_type, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - return invoke(DefaultQueueId, input_tensor_a, input_tensor_b, output_dtype, memory_config, optional_output_tensor); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + return invoke( + DefaultQueueId, + input_tensor_a, + input_tensor_b, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( uint8_t queue_id, - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - return ttnn::prim::binary_ng(queue_id, input_tensor_a, scalar, binary_op_type, output_dtype, memory_config, optional_output_tensor); + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + return ttnn::prim::binary_ng( + queue_id, + input_tensor_a, + scalar, + binary_op_type, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - return invoke(DefaultQueueId, input_tensor_a, scalar, output_dtype, memory_config, optional_output_tensor); + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + return invoke( + DefaultQueueId, + input_tensor_a, + scalar, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template struct BinaryNg; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp index d41261f60b9..fdc517c37d2 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp @@ -7,6 +7,7 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" namespace ttnn::operations::binary_ng { @@ -14,33 +15,45 @@ template struct BinaryNg { static Tensor invoke( uint8_t queue_id, - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( uint8_t queue_id, - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); }; } // namespace ttnn::operations::binary_ng @@ -117,4 +130,5 @@ constexpr auto logaddexp = ttnn::register_operation_with_auto_launch_op< constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op< "ttnn::experimental::logaddexp2", ttnn::operations::binary_ng::BinaryNg>(); -} + +} // namespace ttnn::experimental diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp index da264795941..a392c476a8a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp @@ -24,8 +24,20 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst const std::optional& dtype, const std::optional& memory_config, const std::optional& output_tensor, + const ttnn::SmallVector& lhs_activations, + const ttnn::SmallVector& rhs_activations, + const ttnn::SmallVector& post_activations, const uint8_t& queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor); + return self( + queue_id, + input_tensor_a, + scalar, + dtype, + memory_config, + output_tensor, + lhs_activations, + rhs_activations, + post_activations); }, py::arg("input_tensor_a"), py::arg("scalar"), @@ -33,6 +45,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst py::arg("dtype") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt, + py::arg("lhs_activations") = ttnn::SmallVector(), + py::arg("rhs_activations") = ttnn::SmallVector(), + py::arg("post_activations") = ttnn::SmallVector(), py::arg("queue_id") = 0}, // tensor and tensor @@ -43,8 +58,20 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst const std::optional& dtype, const std::optional& memory_config, const std::optional& output_tensor, + const ttnn::SmallVector& lhs_activations, + const ttnn::SmallVector& rhs_activations, + const ttnn::SmallVector& post_activations, uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor); + return self( + queue_id, + input_tensor_a, + input_tensor_b, + dtype, + memory_config, + output_tensor, + lhs_activations, + rhs_activations, + post_activations); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -52,6 +79,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst py::arg("dtype") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt, + py::arg("lhs_activations") = ttnn::SmallVector(), + py::arg("rhs_activations") = ttnn::SmallVector(), + py::arg("post_activations") = ttnn::SmallVector(), py::arg("queue_id") = 0}); } } // namespace detail @@ -77,4 +107,4 @@ void py_module(py::module& module) { detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp, "Binary Logaddexp Operation"); detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp2, "Binary Logaddexp2 Operation"); } -} // namespace ttnn::operations::eltwise::binary_ng +} // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp index b55c8a8cccd..5046ea600f2 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp @@ -40,7 +40,14 @@ SubtileBroadcastType get_subtile_broadcast_type(uint32_t a_h, uint32_t a_w, uint tt::stl::hash::hash_t BinaryNgDeviceOperation::operation_attributes_t::to_hash() const { return tt::stl::hash::hash_objects_with_default_seed( - binary_op_type, memory_config, get_dtype(), compute_kernel_config, subtile_broadcast_type); + binary_op_type, + lhs_activations, + rhs_activations, + post_activations, + memory_config, + get_dtype(), + compute_kernel_config, + subtile_broadcast_type); } DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const { @@ -197,7 +204,10 @@ BinaryNgDeviceOperation::invoke( BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor) { + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { auto subtile_broadcast_type = get_subtile_broadcast_type( input_tensor_a_arg.get_logical_shape()[-2], input_tensor_a_arg.get_logical_shape()[-1], @@ -207,6 +217,9 @@ BinaryNgDeviceOperation::invoke( return { operation_attributes_t{ binary_op_type, + {lhs_activations.begin(), lhs_activations.end()}, + {rhs_activations.begin(), rhs_activations.end()}, + {post_activations.begin(), post_activations.end()}, std::nullopt, memory_config.value_or(input_tensor_a_arg.memory_config()), input_tensor_a_arg.get_dtype(), @@ -223,10 +236,16 @@ BinaryNgDeviceOperation::invoke( BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor) { + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { return { operation_attributes_t{ binary_op_type, + {lhs_activations.begin(), lhs_activations.end()}, + {rhs_activations.begin(), rhs_activations.end()}, + {post_activations.begin(), post_activations.end()}, scalar, memory_config.value_or(input_tensor_a_arg.memory_config()), input_tensor_a_arg.get_dtype(), diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp index 2305d7c261d..7b411694c8d 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp @@ -8,7 +8,7 @@ #include "ttnn/device_operation.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" - +#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" namespace ttnn::operations::binary_ng { enum class SubtileBroadcastType { @@ -31,6 +31,9 @@ struct BinaryNgDeviceOperation { struct operation_attributes_t { BinaryOpType binary_op_type; + ttnn::SmallVector lhs_activations; + ttnn::SmallVector rhs_activations; + ttnn::SmallVector post_activations; std::optional scalar; MemoryConfig memory_config; DataType input_dtype; @@ -86,7 +89,10 @@ struct BinaryNgDeviceOperation { BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor); + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations); // tensor-scalar invocation static std::tuple invoke( @@ -95,7 +101,10 @@ struct BinaryNgDeviceOperation { BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor); + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations); }; } // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp index 8abf63a2877..da6ad197e54 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp @@ -5,6 +5,7 @@ #include "binary_ng_utils.hpp" #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/cb_utils.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" namespace { namespace CMAKE_UNIQUE_NAMESPACE { @@ -182,7 +183,39 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio Buffer* c_buffer = c.buffer(); auto op_type = operation_attributes.binary_op_type; - auto compute_kernel_defines = OpConfig(op_type).as_defines(); + OpConfig op_config(op_type); + auto compute_kernel_defines = op_config.as_defines(); + + { + ttnn::SmallVector lhs_activations = operation_attributes.lhs_activations; + ttnn::SmallVector rhs_activations = operation_attributes.rhs_activations; + ttnn::SmallVector post_activations = operation_attributes.post_activations; + + if (op_config.process_lhs.has_value()) { + lhs_activations.push_back(*op_config.process_lhs); + } + + if (op_config.process_rhs.has_value()) { + rhs_activations.push_back(*op_config.process_rhs); + } + + if (op_config.postprocess.has_value()) { + post_activations.insert(post_activations.begin(), *op_config.postprocess); + } + + add_activation_defines(compute_kernel_defines, lhs_activations, "LHS"); + add_activation_defines(compute_kernel_defines, rhs_activations, "RHS"); + + if (lhs_activations.empty() and rhs_activations.empty() and post_activations.size() == 1 and + post_activations[0] == unary::UnaryOpType::RELU) { + compute_kernel_defines["PACK_RELU"] = "1"; + compute_kernel_defines["PROCESS_POST_ACTIVATIONS(i)"] = ""; + unary::utils::update_macro_defines(unary::UnaryOpType::RELU, compute_kernel_defines); + } else { + add_activation_defines(compute_kernel_defines, post_activations, "POST"); + } + } + bool op_has_exp = op_type == BinaryOpType::LOGADDEXP || op_type == BinaryOpType::LDEXP || op_type == BinaryOpType::LOGADDEXP2; @@ -191,7 +224,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto [a_cb, a_cb_handle] = create_cb(tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); - if (compute_kernel_defines.find("PREPROCESS_A_INIT") != compute_kernel_defines.end()) { + if (not compute_kernel_defines["PROCESS_LHS_ACTIVATIONS(i)"].empty()) { auto a_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : a_data_format; uint32_t a_intermediate_single_tile_size = tt_metal::detail::TileSize(a_intermediate_format); auto [a_cb_interim, a_cb_interim_handle] = create_cb( @@ -206,7 +239,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto [b_cb, b_cb_handle] = create_cb(tt::CBIndex::c_1, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format); - if (compute_kernel_defines.find("PREPROCESS_B_INIT") != compute_kernel_defines.end()) { + if (not compute_kernel_defines["PROCESS_RHS_ACTIVATIONS(i)"].empty()) { auto b_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : b_data_format; uint32_t b_intermediate_single_tile_size = tt_metal::detail::TileSize(b_intermediate_format); auto [b_cb_interim, b_cb_interim_handle] = create_cb( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp index 3671cd9d10f..442b4a06aa3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp @@ -3,14 +3,19 @@ // SPDX-License-Identifier: Apache-2.0 #include "binary_ng_utils.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" +#include "tt_metal/common/assert.hpp" + +#include #include #include +#include #include template <> struct fmt::formatter : fmt::formatter { - auto format(ttnn::operations::binary_ng::Lowercase const& value, fmt::format_context& ctx) const { + auto format(const ttnn::operations::binary_ng::Lowercase& value, fmt::format_context& ctx) const { auto out = ctx.out(); for (char c : value.view) { *out++ = std::tolower(static_cast(c)); @@ -117,9 +122,6 @@ std::string get_kernel_file_path(KernelName kernel_name) { } } -constexpr OpConfig::SfpuConfig NezConfig("nez_tile_init", "nez_tile(i)"); -constexpr OpConfig::SfpuConfig GtzConfig("gtz_tile_init", "gtz_tile(i)"); - OpConfig::OpConfig(BinaryOpType binary_op_type) { fpu_binary_op = FpuBinaryOp::SUB; switch (binary_op_type) { @@ -127,74 +129,57 @@ OpConfig::OpConfig(BinaryOpType binary_op_type) { case BinaryOpType::SUB: break; case BinaryOpType::MUL: fpu_binary_op = FpuBinaryOp::MUL; break; case BinaryOpType::DIV: - preprocess_b = SfpuConfig("recip_tile_init", "recip_tile(i)", "compute_kernel_api/eltwise_unary/recip.h"); + process_rhs = unary::UnaryOpType::RECIP; fpu_binary_op = FpuBinaryOp::MUL; break; - case BinaryOpType::GT: postprocess = GtzConfig; break; - case BinaryOpType::LT: postprocess = SfpuConfig("ltz_tile_init", "ltz_tile(i)"); break; - case BinaryOpType::GTE: postprocess = SfpuConfig("gez_tile_init", "gez_tile(i)"); break; - case BinaryOpType::LTE: postprocess = SfpuConfig("lez_tile_init", "lez_tile(i)"); break; - case BinaryOpType::EQ: postprocess = SfpuConfig("eqz_tile_init", "eqz_tile(i)"); break; - case BinaryOpType::NE: postprocess = NezConfig; break; - case BinaryOpType::SQUARED_DIFFERENCE: postprocess = SfpuConfig("square_tile_init", "square_tile(i)"); break; + case BinaryOpType::GT: postprocess = unary::UnaryOpType::GTZ; break; + case BinaryOpType::LT: postprocess = unary::UnaryOpType::LTZ; break; + case BinaryOpType::GTE: postprocess = unary::UnaryOpType::GEZ; break; + case BinaryOpType::LTE: postprocess = unary::UnaryOpType::LEZ; break; + case BinaryOpType::EQ: postprocess = unary::UnaryOpType::EQZ; break; + case BinaryOpType::NE: postprocess = unary::UnaryOpType::NEZ; break; + case BinaryOpType::SQUARED_DIFFERENCE: postprocess = unary::UnaryOpType::SQUARE; break; case BinaryOpType::BIAS_GELU: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = - SfpuConfig("gelu_tile_init", "gelu_tile(i)", "compute_kernel_api/eltwise_unary/gelu.h"); + process_lhs = unary::UnaryOpType::GELU; break; case BinaryOpType::LOGICAL_AND: fpu_binary_op = FpuBinaryOp::MUL; - postprocess = NezConfig; + postprocess = unary::UnaryOpType::NEZ; break; case BinaryOpType::LOGICAL_OR: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = NezConfig; - preprocess_b = NezConfig; - postprocess = GtzConfig; + process_lhs = unary::UnaryOpType::NEZ; + process_rhs = unary::UnaryOpType::NEZ; + postprocess = unary::UnaryOpType::GTZ; break; case BinaryOpType::LOGICAL_XOR: - preprocess_a = NezConfig; - preprocess_b = NezConfig; - postprocess = NezConfig; + process_lhs = unary::UnaryOpType::NEZ; + process_rhs = unary::UnaryOpType::NEZ; + postprocess = unary::UnaryOpType::NEZ; break; case BinaryOpType::LDEXP: fpu_binary_op = FpuBinaryOp::MUL; - preprocess_b = SfpuConfig("exp2_tile_init", "exp2_tile(i)"); + process_rhs = unary::UnaryOpType::EXP2; break; case BinaryOpType::LOGADDEXP: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = - SfpuConfig("exp_tile_init", "exp_tile(i)", "compute_kernel_api/eltwise_unary/exp.h"); - preprocess_b = preprocess_a; - postprocess = SfpuConfig("log_tile_init", "log_tile(i)"); + process_lhs = unary::UnaryOpType::EXP; + process_rhs = unary::UnaryOpType::EXP; + postprocess = unary::UnaryOpType::LOG; break; case BinaryOpType::LOGADDEXP2: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = SfpuConfig("exp2_tile_init", "exp2_tile(i)"); - preprocess_b = preprocess_a; - postprocess = SfpuConfig("log_with_base_tile_init", "log_with_base_tile(i, 0x3dc5u);"); + process_lhs = unary::UnaryOpType::EXP2; + process_rhs = unary::UnaryOpType::EXP2; + postprocess = unary::UnaryOpType::LOG2; break; - default: __builtin_unreachable(); + default: TT_THROW("Unsupported binary op"); } } -std::map OpConfig::SfpuConfig::as_defines(std::string_view prefix) const { - if (init.empty()) { - return {}; - } - - std::map defines; - defines[fmt::format("{}_INIT", prefix)] = init; - defines[fmt::format("{}_APPLY(i)", prefix)] = apply; - defines[fmt::format("{}_INCLUDE", prefix)] = include; - return defines; -} - std::map OpConfig::as_defines() const { std::map defines; - defines.merge(preprocess_a.as_defines("PREPROCESS_A")); - defines.merge(preprocess_b.as_defines("PREPROCESS_B")); - defines.merge(postprocess.as_defines("POSTPROCESS")); auto binary_op_str = magic_enum::enum_name(fpu_binary_op); defines["BINARY_OP"] = fmt::format("{}_tiles", Lowercase{binary_op_str}); @@ -203,4 +188,21 @@ std::map OpConfig::as_defines() const { return defines; } +void add_activation_defines( + std::map& defines, + tt::stl::Span activations, + std::string_view operand) { + defines[fmt::format("PROCESS_{}_ACTIVATIONS(i)", operand)] = fmt::format( + "{}", + fmt::join( + activations | std::views::transform([](auto& a) { + return fmt::format("PROCESS_ACTIVATION({}, i)", magic_enum::enum_name(a)); + }), + ";")); + + for (auto& a : activations) { + unary::utils::update_macro_defines(a, defines); + } +} + } // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp index cc8a242fc0c..b1ee7f41700 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp @@ -41,30 +41,23 @@ struct BinaryNgKernelConfig { std::string get_kernel_file_path(KernelName kernel_name); struct OpConfig { - struct SfpuConfig { - SfpuConfig() = default; - constexpr SfpuConfig( - std::string_view init, std::string_view apply, std::string_view include = "compute_kernel_api.h") : - init{init}, apply{apply}, include{include} {} - std::string_view init{}; - std::string_view apply{}; - std::string_view include{}; - - std::map as_defines(std::string_view prefix) const; - }; - enum class FpuBinaryOp { ADD, SUB, MUL }; OpConfig(BinaryOpType binary_op_type); std::map as_defines() const; - SfpuConfig preprocess_a{}; - SfpuConfig preprocess_b{}; - SfpuConfig postprocess{}; + std::optional process_lhs{}; + std::optional process_rhs{}; + std::optional postprocess{}; FpuBinaryOp fpu_binary_op; }; +void add_activation_defines( + std::map& defines, + tt::stl::Span activations, + std::string_view operand); + struct Lowercase { std::string_view view; }; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp index c3f5f414d38..8f6f7990bf1 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp @@ -3,9 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -22,40 +22,32 @@ ALWI void process_tile( constexpr uint32_t onetile = 1; #if BCAST_INPUT - auto cb_bcast = cb_post_rhs; - auto cb_other = cb_post_lhs; +#define CB_PRE_BCAST cb_pre_rhs +#define CB_POST_BCAST cb_post_rhs +#define CB_PRE_OTHER cb_pre_lhs +#define CB_POST_OTHER cb_post_lhs #else - auto cb_bcast = cb_post_lhs; - auto cb_other = cb_post_rhs; +#define CB_PRE_BCAST cb_pre_lhs +#define CB_POST_BCAST cb_post_lhs +#define CB_PRE_OTHER cb_pre_rhs +#define CB_POST_OTHER cb_post_rhs #endif -#if PREPROCESS_A && (BCAST_INPUT == 0) - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#elif PREPROCESS_B && (BCAST_INPUT == 1) - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif - - cb_wait_front(cb_bcast, onetile); + PREPROCESS(BCAST_OP, CB_PRE_BCAST, CB_POST_BCAST, cb_out, onetile); + cb_wait_front(CB_POST_BCAST, onetile); for (uint32_t j = tile_start; j < freq; ++j) { -#if PREPROCESS_A && (BCAST_INPUT == 1) - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#elif PREPROCESS_B && (BCAST_INPUT == 0) - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif - cb_wait_front(cb_other, onetile); + PREPROCESS(OTHER_OP, CB_PRE_OTHER, CB_POST_OTHER, cb_out, onetile); + cb_wait_front(CB_POST_OTHER, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); @@ -63,9 +55,9 @@ ALWI void process_tile( tile_regs_release(); cb_push_back(cb_out, onetile); - cb_pop_front(cb_other, onetile); + cb_pop_front(CB_POST_OTHER, onetile); } - cb_pop_front(cb_bcast, onetile); + cb_pop_front(CB_POST_BCAST, onetile); } void MAIN { @@ -81,12 +73,15 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp index 2e365c860c3..f2a263bfb1a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp @@ -3,10 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 #include + +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "dprint.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -17,39 +17,35 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif constexpr uint32_t onetile = 1; for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { -#if PREPROCESS_A - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#endif + PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile); cb_wait_front(cb_post_lhs, onetile); -#if PREPROCESS_B - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif + PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile); cb_wait_front(cb_post_rhs, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp index 0f468636223..b9f2e29903e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp @@ -3,9 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -16,39 +16,35 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif constexpr uint32_t onetile = 1; -#if PREPROCESS_B - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif + PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile); cb_wait_front(cb_post_rhs, onetile); - for(uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { -#if PREPROCESS_A - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#endif + for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { + PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile); cb_wait_front(cb_post_lhs, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp deleted file mode 100644 index 5eacbd9b9dd..00000000000 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#define DO_QUOTE(x) #x -#define QUOTE(x) DO_QUOTE(x) - -#if defined(PREPROCESS_A_INIT) -#define PREPROCESS_A 1 -#else -#define PREPROCESS_A 0 -#endif - -#if defined(PREPROCESS_B_INIT) -#define PREPROCESS_B 1 -#else -#define PREPROCESS_B 0 -#endif - -#if defined(POSTPROCESS_INIT) -#define POSTPROCESS 1 -#else -#define POSTPROCESS 0 -#endif - -#ifdef PREPROCESS_A_INCLUDE -#include QUOTE(PREPROCESS_A_INCLUDE) -#endif - -#ifdef PREPROCESS_B_INCLUDE -#include QUOTE(PREPROCESS_B_INCLUDE) -#endif - -#ifdef POSTPROCESS_INCLUDE -#include QUOTE(POSTPROCESS_INCLUDE) -#endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp index 63f350c47c1..af07339dbbc 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp @@ -7,35 +7,114 @@ #include "compute_kernel_api/common.h" #include "compute_kernel_api/tile_move_copy.h" -#define PREPROCESS(init, apply, cb_pre, cb_post, cb_out, per_core_block_size) \ - do { \ - using namespace ckernel; \ - \ - copy_tile_to_dst_init_short(); \ - \ - reconfig_data_format_srca(/*old*/ cb_post, /*new*/ cb_pre); \ - pack_reconfig_data_format(/*old*/ cb_out, /*new*/ cb_post); \ - \ - cb_wait_front(cb_pre, per_core_block_size); \ - cb_reserve_back(cb_post, per_core_block_size); \ - \ - tile_regs_acquire(); \ - init(); \ - for (uint32_t i = 0; i < per_core_block_size; ++i) { \ - copy_tile(cb_pre, i, i); \ - apply(i); \ - } \ - tile_regs_commit(); \ - \ - tile_regs_wait(); \ - for (uint32_t i = 0; i < per_core_block_size; ++i) { \ - pack_tile(i, cb_post); /* DST[0]->cb */ \ - } \ - tile_regs_release(); \ - \ - cb_pop_front(cb_pre, per_core_block_size); \ - cb_push_back(cb_post, per_core_block_size); \ - \ - reconfig_data_format_srca(/*old*/ cb_pre, /*new*/ cb_post); \ - pack_reconfig_data_format(/*old*/ cb_post, /*new*/ cb_out); \ +#define ACTIVATION_INIT_RELU relu_tile_init +#define ACTIVATION_APPLY_RELU relu_tile + +#define ACTIVATION_INIT_SQUARE square_tile_init +#define ACTIVATION_APPLY_SQUARE square_tile + +#define ACTIVATION_INIT_GTZ gtz_tile_init +#define ACTIVATION_APPLY_GTZ gtz_tile + +#define ACTIVATION_INIT_LTZ ltz_tile_init +#define ACTIVATION_APPLY_LTZ ltz_tile + +#define ACTIVATION_INIT_GEZ gez_tile_init +#define ACTIVATION_APPLY_GEZ gez_tile + +#define ACTIVATION_INIT_LEZ lez_tile_init +#define ACTIVATION_APPLY_LEZ lez_tile + +#define ACTIVATION_INIT_EQZ eqz_tile_init +#define ACTIVATION_APPLY_EQZ eqz_tile + +#define ACTIVATION_INIT_NEZ nez_tile_init +#define ACTIVATION_APPLY_NEZ nez_tile + +#define ACTIVATION_INIT_LOG log_tile_init +#define ACTIVATION_APPLY_LOG log_tile + +#define ACTIVATION_INIT_LOG2 log_with_base_tile_init +#define ACTIVATION_APPLY_LOG2(i) log_with_base_tile(i, 0x3dc5u) + +#define ACTIVATION_INIT_EXP exp_tile_init +#define ACTIVATION_APPLY_EXP exp_tile + +#define ACTIVATION_INIT_EXP2 exp2_tile_init +#define ACTIVATION_APPLY_EXP2 exp2_tile + +#define ACTIVATION_INIT_RECIP recip_tile_init +#define ACTIVATION_APPLY_RECIP recip_tile + +#define ACTIVATION_INIT_GELU gelu_tile_init +#define ACTIVATION_APPLY_GELU gelu_tile + +#define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__) +#define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__()) +#define IS_EMPTY_1(...) 0 +#define IS_EMPTY_NON_FUNCTION_C(...) () + +#define IS_BEGIN_PARENS(...) P_FIRST(P_CAT(P_IS_VARIADIC_R_, P_IS_VARIADIC_C __VA_ARGS__)) + +#define P_IS_VARIADIC_R_1 1, +#define P_IS_VARIADIC_R_P_IS_VARIADIC_C 0, +#define P_IS_VARIADIC_C(...) 1 + +#define P_FIRST(...) P_FIRST_(__VA_ARGS__, ) +#define P_FIRST_(a, ...) a + +#define P_CAT(a, ...) P_CAT_(a, __VA_ARGS__) +#define P_CAT_(a, ...) a##__VA_ARGS__ + +#define P_COMPL(b) P_CAT(P_COMPL_, b) +#define P_COMPL_0 1 +#define P_COMPL_1 0 + +#define ACTIVATION_INIT(elem) ACTIVATION_INIT_##elem() +#define ACTIVATION_APPLY(elem, i) ACTIVATION_APPLY_##elem(i) + +#define PROCESS_ACTIVATION(elem, i) \ + ACTIVATION_INIT(elem); \ + ACTIVATION_APPLY(elem, i) + +#define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i) +#define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS +#define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0))) + +#define BCAST_OP P_CAT(BCAST_OP_, BCAST_INPUT) +#define OTHER_OP P_CAT(BCAST_OP_, P_COMPL(BCAST_INPUT)) +#define BCAST_OP_0 LHS +#define BCAST_OP_1 RHS + +#define PREPROCESS(op, ...) P_CAT(PREPROCESS_, HAS_ACTIVATIONS(op))(op, __VA_ARGS__) +#define PREPROCESS_0(...) +#define PREPROCESS_1(op, cb_pre, cb_post, cb_out, per_core_block_size) \ + do { \ + using namespace ckernel; \ + \ + reconfig_data_format_srca(/*old*/ cb_post, /*new*/ cb_pre); \ + pack_reconfig_data_format(/*old*/ cb_out, /*new*/ cb_post); \ + \ + cb_wait_front(cb_pre, per_core_block_size); \ + cb_reserve_back(cb_post, per_core_block_size); \ + \ + tile_regs_acquire(); \ + for (uint32_t i = 0; i < per_core_block_size; ++i) { \ + copy_tile_to_dst_init_short(); \ + copy_tile(cb_pre, i, i); \ + PROCESS_ACTIVATIONS(op, i); \ + } \ + tile_regs_commit(); \ + \ + tile_regs_wait(); \ + for (uint32_t i = 0; i < per_core_block_size; ++i) { \ + pack_tile(i, cb_post); \ + } \ + tile_regs_release(); \ + \ + cb_pop_front(cb_pre, per_core_block_size); \ + cb_push_back(cb_post, per_core_block_size); \ + \ + reconfig_data_format_srca(/*old*/ cb_pre, /*new*/ cb_post); \ + pack_reconfig_data_format(/*old*/ cb_post, /*new*/ cb_out); \ } while (0) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp index 4ed08212b53..3c2c71bb06c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp @@ -27,60 +27,59 @@ union Converter { } }; -// update split eltwise ops include macros -void update_macro_defines(UnaryOpType op_type, std::map& defines) { +std::string get_macro_definition(UnaryOpType op_type) { switch (op_type) { - case UnaryOpType::EXP: defines["SFPU_OP_EXP_INCLUDE"] = "1"; break; - case UnaryOpType::GELU: defines["SFPU_OP_GELU_INCLUDE"] = "1"; break; - case UnaryOpType::RECIP: defines["SFPU_OP_RECIP_INCLUDE"] = "1"; break; - case UnaryOpType::SQRT: defines["SFPU_OP_SQRT_INCLUDE"] = "1"; break; - case UnaryOpType::ERFINV: defines["SFPU_OP_ERFINV_INCLUDE"] = "1"; break; + case UnaryOpType::EXP: return "SFPU_OP_EXP_INCLUDE"; + case UnaryOpType::GELU: return "SFPU_OP_GELU_INCLUDE"; + case UnaryOpType::RECIP: return "SFPU_OP_RECIP_INCLUDE"; + case UnaryOpType::SQRT: return "SFPU_OP_SQRT_INCLUDE"; + case UnaryOpType::ERFINV: return "SFPU_OP_ERFINV_INCLUDE"; case UnaryOpType::ERFC: - case UnaryOpType::ERF: defines["SFPU_OP_ERF_ERFC_INCLUDE"] = "1"; break; - case UnaryOpType::ELU: defines["SFPU_OP_ELU_INCLUDE"] = "1"; break; + case UnaryOpType::ERF: return "SFPU_OP_ERF_ERFC_INCLUDE"; + case UnaryOpType::ELU: return "SFPU_OP_ELU_INCLUDE"; case UnaryOpType::RELU: case UnaryOpType::RELU6: case UnaryOpType::RELU_MAX: case UnaryOpType::RELU_MIN: - case UnaryOpType::LEAKY_RELU: defines["SFPU_OP_RELU_FAMILY_INCLUDE"] = "1"; break; + case UnaryOpType::LEAKY_RELU: return "SFPU_OP_RELU_FAMILY_INCLUDE"; case UnaryOpType::ADD_UNARY_SFPU: case UnaryOpType::SUB_UNARY_SFPU: case UnaryOpType::MUL_UNARY_SFPU: - case UnaryOpType::DIV_UNARY_SFPU: defines["SFPU_OP_BINOP_WITH_SCALAR_INCLUDE"] = "1"; break; + case UnaryOpType::DIV_UNARY_SFPU: return "SFPU_OP_BINOP_WITH_SCALAR_INCLUDE"; case UnaryOpType::IDENTITY: - case UnaryOpType::IDENTITY_UINT32: defines["SFPU_OP_IDENTITY_INCLUDE"] = "1"; break; + case UnaryOpType::IDENTITY_UINT32: return "SFPU_OP_IDENTITY_INCLUDE"; case UnaryOpType::FLOOR: - case UnaryOpType::FLOOR_FLOAT32: defines["SFPU_OP_FLOOR_INCLUDE"] = "1"; break; + case UnaryOpType::FLOOR_FLOAT32: return "SFPU_OP_FLOOR_INCLUDE"; case UnaryOpType::CEIL: - case UnaryOpType::CEIL_FLOAT32: defines["SFPU_OP_CEIL_INCLUDE"] = "1"; break; - case UnaryOpType::RDIV: break; - case UnaryOpType::RSUB: defines["SFPU_OP_REVERSE_FAMILY_INCLUDE"] = "1"; + case UnaryOpType::CEIL_FLOAT32: return "SFPU_OP_CEIL_INCLUDE"; + case UnaryOpType::RDIV: + case UnaryOpType::RSUB: return "SFPU_OP_REVERSE_FAMILY_INCLUDE"; case UnaryOpType::ISINF: case UnaryOpType::ISNAN: case UnaryOpType::ISNEGINF: case UnaryOpType::ISPOSINF: - case UnaryOpType::ISFINITE: defines["SFPU_OP_ISINF_ISNAN_INCLUDE"] = "1"; break; - case UnaryOpType::LOGICAL_NOT_UNARY: defines["SFPU_OP_LOGICAL_NOT_NOTI_INCLUDE"] = "1"; break; - case UnaryOpType::I0: defines["SFPU_OP_I0_INCLUDE"] = "1"; break; - case UnaryOpType::I1: defines["SFPU_OP_I1_INCLUDE"] = "1"; break; + case UnaryOpType::ISFINITE: return "SFPU_OP_ISINF_ISNAN_INCLUDE"; + case UnaryOpType::LOGICAL_NOT_UNARY: return "SFPU_OP_LOGICAL_NOT_NOTI_INCLUDE"; + case UnaryOpType::I0: return "SFPU_OP_I0_INCLUDE"; + case UnaryOpType::I1: return "SFPU_OP_I1_INCLUDE"; case UnaryOpType::COS: case UnaryOpType::SIN: - case UnaryOpType::TAN: defines["SFPU_OP_TRIG_FAMILY_INCLUDE"] = "1"; break; - case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break; - case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break; - case UnaryOpType::PRELU_SFPU: defines["SFPU_OP_PRELU_INCLUDE"] = "1"; break; - case UnaryOpType::TYPECAST: defines["SFPU_OP_TYPECAST_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_XOR: defines["SFPU_OP_BITWISE_XOR_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_NOT: defines["SFPU_OP_BITWISE_NOT_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_AND: defines["SFPU_OP_BITWISE_AND_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_OR: defines["SFPU_OP_BITWISE_OR_INCLUDE"] = "1"; break; - case UnaryOpType::RIGHT_SHIFT: defines["SFPU_OP_RIGHT_SHIFT_INCLUDE"] = "1"; break; - case UnaryOpType::LEFT_SHIFT: defines["SFPU_OP_LEFT_SHIFT_INCLUDE"] = "1"; break; - case UnaryOpType::REMAINDER: defines["SFPU_OP_REMAINDER_INCLUDE"] = "1"; break; - case UnaryOpType::FMOD: defines["SFPU_OP_FMOD_INCLUDE"] = "1"; break; - case UnaryOpType::DROPOUT: defines["SFPU_OP_DROPOUT_INCLUDE"] = "1"; break; - case UnaryOpType::FILL: defines["SFPU_OP_FILL_INCLUDE"] = "1"; break; - default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break; + case UnaryOpType::TAN: return "SFPU_OP_TRIG_FAMILY_INCLUDE"; + case UnaryOpType::NEG: return "SFPU_OP_NEG_INCLUDE"; + case UnaryOpType::SOFTPLUS: return "SFPU_OP_SOFTPLUS_INCLUDE"; + case UnaryOpType::PRELU_SFPU: return "SFPU_OP_PRELU_INCLUDE"; + case UnaryOpType::TYPECAST: return "SFPU_OP_TYPECAST_INCLUDE"; + case UnaryOpType::BITWISE_XOR: return "SFPU_OP_BITWISE_XOR_INCLUDE"; + case UnaryOpType::BITWISE_NOT: return "SFPU_OP_BITWISE_NOT_INCLUDE"; + case UnaryOpType::BITWISE_AND: return "SFPU_OP_BITWISE_AND_INCLUDE"; + case UnaryOpType::BITWISE_OR: return "SFPU_OP_BITWISE_OR_INCLUDE"; + case UnaryOpType::RIGHT_SHIFT: return "SFPU_OP_RIGHT_SHIFT_INCLUDE"; + case UnaryOpType::LEFT_SHIFT: return "SFPU_OP_LEFT_SHIFT_INCLUDE"; + case UnaryOpType::REMAINDER: return "SFPU_OP_REMAINDER_INCLUDE"; + case UnaryOpType::FMOD: return "SFPU_OP_FMOD_INCLUDE"; + case UnaryOpType::DROPOUT: return "SFPU_OP_DROPOUT_INCLUDE"; + case UnaryOpType::FILL: return "SFPU_OP_FILL_INCLUDE"; + default: return "SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"; }; } @@ -451,4 +450,9 @@ std::map get_block_defines( return block_defines; } +// update split eltwise ops include macros +void update_macro_defines(UnaryOpType op_type, std::map& defines) { + defines[get_macro_definition(op_type)] = "1"; +} + } // namespace ttnn::operations::unary::utils diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp index 920ee695be9..e7046f61ffc 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp @@ -71,4 +71,6 @@ bool is_parametrized_type(T val) { return false; } +void update_macro_defines(UnaryOpType op_type, std::map& defines); + } // namespace ttnn::operations::unary::utils diff --git a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp b/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp index 5adce7b92b4..90a984ae656 100644 --- a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp @@ -28,6 +28,7 @@ std::ostream& operator<<(std::ostream& os, const SmallVector 0) { os << ", "; } + using tt::stl::reflection::operator<<; os << vec[i]; } os << "])";