Skip to content

Commit

Permalink
#16165: Add binary sfpu div init
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Dec 21, 2024
1 parent 31dca41 commit 2c7e9ad
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,6 @@ def test_nei_ttnn(input_shapes, scalar, device):
assert comp_pass


@pytest.mark.skip(reason="#16165: Test is broken if you run after individually.")
@pytest.mark.parametrize(
"input_shapes",
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int BINOP_MODE, int ITERATIONS = 8>
inline void calculate_sfpu_binary(const uint dst_offset)
{
inline void calculate_sfpu_binary(const uint dst_offset) {
_calculate_sfpu_binary_<APPROXIMATION_MODE, BINOP_MODE, ITERATIONS>(dst_offset);
}

template <bool APPROXIMATION_MODE /*unused*/, int BINOP_MODE>
inline void sfpu_binary_init() {
_sfpu_binary_init_<APPROXIMATION_MODE, BINOP_MODE>();
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
template <bool APPROXIMATE, int binop_mode>
inline void llk_math_eltwise_binary_sfpu_binop_init() {
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>();
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>(
ckernel::sfpu::sfpu_binary_init<APPROXIMATE, binop_mode>);
}

template <bool APPROXIMATE, int binop_mode>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int BINOP_MODE, int ITERATIONS = 8>
inline void calculate_sfpu_binary(const uint dst_offset)
{
inline void calculate_sfpu_binary(const uint dst_offset) {
_calculate_sfpu_binary_<APPROXIMATION_MODE, BINOP_MODE, ITERATIONS>(dst_offset);
}

template <bool APPROXIMATION_MODE /*unused*/, int BINOP_MODE>
inline void sfpu_binary_init() {
_sfpu_binary_init_<APPROXIMATION_MODE, BINOP_MODE>();
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
template <bool APPROXIMATE, int binop_mode>
inline void llk_math_eltwise_binary_sfpu_binop_init() {
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>();
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>(
ckernel::sfpu::sfpu_binary_init<APPROXIMATE, binop_mode>);
}

template <bool APPROXIMATE, int binop_mode>
Expand Down
12 changes: 11 additions & 1 deletion tt_metal/include/compute_kernel_api/eltwise_binary_sfpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ ALWI void power_binary_tile(uint32_t idst0, uint32_t idst1) {
/**
* Please refer to documentation for any_init.
*/
ALWI void eltwise_binop_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX>())); }
ALWI void add_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, ADD_BINARY>())); }

ALWI void sub_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, SUB_BINARY>())); }

ALWI void mul_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, MUL_BINARY>())); }

ALWI void div_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, DIV_BINARY>())); }

ALWI void rsub_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, RSUB_BINARY>())); }

ALWI void power_binary_tile_init() { MATH((llk_math_eltwise_binary_sfpu_binop_init<APPROX, POW_BINARY>())); }

} // namespace ckernel
2 changes: 1 addition & 1 deletion tt_metal/third_party/tt_llk_blackhole
2 changes: 1 addition & 1 deletion tt_metal/third_party/tt_llk_wormhole_b0
30 changes: 25 additions & 5 deletions ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,30 @@ std::map<std::string, std::string> get_defines_fp32(
new_defines.insert({"ADD_INT32_INIT", fmt::format("add_int32_tile_init();")});
op_name = "add_int32_tile";
} else {
new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")});
op_name = "add_binary_tile";
}
break;
case BinaryOpType::SUB: op_name = "sub_binary_tile"; break;
case BinaryOpType::MUL: op_name = "mul_binary_tile"; break;
case BinaryOpType::RSUB: op_name = "rsub_binary_tile"; break;
case BinaryOpType::POWER: op_name = "power_binary_tile"; break;
case BinaryOpType::DIV_FAST: op_name = "div_binary_tile"; break;
case BinaryOpType::SUB:
new_defines.insert({"BINOP_INIT", fmt::format("sub_binary_tile_init();")});
op_name = "sub_binary_tile";
break;
case BinaryOpType::MUL:
new_defines.insert({"BINOP_INIT", fmt::format("mul_binary_tile_init();")});
op_name = "mul_binary_tile";
break;
case BinaryOpType::RSUB:
new_defines.insert({"BINOP_INIT", fmt::format("rsub_binary_tile_init();")});
op_name = "rsub_binary_tile";
break;
case BinaryOpType::POWER:
new_defines.insert({"BINOP_INIT", fmt::format("power_binary_tile_init();")});
op_name = "power_binary_tile";
break;
case BinaryOpType::DIV_FAST:
new_defines.insert({"BINOP_INIT", fmt::format("div_binary_tile_init();")});
op_name = "div_binary_tile";
break;
case BinaryOpType::BITWISE_AND:
new_defines.insert({"BITWISE_INIT", fmt::format("binary_bitwise_tile_init();")});
op_name = "and_binary_tile";
Expand All @@ -217,12 +233,14 @@ std::map<std::string, std::string> get_defines_fp32(
// PRE_IN1_0 ====> Applies prescaling for second input
new_defines.merge(get_defines(UnaryOpType::EXP, std::vector<float>{0}, "PRE_IN0_0"));
new_defines.merge(get_defines(UnaryOpType::EXP, std::vector<float>{0}, "PRE_IN1_0"));
new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")});
op_name = "add_binary_tile";
new_defines.merge(get_defines(UnaryOpType::LOG, std::nullopt, "0", idst1));
break;
case BinaryOpType::LOGADDEXP2:
new_defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN0_0"));
new_defines.merge(get_defines(UnaryOpType::EXP2, std::nullopt, "PRE_IN1_0"));
new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")});
op_name = "add_binary_tile";
new_defines.merge(get_defines(UnaryOpType::LOG2, std::nullopt, "0", idst1));
break;
Expand All @@ -239,12 +257,14 @@ std::map<std::string, std::string> get_defines_fp32(
new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "0", idst1));
break;
case BinaryOpType::BIAS_GELU:
new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")});
op_name = "add_binary_tile";
new_defines.merge(get_defines(UnaryOpType::GELU, std::vector<float>{0}, "0", idst1));
break;
case BinaryOpType::LOGICAL_OR:
new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN0_0"));
new_defines.merge(get_defines(UnaryOpType::NEZ, std::nullopt, "PRE_IN1_0"));
new_defines.insert({"BINOP_INIT", fmt::format("add_binary_tile_init();")});
op_name = "add_binary_tile";
new_defines.merge(get_defines(UnaryOpType::GTZ, std::nullopt, "0", idst1));
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

#define PRE_SCALE defined SFPU_OP_INIT_PRE_IN0_0 || defined SFPU_OP_INIT_PRE_IN1_0

#if defined(ADD_INT32_INIT) || defined(BITWISE_INIT) || defined(SHIFT_INIT)
#define INT32_INIT
#endif

namespace NAMESPACE {
void MAIN {
uint32_t per_core_block_cnt = get_arg_val<uint32_t>(0);
Expand Down Expand Up @@ -111,10 +107,9 @@ void MAIN {
for (uint32_t i = 0; i < per_core_block_size; ++i) {
copy_tile(cb_inp1, i, i * 2 + 1);

#ifndef INT32_INIT
eltwise_binop_tile_init();
#ifdef BINOP_INIT
BINOP_INIT
#endif

#ifdef ADD_INT32_INIT
ADD_INT32_INIT
#endif
Expand Down

0 comments on commit 2c7e9ad

Please sign in to comment.