From 7f4eb3237f0defe1ad7e6f4bd11a867bd4efe0be Mon Sep 17 00:00:00 2001 From: Patrick Roberts Date: Mon, 23 Dec 2024 16:16:29 +0000 Subject: [PATCH] #16153: Implement binary-ng fused input activations --- .../operations/eltwise/test_binary_bcast.py | 239 ++++++++++++------ .../eltwise/binary/common/binary_op_utils.cpp | 6 +- .../eltwise/binary_ng/binary_ng.cpp | 98 +++++-- .../eltwise/binary_ng/binary_ng.hpp | 52 ++-- .../eltwise/binary_ng/binary_ng_pybind.cpp | 36 ++- .../device/binary_ng_device_operation.cpp | 25 +- .../device/binary_ng_device_operation.hpp | 15 +- .../device/binary_ng_program_factory.cpp | 39 ++- .../binary_ng/device/binary_ng_utils.cpp | 91 +++---- .../binary_ng/device/binary_ng_utils.hpp | 23 +- .../device/kernels/compute/eltwise_binary.cpp | 51 ++-- .../compute/eltwise_binary_no_bcast.cpp | 28 +- .../kernels/compute/eltwise_binary_scalar.cpp | 28 +- .../kernels/compute/eltwise_defines.hpp | 38 --- .../device/kernels/compute/eltwise_utils.hpp | 237 ++++++++++++++--- .../eltwise/unary/common/unary_op_utils.cpp | 74 +++--- .../eltwise/unary/common/unary_op_utils.hpp | 2 + ttnn/cpp/ttnn/tensor/shape/small_vector.hpp | 1 + 18 files changed, 726 insertions(+), 357 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py index 236c2a8b085..8f472111298 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py @@ -10,11 +10,101 @@ compare_pcc, ) from models.utility_functions import skip_for_grayskull +from itertools import product as parameters + + +binary_fns = { + "gte", + "gt", + "lte", + "lt", + "eq", + "ne", + "logical_and", + "logical_or", + "logical_xor", + "ldexp", + "logaddexp", + "logaddexp2", + "squared_difference", + "add", + "sub", + "mul", + "div", + "bias_gelu", +} +activation_fns = { + "EXP": torch.exp, + "GELU": torch.nn.functional.gelu, + "RELU": torch.relu, + "SQRT": torch.sqrt, + "SIGMOID": torch.sigmoid, + "LOG": torch.log, + "TANH": torch.tanh, + "LOG2": torch.log2, + "LOG10": torch.log10, + "SIN": torch.sin, + "COS": torch.cos, + "ABS": torch.abs, + "SIGN": torch.sign, + "SQUARE": torch.square, + "EQZ": lambda x: torch.eq(x, 0), + "NEZ": lambda x: torch.not_equal(x, 0), + "GTZ": lambda x: torch.greater(x, 0), + "LTZ": lambda x: torch.less(x, 0), + "GEZ": lambda x: torch.greater_equal(x, 0), + "LEZ": lambda x: torch.less_equal(x, 0), + "EXP2": torch.exp2, + "EXPM1": torch.expm1, + "SIGNBIT": torch.signbit, + "RSQRT": torch.rsqrt, + "RELU6": torch.nn.functional.relu6, + "ATAN": torch.atan, + "ERF": torch.erf, + "ERFC": torch.erfc, + "ISINF": torch.isinf, + "ISPOSINF": torch.isposinf, + "ISNEGINF": torch.isneginf, + "ISNAN": torch.isnan, + "LOGICAL_NOT_UNARY": torch.logical_not, + "ISFINITE": torch.isfinite, + "ERFINV": torch.erfinv, + "I0": torch.i0, + "TAN": torch.tan, + "SILU": torch.nn.functional.silu, + "NEG": torch.neg, + "FLOOR": torch.floor, + "CEIL": torch.ceil, +} +no_activations = ((), (), ()) +square_lhs = (("SQUARE",), (), ()) +sin_rhs = ((), ("SIN",), ()) +floor_lhs_ceil_rhs_cos_post = (("FLOOR",), ("CEIL",), ("COS",)) +exp_floor_lhs_exp_rhs = (("FLOOR", "EXP"), ("EXP",), ()) +log_lhs_sqrt_abs_post = (("LOG",), (), ("ABS", "SQRT")) +exp_post = ((), (), ("EXP",)) +log_post = ((), (), ("LOG",)) +tanh_post = ((), (), ("TANH",)) +log2_post = ((), (), ("LOG2",)) 
+log10_post = ((), (), ("LOG10",)) +exp2_post = ((), (), ("EXP2",)) +expm1_post = ((), (), ("EXPM1",)) +erfinv_post = ((), (), ("ERFINV",)) +i0_post = ((), (), ("I0",)) +tan_post = ((), (), ("TAN",)) +floor_post = ((), (), ("FLOOR",)) +ceil_post = ((), (), ("CEIL",)) + + +def rand_bf16_gen(shape, device, *, min=0, max=1): + pt = torch.rand(shape, dtype=torch.bfloat16) * (max - min) + min + tt = ttnn.from_torch(pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) + return pt, tt @skip_for_grayskull("Possible accuracy issues with grayskull") @pytest.mark.parametrize( - "input_shapes", + "a_shape, b_shape", ( (torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])), (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])), @@ -22,46 +112,71 @@ ), ) @pytest.mark.parametrize( - "ttnn_fn", - [ - ttnn.experimental.gte, - ttnn.experimental.gt, - ttnn.experimental.lte, - ttnn.experimental.lt, - ttnn.experimental.eq, - ttnn.experimental.ne, - ttnn.experimental.logical_and, - ttnn.experimental.logical_or, - ttnn.experimental.logical_xor, - ttnn.experimental.ldexp, - ttnn.experimental.logaddexp, - ttnn.experimental.logaddexp2, - ttnn.experimental.squared_difference, - ttnn.experimental.add, - ttnn.experimental.sub, - ttnn.experimental.mul, - ttnn.experimental.div, - ttnn.experimental.bias_gelu, - ], + "ttnn_fn, activations", + { + *parameters( + binary_fns, + { + no_activations, + square_lhs, + sin_rhs, + floor_lhs_ceil_rhs_cos_post, + exp_floor_lhs_exp_rhs, + log_lhs_sqrt_abs_post, + }, + ), + *parameters({"add"}, {((), (), (op,)) for op in activation_fns.keys()}), + }.difference( + parameters({"eq", "ne"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), + parameters({"logaddexp", "logaddexp2"}, {floor_lhs_ceil_rhs_cos_post}), + parameters({"gte", "lt", "lte"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}), + parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu"}, {log_lhs_sqrt_abs_post}), + parameters({"div"}, {exp_post, tanh_post, exp2_post, expm1_post, i0_post, tan_post}), + parameters({"sub"}, {log_post, log2_post, log10_post}), + parameters({"ldexp"}, {erfinv_post, tan_post, floor_post, ceil_post}), + parameters({"squared_difference"}, {erfinv_post, i0_post}), + parameters({"add"}, {tan_post, tanh_post}), + {("mul", log_lhs_sqrt_abs_post)}, + ), ) -def test_binary_scalar_ops(input_shapes, ttnn_fn, device): - a_shape, b_shape = input_shapes - a_pt = torch.rand(a_shape).bfloat16() - b_pt = torch.rand(b_shape).bfloat16() +def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device): + torch.manual_seed(0) + ttnn_op = getattr(ttnn.experimental, ttnn_fn) + lhs, rhs, post = ([getattr(ttnn.UnaryOpType, op) for op in ops] for ops in activations) + golden_lhs, golden_rhs, golden_post = ((activation_fns[op] for op in ops) for ops in activations) + # make 0 exclusive for rhs of div + min, max = (1, 0) if ttnn_fn == "div" else (0, 1) + + a_pt, a_tt = rand_bf16_gen(a_shape, device) + b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max) - a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) - b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) cq_id = 0 - out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id) - golden_fn = ttnn.get_golden_function(ttnn_fn) - out_pt = golden_fn(a_pt, b_pt) + out_tt = ttnn_op(a_tt, b_tt, queue_id=cq_id, lhs_activations=lhs, rhs_activations=rhs, post_activations=post) + + for golden_activation 
in golden_lhs: + a_pt = golden_activation(a_pt).bfloat16() + + for golden_activation in golden_rhs: + b_pt = golden_activation(b_pt).bfloat16() + + golden_fn = ttnn.get_golden_function(ttnn_op) + out_pt = golden_fn(a_pt, b_pt).bfloat16() - comp_pass = compare_pcc([out_tt], [out_pt]) - assert comp_pass + for golden_activation in golden_post: + out_pt = golden_activation(out_pt).bfloat16() + + def compare(tt, pt): + imprecise_cases = { + *parameters({"bias_gelu"}, {square_lhs, floor_lhs_ceil_rhs_cos_post}), + *parameters({"gte", "gt", "lte", "lt"}, {sin_rhs}), + } + return compare_pcc(tt, pt, 0.98) if (ttnn_fn, activations) in imprecise_cases else compare_pcc(tt, pt) + + assert compare([out_tt], [out_pt]) @pytest.mark.parametrize( - "input_shapes", + "a_shape, b_shape", ( (torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])), (torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 128])), @@ -70,43 +185,23 @@ def test_binary_scalar_ops(input_shapes, ttnn_fn, device): ) @pytest.mark.parametrize( "ttnn_fn", - [ - ttnn.experimental.gte, - ttnn.experimental.gt, - ttnn.experimental.lte, - ttnn.experimental.lt, - ttnn.experimental.eq, - ttnn.experimental.ne, - ttnn.experimental.logical_and, - ttnn.experimental.logical_or, - ttnn.experimental.logical_xor, - ttnn.experimental.ldexp, - ttnn.experimental.logaddexp, - ttnn.experimental.logaddexp2, - ttnn.experimental.squared_difference, - ttnn.experimental.add, - ttnn.experimental.sub, - ttnn.experimental.mul, - ttnn.experimental.div, - ttnn.experimental.bias_gelu, - ], + binary_fns, ) -def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device): - a_shape, b_shape = input_shapes - a_pt = torch.rand(a_shape).bfloat16() - b_pt = torch.rand(b_shape).bfloat16() +def test_binary_scalar_ops_invalid_bcast(a_shape, b_shape, ttnn_fn, device): + torch.manual_seed(0) + ttnn_op = getattr(ttnn.experimental, ttnn_fn) - a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) - b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG) + _, a_tt = rand_bf16_gen(a_shape, device) + _, b_tt = rand_bf16_gen(b_shape, device) with pytest.raises(RuntimeError) as e: cq_id = 0 - _ = ttnn_fn(a_tt, b_tt, queue_id=cq_id) + _ = ttnn_op(a_tt, b_tt, queue_id=cq_id) assert "Broadcasting rule violation" in str(e.value) @pytest.mark.parametrize( - "shapes", + "a_shape, b_shape", [ [[1, 71, 7, 7], [7, 7]], [[920, 1, 256], [256]], @@ -119,17 +214,14 @@ def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device): [[16, 1], [1, 1, 32]], ], ) -def test_unequal_ranks(device, shapes): +def test_unequal_ranks(a_shape, b_shape, device): torch.manual_seed(0) - torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16) - torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16) + + torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device) + torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device) + torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b - input_tensor_a = ttnn.from_torch( - torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG - ) - input_tensor_b = ttnn.from_torch( - torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG - ) + output_tensor = ttnn.experimental.add(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG) output_tensor = ttnn.to_torch(output_tensor) @@ -138,7 +230,7 @@ def 
test_unequal_ranks(device, shapes): @pytest.mark.parametrize( - "data", + "a, b, c_golden", [ ([], [], []), ([1], [2], [3]), @@ -150,8 +242,7 @@ def test_unequal_ranks(device, shapes): ], ) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) -def test_01_volume_tensors(device, data, memory_config): - (a, b, c_golden) = data +def test_01_volume_tensors(device, a, b, c_golden, memory_config): a = torch.BFloat16Tensor(a) b = torch.BFloat16Tensor(b) assert torch.add(a, b).tolist() == c_golden diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp index 846d7fc0f68..b91a8e0491e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp @@ -151,11 +151,11 @@ std::map get_defines( defines["ELTWISE_OP"] = op_name.c_str(); defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str(); if (fused_activations.has_value()) { - if (op_type == BinaryOpType::ADD and fused_activations.value().size() == 1 and - fused_activations.value().at(0).op_type == UnaryOpType::RELU) { + if (op_type == BinaryOpType::ADD and fused_activations->size() == 1 and + fused_activations->at(0).op_type == UnaryOpType::RELU and not input_tensor_a_activation.has_value()) { defines["PACK_RELU"] = "1"; } else { - defines.merge(ttnn::operations::unary::utils::get_block_defines(fused_activations.value(), "0", idst)); + defines.merge(ttnn::operations::unary::utils::get_block_defines(*fused_activations, "0", idst)); } } diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp index 5cd0c53491d..4d844cab987 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp @@ -11,50 +11,98 @@ namespace ttnn::operations::binary_ng { template Tensor BinaryNg::invoke( uint8_t queue_id, - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); - Tensor input_b = typecast_to(DataType::BFLOAT16, input_tensor_b); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); + auto input_b = typecast_to(DataType::BFLOAT16, input_tensor_b); return ttnn::prim::binary_ng( - queue_id, input_a, input_b, binary_op_type, output_dtype, memory_config, optional_output_tensor); + queue_id, + input_a, + input_b, + binary_op_type, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - return invoke(DefaultQueueId, input_tensor_a, input_tensor_b, output_dtype, memory_config, optional_output_tensor); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional 
optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + return invoke( + DefaultQueueId, + input_tensor_a, + input_tensor_b, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( uint8_t queue_id, - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - Tensor input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a); return ttnn::prim::binary_ng( - queue_id, input_a, scalar, binary_op_type, output_dtype, memory_config, optional_output_tensor); + queue_id, + input_a, + scalar, + binary_op_type, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template Tensor BinaryNg::invoke( - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype, - const std::optional &memory_config, - std::optional optional_output_tensor) { - return invoke(DefaultQueueId, input_tensor_a, scalar, output_dtype, memory_config, optional_output_tensor); + const std::optional& output_dtype, + const std::optional& memory_config, + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { + return invoke( + DefaultQueueId, + input_tensor_a, + scalar, + output_dtype, + memory_config, + optional_output_tensor, + lhs_activations, + rhs_activations, + post_activations); } template struct BinaryNg; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp index 57311622e38..34cf4165e61 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp @@ -8,6 +8,7 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" #include "ttnn/operations/copy.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" namespace ttnn::operations::binary_ng { @@ -15,33 +16,45 @@ template struct BinaryNg { static Tensor invoke( uint8_t queue_id, - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( - const Tensor &input_tensor_a, - const Tensor &input_tensor_b, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional& output_dtype = std::nullopt, + const 
std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( uint8_t queue_id, - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); static Tensor invoke( - const Tensor &input_tensor_a, + const Tensor& input_tensor_a, float scalar, - const std::optional &output_dtype = std::nullopt, - const std::optional &memory_config = std::nullopt, - std::optional optional_output_tensor = std::nullopt); + const std::optional& output_dtype = std::nullopt, + const std::optional& memory_config = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + tt::stl::Span lhs_activations = {}, + tt::stl::Span rhs_activations = {}, + tt::stl::Span post_activations = {}); }; } // namespace ttnn::operations::binary_ng @@ -122,4 +135,5 @@ constexpr auto logaddexp = ttnn::register_operation_with_auto_launch_op< constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op< "ttnn::experimental::logaddexp2", ttnn::operations::binary_ng::BinaryNg>(); -} + +} // namespace ttnn::experimental diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp index da264795941..a392c476a8a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp @@ -24,8 +24,20 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst const std::optional& dtype, const std::optional& memory_config, const std::optional& output_tensor, + const ttnn::SmallVector& lhs_activations, + const ttnn::SmallVector& rhs_activations, + const ttnn::SmallVector& post_activations, const uint8_t& queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor); + return self( + queue_id, + input_tensor_a, + scalar, + dtype, + memory_config, + output_tensor, + lhs_activations, + rhs_activations, + post_activations); }, py::arg("input_tensor_a"), py::arg("scalar"), @@ -33,6 +45,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst py::arg("dtype") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt, + py::arg("lhs_activations") = ttnn::SmallVector(), + py::arg("rhs_activations") = ttnn::SmallVector(), + py::arg("post_activations") = ttnn::SmallVector(), py::arg("queue_id") = 0}, // tensor and tensor @@ -43,8 +58,20 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst const std::optional& dtype, const std::optional& memory_config, const std::optional& output_tensor, + const ttnn::SmallVector& lhs_activations, + const ttnn::SmallVector& rhs_activations, + const ttnn::SmallVector& post_activations, uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor); + return self( + queue_id, + input_tensor_a, + 
input_tensor_b, + dtype, + memory_config, + output_tensor, + lhs_activations, + rhs_activations, + post_activations); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -52,6 +79,9 @@ void bind_binary_ng_operation(py::module& module, T op, const std::string& docst py::arg("dtype") = std::nullopt, py::arg("memory_config") = std::nullopt, py::arg("output_tensor") = std::nullopt, + py::arg("lhs_activations") = ttnn::SmallVector(), + py::arg("rhs_activations") = ttnn::SmallVector(), + py::arg("post_activations") = ttnn::SmallVector(), py::arg("queue_id") = 0}); } } // namespace detail @@ -77,4 +107,4 @@ void py_module(py::module& module) { detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp, "Binary Logaddexp Operation"); detail::bind_binary_ng_operation(module, ttnn::experimental::logaddexp2, "Binary Logaddexp2 Operation"); } -} // namespace ttnn::operations::eltwise::binary_ng +} // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp index c36a354edde..15a53a734b8 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp @@ -42,7 +42,14 @@ SubtileBroadcastType get_subtile_broadcast_type(uint32_t a_h, uint32_t a_w, uint tt::stl::hash::hash_t BinaryNgDeviceOperation::operation_attributes_t::to_hash() const { return tt::stl::hash::hash_objects_with_default_seed( - binary_op_type, memory_config, get_dtype(), compute_kernel_config, subtile_broadcast_type); + binary_op_type, + lhs_activations, + rhs_activations, + post_activations, + memory_config, + get_dtype(), + compute_kernel_config, + subtile_broadcast_type); } DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const { @@ -199,7 +206,10 @@ BinaryNgDeviceOperation::invoke( BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor) { + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { auto subtile_broadcast_type = get_subtile_broadcast_type( input_tensor_a_arg.get_logical_shape()[-2], input_tensor_a_arg.get_logical_shape()[-1], @@ -209,6 +219,9 @@ BinaryNgDeviceOperation::invoke( return { operation_attributes_t{ binary_op_type, + {lhs_activations.begin(), lhs_activations.end()}, + {rhs_activations.begin(), rhs_activations.end()}, + {post_activations.begin(), post_activations.end()}, std::nullopt, memory_config.value_or(input_tensor_a_arg.memory_config()), input_tensor_a_arg.get_dtype(), @@ -225,10 +238,16 @@ BinaryNgDeviceOperation::invoke( BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor) { + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations) { return { operation_attributes_t{ binary_op_type, + {lhs_activations.begin(), lhs_activations.end()}, + {rhs_activations.begin(), rhs_activations.end()}, + {post_activations.begin(), post_activations.end()}, scalar, memory_config.value_or(input_tensor_a_arg.memory_config()), input_tensor_a_arg.get_dtype(), diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp 
b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp index 2be532e9943..0603ecbee2c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp @@ -8,7 +8,7 @@ #include "ttnn/device_operation.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" - +#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" namespace ttnn::operations::binary_ng { enum class SubtileBroadcastType { @@ -31,6 +31,9 @@ struct BinaryNgDeviceOperation { struct operation_attributes_t { BinaryOpType binary_op_type; + ttnn::SmallVector lhs_activations; + ttnn::SmallVector rhs_activations; + ttnn::SmallVector post_activations; std::optional scalar; tt::tt_metal::MemoryConfig memory_config; DataType input_dtype; @@ -86,7 +89,10 @@ struct BinaryNgDeviceOperation { BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor); + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations); // tensor-scalar invocation static std::tuple invoke( @@ -95,7 +101,10 @@ struct BinaryNgDeviceOperation { BinaryOpType binary_op_type, const std::optional& output_dtype, const std::optional& memory_config, - std::optional optional_output_tensor); + std::optional optional_output_tensor, + tt::stl::Span lhs_activations, + tt::stl::Span rhs_activations, + tt::stl::Span post_activations); }; } // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp index 9dc1061f6f6..3121cc2e7ae 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp @@ -5,6 +5,7 @@ #include "binary_ng_utils.hpp" #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/cb_utils.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" using namespace tt::tt_metal; @@ -184,7 +185,39 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio Buffer* c_buffer = c.buffer(); auto op_type = operation_attributes.binary_op_type; - auto compute_kernel_defines = OpConfig(op_type).as_defines(); + OpConfig op_config(op_type); + auto compute_kernel_defines = op_config.as_defines(); + + { + ttnn::SmallVector lhs_activations = operation_attributes.lhs_activations; + ttnn::SmallVector rhs_activations = operation_attributes.rhs_activations; + ttnn::SmallVector post_activations = operation_attributes.post_activations; + + if (op_config.process_lhs.has_value()) { + lhs_activations.push_back(*op_config.process_lhs); + } + + if (op_config.process_rhs.has_value()) { + rhs_activations.push_back(*op_config.process_rhs); + } + + if (op_config.postprocess.has_value()) { + post_activations.insert(post_activations.begin(), *op_config.postprocess); + } + + add_activation_defines(compute_kernel_defines, lhs_activations, "LHS"); + add_activation_defines(compute_kernel_defines, rhs_activations, "RHS"); + + if (lhs_activations.empty() and rhs_activations.empty() and post_activations.size() == 1 and + post_activations[0] == unary::UnaryOpType::RELU) { + compute_kernel_defines["PACK_RELU"] = 
"1"; + compute_kernel_defines["PROCESS_POST_ACTIVATIONS(i)"] = ""; + unary::utils::update_macro_defines(unary::UnaryOpType::RELU, compute_kernel_defines); + } else { + add_activation_defines(compute_kernel_defines, post_activations, "POST"); + } + } + bool op_has_exp = op_type == BinaryOpType::LOGADDEXP || op_type == BinaryOpType::LDEXP || op_type == BinaryOpType::LOGADDEXP2; @@ -193,7 +226,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto [a_cb, a_cb_handle] = create_cb(tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format); - if (compute_kernel_defines.find("PREPROCESS_A_INIT") != compute_kernel_defines.end()) { + if (not compute_kernel_defines["PROCESS_LHS_ACTIVATIONS(i)"].empty()) { auto a_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : a_data_format; uint32_t a_intermediate_single_tile_size = tt_metal::detail::TileSize(a_intermediate_format); auto [a_cb_interim, a_cb_interim_handle] = create_cb( @@ -208,7 +241,7 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio auto [b_cb, b_cb_handle] = create_cb(tt::CBIndex::c_1, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format); - if (compute_kernel_defines.find("PREPROCESS_B_INIT") != compute_kernel_defines.end()) { + if (not compute_kernel_defines["PROCESS_RHS_ACTIVATIONS(i)"].empty()) { auto b_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : b_data_format; uint32_t b_intermediate_single_tile_size = tt_metal::detail::TileSize(b_intermediate_format); auto [b_cb_interim, b_cb_interim_handle] = create_cb( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp index 3671cd9d10f..7d4410f7b3f 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #include "binary_ng_utils.hpp" +#include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" +#include "tt_metal/common/assert.hpp" #include #include @@ -10,7 +12,7 @@ template <> struct fmt::formatter : fmt::formatter { - auto format(ttnn::operations::binary_ng::Lowercase const& value, fmt::format_context& ctx) const { + auto format(const ttnn::operations::binary_ng::Lowercase& value, fmt::format_context& ctx) const { auto out = ctx.out(); for (char c : value.view) { *out++ = std::tolower(static_cast(c)); @@ -117,9 +119,6 @@ std::string get_kernel_file_path(KernelName kernel_name) { } } -constexpr OpConfig::SfpuConfig NezConfig("nez_tile_init", "nez_tile(i)"); -constexpr OpConfig::SfpuConfig GtzConfig("gtz_tile_init", "gtz_tile(i)"); - OpConfig::OpConfig(BinaryOpType binary_op_type) { fpu_binary_op = FpuBinaryOp::SUB; switch (binary_op_type) { @@ -127,74 +126,57 @@ OpConfig::OpConfig(BinaryOpType binary_op_type) { case BinaryOpType::SUB: break; case BinaryOpType::MUL: fpu_binary_op = FpuBinaryOp::MUL; break; case BinaryOpType::DIV: - preprocess_b = SfpuConfig("recip_tile_init", "recip_tile(i)", "compute_kernel_api/eltwise_unary/recip.h"); + process_rhs = unary::UnaryOpType::RECIP; fpu_binary_op = FpuBinaryOp::MUL; break; - case BinaryOpType::GT: postprocess = GtzConfig; break; - case BinaryOpType::LT: postprocess = SfpuConfig("ltz_tile_init", "ltz_tile(i)"); break; - case BinaryOpType::GTE: postprocess = SfpuConfig("gez_tile_init", "gez_tile(i)"); break; - case 
BinaryOpType::LTE: postprocess = SfpuConfig("lez_tile_init", "lez_tile(i)"); break; - case BinaryOpType::EQ: postprocess = SfpuConfig("eqz_tile_init", "eqz_tile(i)"); break; - case BinaryOpType::NE: postprocess = NezConfig; break; - case BinaryOpType::SQUARED_DIFFERENCE: postprocess = SfpuConfig("square_tile_init", "square_tile(i)"); break; + case BinaryOpType::GT: postprocess = unary::UnaryOpType::GTZ; break; + case BinaryOpType::LT: postprocess = unary::UnaryOpType::LTZ; break; + case BinaryOpType::GTE: postprocess = unary::UnaryOpType::GEZ; break; + case BinaryOpType::LTE: postprocess = unary::UnaryOpType::LEZ; break; + case BinaryOpType::EQ: postprocess = unary::UnaryOpType::EQZ; break; + case BinaryOpType::NE: postprocess = unary::UnaryOpType::NEZ; break; + case BinaryOpType::SQUARED_DIFFERENCE: postprocess = unary::UnaryOpType::SQUARE; break; case BinaryOpType::BIAS_GELU: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = - SfpuConfig("gelu_tile_init", "gelu_tile(i)", "compute_kernel_api/eltwise_unary/gelu.h"); + process_lhs = unary::UnaryOpType::GELU; break; case BinaryOpType::LOGICAL_AND: fpu_binary_op = FpuBinaryOp::MUL; - postprocess = NezConfig; + postprocess = unary::UnaryOpType::NEZ; break; case BinaryOpType::LOGICAL_OR: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = NezConfig; - preprocess_b = NezConfig; - postprocess = GtzConfig; + process_lhs = unary::UnaryOpType::NEZ; + process_rhs = unary::UnaryOpType::NEZ; + postprocess = unary::UnaryOpType::GTZ; break; case BinaryOpType::LOGICAL_XOR: - preprocess_a = NezConfig; - preprocess_b = NezConfig; - postprocess = NezConfig; + process_lhs = unary::UnaryOpType::NEZ; + process_rhs = unary::UnaryOpType::NEZ; + postprocess = unary::UnaryOpType::NEZ; break; case BinaryOpType::LDEXP: fpu_binary_op = FpuBinaryOp::MUL; - preprocess_b = SfpuConfig("exp2_tile_init", "exp2_tile(i)"); + process_rhs = unary::UnaryOpType::EXP2; break; case BinaryOpType::LOGADDEXP: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = - SfpuConfig("exp_tile_init", "exp_tile(i)", "compute_kernel_api/eltwise_unary/exp.h"); - preprocess_b = preprocess_a; - postprocess = SfpuConfig("log_tile_init", "log_tile(i)"); + process_lhs = unary::UnaryOpType::EXP; + process_rhs = unary::UnaryOpType::EXP; + postprocess = unary::UnaryOpType::LOG; break; case BinaryOpType::LOGADDEXP2: fpu_binary_op = FpuBinaryOp::ADD; - preprocess_a = SfpuConfig("exp2_tile_init", "exp2_tile(i)"); - preprocess_b = preprocess_a; - postprocess = SfpuConfig("log_with_base_tile_init", "log_with_base_tile(i, 0x3dc5u);"); + process_lhs = unary::UnaryOpType::EXP2; + process_rhs = unary::UnaryOpType::EXP2; + postprocess = unary::UnaryOpType::LOG2; break; - default: __builtin_unreachable(); - } -} - -std::map OpConfig::SfpuConfig::as_defines(std::string_view prefix) const { - if (init.empty()) { - return {}; + default: TT_THROW("Unsupported binary op {}", binary_op_type); } - - std::map defines; - defines[fmt::format("{}_INIT", prefix)] = init; - defines[fmt::format("{}_APPLY(i)", prefix)] = apply; - defines[fmt::format("{}_INCLUDE", prefix)] = include; - return defines; } std::map OpConfig::as_defines() const { std::map defines; - defines.merge(preprocess_a.as_defines("PREPROCESS_A")); - defines.merge(preprocess_b.as_defines("PREPROCESS_B")); - defines.merge(postprocess.as_defines("POSTPROCESS")); auto binary_op_str = magic_enum::enum_name(fpu_binary_op); defines["BINARY_OP"] = fmt::format("{}_tiles", Lowercase{binary_op_str}); @@ -203,4 +185,23 @@ std::map OpConfig::as_defines() const { return 
defines; } +void add_activation_defines( + std::map& defines, + tt::stl::Span activations, + std::string_view operand) { + auto prepend_separator = false; + std::string process = ""; + + for (auto& a : activations) { + if (prepend_separator) { + process += ';'; + } + prepend_separator = true; + process += fmt::format("PROCESS_ACTIVATION({}, i)", magic_enum::enum_name(a)); + unary::utils::update_macro_defines(a, defines); + } + + defines[fmt::format("PROCESS_{}_ACTIVATIONS(i)", operand)] = process; +} + } // namespace ttnn::operations::binary_ng diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp index cc8a242fc0c..b1ee7f41700 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp @@ -41,30 +41,23 @@ struct BinaryNgKernelConfig { std::string get_kernel_file_path(KernelName kernel_name); struct OpConfig { - struct SfpuConfig { - SfpuConfig() = default; - constexpr SfpuConfig( - std::string_view init, std::string_view apply, std::string_view include = "compute_kernel_api.h") : - init{init}, apply{apply}, include{include} {} - std::string_view init{}; - std::string_view apply{}; - std::string_view include{}; - - std::map as_defines(std::string_view prefix) const; - }; - enum class FpuBinaryOp { ADD, SUB, MUL }; OpConfig(BinaryOpType binary_op_type); std::map as_defines() const; - SfpuConfig preprocess_a{}; - SfpuConfig preprocess_b{}; - SfpuConfig postprocess{}; + std::optional process_lhs{}; + std::optional process_rhs{}; + std::optional postprocess{}; FpuBinaryOp fpu_binary_op; }; +void add_activation_defines( + std::map& defines, + tt::stl::Span activations, + std::string_view operand); + struct Lowercase { std::string_view view; }; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp index c3f5f414d38..8f6f7990bf1 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp @@ -3,9 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -22,40 +22,32 @@ ALWI void process_tile( constexpr uint32_t onetile = 1; #if BCAST_INPUT - auto cb_bcast = cb_post_rhs; - auto cb_other = cb_post_lhs; +#define CB_PRE_BCAST cb_pre_rhs +#define CB_POST_BCAST cb_post_rhs +#define CB_PRE_OTHER cb_pre_lhs +#define CB_POST_OTHER cb_post_lhs #else - auto cb_bcast = cb_post_lhs; - auto cb_other = cb_post_rhs; +#define CB_PRE_BCAST cb_pre_lhs +#define CB_POST_BCAST cb_post_lhs +#define CB_PRE_OTHER cb_pre_rhs +#define CB_POST_OTHER cb_post_rhs #endif -#if PREPROCESS_A && (BCAST_INPUT == 0) - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#elif PREPROCESS_B && (BCAST_INPUT == 1) - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif - - cb_wait_front(cb_bcast, onetile); + PREPROCESS(BCAST_OP, CB_PRE_BCAST, CB_POST_BCAST, cb_out, onetile); + cb_wait_front(CB_POST_BCAST, onetile); for (uint32_t j = tile_start; j < freq; ++j) { -#if PREPROCESS_A && (BCAST_INPUT == 1) - 
PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#elif PREPROCESS_B && (BCAST_INPUT == 0) - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif - cb_wait_front(cb_other, onetile); + PREPROCESS(OTHER_OP, CB_PRE_OTHER, CB_POST_OTHER, cb_out, onetile); + cb_wait_front(CB_POST_OTHER, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); @@ -63,9 +55,9 @@ ALWI void process_tile( tile_regs_release(); cb_push_back(cb_out, onetile); - cb_pop_front(cb_other, onetile); + cb_pop_front(CB_POST_OTHER, onetile); } - cb_pop_front(cb_bcast, onetile); + cb_pop_front(CB_POST_BCAST, onetile); } void MAIN { @@ -81,12 +73,15 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp index 2e365c860c3..f2a263bfb1a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp @@ -3,10 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 #include + +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "dprint.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -17,39 +17,35 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? 
tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif constexpr uint32_t onetile = 1; for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { -#if PREPROCESS_A - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#endif + PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile); cb_wait_front(cb_post_lhs, onetile); -#if PREPROCESS_B - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif + PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile); cb_wait_front(cb_post_rhs, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp index 0f468636223..b9f2e29903e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp @@ -3,9 +3,9 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" -#include "eltwise_defines.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { @@ -16,39 +16,35 @@ void MAIN { constexpr auto cb_pre_rhs = tt::CBIndex::c_1; constexpr auto cb_out = tt::CBIndex::c_2; - constexpr auto cb_post_lhs = PREPROCESS_A ? tt::CBIndex::c_3 : cb_pre_lhs; - constexpr auto cb_post_rhs = PREPROCESS_B ? tt::CBIndex::c_4 : cb_pre_rhs; + constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs; + constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? 
tt::CBIndex::c_4 : cb_pre_rhs; binary_op_init_common(cb_post_lhs, cb_post_rhs, cb_out); +#ifdef PACK_RELU + PACK((llk_pack_relu_config(ReluType::ZERO_RELU))); +#endif -#if not(PREPROCESS_A || PREPROCESS_B) +#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) binary_op_specific_init(); #endif constexpr uint32_t onetile = 1; -#if PREPROCESS_B - PREPROCESS(PREPROCESS_B_INIT, PREPROCESS_B_APPLY, cb_pre_rhs, cb_post_rhs, cb_out, onetile); -#endif + PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile); cb_wait_front(cb_post_rhs, onetile); - for(uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { -#if PREPROCESS_A - PREPROCESS(PREPROCESS_A_INIT, PREPROCESS_A_APPLY, cb_pre_lhs, cb_post_lhs, cb_out, onetile); -#endif + for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) { + PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile); cb_wait_front(cb_post_lhs, onetile); cb_reserve_back(cb_out, onetile); -#if PREPROCESS_A || PREPROCESS_B +#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) binary_op_specific_init(); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); -#if POSTPROCESS - POSTPROCESS_INIT(); - POSTPROCESS_APPLY(0); -#endif + PROCESS_POST_ACTIVATIONS(0); tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp deleted file mode 100644 index 5eacbd9b9dd..00000000000 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_defines.hpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#define DO_QUOTE(x) #x -#define QUOTE(x) DO_QUOTE(x) - -#if defined(PREPROCESS_A_INIT) -#define PREPROCESS_A 1 -#else -#define PREPROCESS_A 0 -#endif - -#if defined(PREPROCESS_B_INIT) -#define PREPROCESS_B 1 -#else -#define PREPROCESS_B 0 -#endif - -#if defined(POSTPROCESS_INIT) -#define POSTPROCESS 1 -#else -#define POSTPROCESS 0 -#endif - -#ifdef PREPROCESS_A_INCLUDE -#include QUOTE(PREPROCESS_A_INCLUDE) -#endif - -#ifdef PREPROCESS_B_INCLUDE -#include QUOTE(PREPROCESS_B_INCLUDE) -#endif - -#ifdef POSTPROCESS_INCLUDE -#include QUOTE(POSTPROCESS_INCLUDE) -#endif diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp index 63f350c47c1..4451eb0e8f3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp @@ -7,35 +7,210 @@ #include "compute_kernel_api/common.h" #include "compute_kernel_api/tile_move_copy.h" -#define PREPROCESS(init, apply, cb_pre, cb_post, cb_out, per_core_block_size) \ - do { \ - using namespace ckernel; \ - \ - copy_tile_to_dst_init_short(); \ - \ - reconfig_data_format_srca(/*old*/ cb_post, /*new*/ cb_pre); \ - pack_reconfig_data_format(/*old*/ cb_out, /*new*/ cb_post); \ - \ - cb_wait_front(cb_pre, per_core_block_size); \ - cb_reserve_back(cb_post, per_core_block_size); \ - \ - tile_regs_acquire(); \ - init(); \ - for (uint32_t i = 0; i < per_core_block_size; ++i) { \ - copy_tile(cb_pre, i, i); \ - apply(i); \ - } \ - tile_regs_commit(); \ - \ - tile_regs_wait(); \ - for (uint32_t i = 0; i < per_core_block_size; ++i) { \ - pack_tile(i, cb_post); /* DST[0]->cb */ \ - } \ - tile_regs_release(); \ - \ - cb_pop_front(cb_pre, 
per_core_block_size); \ - cb_push_back(cb_post, per_core_block_size); \ - \ - reconfig_data_format_srca(/*old*/ cb_pre, /*new*/ cb_post); \ - pack_reconfig_data_format(/*old*/ cb_post, /*new*/ cb_out); \ +#define ACTIVATION_INIT_RELU relu_tile_init +#define ACTIVATION_APPLY_RELU relu_tile + +#define ACTIVATION_INIT_SQUARE square_tile_init +#define ACTIVATION_APPLY_SQUARE square_tile + +#define ACTIVATION_INIT_GTZ gtz_tile_init +#define ACTIVATION_APPLY_GTZ gtz_tile + +#define ACTIVATION_INIT_LTZ ltz_tile_init +#define ACTIVATION_APPLY_LTZ ltz_tile + +#define ACTIVATION_INIT_GEZ gez_tile_init +#define ACTIVATION_APPLY_GEZ gez_tile + +#define ACTIVATION_INIT_LEZ lez_tile_init +#define ACTIVATION_APPLY_LEZ lez_tile + +#define ACTIVATION_INIT_EQZ eqz_tile_init +#define ACTIVATION_APPLY_EQZ eqz_tile + +#define ACTIVATION_INIT_NEZ nez_tile_init +#define ACTIVATION_APPLY_NEZ nez_tile + +#define ACTIVATION_INIT_LOG log_tile_init +#define ACTIVATION_APPLY_LOG log_tile + +#define ACTIVATION_INIT_TANH tanh_tile_init +#define ACTIVATION_APPLY_TANH tanh_tile + +#define ACTIVATION_INIT_LOG2 log_with_base_tile_init +#define ACTIVATION_APPLY_LOG2(i) log_with_base_tile(i, 0x3dc5u) + +#define ACTIVATION_INIT_LOG10 log_with_base_tile_init +#define ACTIVATION_APPLY_LOG10(i) log_with_base_tile(i, 0x36f3u) + +#define ACTIVATION_INIT_EXP exp_tile_init +#define ACTIVATION_APPLY_EXP exp_tile + +#define ACTIVATION_INIT_EXP2 exp2_tile_init +#define ACTIVATION_APPLY_EXP2 exp2_tile + +#define ACTIVATION_INIT_EXPM1 expm1_tile_init +#define ACTIVATION_APPLY_EXPM1 expm1_tile + +#define ACTIVATION_INIT_RECIP recip_tile_init +#define ACTIVATION_APPLY_RECIP recip_tile + +#define ACTIVATION_INIT_GELU gelu_tile_init +#define ACTIVATION_APPLY_GELU gelu_tile + +#define ACTIVATION_INIT_SQRT sqrt_tile_init +#define ACTIVATION_APPLY_SQRT sqrt_tile + +#define ACTIVATION_INIT_SIGMOID sigmoid_tile_init +#define ACTIVATION_APPLY_SIGMOID sigmoid_tile + +#define ACTIVATION_INIT_SIN sin_tile_init +#define ACTIVATION_APPLY_SIN sin_tile + +#define ACTIVATION_INIT_COS cos_tile_init +#define ACTIVATION_APPLY_COS cos_tile + +#define ACTIVATION_INIT_TAN tan_tile_init +#define ACTIVATION_APPLY_TAN tan_tile + +#define ACTIVATION_INIT_ASIN asin_tile_init +#define ACTIVATION_APPLY_ASIN asin_tile + +#define ACTIVATION_INIT_ACOS acos_tile_init +#define ACTIVATION_APPLY_ACOS acos_tile + +#define ACTIVATION_INIT_ATAN atan_tile_init +#define ACTIVATION_APPLY_ATAN atan_tile + +#define ACTIVATION_INIT_ABS abs_tile_init +#define ACTIVATION_APPLY_ABS abs_tile + +#define ACTIVATION_INIT_SIGN sign_tile_init +#define ACTIVATION_APPLY_SIGN sign_tile + +#define ACTIVATION_INIT_SIGNBIT signbit_tile_init +#define ACTIVATION_APPLY_SIGNBIT signbit_tile + +#define ACTIVATION_INIT_RSQRT rsqrt_tile_init +#define ACTIVATION_APPLY_RSQRT rsqrt_tile + +#define ACTIVATION_INIT_RELU6 relu_max_tile_init +#define ACTIVATION_APPLY_RELU6(i) relu_max_tile(i, 0x40c00000u) + +#define ACTIVATION_INIT_ERF erf_tile_init +#define ACTIVATION_APPLY_ERF erf_tile + +#define ACTIVATION_INIT_ERFC erfc_tile_init +#define ACTIVATION_APPLY_ERFC erfc_tile + +#define ACTIVATION_INIT_ISINF isinf_tile_init +#define ACTIVATION_APPLY_ISINF isinf_tile + +#define ACTIVATION_INIT_ISPOSINF isposinf_tile_init +#define ACTIVATION_APPLY_ISPOSINF isposinf_tile + +#define ACTIVATION_INIT_ISNEGINF isneginf_tile_init +#define ACTIVATION_APPLY_ISNEGINF isneginf_tile + +#define ACTIVATION_INIT_ISNAN isnan_tile_init +#define ACTIVATION_APPLY_ISNAN isnan_tile + +#define ACTIVATION_INIT_ISFINITE isfinite_tile_init 
+#define ACTIVATION_APPLY_ISFINITE isfinite_tile + +#define ACTIVATION_INIT_LOGICAL_NOT_UNARY logical_not_unary_tile_init +#define ACTIVATION_APPLY_LOGICAL_NOT_UNARY logical_not_unary_tile + +#define ACTIVATION_INIT_ERFINV erfinv_tile_init +#define ACTIVATION_APPLY_ERFINV erfinv_tile + +#define ACTIVATION_INIT_I0 i0_tile_init +#define ACTIVATION_APPLY_I0 i0_tile + +#define ACTIVATION_INIT_I1 i1_tile_init +#define ACTIVATION_APPLY_I1 i1_tile + +#define ACTIVATION_INIT_SILU silu_tile_init +#define ACTIVATION_APPLY_SILU silu_tile + +#define ACTIVATION_INIT_NEG negative_tile_init +#define ACTIVATION_APPLY_NEG negative_tile + +#define ACTIVATION_INIT_BITWISE_NOT bitwise_not_tile_init +#define ACTIVATION_APPLY_BITWISE_NOT bitwise_not_tile + +#define ACTIVATION_INIT_FLOOR floor_tile_init +#define ACTIVATION_APPLY_FLOOR floor_tile + +#define ACTIVATION_INIT_CEIL ceil_tile_init +#define ACTIVATION_APPLY_CEIL ceil_tile + +#define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__) +#define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__()) +#define IS_EMPTY_1(...) 0 +#define IS_EMPTY_NON_FUNCTION_C(...) () + +#define IS_BEGIN_PARENS(...) P_FIRST(P_CAT(P_IS_VARIADIC_R_, P_IS_VARIADIC_C __VA_ARGS__)) + +#define P_IS_VARIADIC_R_1 1, +#define P_IS_VARIADIC_R_P_IS_VARIADIC_C 0, +#define P_IS_VARIADIC_C(...) 1 + +#define P_FIRST(...) P_FIRST_(__VA_ARGS__, ) +#define P_FIRST_(a, ...) a + +#define P_CAT(a, ...) P_CAT_(a, __VA_ARGS__) +#define P_CAT_(a, ...) a##__VA_ARGS__ + +#define P_COMPL(b) P_CAT(P_COMPL_, b) +#define P_COMPL_0 1 +#define P_COMPL_1 0 + +#define ACTIVATION_INIT(elem) ACTIVATION_INIT_##elem() +#define ACTIVATION_APPLY(elem, i) ACTIVATION_APPLY_##elem(i) + +#define PROCESS_ACTIVATION(elem, i) \ + ACTIVATION_INIT(elem); \ + ACTIVATION_APPLY(elem, i) + +#define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i) +#define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS +#define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0))) + +#define BCAST_OP P_CAT(BCAST_OP_, BCAST_INPUT) +#define OTHER_OP P_CAT(BCAST_OP_, P_COMPL(BCAST_INPUT)) +#define BCAST_OP_0 LHS +#define BCAST_OP_1 RHS + +#define PREPROCESS(op, ...) P_CAT(PREPROCESS_, HAS_ACTIVATIONS(op))(op, __VA_ARGS__) +#define PREPROCESS_0(...) 
+#define PREPROCESS_1(op, cb_pre, cb_post, cb_out, per_core_block_size) \ + do { \ + using namespace ckernel; \ + \ + reconfig_data_format_srca(/*old*/ cb_post, /*new*/ cb_pre); \ + pack_reconfig_data_format(/*old*/ cb_out, /*new*/ cb_post); \ + \ + cb_wait_front(cb_pre, per_core_block_size); \ + cb_reserve_back(cb_post, per_core_block_size); \ + \ + tile_regs_acquire(); \ + for (uint32_t i = 0; i < per_core_block_size; ++i) { \ + copy_tile_to_dst_init_short(); \ + copy_tile(cb_pre, i, i); \ + PROCESS_ACTIVATIONS(op, i); \ + } \ + tile_regs_commit(); \ + \ + tile_regs_wait(); \ + for (uint32_t i = 0; i < per_core_block_size; ++i) { \ + pack_tile(i, cb_post); \ + } \ + tile_regs_release(); \ + \ + cb_pop_front(cb_pre, per_core_block_size); \ + cb_push_back(cb_post, per_core_block_size); \ + \ + reconfig_data_format_srca(/*old*/ cb_pre, /*new*/ cb_post); \ + pack_reconfig_data_format(/*old*/ cb_post, /*new*/ cb_out); \ } while (0) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp index 7c8513209ae..ccc7604724f 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp @@ -27,59 +27,58 @@ union Converter { } }; -// update split eltwise ops include macros -void update_macro_defines(UnaryOpType op_type, std::map& defines) { +std::string get_macro_definition(UnaryOpType op_type) { switch (op_type) { - case UnaryOpType::EXP: defines["SFPU_OP_EXP_INCLUDE"] = "1"; break; - case UnaryOpType::GELU: defines["SFPU_OP_GELU_INCLUDE"] = "1"; break; - case UnaryOpType::RECIP: defines["SFPU_OP_RECIP_INCLUDE"] = "1"; break; - case UnaryOpType::SQRT: defines["SFPU_OP_SQRT_INCLUDE"] = "1"; break; - case UnaryOpType::ERFINV: defines["SFPU_OP_ERFINV_INCLUDE"] = "1"; break; + case UnaryOpType::EXP: return "SFPU_OP_EXP_INCLUDE"; + case UnaryOpType::GELU: return "SFPU_OP_GELU_INCLUDE"; + case UnaryOpType::RECIP: return "SFPU_OP_RECIP_INCLUDE"; + case UnaryOpType::SQRT: return "SFPU_OP_SQRT_INCLUDE"; + case UnaryOpType::ERFINV: return "SFPU_OP_ERFINV_INCLUDE"; case UnaryOpType::ERFC: - case UnaryOpType::ERF: defines["SFPU_OP_ERF_ERFC_INCLUDE"] = "1"; break; - case UnaryOpType::ELU: defines["SFPU_OP_ELU_INCLUDE"] = "1"; break; + case UnaryOpType::ERF: return "SFPU_OP_ERF_ERFC_INCLUDE"; + case UnaryOpType::ELU: return "SFPU_OP_ELU_INCLUDE"; case UnaryOpType::RELU: case UnaryOpType::RELU6: case UnaryOpType::RELU_MAX: case UnaryOpType::RELU_MIN: - case UnaryOpType::LEAKY_RELU: defines["SFPU_OP_RELU_FAMILY_INCLUDE"] = "1"; break; + case UnaryOpType::LEAKY_RELU: return "SFPU_OP_RELU_FAMILY_INCLUDE"; case UnaryOpType::ADD_UNARY_SFPU: case UnaryOpType::SUB_UNARY_SFPU: case UnaryOpType::MUL_UNARY_SFPU: - case UnaryOpType::DIV_UNARY_SFPU: defines["SFPU_OP_BINOP_WITH_SCALAR_INCLUDE"] = "1"; break; + case UnaryOpType::DIV_UNARY_SFPU: return "SFPU_OP_BINOP_WITH_SCALAR_INCLUDE"; case UnaryOpType::IDENTITY: - case UnaryOpType::IDENTITY_UINT32: defines["SFPU_OP_IDENTITY_INCLUDE"] = "1"; break; + case UnaryOpType::IDENTITY_UINT32: return "SFPU_OP_IDENTITY_INCLUDE"; case UnaryOpType::FLOOR: - case UnaryOpType::FLOOR_FLOAT32: defines["SFPU_OP_FLOOR_INCLUDE"] = "1"; break; + case UnaryOpType::FLOOR_FLOAT32: return "SFPU_OP_FLOOR_INCLUDE"; case UnaryOpType::CEIL: - case UnaryOpType::CEIL_FLOAT32: defines["SFPU_OP_CEIL_INCLUDE"] = "1"; break; - case UnaryOpType::RDIV: break; - case UnaryOpType::RSUB: defines["SFPU_OP_REVERSE_FAMILY_INCLUDE"] = "1"; + case 
UnaryOpType::CEIL_FLOAT32: return "SFPU_OP_CEIL_INCLUDE"; + case UnaryOpType::RDIV: + case UnaryOpType::RSUB: return "SFPU_OP_REVERSE_FAMILY_INCLUDE"; case UnaryOpType::ISINF: case UnaryOpType::ISNAN: case UnaryOpType::ISNEGINF: case UnaryOpType::ISPOSINF: - case UnaryOpType::ISFINITE: defines["SFPU_OP_ISINF_ISNAN_INCLUDE"] = "1"; break; - case UnaryOpType::LOGICAL_NOT_UNARY: defines["SFPU_OP_LOGICAL_NOT_NOTI_INCLUDE"] = "1"; break; - case UnaryOpType::I0: defines["SFPU_OP_I0_INCLUDE"] = "1"; break; - case UnaryOpType::I1: defines["SFPU_OP_I1_INCLUDE"] = "1"; break; + case UnaryOpType::ISFINITE: return "SFPU_OP_ISINF_ISNAN_INCLUDE"; + case UnaryOpType::LOGICAL_NOT_UNARY: return "SFPU_OP_LOGICAL_NOT_NOTI_INCLUDE"; + case UnaryOpType::I0: return "SFPU_OP_I0_INCLUDE"; + case UnaryOpType::I1: return "SFPU_OP_I1_INCLUDE"; case UnaryOpType::COS: case UnaryOpType::SIN: - case UnaryOpType::TAN: defines["SFPU_OP_TRIG_FAMILY_INCLUDE"] = "1"; break; - case UnaryOpType::NEG: defines["SFPU_OP_NEG_INCLUDE"] = "1"; break; - case UnaryOpType::SOFTPLUS: defines["SFPU_OP_SOFTPLUS_INCLUDE"] = "1"; break; - case UnaryOpType::PRELU_SFPU: defines["SFPU_OP_PRELU_INCLUDE"] = "1"; break; - case UnaryOpType::TYPECAST: defines["SFPU_OP_TYPECAST_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_XOR: defines["SFPU_OP_BITWISE_XOR_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_NOT: defines["SFPU_OP_BITWISE_NOT_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_AND: defines["SFPU_OP_BITWISE_AND_INCLUDE"] = "1"; break; - case UnaryOpType::BITWISE_OR: defines["SFPU_OP_BITWISE_OR_INCLUDE"] = "1"; break; - case UnaryOpType::RIGHT_SHIFT: defines["SFPU_OP_RIGHT_SHIFT_INCLUDE"] = "1"; break; - case UnaryOpType::LEFT_SHIFT: defines["SFPU_OP_LEFT_SHIFT_INCLUDE"] = "1"; break; - case UnaryOpType::REMAINDER: defines["SFPU_OP_REMAINDER_INCLUDE"] = "1"; break; - case UnaryOpType::FMOD: defines["SFPU_OP_FMOD_INCLUDE"] = "1"; break; - case UnaryOpType::FILL: defines["SFPU_OP_FILL_INCLUDE"] = "1"; break; - default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break; + case UnaryOpType::TAN: return "SFPU_OP_TRIG_FAMILY_INCLUDE"; + case UnaryOpType::NEG: return "SFPU_OP_NEG_INCLUDE"; + case UnaryOpType::SOFTPLUS: return "SFPU_OP_SOFTPLUS_INCLUDE"; + case UnaryOpType::PRELU_SFPU: return "SFPU_OP_PRELU_INCLUDE"; + case UnaryOpType::TYPECAST: return "SFPU_OP_TYPECAST_INCLUDE"; + case UnaryOpType::BITWISE_XOR: return "SFPU_OP_BITWISE_XOR_INCLUDE"; + case UnaryOpType::BITWISE_NOT: return "SFPU_OP_BITWISE_NOT_INCLUDE"; + case UnaryOpType::BITWISE_AND: return "SFPU_OP_BITWISE_AND_INCLUDE"; + case UnaryOpType::BITWISE_OR: return "SFPU_OP_BITWISE_OR_INCLUDE"; + case UnaryOpType::RIGHT_SHIFT: return "SFPU_OP_RIGHT_SHIFT_INCLUDE"; + case UnaryOpType::LEFT_SHIFT: return "SFPU_OP_LEFT_SHIFT_INCLUDE"; + case UnaryOpType::REMAINDER: return "SFPU_OP_REMAINDER_INCLUDE"; + case UnaryOpType::FMOD: return "SFPU_OP_FMOD_INCLUDE"; + case UnaryOpType::FILL: return "SFPU_OP_FILL_INCLUDE"; + default: return "SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"; }; } @@ -436,4 +435,9 @@ std::map get_block_defines( return block_defines; } +// update split eltwise ops include macros +void update_macro_defines(UnaryOpType op_type, std::map& defines) { + defines[get_macro_definition(op_type)] = "1"; +} + } // namespace ttnn::operations::unary::utils diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp index bd2906cb4e8..2d5594c12e0 100644 --- 
a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.hpp @@ -70,4 +70,6 @@ bool is_parametrized_type(T val) { return false; } +void update_macro_defines(UnaryOpType op_type, std::map& defines); + } // namespace ttnn::operations::unary::utils diff --git a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp b/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp index 5adce7b92b4..90a984ae656 100644 --- a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp @@ -28,6 +28,7 @@ std::ostream& operator<<(std::ostream& os, const SmallVector 0) { os << ", "; } + using tt::stl::reflection::operator<<; os << vec[i]; } os << "])";
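Usage note (not part of the patch): a minimal sketch of how the fused input/output activations added above can be invoked from Python, mirroring the floor_lhs_ceil_rhs_cos_post case in the updated test_binary_bcast.py. The device id, tensor shapes, and the open_device/close_device calls are illustrative assumptions; the lhs_activations/rhs_activations/post_activations keyword arguments and the ttnn.UnaryOpType values come directly from the changes above.

import torch
import ttnn

# Illustrative single-device setup; the device id is an assumption.
device = ttnn.open_device(device_id=0)

# bfloat16 operands on device, as produced by rand_bf16_gen in the test above.
a = ttnn.from_torch(
    torch.rand((5, 3, 32, 32), dtype=torch.bfloat16),
    device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b = ttnn.from_torch(
    torch.rand((5, 3, 32, 32), dtype=torch.bfloat16),
    device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

# Fused on device: cos(floor(a) + ceil(b)); floor is applied to the lhs input,
# ceil to the rhs input, and cos to the result of the add.
out = ttnn.experimental.add(
    a, b,
    lhs_activations=[ttnn.UnaryOpType.FLOOR],
    rhs_activations=[ttnn.UnaryOpType.CEIL],
    post_activations=[ttnn.UnaryOpType.COS],
)

# Golden reference computed the same way as in the test.
expected = torch.cos(torch.floor(ttnn.to_torch(a)) + torch.ceil(ttnn.to_torch(b))).bfloat16()
result = ttnn.to_torch(out)

ttnn.close_device(device)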