
#10778: Update polygamma with ttnn support
bharane-ab committed Jul 31, 2024
1 parent ba2ab9f · commit 0f35c21
Showing 5 changed files with 4 additions and 41 deletions.
docs/source/ttnn/ttnn/dependencies/tt_lib.rst (2 changes: 0 additions & 2 deletions)
@@ -300,8 +300,6 @@ Tensor elementwise operations
 
 .. autofunction:: tt_lib.tensor.logical_ori
 
-.. autofunction:: tt_lib.tensor.polygamma
-
 .. autofunction:: tt_lib.tensor.frac
 
 .. autofunction:: tt_lib.tensor.floor_div
tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py (8 changes: 4 additions & 4 deletions)
@@ -817,7 +817,7 @@ def eltwise_logit(x, *args, eps, device, dtype, layout, input_mem_config, output
 @setup_host_and_device
 def eltwise_polygamma(x, *args, k, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.polygamma(t0, k, output_mem_config=output_mem_config)
+    t1 = ttnn.polygamma(t0, k=k, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
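
The test now exercises the ttnn entry point with keyword arguments. A minimal standalone sketch of the migrated call, assuming a single default device and ttnn's torch interop helpers (the tensor setup below is illustrative, not part of this commit):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
# The op is documented for orders and inputs in the range 1 to 10, so sample inputs >= 1.
x = ttnn.from_torch(torch.rand(1, 1, 32, 32) + 1.0, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
y = ttnn.polygamma(x, k=2)  # previously: ttl.tensor.polygamma(t0, k, output_mem_config=...)
print(ttnn.to_torch(y))
ttnn.close_device(device)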

@@ -994,7 +994,7 @@ def eltwise_round(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.round(t0, decimals, output_mem_config=output_mem_config)
+    t1 = ttnn.round(t0, decimals=decimals, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)

@@ -1416,15 +1416,15 @@ def zeros(x, *args, device, dtype, layout, input_mem_config, output_mem_config,
 def triu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     tx = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
     diag = kwargs.get("diag", 0)
-    t1 = ttnn.triu(tx, diag, memory_config=output_mem_config)
+    t1 = ttnn.triu(tx, diagonal=diag, memory_config=output_mem_config)
     return tt2torch_tensor(t1)
 
 
 @setup_host_and_device
 def tril(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     tx = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
     diag = kwargs.get("diag", 0)
-    t1 = ttnn.tril(tx, diag, memory_config=output_mem_config)
+    t1 = ttnn.tril(tx, diagonal=diag, memory_config=output_mem_config)
     return tt2torch_tensor(t1)
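
The round, triu, and tril updates follow the same migration pattern: positional arguments become explicit keywords (decimals=, diagonal=) and output_mem_config= becomes memory_config=. The new keyword names match torch's, which presumably implies matching semantics; a quick host-side sketch of what those keywords mean, using the torch equivalents purely for illustration:

import torch

a = torch.arange(16.0).reshape(4, 4)
print(torch.triu(a, diagonal=1))        # zero the main diagonal and everything below it
print(torch.tril(a, diagonal=-1))       # zero the main diagonal and everything above it
print(torch.round(a / 3, decimals=2))   # round to 2 decimal places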


@@ -493,30 +493,6 @@ Tensor logit(const Tensor& input_a, float eps, const MemoryConfig& output_mem_co
     return operation::decorate_as_composite(__func__, _logit)(input_a, eps, output_mem_config);
 }
 
-// polygamma support for the range of input(1, 10) and n(1, 10)
-Tensor _polygamma(const Tensor& input_a, uint32_t k, const MemoryConfig& output_mem_config) {
-    float k_der = 1.0f + k;
-    float fact_val = std::tgamma(k_der);
-    float pos_neg = 1.0f;
-    if (k == 2 || k == 4 || k == 6 || k == 8 || k == 10) {
-        pos_neg = -1.0f;
-    }
-    Tensor temp(input_a);
-    {
-        Tensor z1 = ttnn::reciprocal(ttnn::power(input_a, k_der, output_mem_config), output_mem_config);
-        temp = z1;
-        for (int idx = 1; idx < 11; idx++) {
-            z1 = ttnn::reciprocal(ttnn::power(ttnn::add(input_a, idx, std::nullopt, output_mem_config), k_der, output_mem_config), output_mem_config);
-            temp = ttnn::add(temp, z1, std::nullopt, output_mem_config);
-        }
-    }
-    fact_val *= pos_neg;
-    return ttnn::multiply(temp, fact_val, std::nullopt, output_mem_config);
-}
-Tensor polygamma(const Tensor& input_a, uint32_t value, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _polygamma)(input_a, value, output_mem_config);
-}
 
 // logical_xori
 Tensor _logical_xori(const Tensor& input_a, float value, const MemoryConfig& output_mem_config) {
     if (std::fpclassify(value) == FP_ZERO) {
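
For reference, the deleted composite evaluated a truncated series for the k-th polygamma function, psi^(k)(x) ≈ (-1)^(k+1) · k! · sum_{i=0}^{10} (x + i)^(-(k+1)), computing k! via tgamma(k + 1) and flipping the sign for even k. A host-side Python reconstruction of that series, checked against torch.special.polygamma (an illustrative sketch, not code from this repository):

import math
import torch

def polygamma_series(x: torch.Tensor, k: int, terms: int = 11) -> torch.Tensor:
    sign = -1.0 if k % 2 == 0 else 1.0  # (-1)^(k+1), the pos_neg flag above
    fact = math.gamma(k + 1.0)          # k!, as std::tgamma(k_der) computed it
    acc = torch.zeros_like(x)
    for i in range(terms):              # i = 0 is the input itself, then 10 shifted terms
        acc = acc + (x + i) ** (-(k + 1.0))
    return sign * fact * acc

x = torch.linspace(1.0, 10.0, steps=5)
print(polygamma_series(x, k=2))
print(torch.special.polygamma(2, x))  # truncation error is largest near x = 1
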
@@ -250,10 +250,6 @@ Tensor eps(
 Tensor logit(
     const Tensor& input_a, float eps, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
-// polygamma
-Tensor polygamma(
-    const Tensor& input_a, uint32_t k, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 Tensor logical_xori(
     const Tensor& input_a,
     float immediate,
@@ -832,13 +832,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
         R"doc("dimension to logit along", "int", "0, 1, 2, or 3")doc"
     );
 
-    detail::bind_unary_op_with_param(
-        m_tensor, "polygamma", &polygamma,
-        py::arg("n"),
-        R"doc(Returns a tensor that is a polygamma of input tensor where the range supports from 1 to 10 with shape ``[W, Z, Y, X]`` along n ``{1}``.)doc",
-        R"doc("the order of the polygamma along", "int", "1 to 10")doc"
-    );
-
     detail::bind_unary_op_with_param(
         m_tensor, "logical_xori", &logical_xori,
         py::arg("immediate"),
