From b8dde20f649a8d6467b9f65e7c619bdfbddaa517 Mon Sep 17 00:00:00 2001
From: Patrick Roberts
Date: Mon, 5 Aug 2024 12:56:20 -0500
Subject: [PATCH] Revert "#11043: Overload complex forward ops" (#11091)

This reverts commit 8d4008d36e1ee5c0a7e4cf453c57c74c423893ca.
---
 .../unit_tests/operations/test_complex.py      |   2 +
 .../ttnn/operations/eltwise/binary/binary.cpp  |  25 -----
 .../ttnn/operations/eltwise/binary/binary.hpp  |  11 +-
 .../eltwise/binary/binary_pybind.hpp           | 103 +-----------------
 4 files changed, 9 insertions(+), 132 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_complex.py b/tests/ttnn/unit_tests/operations/test_complex.py
index 5ec748aec2c..5c912785fcf 100644
--- a/tests/ttnn/unit_tests/operations/test_complex.py
+++ b/tests/ttnn/unit_tests/operations/test_complex.py
@@ -332,6 +332,7 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults):
     assert passing


+@pytest.mark.skip(reason="This test is failing because ttnn.mul doesn't support complex tensors")
 @pytest.mark.parametrize(
     "memcfg",
     (
@@ -369,6 +370,7 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults):
     assert passing


+@pytest.mark.skip(reason="This test is failing because ttnn.div doesn't support complex tensors")
 @pytest.mark.parametrize(
     "memcfg",
     (
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
index 8f59389818a..3e45d16d6fe 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -8,7 +8,6 @@
 #include "ttnn/device_operation.hpp"
 #include "ttnn/operations/data_movement.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
-#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"

 namespace ttnn::operations::binary {

@@ -219,30 +218,6 @@ Tensor BinaryOperation<binary_op_type>::operator()(
 }


-template <BinaryOpType binary_op_type>
-ComplexTensor BinaryOperation<binary_op_type>::operator()(
-    const ComplexTensor &input_a,
-    const ComplexTensor &input_b,
-    const ttnn::MemoryConfig &output_mem_config) {
-    if constexpr(binary_op_type == BinaryOpType::MUL) {
-        Tensor re_part = ttnn::subtract(
-            ttnn::multiply(input_a[0],input_b[0],std::nullopt,output_mem_config),
-            ttnn::multiply(input_a[1],input_b[1],std::nullopt,output_mem_config),
-            std::nullopt, output_mem_config);
-
-        Tensor im_part = ttnn::add(
-            ttnn::multiply(input_a[0],input_b[1],std::nullopt,output_mem_config),
-            ttnn::multiply(input_a[1],input_b[0],std::nullopt,output_mem_config),
-            std::nullopt, output_mem_config);
-
-        return ComplexTensor({ re_part, im_part });
-    }else if constexpr(binary_op_type == BinaryOpType::DIV_FAST) {
-        return ttnn::multiply( input_a, ttnn::operations::complex_unary::_reciprocal( input_b , output_mem_config ), output_mem_config ); //TODO: Overload reciprocal
-    }else {
-        TT_THROW("Unsupported operation (expected MUL or DIV_FAST)");
-    }
-}
-
 template <BinaryOpType binary_op_type>
 Tensor RelationalBinary<binary_op_type>::operator()(
     uint8_t queue_id,
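The overload removed above computed the complex product plane-by-plane: for z1 = a + bi and z2 = c + di it built re = ac - bd and im = ad + bc from four real multiplies, and DIV_FAST reduced division to a multiply by the reciprocal of the denominator. A minimal Python sketch of the same multiply decomposition, assuming the ComplexTensor planes are reachable as .real and .imag (the C++ above indexes them as [0] and [1]) and that ttnn.complex_tensor repacks two planes; complex_mul is a hypothetical helper, not a ttnn API:

    import ttnn

    def complex_mul(z1, z2, memory_config=None):
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, matching the
        # decomposition the reverted BinaryOperation<MUL> overload ran on device.
        re_part = ttnn.subtract(
            ttnn.multiply(z1.real, z2.real, memory_config=memory_config),
            ttnn.multiply(z1.imag, z2.imag, memory_config=memory_config),
            memory_config=memory_config,
        )
        im_part = ttnn.add(
            ttnn.multiply(z1.real, z2.imag, memory_config=memory_config),
            ttnn.multiply(z1.imag, z2.real, memory_config=memory_config),
            memory_config=memory_config,
        )
        return ttnn.complex_tensor(re_part, im_part)

A helper along these lines can stand in for the skipped test_level2_mul path until the overload is reinstated.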
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
index 738fcaa925e..5fe6dba4140 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
@@ -8,7 +8,6 @@
 #include "ttnn/decorators.hpp"
 #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
 #include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
-#include "ttnn/operations/eltwise/complex/complex.hpp"

 namespace ttnn {

@@ -59,12 +58,6 @@ struct BinaryOperation {
         const std::optional<Tensor> &optional_output_tensor = std::nullopt,
         std::optional<unary::FusedActivations> activations = std::nullopt,
         std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);
-
-    static ComplexTensor operator()(
-        const ComplexTensor &input_tensor_a_arg,
-        const ComplexTensor &input_tensor_b_arg,
-        const MemoryConfig &memory_config);
-
 };

 template <BinaryOpType binary_op_type>
@@ -132,7 +125,7 @@ constexpr auto subtract = ttnn::register_operation_with_auto_launch_op<
 constexpr auto subtract_ = ttnn::register_operation_with_auto_launch_op<
     "ttnn::subtract_",
     operations::binary::BinaryOperation<operations::binary::BinaryOpType::SUB>>();
-constexpr auto multiply = ttnn::register_operation<
+constexpr auto multiply = ttnn::register_operation_with_auto_launch_op<
     "ttnn::multiply",
     operations::binary::BinaryOperation<operations::binary::BinaryOpType::MUL>>();
 constexpr auto multiply_ = ttnn::register_operation_with_auto_launch_op<
@@ -176,7 +169,7 @@ constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
 constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
     "ttnn::squared_difference",
     operations::binary::BinaryOperation<operations::binary::BinaryOpType::SQUARED_DIFFERENCE>>();
-constexpr auto divide = ttnn::register_operation<
+constexpr auto divide = ttnn::register_operation_with_auto_launch_op<
     "ttnn::divide",
     operations::binary::BinaryOperation<operations::binary::BinaryOpType::DIV_FAST>>();
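The declaration removed here also covered BinaryOpType::DIV_FAST, which binary.cpp above reduced to z1 * (1/z2) via complex_unary::_reciprocal. Since 1/(c + di) = (c - di)/(c^2 + d^2), the same fallback can be sketched on the user side under the same assumptions as the complex_mul helper above (complex_div is hypothetical; ttnn.divide and ttnn.neg are assumed to accept these arguments):

    import ttnn

    def complex_div(z1, z2, memory_config=None):
        # 1/(c + di) = (c - di) / (c^2 + d^2); the reverted DIV_FAST branch
        # likewise multiplied z1 by the reciprocal of z2.
        denom = ttnn.add(
            ttnn.multiply(z2.real, z2.real, memory_config=memory_config),
            ttnn.multiply(z2.imag, z2.imag, memory_config=memory_config),
            memory_config=memory_config,
        )
        reciprocal = ttnn.complex_tensor(
            ttnn.divide(z2.real, denom, memory_config=memory_config),
            ttnn.neg(ttnn.divide(z2.imag, denom, memory_config=memory_config)),
        )
        return complex_mul(z1, reciprocal, memory_config)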
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index c8c593d6dd6..d9bbb16e9e7 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -104,99 +104,6 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
             py::arg("queue_id") = 0});
 }

-template <typename binary_operation_t>
-void bind_binary_operation_with_complex(py::module& module, const binary_operation_t& operation, const std::string& description) {
-    auto doc = fmt::format(
-        R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None, activations: Optional[List[str]] = None) -> ttnn.Tensor or ComplexTensor
-
-        {2}
-        Supports broadcasting.
-
-        Args:
-            * :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor)
-            * :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number)
-
-        Keyword args:
-            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
-            * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
-            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
-            * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
-            * :attr:`queue_id` (Optional[uint8]): command queue id
-
-        Example:
-
-            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
-            >>> output = {1}(tensor1, tensor2)
-        )doc",
-        operation.base_name(),
-        operation.python_fully_qualified_name(),
-        description);
-
-    bind_registered_operation(
-        module,
-        operation,
-        doc,
-        // tensor and scalar
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ttnn::Tensor& input_tensor_a,
-               const float scalar,
-               const std::optional<const DataType>& dtype,
-               const std::optional<ttnn::MemoryConfig>& memory_config,
-               const std::optional<ttnn::Tensor>& output_tensor,
-               const std::optional<unary::FusedActivations>& activations,
-               const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
-               const uint8_t& queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("dtype") = std::nullopt,
-            py::arg("memory_config") = std::nullopt,
-            py::arg("output_tensor") = std::nullopt,
-            py::arg("activations") = std::nullopt,
-            py::arg("input_tensor_a_activation") = std::nullopt,
-            py::arg("queue_id") = 0},
-
-        // tensor and tensor
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ttnn::Tensor& input_tensor_a,
-               const ttnn::Tensor& input_tensor_b,
-               const std::optional<const DataType>& dtype,
-               const std::optional<ttnn::MemoryConfig>& memory_config,
-               const std::optional<ttnn::Tensor>& output_tensor,
-               const std::optional<unary::FusedActivations>& activations,
-               const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
-               const uint8_t& queue_id) -> ttnn::Tensor {
-                return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("dtype") = std::nullopt,
-            py::arg("memory_config") = std::nullopt,
-            py::arg("output_tensor") = std::nullopt,
-            py::arg("activations") = std::nullopt,
-            py::arg("input_tensor_a_activation") = std::nullopt,
-            py::arg("queue_id") = 0},
-
-        // complex tensor
-        ttnn::pybind_overload_t{
-            [](const binary_operation_t& self,
-               const ComplexTensor& input_tensor_a,
-               const ComplexTensor& input_tensor_b,
-               const MemoryConfig& memory_config) {
-                return self(input_tensor_a, input_tensor_b, memory_config);
-            },
-            py::arg("input_tensor_a"),
-            py::arg("input_tensor_b"),
-            py::kw_only(),
-            py::arg("memory_config")});
-}
-
 template <typename binary_operation_t>
 void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description) {
     auto doc = fmt::format(
@@ -380,11 +287,11 @@ void bind_div_like_ops(py::module& module, const binary_operation_t& operation,
 template <typename binary_operation_t>
 void bind_div(py::module& module, const binary_operation_t& operation, const std::string& description) {
     auto doc = fmt::format(
-        R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+        R"doc({0}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

         Args:
-            * :attr:`input_tensor_a` (ttnn.Tensor or ComplexTensor)
-            * :attr:`input_tensor_b` (ttnn.Tensor or Number or ComplexTensor)
+            * :attr:`input_tensor_a`
+            * :attr:`input_tensor_b` (ttnn.Tensor or Number)

         * :attr:`accurate_mode`: ``false`` if input_tensor_b is non-zero, else ``true``.
         * :attr:`round_mode`

@@ -557,7 +464,7 @@ void py_module(py::module& module) {
         R"doc(Subtracts :attr:`input_tensor_b` from :attr:`input_tensor_a` and returns the tensor with the same layout as :attr:`input_tensor_a` in-place
         .. math:: \mathrm{{input\_tensor\_a}}_i - \mathrm{{input\_tensor\_b}}_i)doc");

-    detail::bind_binary_operation_with_complex(
+    detail::bind_binary_operation(
         module,
         ttnn::multiply,
         R"doc(Multiplies :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
@@ -647,7 +554,7 @@ void py_module(py::module& module) {
         R"doc(Compute bias_gelu of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
         .. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc");

-    detail::bind_binary_operation_with_complex(
+    detail::bind_binary_operation(
         module,
         ttnn::divide,
         R"doc(Divides :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`