diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst
index f01e1dccf2ac..89835b414831 100644
--- a/docs/source/ttnn/ttnn/api.rst
+++ b/docs/source/ttnn/ttnn/api.rst
@@ -82,6 +82,7 @@ Pointwise Unary
    ttnn/atan2
    ttnn/atanh
    ttnn/cbrt
+   ttnn/celu
    ttnn/clip
    ttnn/clone
    ttnn/cos
diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 3a81bcdbb554..482d3e2a5290 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -425,6 +425,8 @@ Tensor elementwise operations
 
 .. autofunction:: tt_lib.tensor.subalpha
 
+.. autofunction:: tt_lib.tensor.celu
+
 .. autofunction:: tt_lib.tensor.addalpha
 
 .. autofunction:: tt_lib.tensor.ldexp
diff --git a/docs/source/ttnn/ttnn/ttnn/celu.rst b/docs/source/ttnn/ttnn/ttnn/celu.rst
new file mode 100644
index 000000000000..7ba75a8cbd36
--- /dev/null
+++ b/docs/source/ttnn/ttnn/ttnn/celu.rst
@@ -0,0 +1,6 @@
+.. _ttnn.celu:
+
+ttnn.celu
+#########
+
+.. autofunction:: ttnn.celu
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
index 29c474531d20..018eac6a1202 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -546,6 +546,10 @@
         "tt_lib_op": tt_lib_ops.eltwise_addcmul,
         "pytorch_op": pytorch_ops.addcmul,
     },
+    "eltwise-celu": {
+        "tt_lib_op": tt_lib_ops.eltwise_celu,
+        "pytorch_op": pytorch_ops.celu,
+    },
     "eltwise-addcdiv": {
         "tt_lib_op": tt_lib_ops.eltwise_addcdiv,
         "pytorch_op": pytorch_ops.addcdiv,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
index b6fe77e640d3..5ed3ae292798 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_composite.py
@@ -112,6 +112,7 @@ def custom_compare(*args, **kwargs):
         "polygamma",
         "nextafter",
         "scatter",
+        "celu",
     ),
     shapes,
 )
@@ -127,6 +128,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
     options["rad2deg"] = (0, 2 * pi)
     options["hypot"] = (1, 100)
     options["atan2"] = (-100, 100)
+    options["celu"] = (-100, 100)
     options["cbrt"] = (-1000, 1000)
     options["hardsigmoid"] = (-100, 100)
     options["hardswish"] = (-100, 100)
@@ -209,7 +211,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
         test_args.update({"value": np.random.randint(1, 100)})
     elif fn in ["lerp_binary"]:
         test_args.update({"weight": np.random.randint(1, 100)})
-    elif fn in ["subalpha"]:
+    elif fn in ["subalpha", "celu"]:
         test_args.update({"alpha": np.random.randint(1, 100)})
     elif fn in ["addalpha"]:
         test_args.update({"alpha": np.random.randint(1, 100)})
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
index 4234dfa799cd..51eb5159632f 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -879,6 +879,10 @@ def addalpha(x, y, *args, alpha, **kwargs):
     return torch.add(x, y, alpha=alpha)
 
 
+def celu(x, *args, alpha, **kwargs):
+    return torch.celu(x, alpha=alpha)
+
+
 def lamb_optimizer(x, y, z, w, *args, beta1, beta2, step_size, eps, weight_decay, **kwargs):
     exp_avg_out, exp_avg_sq_out, param = lamb_optimizer_kernel.lamb_kernel(
         x, y, z, w, beta1=beta1, beta2=beta2, step_size=step_size, eps=eps, weight_decay=weight_decay
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index 26b7b26983cd..4f5b0d407123 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -935,6 +935,24 @@ def eltwise_subalpha(
     return tt2torch_tensor(t2)
 
 
+@setup_host_and_device
+def eltwise_celu(
+    x,
+    *args,
+    alpha,
+    device,
+    dtype,
+    layout,
+    input_mem_config,
+    output_mem_config,
+    **kwargs,
+):
+    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
+    t2 = ttl.tensor.celu(t0, alpha, output_mem_config=output_mem_config)
+
+    return tt2torch_tensor(t2)
+
+
 @setup_host_and_device
 def eltwise_logit(x, *args, eps, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
diff --git a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
index ce09f3999b29..7c03d27af2bf 100644
--- a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
+++ b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
@@ -476,6 +476,23 @@ def eltwise_log_sigmoid(
     return ttnn_tensor_to_torch(t1)
 
 
+def eltwise_celu(
+    x,
+    *args,
+    alpha,
+    device,
+    dtype,
+    layout,
+    input_mem_config,
+    output_mem_config,
+    **kwargs,
+):
+    t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
+    t1 = ttnn.celu(t0, alpha, memory_config=memory_config_to_ttnn(output_mem_config))
+
+    return ttnn_tensor_to_torch(t1)
+
+
 def eltwise_logit(
     x,
     *args,
diff --git a/tests/ttnn/unit_tests/operations/test_activation.py b/tests/ttnn/unit_tests/operations/test_activation.py
index 53a87d6481a1..0e407d481dd9 100644
--- a/tests/ttnn/unit_tests/operations/test_activation.py
+++ b/tests/ttnn/unit_tests/operations/test_activation.py
@@ -239,6 +239,13 @@ def test_scalarB_elu(device, h, w, scalar):
     run_activation_test_scalarB(device, h, w, scalar, ttnn.elu, F.elu)
 
 
+@pytest.mark.parametrize("alpha", [1, 2.5, 5.0])
+@pytest.mark.parametrize("h", [64])
+@pytest.mark.parametrize("w", [128])
+def test_scalarB_celu(device, h, w, alpha):
+    run_activation_test_scalarB(device, h, w, alpha, ttnn.celu, F.celu)
+
+
 @pytest.mark.parametrize("scalar", [0.5, 1.0])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index c87a6b1f2e62..9ca75a3df723 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -901,6 +901,21 @@ Tensor xlogy(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
     return operation::decorate_as_composite(__func__, _xlogy)(input_a, input_b, output_mem_config);
 }
 
+// Celu
+// torch.where(x > 0, x, alpha * (torch.exp(x / alpha) - 1))
+Tensor _celu(const Tensor& input_a, float alpha, const MemoryConfig& output_mem_config) {
+    float recip_val = 1.0f / alpha;
+    Tensor tmp = sub_unary(exp(mul_unary(input_a, recip_val, output_mem_config), output_mem_config), 1.0f, output_mem_config);
+    Tensor result = mul_unary(tmp, alpha, output_mem_config);
+    tmp.deallocate();
+    result = where(gtz(input_a, output_mem_config), input_a, result, output_mem_config);
+    return result;
+}
+Tensor celu(const Tensor& input_a, float alpha, const MemoryConfig& output_mem_config) {
+    return operation::decorate_as_composite(__func__, _celu)(input_a, alpha, output_mem_config);
+}
+
+
 Tensor _variance_impl(
     const Tensor& y, const Tensor& mean_y, Tensor& y_minus_mean_y, const MemoryConfig& output_mem_config) {
     constexpr float correction = 0.0f;
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
index ffd734044ccf..8583d1bfbf61 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -99,6 +99,11 @@ Tensor selu(
     const float alpha = 1.6732632423543772848170429916717,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
+Tensor celu(
+    const Tensor& x,
+    float alpha,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
+
 // Function Swish = same as SILU
 // use transformation y = x * sigmoid( x ) by broadcast
 Tensor swish(const Tensor& a, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index 78d3551bd3ab..4012f83585e1 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -418,6 +418,22 @@ namespace tt::tt_metal::detail{
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
+    m_tensor.def("celu", &celu,
+        py::arg("input").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+        Applies the celu function to the elements of the input tensor ``input``.
+
+        Input tensor must have BFLOAT16 data type.
+
+        Output tensor will have BFLOAT16 data type.
+
+        .. csv-table::
+            :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+            "input", "Tensor celu is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "alpha", "Alpha value for the celu formula", "float", "Defaults to 1.0 (PyTorch default)", "No"
+            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+    )doc");
+
     m_tensor.def("subalpha", &subalpha,
         py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
         Subtracts ``input_b``, scaled by ``alpha``, from ``input_a``.
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 0365b3b565b1..17e2d984285f 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -364,6 +364,7 @@ def manage_config_attribute(name, value):
     geglu,
     reglu,
     swiglu,
+    celu,
 )
 
 from ttnn.operations.math import (
diff --git a/ttnn/ttnn/operations/activation.py b/ttnn/ttnn/operations/activation.py
index 0c1e6fe1c40d..825733a1c138 100644
--- a/ttnn/ttnn/operations/activation.py
+++ b/ttnn/ttnn/operations/activation.py
@@ -31,6 +31,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     "relu6": F.relu6,
     "sigmoid": torch.sigmoid,
     "sign": torch.sign,
+    "celu": F.celu,
     "softsign": F.softsign,
     "swish": F.hardswish,
     "softplus": F.softplus,
@@ -363,6 +364,7 @@ def activation_function(
     ("leaky_relu", ttl.tensor.leaky_relu, "leaky relu", "slope"),
     ("prelu", ttl.tensor.prelu, "prelu", "weight"),
     ("elu", ttl.tensor.elu, "elu", "alpha"),
+    ("celu", ttl.tensor.celu, "celu", "alpha"),
     ("softshrink", ttl.tensor.softshrink, "softshrink", "lambda"),
 ]
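
For reference, a minimal host-side sketch (not part of the patch; the helper name celu_reference is illustrative) of how the decomposition used by _celu in composite_ops.cpp lines up with PyTorch's F.celu, which the sweep and unit tests above use as the golden reference:

# Illustrative check only; celu_reference is a hypothetical helper, not in the diff.
import torch
import torch.nn.functional as F


def celu_reference(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Same decomposition as _celu: where(x > 0, x, alpha * (exp(x / alpha) - 1))
    return torch.where(x > 0, x, alpha * (torch.exp(x / alpha) - 1))


# Shapes and alpha values mirror test_scalarB_celu above.
x = torch.randn(64, 128)
for alpha in (1.0, 2.5, 5.0):
    assert torch.allclose(celu_reference(x, alpha), F.celu(x, alpha=alpha), atol=1e-5)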