diff --git a/tests/ttnn/unit_tests/operations/test_complex.py b/tests/ttnn/unit_tests/operations/test_complex.py index 5c912785fcf..5ec748aec2c 100644 --- a/tests/ttnn/unit_tests/operations/test_complex.py +++ b/tests/ttnn/unit_tests/operations/test_complex.py @@ -332,7 +332,6 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults): assert passing -@pytest.mark.skip(reason="This test is failing because ttnn.mul doesn't support complex tensors") @pytest.mark.parametrize( "memcfg", ( @@ -370,7 +369,6 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults): assert passing -@pytest.mark.skip(reason="This test is failing because ttnn.div doesn't support complex tensors") @pytest.mark.parametrize( "memcfg", ( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp index 3e45d16d6fe..8f59389818a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp @@ -8,6 +8,7 @@ #include "ttnn/device_operation.hpp" #include "ttnn/operations/data_movement.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" +#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" namespace ttnn::operations::binary { @@ -218,6 +219,30 @@ Tensor BinaryOperation::operator()( } +template +ComplexTensor BinaryOperation::operator()( + const ComplexTensor &input_a, + const ComplexTensor &input_b, + const ttnn::MemoryConfig &output_mem_config) { + if constexpr(binary_op_type == BinaryOpType::MUL) { + Tensor re_part = ttnn::subtract( + ttnn::multiply(input_a[0],input_b[0],std::nullopt,output_mem_config), + ttnn::multiply(input_a[1],input_b[1],std::nullopt,output_mem_config), + std::nullopt, output_mem_config); + + Tensor im_part = ttnn::add( + ttnn::multiply(input_a[0],input_b[1],std::nullopt,output_mem_config), + ttnn::multiply(input_a[1],input_b[0],std::nullopt,output_mem_config), + std::nullopt, output_mem_config); + + return ComplexTensor({ re_part, im_part }); + }else if constexpr(binary_op_type == BinaryOpType::DIV_FAST) { + return ttnn::multiply( input_a, ttnn::operations::complex_unary::_reciprocal( input_b , output_mem_config ), output_mem_config ); //TODO: Overload reciprocal + }else { + TT_THROW("Unsupported operation (expected MUL or DIV_FAST)"); + } +} + template Tensor RelationalBinary::operator()( uint8_t queue_id, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp index 5fe6dba4140..738fcaa925e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp @@ -8,6 +8,7 @@ #include "ttnn/decorators.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" #include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn { @@ -58,6 +59,12 @@ struct BinaryOperation { const std::optional &optional_output_tensor = std::nullopt, std::optional activations = std::nullopt, std::optional input_tensor_a_activation = std::nullopt); + + static ComplexTensor operator()( + const ComplexTensor &input_tensor_a_arg, + const ComplexTensor &input_tensor_b_arg, + const MemoryConfig &memory_config); + }; template @@ -125,7 +132,7 @@ constexpr auto subtract = ttnn::register_operation_with_auto_launch_op< constexpr auto subtract_ = ttnn::register_operation_with_auto_launch_op< "ttnn::subtract_", operations::binary::BinaryOperation>(); -constexpr auto multiply = ttnn::register_operation_with_auto_launch_op< +constexpr auto multiply = ttnn::register_operation< "ttnn::multiply", operations::binary::BinaryOperation>(); constexpr auto multiply_ = ttnn::register_operation_with_auto_launch_op< @@ -169,7 +176,7 @@ constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op< constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op< "ttnn::squared_difference", operations::binary::BinaryOperation>(); -constexpr auto divide = ttnn::register_operation_with_auto_launch_op< +constexpr auto divide = ttnn::register_operation< "ttnn::divide", operations::binary::BinaryOperation>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index d9bbb16e9e7..c8c593d6dd6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -104,6 +104,99 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati py::arg("queue_id") = 0}); } +template +void bind_binary_operation_with_complex(py::module& module, const binary_operation_t& operation, const std::string& description) { + auto doc = fmt::format( + R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None, activations: Optional[List[str]] = None) -> ttnn.Tensor or ComplexTensor + + {2} + Supports broadcasting. + + Args: + * :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor) + * :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number) + + Keyword args: + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor + * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor + * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor + * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor + * :attr:`queue_id` (Optional[uint8]): command queue id + + Example: + + >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) + >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) + >>> output = {1}(tensor1, tensor2) + )doc", + operation.base_name(), + operation.python_fully_qualified_name(), + description); + + bind_registered_operation( + module, + operation, + doc, + // tensor and scalar + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const ttnn::Tensor& input_tensor_a, + const float scalar, + const std::optional& dtype, + const std::optional& memory_config, + const std::optional& output_tensor, + const std::optional& activations, + const std::optional& input_tensor_a_activation, + const uint8_t& queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor, activations, input_tensor_a_activation); + }, + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("dtype") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("activations") = std::nullopt, + py::arg("input_tensor_a_activation") = std::nullopt, + py::arg("queue_id") = 0}, + + // tensor and tensor + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const ttnn::Tensor& input_tensor_a, + const ttnn::Tensor& input_tensor_b, + const std::optional& dtype, + const std::optional& memory_config, + const std::optional& output_tensor, + const std::optional& activations, + const std::optional& input_tensor_a_activation, + const uint8_t& queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor, activations, input_tensor_a_activation); + }, + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("dtype") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("activations") = std::nullopt, + py::arg("input_tensor_a_activation") = std::nullopt, + py::arg("queue_id") = 0}, + + // complex tensor + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const ComplexTensor& input_tensor_a, + const ComplexTensor& input_tensor_b, + const MemoryConfig& memory_config) { + return self(input_tensor_a, input_tensor_b, memory_config); + }, + py::arg("input_tensor_a"), + py::arg("input_tensor_b"), + py::kw_only(), + py::arg("memory_config")}); +} + template void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description) { auto doc = fmt::format( @@ -287,11 +380,11 @@ void bind_div_like_ops(py::module& module, const binary_operation_t& operation, template void bind_div(py::module& module, const binary_operation_t& operation, const std::string& description) { auto doc = fmt::format( - R"doc({0}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor Args: - * :attr:`input_tensor_a` - * :attr:`input_tensor_b` (ttnn.Tensor or Number) + * :attr:`input_tensor_a` (ttnn.Tensor or ComplexTensor) + * :attr:`input_tensor_b` (ttnn.Tensor or Number or ComplexTensor) * :attr:`accurate_mode`: ``false`` if input_tensor_b is non-zero, else ``true``. * :attr:`round_mode` @@ -464,7 +557,7 @@ void py_module(py::module& module) { R"doc(Subtracts :attr:`input_tensor_b` from :attr:`input_tensor_a` and returns the tensor with the same layout as :attr:`input_tensor_a` in-place .. math:: \mathrm{{input\_tensor\_a}}_i - \mathrm{{input\_tensor\_b}}_i)doc"); - detail::bind_binary_operation( + detail::bind_binary_operation_with_complex( module, ttnn::multiply, R"doc(Multiplies :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a` @@ -554,7 +647,7 @@ void py_module(py::module& module) { R"doc(Compute bias_gelu of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a` .. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc"); - detail::bind_binary_operation( + detail::bind_binary_operation_with_complex( module, ttnn::divide, R"doc(Divides :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`