#7285: Add celu forward support
ruthreshx committed Apr 15, 2024
1 parent 6a41653 commit def0269
Showing 14 changed files with 101 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -82,6 +82,7 @@ Pointwise Unary
ttnn/atan2
ttnn/atanh
ttnn/cbrt
ttnn/celu
ttnn/clip
ttnn/clone
ttnn/cos
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -425,6 +425,8 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.subalpha

.. autofunction:: tt_lib.tensor.celu

.. autofunction:: tt_lib.tensor.addalpha

.. autofunction:: tt_lib.tensor.ldexp
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/celu.rst
@@ -0,0 +1,6 @@
.. _ttnn.celu:

ttnn.celu
#########

.. autofunction:: ttnn.celu
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -546,6 +546,10 @@
"tt_lib_op": tt_lib_ops.eltwise_addcmul,
"pytorch_op": pytorch_ops.addcmul,
},
"eltwise-celu": {
"tt_lib_op": tt_lib_ops.eltwise_celu,
"pytorch_op": pytorch_ops.celu,
},
"eltwise-addcdiv": {
"tt_lib_op": tt_lib_ops.eltwise_addcdiv,
"pytorch_op": pytorch_ops.addcdiv,
@@ -112,6 +112,7 @@ def custom_compare(*args, **kwargs):
"polygamma",
"nextafter",
"scatter",
"celu",
),
shapes,
)
@@ -127,6 +128,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
options["rad2deg"] = (0, 2 * pi)
options["hypot"] = (1, 100)
options["atan2"] = (-100, 100)
options["celu"] = (-100, 100)
options["cbrt"] = (-1000, 1000)
options["hardsigmoid"] = (-100, 100)
options["hardswish"] = (-100, 100)
@@ -209,7 +211,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
test_args.update({"value": np.random.randint(1, 100)})
elif fn in ["lerp_binary"]:
test_args.update({"weight": np.random.randint(1, 100)})
elif fn in ["subalpha"]:
elif fn in ["subalpha", "celu"]:
test_args.update({"alpha": np.random.randint(1, 100)})
elif fn in ["addalpha"]:
test_args.update({"alpha": np.random.randint(1, 100)})
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -879,6 +879,10 @@ def addalpha(x, y, *args, alpha, **kwargs):
return torch.add(x, y, alpha=alpha)


def celu(x, *args, alpha, **kwargs):
return torch.celu(x, alpha=alpha)
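
As a quick aside (not part of this commit), the golden function above can be checked on CPU against the where-based decomposition that the composite device op uses in composite_ops.cpp; a minimal sketch:

import torch

# torch.celu vs. the decomposition where(x > 0, x, alpha * (exp(x / alpha) - 1))
x = torch.linspace(-5.0, 5.0, steps=64)
alpha = 2.5
decomposed = torch.where(x > 0, x, alpha * (torch.exp(x / alpha) - 1))
assert torch.allclose(torch.celu(x, alpha=alpha), decomposed)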


def lamb_optimizer(x, y, z, w, *args, beta1, beta2, step_size, eps, weight_decay, **kwargs):
exp_avg_out, exp_avg_sq_out, param = lamb_optimizer_kernel.lamb_kernel(
x, y, z, w, beta1=beta1, beta2=beta2, step_size=step_size, eps=eps, weight_decay=weight_decay
18 changes: 18 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -935,6 +935,24 @@ def eltwise_subalpha(
return tt2torch_tensor(t2)


@setup_host_and_device
def eltwise_celu(
x,
*args,
alpha,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t2 = ttl.tensor.celu(t0, alpha, output_mem_config=output_mem_config)

return tt2torch_tensor(t2)


@setup_host_and_device
def eltwise_logit(x, *args, eps, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
17 changes: 17 additions & 0 deletions tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
@@ -476,6 +476,23 @@ def eltwise_log_sigmoid(
return ttnn_tensor_to_torch(t1)


def eltwise_celu(
x,
*args,
alpha,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttnn.celu(t0, alpha, memory_config=memory_config_to_ttnn(output_mem_config))

return ttnn_tensor_to_torch(t1)
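
For context, a minimal standalone sketch (not part of this commit) of the same ttnn.celu call outside the sweep harness; it assumes a device is available and uses the public ttnn.open_device / from_torch / to_torch helpers:

import torch
import torch.nn.functional as F
import ttnn

device = ttnn.open_device(device_id=0)
torch_x = torch.rand((1, 1, 64, 128), dtype=torch.bfloat16) * 10 - 5
x = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT, device=device)
y = ttnn.celu(x, 2.5)  # alpha passed positionally, as in the wrapper above
out = ttnn.to_torch(y).float()
print(torch.allclose(out, F.celu(torch_x.float(), alpha=2.5), atol=0.25, rtol=0.05))
ttnn.close_device(device)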


def eltwise_logit(
x,
*args,
7 changes: 7 additions & 0 deletions tests/ttnn/unit_tests/operations/test_activation.py
@@ -239,6 +239,13 @@ def test_scalarB_elu(device, h, w, scalar):
run_activation_test_scalarB(device, h, w, scalar, ttnn.elu, F.elu)


@pytest.mark.parametrize("alpha", [1, 2.5, 5.0])
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_scalarB_celu(device, h, w, alpha):
run_activation_test_scalarB(device, h, w, alpha, ttnn.celu, F.celu)
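
run_activation_test_scalarB itself is not shown in this diff; roughly, it builds a random bfloat16 input of shape (h, w), runs the ttnn op against the torch reference, and compares the results. A hypothetical, simplified equivalent (helper name, input range, and tolerances are assumptions here) looks like:

import torch
import ttnn

def run_activation_test_scalarB_sketch(device, h, w, scalar, ttnn_fn, torch_fn):
    # random bfloat16 inputs in roughly [-10, 10)
    torch_x = torch.rand((h, w), dtype=torch.bfloat16) * 20 - 10
    golden = torch_fn(torch_x.float(), scalar)
    x = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT, device=device)
    out = ttnn.to_torch(ttnn_fn(x, scalar)).float()
    assert torch.allclose(out, golden, atol=0.25, rtol=0.05)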


@pytest.mark.parametrize("scalar", [0.5, 1.0])
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
15 changes: 15 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -901,6 +901,21 @@ Tensor xlogy(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
return operation::decorate_as_composite(__func__, _xlogy)(input_a, input_b, output_mem_config);
}

// Celu
// torch.where(x > 0, x, alpha * (torch.exp(x / alpha) - 1))
Tensor _celu(const Tensor& input_a, float alpha, const MemoryConfig& output_mem_config) {
float recip_val = 1.0f / alpha;
Tensor tmp = sub_unary(exp(mul_unary(input_a, recip_val, output_mem_config), output_mem_config), 1.0f, output_mem_config);
Tensor result = mul_unary(tmp, alpha, output_mem_config);
tmp.deallocate();
result = where(gtz(input_a, output_mem_config), input_a, result, output_mem_config);
return result;
}
Tensor celu(const Tensor& input_a, float alpha, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _celu)(input_a, alpha, output_mem_config);
}


Tensor _variance_impl(
const Tensor& y, const Tensor& mean_y, Tensor& y_minus_mean_y, const MemoryConfig& output_mem_config) {
constexpr float correction = 0.0f;
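
For reference, the _celu decomposition added above matches PyTorch's definition of CELU, CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)): for x > 0 the result is x, and for x <= 0 it is alpha * (exp(x / alpha) - 1), which is exactly what the where(gtz(x), x, ...) expression computes. With alpha = 1.0 this reduces to ELU.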
5 changes: 5 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
@@ -99,6 +99,11 @@ Tensor selu(
const float alpha = 1.6732632423543772848170429916717,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor celu(
const Tensor& x,
float alpha,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// Function Swish = same as SILU
// use transformation y = x * sigmoid( x ) by broadcast
Tensor swish(const Tensor& a, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
16 changes: 16 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -418,6 +418,22 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("celu", &celu,
py::arg("input").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Applies the celu function to the elements of the input tensor ``input``.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Tensor celu is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"alpha", "alpha value (PyTorch default)", "float", "default to 1.0", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("subalpha", &subalpha,
py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("alpha") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Subtracts ``input_b``, scaled by ``alpha``, from ``input_a``.
1 change: 1 addition & 0 deletions ttnn/ttnn/__init__.py
@@ -364,6 +364,7 @@ def manage_config_attribute(name, value):
geglu,
reglu,
swiglu,
celu,
)

from ttnn.operations.math import (
2 changes: 2 additions & 0 deletions ttnn/ttnn/operations/activation.py
@@ -31,6 +31,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
"relu6": F.relu6,
"sigmoid": torch.sigmoid,
"sign": torch.sign,
"celu": F.celu,
"softsign": F.softsign,
"swish": F.hardswish,
"softplus": F.softplus,
@@ -363,6 +364,7 @@ def activation_function(
("leaky_relu", ttl.tensor.leaky_relu, "leaky relu", "slope"),
("prelu", ttl.tensor.prelu, "prelu", "weight"),
("elu", ttl.tensor.elu, "elu", "alpha"),
("celu", ttl.tensor.celu, "celu", "alpha"),
("softshrink", ttl.tensor.softshrink, "softshrink", "lambda"),
]
