diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
index 34774b95cc4..7f6f592bf6d 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -5,10 +5,8 @@
 import torch
 import pytest
 import ttnn
-
-from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
-    compare_pcc,
-)
+import random
+from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc
 
 
 @pytest.mark.parametrize(
@@ -19,30 +17,7 @@
         (torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
     ),
 )
-@pytest.mark.parametrize(
-    "ttnn_fn",
-    [
-        ttnn.experimental.gte,
-        ttnn.experimental.gt,
-        ttnn.experimental.lte,
-        ttnn.experimental.lt,
-        ttnn.experimental.eq,
-        ttnn.experimental.ne,
-        ttnn.experimental.logical_and,
-        ttnn.experimental.logical_or,
-        ttnn.experimental.logical_xor,
-        ttnn.experimental.ldexp,
-        ttnn.experimental.logaddexp,
-        ttnn.experimental.logaddexp2,
-        ttnn.experimental.squared_difference,
-        ttnn.experimental.add,
-        ttnn.experimental.sub,
-        ttnn.experimental.mul,
-        ttnn.experimental.div,
-        ttnn.experimental.bias_gelu,
-    ],
-)
-def test_binary_scalar_ops(input_shapes, ttnn_fn, device):
+def test_binary_scalar_ops(input_shapes, device):
     a_shape, b_shape = input_shapes
     a_pt = torch.rand(a_shape).bfloat16()
     b_pt = torch.rand(b_shape).bfloat16()
@@ -50,59 +25,13 @@ def test_binary_scalar_ops(input_shapes, device):
     a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
     b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
     cq_id = 0
-    out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
-    golden_fn = ttnn.get_golden_function(ttnn_fn)
-    out_pt = golden_fn(a_pt, b_pt)
+    out_tt = ttnn.experimental.add(a_tt, b_tt, queue_id=cq_id)
+    out_pt = a_pt + b_pt
 
     comp_pass = compare_pcc([out_tt], [out_pt])
     assert comp_pass
 
 
-@pytest.mark.parametrize(
-    "input_shapes",
-    (
-        (torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])),
-        (torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 128])),
-        (torch.Size([5, 1, 1, 64]), torch.Size([2, 3, 128, 1])),
-    ),
-)
-@pytest.mark.parametrize(
-    "ttnn_fn",
-    [
-        ttnn.experimental.gte,
-        ttnn.experimental.gt,
-        ttnn.experimental.lte,
-        ttnn.experimental.lt,
-        ttnn.experimental.eq,
-        ttnn.experimental.ne,
-        ttnn.experimental.logical_and,
-        ttnn.experimental.logical_or,
-        ttnn.experimental.logical_xor,
-        ttnn.experimental.ldexp,
-        ttnn.experimental.logaddexp,
-        ttnn.experimental.logaddexp2,
-        ttnn.experimental.squared_difference,
-        ttnn.experimental.add,
-        ttnn.experimental.sub,
-        ttnn.experimental.mul,
-        ttnn.experimental.div,
-        ttnn.experimental.bias_gelu,
-    ],
-)
-def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device):
-    a_shape, b_shape = input_shapes
-    a_pt = torch.rand(a_shape).bfloat16()
-    b_pt = torch.rand(b_shape).bfloat16()
-
-    a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-
-    with pytest.raises(RuntimeError) as e:
-        cq_id = 0
-        _ = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
-    assert "Broadcasting rule violation" in str(e.value)
-
-
 @pytest.mark.parametrize(
     "shapes",
    [
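For context, the pared-down test above now exercises only the add path against an inline torch golden. A minimal sketch of the same flow outside pytest (assumes an open `device` handle; shapes are taken from the surviving parametrization):

```python
import torch
import ttnn

def run_add_bcast(device):
    # Rank-4 broadcast pair from the test's parametrization.
    a_pt = torch.rand(torch.Size([5, 1, 1, 64])).bfloat16()
    b_pt = torch.rand(torch.Size([1, 3, 128, 1])).bfloat16()
    a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    out_tt = ttnn.experimental.add(a_tt, b_tt, queue_id=0)
    # Device result vs. plain torch golden, mirroring compare_pcc's inputs.
    return ttnn.to_torch(out_tt), a_pt + b_pt
```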
diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt
index a3bcf32d425..cf777f35fd7 100644
--- a/ttnn/CMakeLists.txt
+++ b/ttnn/CMakeLists.txt
@@ -123,7 +123,6 @@ set(ALL_TTNN_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/binary.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
index 1439186fbd4..7e476e293b2 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -52,22 +52,5 @@ Tensor BinaryNg<binary_op_type>::invoke(
 }
 
 template struct BinaryNg<BinaryOpType::ADD>;
-template struct BinaryNg<BinaryOpType::SUB>;
-template struct BinaryNg<BinaryOpType::MUL>;
-template struct BinaryNg<BinaryOpType::DIV>;
-template struct BinaryNg<BinaryOpType::EQ>;
-template struct BinaryNg<BinaryOpType::NE>;
-template struct BinaryNg<BinaryOpType::GT>;
-template struct BinaryNg<BinaryOpType::GTE>;
-template struct BinaryNg<BinaryOpType::LT>;
-template struct BinaryNg<BinaryOpType::LTE>;
-template struct BinaryNg<BinaryOpType::SQUARED_DIFFERENCE>;
-template struct BinaryNg<BinaryOpType::BIAS_GELU>;
-template struct BinaryNg<BinaryOpType::LOGICAL_AND>;
-template struct BinaryNg<BinaryOpType::LOGICAL_OR>;
-template struct BinaryNg<BinaryOpType::LOGICAL_XOR>;
-template struct BinaryNg<BinaryOpType::LDEXP>;
-template struct BinaryNg<BinaryOpType::LOGADDEXP>;
-template struct BinaryNg<BinaryOpType::LOGADDEXP2>;
 
 }  // namespace ttnn::operations::binary_ng
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
index d41261f60b9..ac1faf4a4e7 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -49,72 +49,4 @@ namespace ttnn::experimental {
 constexpr auto add = ttnn::register_operation_with_auto_launch_op<
     "ttnn::experimental::add",
     ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::ADD>>();
-
-constexpr auto sub = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::sub",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::SUB>>();
-
-constexpr auto mul = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::mul",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::MUL>>();
-
-constexpr auto div = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::div",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::DIV>>();
-
-constexpr auto eq = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::eq",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::EQ>>();
-
-constexpr auto ne = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::ne",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::NE>>();
-
-constexpr auto gt = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::gt",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::GT>>();
-
-constexpr auto gte = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::gte",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::GTE>>();
-
-constexpr auto lt = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::lt",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LT>>();
-
-constexpr auto lte = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::lte",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LTE>>();
-
-constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::squared_difference",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::SQUARED_DIFFERENCE>>();
-
-constexpr auto bias_gelu = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::bias_gelu",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::BIAS_GELU>>();
-
-constexpr auto logical_and = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::logical_and",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LOGICAL_AND>>();
-
-constexpr auto logical_or = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::logical_or",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LOGICAL_OR>>();
-
-constexpr auto logical_xor = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::logical_xor",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LOGICAL_XOR>>();
-
-constexpr auto ldexp = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::ldexp",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LDEXP>>();
-
-constexpr auto logaddexp = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::logaddexp",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LOGADDEXP>>();
-
-constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
-    "ttnn::experimental::logaddexp2",
-    ttnn::operations::binary_ng::BinaryNg<ttnn::operations::binary_ng::BinaryOpType::LOGADDEXP2>>();
 }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
index da264795941..ab6c28e4421 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
@@ -9,16 +9,17 @@ namespace ttnn::operations::binary_ng {
 namespace detail {
 
-template <typename T>
-void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring) {
+void bind_binary_ng_operation(py::module& module) {
+    using OperationType = decltype(ttnn::experimental::add);
+
     bind_registered_operation(
         module,
-        op,
-        docstring,
+        ttnn::experimental::add,
+        "Binary Add Ng Operation",
 
         // tensor and scalar
         ttnn::pybind_overload_t{
-            [](const T& self,
+            [](const OperationType& self,
                const ttnn::Tensor& input_tensor_a,
                const float scalar,
                const std::optional<const DataType>& dtype,
@@ -37,7 +38,7 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
 
         // tensor and tensor
         ttnn::pybind_overload_t{
-            [](const T& self,
+            [](const OperationType& self,
                const ttnn::Tensor& input_tensor_a,
                const ttnn::Tensor& input_tensor_b,
                const std::optional<const DataType>& dtype,
@@ -56,25 +57,5 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst
 }
 
 }  // namespace detail
 
-void py_module(py::module& module) {
-    detail::bind_binary_ng_operation(module, ttnn::experimental::add, "Binary Add Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::sub, "Binary Sub Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::mul, "Binary Mul Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::div, "Binary Div Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::gt, "Binary Greater Than Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::lt, "Binary Less Than Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::lte, "Binary Less Than or Equal To Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::gte, "Binary Greater Than or Equal To Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::eq, "Binary Equal Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::ne, "Binary Not Equal Operation");
-    detail::bind_binary_ng_operation(
-        module, ttnn::experimental::squared_difference, "Binary Squared Difference Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::bias_gelu, "Binary Bias GELU Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::logical_and, "Binary Logical And Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::logical_or, "Binary Logical Or Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::logical_xor, "Binary Logical Xor Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::ldexp, "Binary Ldexp Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp, "Binary Logaddexp Operation");
-    detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp2, "Binary Logaddexp2 Operation");
-}
+void py_module(py::module& module) { detail::bind_binary_ng_operation(module); }
 
 }  // namespace ttnn::operations::eltwise::binary_ng
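The binding above registers two overloads for the one remaining op. A hedged sketch of both call shapes (assumes `a_tt` and `b_tt` are tile-layout device tensors as in the test earlier; the scalar form follows the float overload in the binding, if it is wired through on this branch):

```python
out_tt = ttnn.experimental.add(a_tt, b_tt, queue_id=0)  # tensor-tensor overload
out_ts = ttnn.experimental.add(a_tt, 2.0, queue_id=0)   # tensor-scalar overload (float)
```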
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.hpp
index b8cc769c780..3a417d010cb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.hpp
@@ -5,14 +5,12 @@
 #pragma once
 
 #include "pybind11/pybind_fwd.hpp"
-#include <string>
 
 namespace py = pybind11;
 
 namespace ttnn::operations::binary_ng {
 namespace detail {
-template <typename T>
-void bind_binary_ng_operation(py::module& module, T op, const std::string& docstring);
+void bind_binary_ng_operation(py::module& module);
 }
 
 void py_module(py::module& module);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
index b55c8a8cccd..0279ff3734e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
@@ -90,18 +90,18 @@ void BinaryNgDeviceOperation::validate_on_program_cache_hit(
     const auto input_shape_b =
         tensor_args.input_tensor_b.has_value() ? tensor_args.input_tensor_b->get_logical_shape() : ttnn::Shape{1, 1};
 
-    const int rank_a = input_shape_a.rank();
-    const int rank_b = input_shape_b.rank();
-    const int larger_rank = std::max(rank_a, rank_b);
-    for (int i = -1; i >= -larger_rank; --i) {
-        auto a_dim = (i >= -rank_a) ? input_shape_a[i] : 1;
-        auto b_dim = (i >= -rank_b) ? input_shape_b[i] : 1;
-        TT_FATAL(
-            a_dim == b_dim || a_dim == 1 || b_dim == 1,
-            "Broadcasting rule violation for rank {}, dim a: {}, dim b: {}",
-            i,
-            a_dim,
-            b_dim);
+    constexpr int max_rank = 4;
+    if (input_shape_a.rank() > 0 && input_shape_b.rank() > 0) {
+        for (int i = 1; i <= max_rank; i++) {
+            auto a_dim = i <= input_shape_a.rank() ? input_shape_a[-i] : 1;
+            auto b_dim = i <= input_shape_b.rank() ? input_shape_b[-i] : 1;
+            TT_FATAL(
+                a_dim == b_dim || a_dim == 1 || b_dim == 1,
+                "Broadcasting rule violation for rank {}, dim a: {}, dim b: {}",
+                i,
+                a_dim,
+                b_dim);
+        }
     }
 }
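The rewritten validation walks dimensions right-to-left up to rank 4 and applies the usual NumPy-style rule. A Python model of the same check, using shape pairs from the tests (the failing pair comes from the deleted invalid-broadcast test):

```python
def broadcast_ok(a_shape, b_shape, max_rank=4):
    # Mirrors the TT_FATAL condition: dims must match, or one of them must be 1.
    for i in range(1, max_rank + 1):
        a_dim = a_shape[-i] if i <= len(a_shape) else 1
        b_dim = b_shape[-i] if i <= len(b_shape) else 1
        if not (a_dim == b_dim or a_dim == 1 or b_dim == 1):
            return False
    return True

assert broadcast_ok((5, 1, 1, 64), (1, 3, 128, 1))      # pair from the kept test
assert not broadcast_ok((5, 1, 1, 64), (2, 3, 128, 1))  # pair from the removed invalid-bcast test
```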
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
index d1125e8238f..bce5cfdc824 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
@@ -2,7 +2,7 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include "binary_ng_utils.hpp"
+#include "binary_ng_device_operation.hpp"
 #include "tt_metal/common/work_split.hpp"
 #include "ttnn/operations/cb_utils.hpp"
 
@@ -17,6 +17,124 @@ std::tuple extract_shape_dims(const Tens
     return {shape[-4], shape[-3], shape[-2] / tile.get_height(), shape[-1] / tile.get_width()};
 }
 
+enum class KernelName {
+    ReaderNoBcast,
+    ReaderRowBcast,
+    ReaderColBcast,
+    ReaderScalarBcast,
+    WriterNoBcast,
+    WriterRowBcast,
+    WriterColBcast,
+    WriterScalarBcast,
+    WriterScalar,
+    ComputeNoBcast,
+    ComputeBcast,
+    ComputeScalar
+};
+
+struct BinaryNgKernelConfig {
+    BinaryNgKernelConfig(SubtileBroadcastType subtile_broadcast_type) {
+        switch (subtile_broadcast_type) {
+            case SubtileBroadcastType::NONE:
+                reader_kernel = KernelName::ReaderNoBcast;
+                compute_kernel = KernelName::ComputeNoBcast;
+                writer_kernel = KernelName::WriterNoBcast;
+                bcast_input = std::nullopt;
+                break;
+
+            case SubtileBroadcastType::SCALAR_A:
+                reader_kernel = KernelName::ReaderScalarBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterNoBcast;
+                bcast_input = 0;
+                break;
+
+            case SubtileBroadcastType::SCALAR_B:
+                reader_kernel = KernelName::ReaderNoBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterScalarBcast;
+                bcast_input = 1;
+                break;
+
+            case SubtileBroadcastType::ROW_A:
+                reader_kernel = KernelName::ReaderRowBcast;
+                compute_kernel = KernelName::ComputeNoBcast;
+                writer_kernel = KernelName::WriterNoBcast;
+                bcast_input = std::nullopt;
+                break;
+
+            case SubtileBroadcastType::ROW_B:
+                reader_kernel = KernelName::ReaderNoBcast;
+                compute_kernel = KernelName::ComputeNoBcast;
+                writer_kernel = KernelName::WriterRowBcast;
+                bcast_input = std::nullopt;
+                break;
+
+            case SubtileBroadcastType::COL_A:
+                reader_kernel = KernelName::ReaderColBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterNoBcast;
+                bcast_input = 0;
+                break;
+
+            case SubtileBroadcastType::COL_B:
+                reader_kernel = KernelName::ReaderNoBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterColBcast;
+                bcast_input = 1;
+                break;
+
+            case SubtileBroadcastType::ROW_A_COL_B:
+                reader_kernel = KernelName::ReaderRowBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterColBcast;
+                bcast_input = 1;
+                break;
+
+            case SubtileBroadcastType::ROW_B_COL_A:
+                reader_kernel = KernelName::ReaderColBcast;
+                compute_kernel = KernelName::ComputeBcast;
+                writer_kernel = KernelName::WriterRowBcast;
+                bcast_input = 0;
+                break;
+        }
+    }
+
+    std::string bcast_input_str() const {
+        if (bcast_input.has_value()) {
+            return std::to_string(*bcast_input);
+        }
+        return "";
+    }
+
+    KernelName reader_kernel;
+    KernelName compute_kernel;
+    KernelName writer_kernel;
+    std::optional<uint32_t> bcast_input;
+};
+
= "ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels"; + constexpr std::string_view dataflow = "{}/dataflow/{}"; + constexpr std::string_view compute = "{}/compute/{}"; + + switch (kernel_name) { + case KernelName::ReaderNoBcast: return fmt::format(dataflow, root, "reader_interleaved_no_bcast.cpp"); + case KernelName::ReaderRowBcast: return fmt::format(dataflow, root, "reader_interleaved_row_bcast.cpp"); + case KernelName::ReaderColBcast: return fmt::format(dataflow, root, "reader_interleaved_col_bcast.cpp"); + case KernelName::ReaderScalarBcast: return fmt::format(dataflow, root, "reader_interleaved_scalar_bcast.cpp"); + case KernelName::WriterNoBcast: return fmt::format(dataflow, root, "writer_interleaved_no_bcast.cpp"); + case KernelName::WriterRowBcast: return fmt::format(dataflow, root, "writer_interleaved_row_bcast.cpp"); + case KernelName::WriterColBcast: return fmt::format(dataflow, root, "writer_interleaved_col_bcast.cpp"); + case KernelName::WriterScalarBcast: return fmt::format(dataflow, root, "writer_interleaved_scalar_bcast.cpp"); + case KernelName::WriterScalar: return fmt::format(dataflow, root, "writer_interleaved_scalar.cpp"); + case KernelName::ComputeNoBcast: return fmt::format(compute, root, "eltwise_binary_no_bcast.cpp"); + case KernelName::ComputeBcast: return fmt::format(compute, root, "eltwise_binary.cpp"); + case KernelName::ComputeScalar: return fmt::format(compute, root, "eltwise_binary_scalar.cpp"); + default: __builtin_unreachable(); // GCC 12 doesn't compile even though we exhaustively match + } +} + std::tuple calculate_compute_kernel_args( SubtileBroadcastType broadcast_type, uint32_t start_tile_id, uint32_t HtWt, uint32_t Wt) { uint32_t start_t = start_tile_id % HtWt; @@ -164,8 +282,6 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto b_data_format = b.has_value() ? datatype_to_dataformat_converter(b->get_dtype()) : DataFormat::Float16_b; auto c_data_format = datatype_to_dataformat_converter(c.get_dtype()); - tt::DataFormat b_intermediate_format = b_data_format; - uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format); uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format); uint32_t c_single_tile_size = tt_metal::detail::TileSize(c_data_format); @@ -183,23 +299,10 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio Buffer* b_buffer = nullptr; Buffer* c_buffer = c.buffer(); - auto op_type = operation_attributes.binary_op_type; - auto compute_kernel_defines = OpConfig(op_type).as_defines(); - bool op_has_exp = - op_type == BinaryOpType::LOGADDEXP || op_type == BinaryOpType::LDEXP || op_type == BinaryOpType::LOGADDEXP2; - // How many tiles to store per input CB (double buffer) constexpr uint32_t num_tiles_per_cb = 2; auto [a_cb, a_cb_handle] = create_cb(tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); - - if (compute_kernel_defines.find("PREPROCESS_A_INIT") != compute_kernel_defines.end()) { - auto a_intermediate_format = op_has_exp ? 
-    if (compute_kernel_defines.find("PREPROCESS_B_INIT") != compute_kernel_defines.end()) {
-        auto b_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : b_data_format;
-        uint32_t b_intermediate_single_tile_size = tt_metal::detail::TileSize(b_intermediate_format);
-        auto [b_cb_interim, b_cb_interim_handle] = create_cb(
-            tt::CBIndex::c_4, program, all_device_cores, b_intermediate_single_tile_size, 1, b_intermediate_format);
-    }
-
     auto a_is_dram = static_cast<uint32_t>(a_buffer->buffer_type() == tt_metal::BufferType::DRAM);
     bool b_is_dram = false;
     auto c_is_dram = static_cast<uint32_t>(c_buffer->buffer_type() == tt_metal::BufferType::DRAM);
@@ -251,12 +347,12 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio
     // Compute kernel needs to know which op it's going to perform
     // This has to be passed as a compile-time argument
     // For now we're just going to do addition
-    compute_kernel_defines["BCAST_INPUT"] = kernel_config.bcast_input_str();
     auto compute_kernel_id = tt_metal::CreateKernel(
         program,
         get_kernel_file_path(compute_kernel),
         all_device_cores,
-        tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .defines = compute_kernel_defines});
+        tt_metal::ComputeConfig{
+            .fp32_dest_acc_en = fp32_dest_acc_en, .defines = {{"BCAST_INPUT", kernel_config.bcast_input_str()}}});
 
     auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) {
         tt_metal::SetRuntimeArgs(program, kernel_id, core, args);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
deleted file mode 100644
index 3671cd9d10f..00000000000
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
+++ /dev/null
@@ -1,206 +0,0 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#include "binary_ng_utils.hpp"
-
-#include <cctype>
-#include <fmt/format.h>
-#include <magic_enum.hpp>
-
-template <>
-struct fmt::formatter<ttnn::operations::binary_ng::Lowercase> : fmt::formatter<std::string_view> {
-    auto format(ttnn::operations::binary_ng::Lowercase const& value, fmt::format_context& ctx) const {
-        auto out = ctx.out();
-        for (char c : value.view) {
-            *out++ = std::tolower(static_cast<unsigned char>(c));
-        }
-        return out;
-    }
-};
-
-namespace ttnn::operations::binary_ng {
-
-BinaryNgKernelConfig::BinaryNgKernelConfig(SubtileBroadcastType subtile_broadcast_type) {
-    switch (subtile_broadcast_type) {
-        case SubtileBroadcastType::NONE:
-            reader_kernel = KernelName::ReaderNoBcast;
-            compute_kernel = KernelName::ComputeNoBcast;
-            writer_kernel = KernelName::WriterNoBcast;
-            bcast_input = std::nullopt;
-            break;
-
-        case SubtileBroadcastType::SCALAR_A:
-            reader_kernel = KernelName::ReaderScalarBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterNoBcast;
-            bcast_input = 0;
-            break;
-
-        case SubtileBroadcastType::SCALAR_B:
-            reader_kernel = KernelName::ReaderNoBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterScalarBcast;
-            bcast_input = 1;
-            break;
-
-        case SubtileBroadcastType::ROW_A:
-            reader_kernel = KernelName::ReaderRowBcast;
-            compute_kernel = KernelName::ComputeNoBcast;
-            writer_kernel = KernelName::WriterNoBcast;
-            bcast_input = std::nullopt;
-            break;
-
-        case SubtileBroadcastType::ROW_B:
-            reader_kernel = KernelName::ReaderNoBcast;
-            compute_kernel = KernelName::ComputeNoBcast;
-            writer_kernel = KernelName::WriterRowBcast;
-            bcast_input = std::nullopt;
-            break;
-
-        case SubtileBroadcastType::COL_A:
-            reader_kernel = KernelName::ReaderColBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterNoBcast;
-            bcast_input = 0;
-            break;
-
-        case SubtileBroadcastType::COL_B:
-            reader_kernel = KernelName::ReaderNoBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterColBcast;
-            bcast_input = 1;
-            break;
-
-        case SubtileBroadcastType::ROW_A_COL_B:
-            reader_kernel = KernelName::ReaderRowBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterColBcast;
-            bcast_input = 1;
-            break;
-
-        case SubtileBroadcastType::ROW_B_COL_A:
-            reader_kernel = KernelName::ReaderColBcast;
-            compute_kernel = KernelName::ComputeBcast;
-            writer_kernel = KernelName::WriterRowBcast;
-            bcast_input = 0;
-            break;
-    }
-}
-
-std::string BinaryNgKernelConfig::bcast_input_str() const {
-    if (bcast_input.has_value()) {
-        return std::to_string(*bcast_input);
-    }
-    return "";
-}
-
-std::string get_kernel_file_path(KernelName kernel_name) {
-    constexpr std::string_view root = "ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels";
-    constexpr std::string_view dataflow = "{}/dataflow/{}";
-    constexpr std::string_view compute = "{}/compute/{}";
-
-    switch (kernel_name) {
-        case KernelName::ReaderNoBcast: return fmt::format(dataflow, root, "reader_interleaved_no_bcast.cpp");
-        case KernelName::ReaderRowBcast: return fmt::format(dataflow, root, "reader_interleaved_row_bcast.cpp");
-        case KernelName::ReaderColBcast: return fmt::format(dataflow, root, "reader_interleaved_col_bcast.cpp");
-        case KernelName::ReaderScalarBcast: return fmt::format(dataflow, root, "reader_interleaved_scalar_bcast.cpp");
-        case KernelName::WriterNoBcast: return fmt::format(dataflow, root, "writer_interleaved_no_bcast.cpp");
-        case KernelName::WriterRowBcast: return fmt::format(dataflow, root, "writer_interleaved_row_bcast.cpp");
-        case KernelName::WriterColBcast: return fmt::format(dataflow, root, "writer_interleaved_col_bcast.cpp");
-        case KernelName::WriterScalarBcast: return fmt::format(dataflow, root, "writer_interleaved_scalar_bcast.cpp");
-        case KernelName::WriterScalar: return fmt::format(dataflow, root, "writer_interleaved_scalar.cpp");
-        case KernelName::ComputeNoBcast: return fmt::format(compute, root, "eltwise_binary_no_bcast.cpp");
-        case KernelName::ComputeBcast: return fmt::format(compute, root, "eltwise_binary.cpp");
-        case KernelName::ComputeScalar: return fmt::format(compute, root, "eltwise_binary_scalar.cpp");
-        default: __builtin_unreachable();  // GCC 12 doesn't compile even though we exhaustively match
-    }
-}
-
-constexpr OpConfig::SfpuConfig NezConfig("nez_tile_init", "nez_tile(i)");
-constexpr OpConfig::SfpuConfig GtzConfig("gtz_tile_init", "gtz_tile(i)");
-
-OpConfig::OpConfig(BinaryOpType binary_op_type) {
-    fpu_binary_op = FpuBinaryOp::SUB;
-    switch (binary_op_type) {
-        case BinaryOpType::ADD: fpu_binary_op = FpuBinaryOp::ADD; break;
-        case BinaryOpType::SUB: break;
-        case BinaryOpType::MUL: fpu_binary_op = FpuBinaryOp::MUL; break;
-        case BinaryOpType::DIV:
-            preprocess_b = SfpuConfig("recip_tile_init", "recip_tile(i)", "compute_kernel_api/eltwise_unary/recip.h");
-            fpu_binary_op = FpuBinaryOp::MUL;
-            break;
-        case BinaryOpType::GT: postprocess = GtzConfig; break;
-        case BinaryOpType::LT: postprocess = SfpuConfig("ltz_tile_init", "ltz_tile(i)"); break;
-        case BinaryOpType::GTE: postprocess = SfpuConfig("gez_tile_init", "gez_tile(i)"); break;
-        case BinaryOpType::LTE: postprocess = SfpuConfig("lez_tile_init", "lez_tile(i)"); break;
-        case BinaryOpType::EQ: postprocess = SfpuConfig("eqz_tile_init", "eqz_tile(i)"); break;
-        case BinaryOpType::NE: postprocess = NezConfig; break;
-        case BinaryOpType::SQUARED_DIFFERENCE: postprocess = SfpuConfig("square_tile_init", "square_tile(i)"); break;
-        case BinaryOpType::BIAS_GELU:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            preprocess_a =
-                SfpuConfig("gelu_tile_init", "gelu_tile(i)", "compute_kernel_api/eltwise_unary/gelu.h");
-            break;
-        case BinaryOpType::LOGICAL_AND:
-            fpu_binary_op = FpuBinaryOp::MUL;
-            postprocess = NezConfig;
-            break;
-        case BinaryOpType::LOGICAL_OR:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            preprocess_a = NezConfig;
-            preprocess_b = NezConfig;
-            postprocess = GtzConfig;
-            break;
-        case BinaryOpType::LOGICAL_XOR:
-            preprocess_a = NezConfig;
-            preprocess_b = NezConfig;
-            postprocess = NezConfig;
-            break;
-        case BinaryOpType::LDEXP:
-            fpu_binary_op = FpuBinaryOp::MUL;
-            preprocess_b = SfpuConfig("exp2_tile_init", "exp2_tile(i)");
-            break;
-        case BinaryOpType::LOGADDEXP:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            preprocess_a =
-                SfpuConfig("exp_tile_init", "exp_tile(i)", "compute_kernel_api/eltwise_unary/exp.h");
-            preprocess_b = preprocess_a;
-            postprocess = SfpuConfig("log_tile_init", "log_tile(i)");
-            break;
-        case BinaryOpType::LOGADDEXP2:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            preprocess_a = SfpuConfig("exp2_tile_init", "exp2_tile(i)");
-            preprocess_b = preprocess_a;
-            postprocess = SfpuConfig("log_with_base_tile_init", "log_with_base_tile(i, 0x3dc5u);");
-            break;
-        default: __builtin_unreachable();
-    }
-}
-
-std::map<std::string, std::string> OpConfig::SfpuConfig::as_defines(std::string_view prefix) const {
-    if (init.empty()) {
-        return {};
-    }
-
-    std::map<std::string, std::string> defines;
-    defines[fmt::format("{}_INIT", prefix)] = init;
-    defines[fmt::format("{}_APPLY(i)", prefix)] = apply;
-    defines[fmt::format("{}_INCLUDE", prefix)] = include;
-    return defines;
-}
-
-std::map<std::string, std::string> OpConfig::as_defines() const {
-    std::map<std::string, std::string> defines;
-    defines.merge(preprocess_a.as_defines("PREPROCESS_A"));
-    defines.merge(preprocess_b.as_defines("PREPROCESS_B"));
-    defines.merge(postprocess.as_defines("POSTPROCESS"));
-
-    auto binary_op_str = magic_enum::enum_name(fpu_binary_op);
-    defines["BINARY_OP"] = fmt::format("{}_tiles", Lowercase{binary_op_str});
-    defines["BINARY_OP_TYPE"] = fmt::format("EltwiseBinaryType::ELW{}", binary_op_str);
-
-    return defines;
-}
-
-}  // namespace ttnn::operations::binary_ng
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp
deleted file mode 100644
index cc8a242fc0c..00000000000
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp
+++ /dev/null
@@ -1,72 +0,0 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include "binary_ng_device_operation.hpp"
-#include "ttnn/operations/eltwise/binary_ng/types.hpp"
-
-#include <map>
-#include <optional>
-
-namespace ttnn::operations::binary_ng {
-
-enum class KernelName {
-    ReaderNoBcast,
-    ReaderRowBcast,
-    ReaderColBcast,
-    ReaderScalarBcast,
-    WriterNoBcast,
-    WriterRowBcast,
-    WriterColBcast,
-    WriterScalarBcast,
-    WriterScalar,
-    ComputeNoBcast,
-    ComputeBcast,
-    ComputeScalar
-};
-
-struct BinaryNgKernelConfig {
-    BinaryNgKernelConfig(SubtileBroadcastType subtile_broadcast_type);
-
-    std::string bcast_input_str() const;
-
-    KernelName reader_kernel;
-    KernelName compute_kernel;
-    KernelName writer_kernel;
-    std::optional<uint32_t> bcast_input;
-};
-
-std::string get_kernel_file_path(KernelName kernel_name);
-
-struct OpConfig {
-    struct SfpuConfig {
-        SfpuConfig() = default;
-        constexpr SfpuConfig(
-            std::string_view init, std::string_view apply, std::string_view include = "compute_kernel_api.h") :
-            init{init}, apply{apply}, include{include} {}
-        std::string_view init{};
-        std::string_view apply{};
-        std::string_view include{};
-
-        std::map<std::string, std::string> as_defines(std::string_view prefix) const;
-    };
-
-    enum class FpuBinaryOp { ADD, SUB, MUL };
-
-    OpConfig(BinaryOpType binary_op_type);
-
-    std::map<std::string, std::string> as_defines() const;
-
-    SfpuConfig preprocess_a{};
-    SfpuConfig preprocess_b{};
-    SfpuConfig postprocess{};
-    FpuBinaryOp fpu_binary_op;
-};
-
-struct Lowercase {
-    std::string_view view;
-};
-
-}  // namespace ttnn::operations::binary_ng
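For the record, the deleted OpConfig::as_defines() was what turned a BinaryOpType into the compute kernel's -D defines. A hedged Python reconstruction of the map it produced for LOGICAL_OR (NEZ preprocess on both inputs, FPU ADD, GTZ postprocess, per the switch above):

```python
LOGICAL_OR_DEFINES = {
    "PREPROCESS_A_INIT": "nez_tile_init",
    "PREPROCESS_A_APPLY(i)": "nez_tile(i)",
    "PREPROCESS_A_INCLUDE": "compute_kernel_api.h",  # SfpuConfig's default include
    "PREPROCESS_B_INIT": "nez_tile_init",
    "PREPROCESS_B_APPLY(i)": "nez_tile(i)",
    "PREPROCESS_B_INCLUDE": "compute_kernel_api.h",
    "POSTPROCESS_INIT": "gtz_tile_init",
    "POSTPROCESS_APPLY(i)": "gtz_tile(i)",
    "POSTPROCESS_INCLUDE": "compute_kernel_api.h",
    "BINARY_OP": "add_tiles",                        # lowercased FpuBinaryOp name + "_tiles"
    "BINARY_OP_TYPE": "EltwiseBinaryType::ELWADD",
}
```

With this machinery gone, the program factory passes only BCAST_INPUT (see the ComputeConfig change above), which is why the rewritten kernels below hardcode add_tiles.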
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp
index ca31d410f39..dff89bfd613 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp
@@ -2,72 +2,24 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#include <cstdint>
 #include "compute_kernel_api/eltwise_binary.h"
-#include "eltwise_defines.hpp"
-#include "eltwise_utils.hpp"
-
-#ifdef PREPROCESS_A_INCLUDE
-#include QUOTE(PREPROCESS_A_INCLUDE)
-#endif
+#include <cstdint>
 
-#ifdef PREPROCESS_B_INCLUDE
-#include QUOTE(PREPROCESS_B_INCLUDE)
-#endif
-
-#ifdef POSTPROCESS_INCLUDE
-#include QUOTE(POSTPROCESS_INCLUDE)
-#endif
 
 namespace NAMESPACE {
 
-ALWI void process_tile(
-    tt::CBIndex cb_pre_lhs,
-    tt::CBIndex cb_post_lhs,
-    tt::CBIndex cb_pre_rhs,
-    tt::CBIndex cb_post_rhs,
-    tt::CBIndex cb_out,
-    uint32_t freq,
-    uint32_t tile_start) {
-    using namespace ckernel;
+ALWI void process_tile(uint32_t cb_bcast, uint32_t cb_other, uint32_t cb_out, uint32_t freq, uint32_t tile_start) {
     constexpr uint32_t onetile = 1;
 
-#if BCAST_INPUT
-    auto cb_bcast = cb_post_rhs;
-    auto cb_other = cb_post_lhs;
-#else
-    auto cb_bcast = cb_post_lhs;
-    auto cb_other = cb_post_rhs;
-#endif
-
-#if PREPROCESS_A && (BCAST_INPUT == 0)
-    PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
-#elif PREPROCESS_B && (BCAST_INPUT == 1)
-    PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
-#endif
-
     cb_wait_front(cb_bcast, onetile);
 
     for (uint32_t j = tile_start; j < freq; ++j) {
-#if PREPROCESS_A && (BCAST_INPUT == 1)
-        PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
-#elif PREPROCESS_B && (BCAST_INPUT == 0)
-        PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
-#endif
        cb_wait_front(cb_other, onetile);
-
        cb_reserve_back(cb_out, onetile);
 
-#if PREPROCESS_A || PREPROCESS_B
-        binary_op_specific_init();
-#endif
        tile_regs_acquire();
-        BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0);
-#if POSTPROCESS
-        POSTPROCESS_INIT();
-        POSTPROCESS_APPLY(0);
-#endif
+        add_tiles(cb_bcast, cb_other, 0, 0, 0);
        tile_regs_commit();
 
        tile_regs_wait();
@@ -89,28 +41,30 @@ void MAIN {
        return;
    }
 
-    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
-    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
-    constexpr auto cb_out = tt::CBIndex::c_2;
+    constexpr auto cb_in0 = tt::CBIndex::c_0;
+    constexpr auto cb_in1 = tt::CBIndex::c_1;
+    constexpr auto cb_out0 = tt::CBIndex::c_2;
 
-    constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs;
-    constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs;
-
-    binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out);
-
-#if not(PREPROCESS_A || PREPROCESS_B)
-    binary_op_specific_init();
+#if BCAST_INPUT
+    auto cb_bcast = cb_in1;
+    auto cb_other = cb_in0;
+#else
+    auto cb_bcast = cb_in0;
+    auto cb_other = cb_in1;
 #endif
 
+    binary_op_init_common(cb_bcast, cb_other, cb_out0);
+    add_tiles_init();
+
     uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
     uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
     for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
-        process_tile(cb_pre_lhs, cb_post_lhs, cb_pre_rhs, cb_post_rhs, cb_out, tile_freq, tile_start);
+        process_tile(cb_bcast, cb_other, cb_out0, tile_freq, tile_start);
    }
    if (remaining_iterations > 0) {
-        process_tile(cb_pre_lhs, cb_post_lhs, cb_pre_rhs, cb_post_rhs, cb_out, remaining_iterations, tile_start);
+        process_tile(cb_bcast, cb_other, cb_out0, remaining_iterations, tile_start);
    }
 }
 }  // namespace NAMESPACE
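The rewritten MAIN splits the tile count into whole tile_freq groups plus a remainder, with tile_start offsetting only the first group. A small model of that arithmetic:

```python
def iteration_split(num_tiles, tile_freq, tile_start):
    # Matches the kernel's complete/remaining computation.
    complete = (num_tiles + tile_start) // tile_freq
    remaining = (num_tiles + tile_start) % tile_freq
    return complete, remaining

assert iteration_split(10, 4, 2) == (3, 0)  # three full groups, no remainder pass
```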
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
index e1d3fb08997..1e3703fa4b8 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
@@ -6,71 +6,35 @@
 #include "compute_kernel_api/eltwise_binary.h"
 #include "dprint.h"
 
-#include "eltwise_defines.hpp"
-#include "eltwise_utils.hpp"
-
-#ifdef PREPROCESS_A_INCLUDE
-#include QUOTE(PREPROCESS_A_INCLUDE)
-#endif
-
-#ifdef PREPROCESS_B_INCLUDE
-#include QUOTE(PREPROCESS_B_INCLUDE)
-#endif
-
-#ifdef POSTPROCESS_INCLUDE
-#include QUOTE(POSTPROCESS_INCLUDE)
-#endif
-
 namespace NAMESPACE {
 void MAIN {
     uint32_t num_tiles = get_arg_val<uint32_t>(0);
 
-    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
-    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
-    constexpr auto cb_out = tt::CBIndex::c_2;
+    constexpr auto cb_in0 = tt::CBIndex::c_0;
+    constexpr auto cb_in1 = tt::CBIndex::c_1;
+    constexpr auto cb_out0 = tt::CBIndex::c_2;
 
-    constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs;
-    constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs;
-
-    binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out);
-
-#if not(PREPROCESS_A || PREPROCESS_B)
-    binary_op_specific_init();
-#endif
+    binary_op_init_common(cb_in0, cb_in1, cb_out0);
+    add_tiles_init();
 
     constexpr uint32_t onetile = 1;
 
-    for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
-#if PREPROCESS_A
-        PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
-#endif
-        cb_wait_front(cb_post_lhs, onetile);
-
-#if PREPROCESS_B
-        PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
-#endif
-        cb_wait_front(cb_post_rhs, onetile);
-
-        cb_reserve_back(cb_out, onetile);
+    for(uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
+        cb_wait_front(cb_in0, onetile);
+        cb_wait_front(cb_in1, onetile);
+        cb_reserve_back(cb_out0, onetile);
 
-#if PREPROCESS_A || PREPROCESS_B
-        binary_op_specific_init();
-#endif
        tile_regs_acquire();
-        BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0);
-#if POSTPROCESS
-        POSTPROCESS_INIT();
-        POSTPROCESS_APPLY(0);
-#endif
+        add_tiles(cb_in0, cb_in1, 0, 0, 0);
        tile_regs_commit();
 
        tile_regs_wait();
-        pack_tile(0, cb_out);
+        pack_tile(0, cb_out0);
        tile_regs_release();
 
-        cb_push_back(cb_out, onetile);
-        cb_pop_front(cb_post_lhs, onetile);
-        cb_pop_front(cb_post_rhs, onetile);
+        cb_push_back(cb_out0, onetile);
+        cb_pop_front(cb_in0, onetile);
+        cb_pop_front(cb_in1, onetile);
    }
 }
 }  // namespace NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
index 92e2b95be2e..2b377fb52ca 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
@@ -4,71 +4,37 @@
 #include <cstdint>
 #include "compute_kernel_api/eltwise_binary.h"
-
-#include "eltwise_defines.hpp"
-#include "eltwise_utils.hpp"
-
-#ifdef PREPROCESS_A_INCLUDE
-#include QUOTE(PREPROCESS_A_INCLUDE)
-#endif
-
-#ifdef PREPROCESS_B_INCLUDE
-#include QUOTE(PREPROCESS_B_INCLUDE)
-#endif
-
-#ifdef POSTPROCESS_INCLUDE
-#include QUOTE(POSTPROCESS_INCLUDE)
-#endif
+#include "dprint.h"
 
 namespace NAMESPACE {
 void MAIN {
     uint32_t num_tiles = get_arg_val<uint32_t>(0);
 
-    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
-    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
-    constexpr auto cb_out = tt::CBIndex::c_2;
+    constexpr auto cb_in0 = tt::CBIndex::c_0;
+    constexpr auto cb_in1 = tt::CBIndex::c_1;
+    constexpr auto cb_out0 = tt::CBIndex::c_2;
 
-    constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs;
-    constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs;
-
-    binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out);
-
-#if not(PREPROCESS_A || PREPROCESS_B)
-    binary_op_specific_init();
-#endif
+    binary_op_init_common(cb_in0, cb_in1, cb_out0);
+    add_tiles_init();
 
     constexpr uint32_t onetile = 1;
 
-#if PREPROCESS_B
-    PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
-#endif
-    cb_wait_front(cb_post_rhs, onetile);
+    cb_wait_front(cb_in1, onetile);
 
     for(uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
-#if PREPROCESS_A
-        PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
-#endif
-        cb_wait_front(cb_post_lhs, onetile);
-
-        cb_reserve_back(cb_out, onetile);
+        cb_wait_front(cb_in0, onetile);
+        cb_reserve_back(cb_out0, onetile);
 
-#if PREPROCESS_A || PREPROCESS_B
-        binary_op_specific_init();
-#endif
        tile_regs_acquire();
-        BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0);
-#if POSTPROCESS
-        POSTPROCESS_INIT();
-        POSTPROCESS_APPLY(0);
-#endif
+        add_tiles(cb_in0, cb_in1, 0, 0, 0);
        tile_regs_commit();
 
        tile_regs_wait();
-        pack_tile(0, cb_out);
+        pack_tile(0, cb_out0);
        tile_regs_release();
 
-        cb_pop_front(cb_post_lhs, onetile);
-        cb_push_back(cb_out, onetile);
+        cb_pop_front(cb_in0, onetile);
+        cb_push_back(cb_out0, onetile);
    }
 }
 }  // namespace NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp
deleted file mode 100644
index 5238d3b8db6..00000000000
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp
+++ /dev/null
@@ -1,26 +0,0 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#define DO_QUOTE(x) #x
-#define QUOTE(x) DO_QUOTE(x)
-
-#if defined(PREPROCESS_A_INIT)
-#define PREPROCESS_A 1
-#else
-#define PREPROCESS_A 0
-#endif
-
-#if defined(PREPROCESS_B_INIT)
-#define PREPROCESS_B 1
-#else
-#define PREPROCESS_B 0
-#endif
-
-#if defined(POSTPROCESS_INIT)
-#define POSTPROCESS 1
-#else
-#define POSTPROCESS 0
-#endif
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp
deleted file mode 100644
index 63f350c47c1..00000000000
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp
+++ /dev/null
@@ -1,41 +0,0 @@
-// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
-//
-// SPDX-License-Identifier: Apache-2.0
-
-#pragma once
-
-#include "compute_kernel_api/common.h"
-#include "compute_kernel_api/tile_move_copy.h"
-
-#define PREPROCESS(init, apply, cb_pre, cb_post, cb_out, per_core_block_size) \
-    do { \
-        using namespace ckernel; \
- \
-        copy_tile_to_dst_init_short(); \
- \
-        reconfig_data_format_srca(/*old*/ cb_post, /*new*/ cb_pre); \
-        pack_reconfig_data_format(/*old*/ cb_out, /*new*/ cb_post); \
- \
-        cb_wait_front(cb_pre, per_core_block_size); \
-        cb_reserve_back(cb_post, per_core_block_size); \
- \
-        tile_regs_acquire(); \
-        init(); \
-        for (uint32_t i = 0; i < per_core_block_size; ++i) { \
-            copy_tile(cb_pre, i, i); \
-            apply(i); \
-        } \
-        tile_regs_commit(); \
- \
-        tile_regs_wait(); \
-        for (uint32_t i = 0; i < per_core_block_size; ++i) { \
-            pack_tile(i, cb_post); /* DST[0]->cb */ \
-        } \
-        tile_regs_release(); \
- \
-        cb_pop_front(cb_pre, per_core_block_size); \
-        cb_push_back(cb_post, per_core_block_size); \
- \
-        reconfig_data_format_srca(/*old*/ cb_pre, /*new*/ cb_post); \
-        pack_reconfig_data_format(/*old*/ cb_post, /*new*/ cb_out); \
-    } while (0)
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
index 06c53bfe7e6..ccb5695b3ac 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
@@ -10,7 +10,6 @@ enum class BinaryOpType {
     ADD,
     SUB,
     MUL,
-    DIV,
     GT,
     LT,
     LTE,
@@ -19,11 +18,13 @@
     NE,
     SQUARED_DIFFERENCE,
     BIAS_GELU,
+    LOGADDEXP,
     LOGICAL_AND,
     LOGICAL_OR,
     LOGICAL_XOR,
     LDEXP,
-    LOGADDEXP,
     LOGADDEXP2,
+    DIV_FAST
 };
+
 }
diff --git a/ttnn/ttnn/operations/binary_ng.py b/ttnn/ttnn/operations/binary_ng.py
deleted file mode 100644
index d012271bda5..00000000000
--- a/ttnn/ttnn/operations/binary_ng.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
-
-# SPDX-License-Identifier: Apache-2.0
-
-import ttnn
-import torch
-
-
-ttnn.attach_golden_function(ttnn.experimental.add, golden_function=lambda a, b: a + b)
-ttnn.attach_golden_function(ttnn.experimental.sub, golden_function=lambda a, b: a - b)
-ttnn.attach_golden_function(ttnn.experimental.mul, golden_function=lambda a, b: a * b)
-ttnn.attach_golden_function(ttnn.experimental.div, golden_function=lambda a, b: torch.divide(a, b))
-ttnn.attach_golden_function(ttnn.experimental.eq, golden_function=lambda a, b: torch.eq(a, b))
-ttnn.attach_golden_function(ttnn.experimental.ne, golden_function=lambda a, b: torch.ne(a, b))
-ttnn.attach_golden_function(ttnn.experimental.gt, golden_function=lambda a, b: torch.gt(a, b))
-ttnn.attach_golden_function(ttnn.experimental.lt, golden_function=lambda a, b: torch.lt(a, b))
-ttnn.attach_golden_function(ttnn.experimental.gte, golden_function=lambda a, b: torch.ge(a, b))
-ttnn.attach_golden_function(ttnn.experimental.lte, golden_function=lambda a, b: torch.le(a, b))
-ttnn.attach_golden_function(ttnn.experimental.ldexp, golden_function=lambda a, b: torch.ldexp(a, b))
-ttnn.attach_golden_function(ttnn.experimental.logaddexp, golden_function=lambda a, b: torch.logaddexp(a, b))
-ttnn.attach_golden_function(ttnn.experimental.logaddexp2, golden_function=lambda a, b: torch.logaddexp2(a, b))
-ttnn.attach_golden_function(ttnn.experimental.logical_and, golden_function=lambda a, b: torch.logical_and(a, b))
-ttnn.attach_golden_function(ttnn.experimental.logical_or, golden_function=lambda a, b: torch.logical_or(a, b))
-ttnn.attach_golden_function(ttnn.experimental.logical_xor, golden_function=lambda a, b: torch.logical_xor(a, b))
-ttnn.attach_golden_function(
-    ttnn.experimental.squared_difference, golden_function=lambda a, b: torch.square(torch.sub(a, b))
-)
-ttnn.attach_golden_function(
-    ttnn.experimental.bias_gelu, golden_function=lambda a, b: torch.nn.functional.gelu(torch.add(a, b))
-)
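With binary_ng.py gone, no golden function is attached to ttnn.experimental.add anymore, which is why the simplified test computes `a_pt + b_pt` inline instead of calling ttnn.get_golden_function. If golden-based comparison is still wanted for the one surviving op, the relevant line from the deleted file can be kept (taken verbatim from above):

```python
import ttnn

# Reattach the torch golden for the only op this branch still registers.
ttnn.attach_golden_function(ttnn.experimental.add, golden_function=lambda a, b: a + b)
```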