diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py index 346c8549e981..97cff73907d3 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py @@ -990,7 +990,6 @@ def test_nei_ttnn(input_shapes, scalar, device): assert comp_pass -@pytest.mark.skip(reason="#16165: Test is broken if you run after individually.") @pytest.mark.parametrize( "input_shapes", ( diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h index 6c23abe0a263..56bc3323bfa9 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h @@ -14,10 +14,14 @@ namespace ckernel { namespace sfpu { template -inline void calculate_sfpu_binary(const uint dst_offset) -{ +inline void calculate_sfpu_binary(const uint dst_offset) { _calculate_sfpu_binary_(dst_offset); } +template +inline void sfpu_binary_init() { + _sfpu_binary_init_(); +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h index 09fcb4a530dd..e870ffc804bb 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h @@ -12,9 +12,10 @@ namespace ckernel { // New LLK SFPU APIs -template +template inline void llk_math_eltwise_binary_sfpu_binop_init() { - llk_math_eltwise_binary_sfpu_init(); + llk_math_eltwise_binary_sfpu_init( + ckernel::sfpu::sfpu_binary_init); } template diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h index 6c23abe0a263..56bc3323bfa9 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_binary.h @@ -14,10 +14,14 @@ namespace ckernel { namespace sfpu { template -inline void calculate_sfpu_binary(const uint dst_offset) -{ +inline void calculate_sfpu_binary(const uint dst_offset) { _calculate_sfpu_binary_(dst_offset); } +template +inline void sfpu_binary_init() { + _sfpu_binary_init_(); +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h index 09fcb4a530dd..e870ffc804bb 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binop.h @@ -12,9 +12,10 @@ namespace ckernel { // New LLK SFPU APIs -template +template inline void llk_math_eltwise_binary_sfpu_binop_init() { - llk_math_eltwise_binary_sfpu_init(); + llk_math_eltwise_binary_sfpu_init( + ckernel::sfpu::sfpu_binary_init); } template diff --git a/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h b/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h index 239958919401..274c1ad219cd 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h +++ b/tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h @@ -61,6 +61,16 @@ ALWI void power_binary_tile(uint32_t idst0, uint32_t idst1) { /** * Please refer to documentation for any_init. */ -ALWI void eltwise_binop_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } +ALWI void add_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +ALWI void sub_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +ALWI void mul_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +ALWI void div_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +ALWI void rsub_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } + +ALWI void power_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init())); } } // namespace ckernel diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index c5735f6d4a8b..041753e8a963 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit c5735f6d4a8b66b6e46f26c4c655abb694875bd7 +Subproject commit 041753e8a96342c13fcd3a3220a13054e69f0910 diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 33a7f6a02671..d36e0cd178b2 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 33a7f6a026719af509a119d8a4e8e36c7c31854c +Subproject commit d36e0cd178b29624d3db8f6de5c9f7e54170d24c diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp index 8efea1c20f49..c6abe9dbdde8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp @@ -184,14 +184,30 @@ std::map get_defines_fp32( new_defines.insert({"ADD_INT32_INIT", fmt::format("add_int32_tile_init();")}); op_name = "add_int32_tile"; } else { + new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")}); op_name = "add_binary_tile"; } break; - case BinaryOpType::SUB: op_name = "sub_binary_tile"; break; - case BinaryOpType::MUL: op_name = "mul_binary_tile"; break; - case BinaryOpType::RSUB: op_name = "rsub_binary_tile"; break; - case BinaryOpType::POWER: op_name = "power_binary_tile"; break; - case BinaryOpType::DIV_FAST: op_name = "div_binary_tile"; break; + case BinaryOpType::SUB: + new_defines.insert({"BINOP_INIT", fmt::format("sub_binary_tile_init();")}); + op_name = "sub_binary_tile"; + break; + case BinaryOpType::MUL: + new_defines.insert({"BINOP_INIT", fmt::format("mul_binary_tile_init();")}); + op_name = "mul_binary_tile"; + break; + case BinaryOpType::RSUB: + new_defines.insert({"BINOP_INIT", fmt::format("rsub_binary_tile_init();")}); + op_name = "rsub_binary_tile"; + break; + case BinaryOpType::POWER: + new_defines.insert({"BINOP_INIT", fmt::format("power_binary_tile_init();")}); + op_name = "power_binary_tile"; + break; + case BinaryOpType::DIV_FAST: + new_defines.insert({"BINOP_INIT", fmt::format("div_binary_tile_init();")}); + op_name = "div_binary_tile"; + break; case BinaryOpType::BITWISE_AND: new_defines.insert({"BITWISE_INIT", fmt::format("binary_bitwise_tile_init();")}); op_name = "and_binary_tile"; @@ -217,12 +233,14 @@ std::map get_defines_fp32( // PRE_IN1_0 ====> Applies prescaling for second input new_defines.merge(get_defines(UnaryOpType::EXP, std::vector{0}, "PRE_IN0_0")); new_defines.merge(get_defines(UnaryOpType::EXP, std::vector{0}, "PRE_IN1_0")); + new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")}); op_name = "add_binary_tile"; new_defines.merge(get_defines(UnaryOpType::LOG, std::nullopt, "0", idst1)); break; case BinaryOpType::LOGADDEXP2: new_defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN0_0")); new_defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN1_0")); + new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")}); op_name = "add_binary_tile"; new_defines.merge(get_defines(UnaryOpType::LOG2, std::nullopt, "0", idst1)); break; @@ -239,12 +257,14 @@ std::map get_defines_fp32( new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "0", idst1)); break; case BinaryOpType::BIAS_GELU: + new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")}); op_name = "add_binary_tile"; new_defines.merge(get_defines(UnaryOpType::GELU, std::vector{0}, "0", idst1)); break; case BinaryOpType::LOGICAL_OR: new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN0_0")); new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN1_0")); + new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")}); op_name = "add_binary_tile"; new_defines.merge(get_defines(UnaryOpType::GTZ, std::nullopt, "0", idst1)); break; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp index 6970e4aa4980..0d4aa04d128b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp @@ -16,10 +16,6 @@ #define PRE_SCALE defined SFPU_OP_INIT_PRE_IN0_0 || defined SFPU_OP_INIT_PRE_IN1_0 -#if defined(ADD_INT32_INIT) || defined(BITWISE_INIT) || defined(SHIFT_INIT) -#define INT32_INIT -#endif - namespace NAMESPACE { void MAIN { uint32_t per_core_block_cnt = get_arg_val(0); @@ -111,10 +107,9 @@ void MAIN { for (uint32_t i = 0; i < per_core_block_size; ++i) { copy_tile(cb_inp1, i, i * 2 + 1); -#ifndef INT32_INIT - eltwise_binop_tile_init(); +#ifdef BINOP_INIT + BINOP_INIT #endif - #ifdef ADD_INT32_INIT ADD_INT32_INIT #endif