Skip to content

Commit

Permalink
#0: Cleanup binary op apis (#15906)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue

### Problem description
Move the `ttnn.pow` op from eltwise/unary to eltwise/binary.
Clean up the `_binary` and `_unary` APIs recently added in #15805.

### What's changed
moved `ttnn.pow` op from eltwise/unary to eltwise/binary
removed the recently added _binary and _unary apis in #15805

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12278748922
https://github.com/tenstorrent/tt-metal/actions/runs/12292260933
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
KalaivaniMCW authored Dec 12, 2024
1 parent 1509a6e commit d43dc8d
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 372 deletions.
10 changes: 5 additions & 5 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ Pointwise Unary
ttnn.asinh
ttnn.atan
ttnn.atanh
ttnn.bitwise_and
ttnn.bitwise_or
ttnn.bitwise_xor
ttnn.bitwise_not
ttnn.bitwise_left_shift
ttnn.bitwise_right_shift
Expand Down Expand Up @@ -162,7 +159,6 @@ Pointwise Unary
ttnn.normalize_global
ttnn.normalize_hw
ttnn.polygamma
ttnn.pow
ttnn.prelu
ttnn.rad2deg
ttnn.rdiv
Expand All @@ -175,7 +171,6 @@ Pointwise Unary
ttnn.remainder
ttnn.round
ttnn.rsqrt
ttnn.rsub
ttnn.selu
ttnn.sigmoid
ttnn.sigmoid_accurate
Expand Down Expand Up @@ -309,10 +304,14 @@ Pointwise Binary
ttnn.logical_or_
ttnn.logical_xor_
ttnn.rpow
ttnn.rsub
ttnn.ldexp
ttnn.logical_and
ttnn.logical_or
ttnn.logical_xor
ttnn.bitwise_and
ttnn.bitwise_or
ttnn.bitwise_xor
ttnn.logaddexp
ttnn.logaddexp2
ttnn.hypot
Expand All @@ -335,6 +334,7 @@ Pointwise Binary
ttnn.maximum
ttnn.minimum
ttnn.outer
ttnn.pow
ttnn.polyval
ttnn.scatter
ttnn.atan2
Expand Down
16 changes: 0 additions & 16 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,22 +249,6 @@ constexpr auto ne_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::ne_",
operations::binary::InplaceRelationalBinary<operations::binary::BinaryOpType::NE>>();

// NOTE(review): these *_binary registrations appear to duplicate the composite
// registrations (ttnn::rsub, ttnn::pow, ttnn::bitwise_and/or/xor) declared in
// binary_composite.hpp — confirm whether both sets of symbols are still needed.
constexpr auto rsub_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::rsub_binary",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::RSUB>>();
// SFPU-backed variants: power and the bitwise ops run on the SFPU rather than the FPU.
constexpr auto power_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::power_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::POWER>>();
constexpr auto bitwise_and_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_and_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_AND>>();
constexpr auto bitwise_or_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_or_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_OR>>();
constexpr auto bitwise_xor_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_xor_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_XOR>>();

