#0: Cleanup binary op apis
KalaivaniMCW committed Dec 11, 2024
1 parent 8e49222 commit abd7c42
Showing 9 changed files with 310 additions and 364 deletions.
16 changes: 0 additions & 16 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
@@ -249,22 +249,6 @@ constexpr auto ne_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::ne_",
operations::binary::InplaceRelationalBinary<operations::binary::BinaryOpType::NE>>();

constexpr auto rsub_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::rsub_binary",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::RSUB>>();
constexpr auto power_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::power_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::POWER>>();
constexpr auto bitwise_and_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_and_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_AND>>();
constexpr auto bitwise_or_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_or_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_OR>>();
constexpr auto bitwise_xor_binary = ttnn::register_operation_with_auto_launch_op<
"ttnn::bitwise_xor_binary",
operations::binary::BinaryOperationSfpu<operations::binary::BinaryOpType::BITWISE_XOR>>();

template <typename InputBType>
ttnn::Tensor operator+(const ttnn::Tensor& input_tensor_a, InputBType scalar) {
return add(input_tensor_a, scalar);
64 changes: 64 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -18,6 +18,69 @@ namespace operations {

namespace binary {

/**
 * @brief Performs an element-wise power operation on the input with the exponent.
 * When the exponent is a Tensor, the supported dtypes are float32 and bfloat16.
 * The tested range for the input is (-30, 30) and for the exponent is (-20, 20).
 *
 * @param input The input tensor, i.e. the base.
 * @param exponent The exponent.
 * @return The result tensor.
*/
struct ExecutePower {
static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
uint8_t queue_id,
float input_a,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
float input_a,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
uint8_t queue_id,
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
};
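
A minimal usage sketch of the overloads declared above, called through the `ttnn::pow` registration added at the bottom of this file. The `pow_examples` helper is illustrative only and not part of this commit: it assumes `input` and `exponent_tensor` are tile-layout device tensors created elsewhere, and the include path simply mirrors this header's location in the tree.

// Illustrative sketch only, not part of this commit.
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"

ttnn::Tensor pow_examples(const ttnn::Tensor& input, const ttnn::Tensor& exponent_tensor) {
    // Integer exponent -> uint32_t overload.
    const ttnn::Tensor squared = ttnn::pow(input, static_cast<uint32_t>(2));
    // Float exponent -> float overload; memory_config and output_tensor keep their defaults.
    const ttnn::Tensor scaled = ttnn::pow(squared, 1.5f);
    // Scalar base with tensor exponent -> float/Tensor overload.
    const ttnn::Tensor exp2 = ttnn::pow(2.0f, exponent_tensor);
    // Tensor base with tensor exponent; per the note above, float32 and bfloat16 only.
    return ttnn::pow(exp2, exponent_tensor);
}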

template <BinaryCompositeOpType binary_comp_op_type>
struct ExecuteBinaryCompositeOps {
static Tensor invoke(
@@ -436,5 +499,6 @@ constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<"ttnn::rsub",
constexpr auto bitwise_and = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_and", operations::binary::ExecuteBitwiseAnd>();
constexpr auto bitwise_or = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_or", operations::binary::ExecuteBitwiseOr>();
constexpr auto bitwise_xor = ttnn::register_operation_with_auto_launch_op<"ttnn::bitwise_xor", operations::binary::ExecuteBitwiseXor>();
constexpr auto pow = ttnn::register_operation_with_auto_launch_op<"ttnn::pow", operations::binary::ExecutePower>();

} // namespace ttnn
121 changes: 121 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -1265,6 +1265,124 @@ void bind_binary_inplace_operation(
py::arg("activations") = std::nullopt,
py::arg("input_tensor_a_activation") = std::nullopt});
}

template <typename binary_operation_t>
void bind_power(py::module& module, const binary_operation_t& operation, const std::string& note = "") {
auto doc = fmt::format(
R"doc(
Perform element-wise {0} operation on :attr:`input_tensor` with :attr:`exponent`.

.. math::
    \mathrm{{output\_tensor}}_i = \mathrm{{input\_tensor}}_i ^ {{\mathrm{{exponent}}_i}}

Args:
    input_tensor (ttnn.Tensor): the input tensor.
    exponent (float, int): the exponent value.

Keyword Args:
    memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
    output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
    queue_id (int, optional): command queue id. Defaults to `0`.

Returns:
    ttnn.Tensor: the output tensor.

Note:
    Supported dtypes, layouts, and ranks:

    .. list-table::
       :header-rows: 1

       * - Dtypes
         - Layouts
         - Ranks
       * - BFLOAT16, BFLOAT8_B
         - TILE
         - 2, 3, 4

    {2}

Example:
    >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    >>> exponent = 2
    >>> output = {1}(tensor, exponent)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
note);

bind_registered_operation(
module,
operation,
doc,
// integer exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
uint32_t exponent,
const std::optional<MemoryConfig>& memory_config,
const std::optional<Tensor>& output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = 0},

// float exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
float exponent,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},

// tensor exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, exponent, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},

// scalar input - tensor exponent
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
float input,
const Tensor& exponent,
const std::optional<MemoryConfig>& memory_config,
std::optional<Tensor> output_tensor,
const uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, input, exponent, memory_config, output_tensor);
},
py::arg("input"),
py::arg("exponent"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId});
}
} // namespace detail

void py_module(py::module& module) {
@@ -1689,6 +1807,9 @@ void py_module(py::module& module) {
R"doc(Performs Not equal to in-place operation on :attr:`input_a` and :attr:`input_b` and returns the tensor with the same layout as :attr:`input_tensor`)doc",
R"doc(\mathrm{{input\_tensor\_a}}\: != \mathrm{{input\_tensor\_b}})doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_power(
module, ttnn::pow, R"doc(When :attr:`exponent` is a Tensor, supported dtypes are: BFLOAT16, FLOAT32)doc");
}

} // namespace binary
