Revert "#11043: Overload complex forward ops" #11091

Merged: 1 commit, Aug 5, 2024

tests/ttnn/unit_tests/operations/test_complex.py (2 additions, 0 deletions)
@@ -332,6 +332,7 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults):
assert passing


+@pytest.mark.skip(reason="This test is failing because ttnn.mul doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
@@ -369,6 +370,7 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults):
assert passing


+@pytest.mark.skip(reason="This test is failing because ttnn.div doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
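For reference while these tests are skipped: the reverted overload computed the complex elementwise product from the real and imaginary parts. A minimal NumPy sketch of that identity (plain Python for illustration, not ttnn code; all names are local to this example):

```python
import numpy as np

# The identity the reverted overload computed elementwise:
# (a_re + i*a_im) * (b_re + i*b_im)
#   = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
rng = np.random.default_rng(0)
a_re, a_im = rng.standard_normal(4), rng.standard_normal(4)
b_re, b_im = rng.standard_normal(4), rng.standard_normal(4)

out_re = a_re * b_re - a_im * b_im
out_im = a_re * b_im + a_im * b_re

# Cross-check against NumPy's native complex multiply.
assert np.allclose(out_re + 1j * out_im, (a_re + 1j * a_im) * (b_re + 1j * b_im))
```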
ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp (0 additions, 25 deletions)
@@ -8,7 +8,6 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"

namespace ttnn::operations::binary {

@@ -219,30 +218,6 @@ Tensor BinaryOperation<binary_op_type, in_place>::operator()(
}


-template <BinaryOpType binary_op_type, bool in_place>
-ComplexTensor BinaryOperation<binary_op_type, in_place>::operator()(
-    const ComplexTensor &input_a,
-    const ComplexTensor &input_b,
-    const ttnn::MemoryConfig &output_mem_config) {
-    if constexpr(binary_op_type == BinaryOpType::MUL) {
-        Tensor re_part = ttnn::subtract(
-            ttnn::multiply(input_a[0],input_b[0],std::nullopt,output_mem_config),
-            ttnn::multiply(input_a[1],input_b[1],std::nullopt,output_mem_config),
-            std::nullopt, output_mem_config);
-
-        Tensor im_part = ttnn::add(
-            ttnn::multiply(input_a[0],input_b[1],std::nullopt,output_mem_config),
-            ttnn::multiply(input_a[1],input_b[0],std::nullopt,output_mem_config),
-            std::nullopt, output_mem_config);
-
-        return ComplexTensor({ re_part, im_part });
-    } else if constexpr(binary_op_type == BinaryOpType::DIV_FAST) {
-        return ttnn::multiply( input_a, ttnn::operations::complex_unary::_reciprocal( input_b , output_mem_config ), output_mem_config ); //TODO: Overload reciprocal
-    } else {
-        TT_THROW("Unsupported operation (expected MUL or DIV_FAST)");
-    }
-}

template <BinaryOpType binary_op_type, bool in_place>
Tensor RelationalBinary<binary_op_type, in_place>::operator()(
uint8_t queue_id,
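The deleted DIV_FAST branch divided complex tensors by multiplying with the complex reciprocal (see the //TODO it carried). A hedged NumPy sketch of that arithmetic, standing in for the ttnn calls rather than reproducing them:

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal(4) + 1j * rng.standard_normal(4)
b = rng.standard_normal(4) + 1j * rng.standard_normal(4)

# reciprocal(c + i*d) = (c - i*d) / (c**2 + d**2); note there is no
# zero-denominator guard in this sketch.
recip = (b.real - 1j * b.imag) / (b.real**2 + b.imag**2)

# a / b computed as a * reciprocal(b), as the DIV_FAST branch did.
assert np.allclose(a * recip, a / b)
```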
ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp (2 additions, 9 deletions)
@@ -8,7 +8,6 @@
#include "ttnn/decorators.hpp"
#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
#include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "ttnn/operations/eltwise/complex/complex.hpp"

namespace ttnn {

@@ -59,12 +58,6 @@ struct BinaryOperation {
const std::optional<Tensor> &optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

-    static ComplexTensor operator()(
-        const ComplexTensor &input_tensor_a_arg,
-        const ComplexTensor &input_tensor_b_arg,
-        const MemoryConfig &memory_config);

};

template <BinaryOpType binary_op_type, bool in_place>
@@ -132,7 +125,7 @@ constexpr auto subtract = ttnn::register_operation_with_auto_launch_op<
constexpr auto subtract_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::subtract_",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SUB, true>>();
-constexpr auto multiply = ttnn::register_operation<
+constexpr auto multiply = ttnn::register_operation_with_auto_launch_op<
"ttnn::multiply",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::MUL, false>>();
constexpr auto multiply_ = ttnn::register_operation_with_auto_launch_op<
@@ -176,7 +169,7 @@ constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
"ttnn::squared_difference",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SQUARED_DIFFERENCE, false>>();
-constexpr auto divide = ttnn::register_operation<
+constexpr auto divide = ttnn::register_operation_with_auto_launch_op<
"ttnn::divide",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::DIV_FAST, false>>();

ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp (5 additions, 98 deletions)
@@ -104,99 +104,6 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operation
py::arg("queue_id") = 0});
}

-template <typename binary_operation_t>
-void bind_binary_operation_with_complex(py::module& module, const binary_operation_t& operation, const std::string& description) {
-    auto doc = fmt::format(
-        R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None, activations: Optional[List[str]] = None) -> ttnn.Tensor or ComplexTensor
-
-        {2}
-        Supports broadcasting.
-
-        Args:
-            * :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor)
-            * :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number)
-
-        Keyword args:
-            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
-            * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
-            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
-            * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
-            * :attr:`queue_id` (Optional[uint8]): command queue id
-
-        Example:
-
-            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
-            >>> output = {1}(tensor1, tensor2)
-        )doc",
-        operation.base_name(),
-        operation.python_fully_qualified_name(),
-        description);
-
-    bind_registered_operation(
-        module,
-        operation,
-        doc,
-        // tensor and scalar
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ttnn::Tensor& input_tensor_a,
-               const float scalar,
-               const std::optional<const DataType>& dtype,
-               const std::optional<ttnn::MemoryConfig>& memory_config,
-               const std::optional<ttnn::Tensor>& output_tensor,
-               const std::optional<unary::FusedActivations>& activations,
-               const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
-               const uint8_t& queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("dtype") = std::nullopt,
-            py::arg("memory_config") = std::nullopt,
-            py::arg("output_tensor") = std::nullopt,
-            py::arg("activations") = std::nullopt,
-            py::arg("input_tensor_a_activation") = std::nullopt,
-            py::arg("queue_id") = 0},
-
-        // tensor and tensor
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ttnn::Tensor& input_tensor_a,
-               const ttnn::Tensor& input_tensor_b,
-               const std::optional<const DataType>& dtype,
-               const std::optional<ttnn::MemoryConfig>& memory_config,
-               const std::optional<ttnn::Tensor>& output_tensor,
-               const std::optional<unary::FusedActivations>& activations,
-               const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
-               const uint8_t& queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("dtype") = std::nullopt,
-            py::arg("memory_config") = std::nullopt,
-            py::arg("output_tensor") = std::nullopt,
-            py::arg("activations") = std::nullopt,
-            py::arg("input_tensor_a_activation") = std::nullopt,
-            py::arg("queue_id") = 0},
-
-        // complex tensor
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ComplexTensor& input_tensor_a,
-               const ComplexTensor& input_tensor_b,
-               const MemoryConfig& memory_config) {
-                return self(input_tensor_a, input_tensor_b, memory_config);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("memory_config")});
-}

template <typename binary_operation_t>
void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
@@ -380,11 +287,11 @@ void bind_div_like_ops(py::module& module, const binary_operation_t& operation,
template <typename binary_operation_t>
void bind_div(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Args:
-            * :attr:`input_tensor_a` (ttnn.Tensor or ComplexTensor)
-            * :attr:`input_tensor_b` (ttnn.Tensor or Number or ComplexTensor)
+            * :attr:`input_tensor_a`
+            * :attr:`input_tensor_b` (ttnn.Tensor or Number)
* :attr:`accurate_mode`: ``false`` if input_tensor_b is non-zero, else ``true``.
* :attr:`round_mode`

@@ -557,7 +464,7 @@ void py_module(py::module& module) {
R"doc(Subtracts :attr:`input_tensor_b` from :attr:`input_tensor_a` and returns the tensor with the same layout as :attr:`input_tensor_a` in-place
.. math:: \mathrm{{input\_tensor\_a}}_i - \mathrm{{input\_tensor\_b}}_i)doc");

-    detail::bind_binary_operation_with_complex(
+    detail::bind_binary_operation(
module,
ttnn::multiply,
R"doc(Multiplies :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
@@ -647,7 +554,7 @@ void py_module(py::module& module) {
R"doc(Compute bias_gelu of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
.. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc");

-    detail::bind_binary_operation_with_complex(
+    detail::bind_binary_operation(
module,
ttnn::divide,
R"doc(Divides :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
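Taken together, the removed binding exposed three overloads behind one Python name: tensor and scalar, tensor and tensor, and complex and complex. After the revert, multiply and divide go through plain bind_binary_operation and keep only the first two, which is why the two tests above are now skipped. A toy Python model of the removed dispatch (hypothetical stand-ins; this ComplexTensor is a local dataclass, not ttnn's type):

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ComplexTensor:
    # Hypothetical stand-in for ttnn's ComplexTensor: a (real, imag) pair.
    real: np.ndarray
    imag: np.ndarray


def multiply(a, b):
    # complex x complex: the overload this revert removes from the binding
    if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
        return ComplexTensor(
            a.real * b.real - a.imag * b.imag,
            a.real * b.imag + a.imag * b.real,
        )
    # tensor x tensor and tensor x scalar: the overloads that remain
    return np.asarray(a) * b
```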