template <typename InputBType>
ttnn::Tensor operator+(const ttnn::Tensor& input_tensor_a, InputBType scalar) {
return add(input_tensor_a, scalar);
Expand Down
64 changes: 64 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,69 @@ namespace operations {

namespace binary {

/**
 * @brief Performs element-wise power operation on the input with the exponent.
 * When exponent is Tensor, the supported dtypes are float32 and bfloat16.
 * The tested range for the input is (-30,30) and for the exponent is (-20, 20).
 *
 * @param input The input tensor, i.e. the base.
 * @param exponent The exponent
 * @return The result tensor
 */
struct ExecutePower {
// Tensor base, integer exponent (explicit command-queue id).
static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Tensor base, integer exponent (default command queue).
static Tensor invoke(
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Tensor base, float exponent (explicit command-queue id).
static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Tensor base, float exponent (default command queue).
static Tensor invoke(
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Scalar base, tensor exponent (explicit command-queue id).
static Tensor invoke(
uint8_t queue_id,
float input_a,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Scalar base, tensor exponent (default command queue).
static Tensor invoke(
float input_a,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Tensor base, tensor exponent (explicit command-queue id).
static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

// Tensor base, tensor exponent (default command queue).
static Tensor invoke(
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};

template <BinaryCompositeOpType binary_comp_op_type>
struct ExecuteBinaryCompositeOps {
static Tensor invoke(
Expand Down Expand Up @@ -436,5 +499,6 @@ constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<"ttnn::rsub",
// Public ttnn:: registrations for the composite binary bitwise and power ops.
constexpr auto bitwise_and = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_and", operations::binary::ExecuteBitwiseAnd>();
constexpr auto bitwise_or = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_or", operations::binary::ExecuteBitwiseOr>();
constexpr auto bitwise_xor = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_xor", operations::binary::ExecuteBitwiseXor>();
constexpr auto pow = ttnn::register_operation_with_auto_launch_op<"ttnn::pow", operations::binary::ExecutePower>();

} // namespace ttnn
127 changes: 124 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,124 @@ void bind_binary_inplace_operation(
py::arg("activations") = std::nullopt,
py::arg("input_tensor_a_activation") = std::nullopt});
}

// Binds the four ttnn.pow overloads (tensor**int, tensor**float, tensor**tensor,
// scalar**tensor) to the python module.
//
// Fixes versus the previous version:
//  - `operation` was ignored: the doc format args and bind_registered_operation
//    hardcoded ttnn::pow. Now the passed-in operation is used throughout.
//  - The integer-exponent overload defaulted queue_id to literal 0 while the
//    others used ttnn::DefaultQueueId; all four now use ttnn::DefaultQueueId.
//  - output_tensor is taken by const reference in all overloads (three of the
//    four previously took it by value, forcing an unnecessary copy).
//
// @param module    The pybind11 module to register into.
// @param operation The registered ttnn operation (ttnn::pow).
// @param note      Extra text appended to the docstring's dtype-support table.
template <typename binary_operation_t>
void bind_power(py::module& module, const binary_operation_t& operation, const std::string& note = "") {
auto doc = fmt::format(
R"doc(
Perform element-wise {0} operation on :attr:`input_tensor` with :attr:`exponent`.
.. math::
\mathrm{{output\_tensor}}_i = (\mathrm{{input\_tensor}}_i ** \mathrm{{exponent}}_i)
Args:
input_tensor (ttnn.Tensor, float): the input tensor.
exponent (float, int, ttnn.Tensor): the exponent value.
Keyword Args:
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (int, optional): command queue id. Defaults to `0`.
Returns:
ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16, BFLOAT8_B
- TILE
- 2, 3, 4
{2}
Example:
>>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> exponent = 2
>>> output = {1}(tensor, exponent)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
note);

bind_registered_operation(
module,
operation,
doc,
// integer exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},

// float exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},

// tensor exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},

// scalar input - tensor exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
float input,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input, exponent, memory_config, output_tensor);
},
py::arg("input"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId});
}
} // namespace detail

void py_module(py::module& module) {
Expand Down Expand Up @@ -1434,23 +1552,23 @@ void py_module(py::module& module) {
module,
ttnn::bitwise_and,
R"doc(Perform bitwise_and operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_and| \mathrm{{input\_tensor\_a}}_i)doc",
R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_and|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc",
". ",
R"doc(INT32)doc");

detail::bind_bitwise_binary_ops_operation(
module,
ttnn::bitwise_or,
R"doc(Perform bitwise_or operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_or| \mathrm{{input\_tensor\_a}}_i)doc",
R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_or|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc",
". ",
R"doc(INT32)doc");

detail::bind_bitwise_binary_ops_operation(
module,
ttnn::bitwise_xor,
R"doc(Perform bitwise_xor operation on :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor\_b}}_i \verb|bitwise_xor| \mathrm{{input\_tensor\_a}}_i)doc",
R"doc(\mathrm{{output\_tensor}}_i = \verb|bitwise_xor|(\mathrm{{input\_tensor\_a, input\_tensor\_b}}))doc",
". ",
R"doc(INT32)doc");

Expand Down Expand Up @@ -1689,6 +1807,9 @@ void py_module(py::module& module) {
R"doc(Performs Not equal to in-place operation on :attr:`input_a` and :attr:`input_b` and returns the tensor with the same layout as :attr:`input_tensor`)doc",
R"doc(\mathrm{{input\_tensor\_a}}\: != \mathrm{{input\_tensor\_b}})doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_power(
module, ttnn::pow, R"doc(When :attr:`exponent` is a Tensor, supported dtypes are: BFLOAT16, FLOAT32)doc");
}

} // namespace binary
Expand Down
Loading

0 comments on commit d43dc8d

Please sign in to comment.