diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/complex/complex_ops.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/complex/complex_ops.cpp index 66e3afb41d0f..353ef8352bfc 100644 --- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/complex/complex_ops.cpp +++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/complex/complex_ops.cpp @@ -260,8 +260,8 @@ std::vector complex_recip_bw(const ComplexTensor& grad, const Com std::vector grad_tensor; Tensor condition_nan = ttnn::logical_and(ttnn::eqz(input.real(),output_mem_config), ttnn::eqz(input.imag(),output_mem_config), std::nullopt, output_mem_config); ComplexTensor neg_grad = ComplexTensor({ttnn::neg(grad.real(),output_mem_config), ttnn::neg(grad.imag(),output_mem_config)}); - ComplexTensor inp_recip = ttnn::operations::complex_unary::_complex_recip(input, output_mem_config); - ComplexTensor grad_inp = complex_mul(neg_grad, ttnn::operations::complex_unary::_conj(complex_mul(inp_recip, inp_recip, output_mem_config), output_mem_config), output_mem_config) ; + ComplexTensor inp_recip = ttnn::operations::complex_unary::_reciprocal(input, output_mem_config); + ComplexTensor grad_inp = ttnn::operations::complex_binary::_mul(neg_grad, ttnn::operations::complex_unary::_conj(ttnn::operations::complex_binary::_mul(inp_recip, inp_recip, output_mem_config), output_mem_config), output_mem_config) ; neg_grad.deallocate(); inp_recip.deallocate(); Tensor grad_inp_r = where(condition_nan, full_like(input.real(), std::nanf(""), output_mem_config), grad_inp.real(), output_mem_config); @@ -300,7 +300,7 @@ std::vector angle_bw(const Tensor& grad, const ComplexTensor& inp // polar fwd op uses sin and cos hence input_b range is (0, 2*pi) std::vector polar_bw(const ComplexTensor& grad, const ComplexTensor& input, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - ComplexTensor result = polar(input, output_mem_config); + ComplexTensor result = ttnn::operations::complex_unary::_polar(input, output_mem_config); Tensor abs_result = ttnn::operations::complex_unary::_abs(result, output_mem_config); Tensor sgn_result_r = where(ttnn::eqz(abs_result, output_mem_config), zeros_like(result.real(), output_mem_config), ttnn::multiply(result.real(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config ); Tensor sgn_result_i = where(ttnn::eqz(abs_result, output_mem_config), zeros_like(result.imag(), output_mem_config), ttnn::multiply(result.imag(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config ); @@ -308,10 +308,10 @@ std::vector polar_bw(const ComplexTensor& grad, const ComplexTens ComplexTensor sgn_result = ComplexTensor({ sgn_result_r, sgn_result_i }); sgn_result_r.deallocate(); sgn_result_i.deallocate(); - Tensor grad_abs = ttnn::operations::complex_unary::_real(complex_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config); + Tensor grad_abs = ttnn::operations::complex_unary::_real(ttnn::operations::complex_binary::_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config); sgn_result.deallocate(); ComplexTensor flip_tensor = ComplexTensor({zeros_like(input.real(), output_mem_config), full_like(input.imag(), 1.0, output_mem_config) }); - Tensor grad_angle = ttnn::operations::complex_unary::_real(complex_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), complex_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config); + Tensor grad_angle = ttnn::operations::complex_unary::_real(ttnn::operations::complex_binary::_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), ttnn::operations::complex_binary::_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config); result.deallocate(); flip_tensor.deallocate(); ComplexTensor grad_result = ComplexTensor({grad_abs, grad_angle}); @@ -321,10 +321,6 @@ std::vector polar_bw(const ComplexTensor& grad, const ComplexTens return grad_tensor; } ->>>>>>> #10382: Migrate conj to TTNN with generalized structure ->>>>>>> 6023a929db... #10382: Migrate complex_Recip to TTNN:tt_eager/tt_dnn/op_library/complex/complex_ops.cpp -======= ->>>>>>> 51aa63370e... #10382: Update ops implementation:tt_eager/tt_dnn/op_library/complex/complex_ops.cpp }//namespace tt_metal }//namespace tt diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp index 297c51516caf..58b8ecec4331 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp @@ -8,7 +8,7 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" -#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp" +#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp" #include "ttnn/types.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp index 54df5a9408ae..84f7c2ba1f34 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp @@ -2,15 +2,11 @@ // // SPDX-License-Identifier: Apache-2.0 - -#include "third_party/magic_enum/magic_enum.hpp" - -#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp" -#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp" +#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" namespace ttnn::operations::complex_binary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp index 4fa326e09c11..62f5f47563c3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp @@ -8,7 +8,7 @@ #include #include "tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp" +#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp" namespace ttnn::operations::complex_binary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp index 59de3057d240..372e38c5519c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp @@ -2,17 +2,14 @@ // // SPDX-License-Identifier: Apache-2.0 - -#include "third_party/magic_enum/magic_enum.hpp" - -#include "ttnn/experimental/tt_dnn/op_library/bcast/bcast_op.hpp" -#include "ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp" +#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" + namespace ttnn::operations::complex_binary_backward { @@ -52,9 +49,9 @@ std::vector _complex_sub_bw(const ComplexTensor& grad, const Comp // grad_other = grad * input.conj() std::vector _complex_mul_bw(const ComplexTensor& grad, const ComplexTensor& input, const ComplexTensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; - ComplexTensor grad_a = complex_mul(grad, conj(other,output_mem_config), output_mem_config); + ComplexTensor grad_a = ttnn::operations::complex_binary::_mul(grad, ttnn::operations::complex_unary::_conj(other,output_mem_config), output_mem_config); grad_tensor.emplace_back(grad_a); - ComplexTensor grad_b = complex_mul(grad, conj(input,output_mem_config), output_mem_config); + ComplexTensor grad_b = ttnn::operations::complex_binary::_mul(grad, ttnn::operations::complex_unary::_conj(input,output_mem_config), output_mem_config); grad_tensor.emplace_back(grad_b); return grad_tensor; } @@ -65,7 +62,7 @@ std::vector _complex_mul_bw(const ComplexTensor& grad, const Comp std::vector _complex_div_bw(const ComplexTensor& grad, const ComplexTensor& input, const ComplexTensor& other, const MemoryConfig& output_mem_config) { std::vector grad_tensor; Tensor condition_nan = ttnn::logical_and(ttnn::eqz(other.real(),output_mem_config), ttnn::eqz(other.imag(),output_mem_config), std::nullopt, output_mem_config); - ComplexTensor grad_a = complex_div(grad, conj(other,output_mem_config), output_mem_config); + ComplexTensor grad_a = ttnn::operations::complex_binary::_div(grad, ttnn::operations::complex_unary::_conj(other,output_mem_config), output_mem_config); Tensor grad_a_r = where(condition_nan, ttnn::operations::creation::full_like(grad.real(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_real(grad_a,output_mem_config), output_mem_config); Tensor grad_a_i = where(condition_nan, ttnn::operations::creation::full_like(grad.imag(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_imag(grad_a,output_mem_config), output_mem_config); grad_a = ComplexTensor({grad_a_r, grad_a_i}); @@ -73,7 +70,7 @@ std::vector _complex_div_bw(const ComplexTensor& grad, const Comp grad_a_i.deallocate(); grad_tensor.emplace_back(grad_a); ComplexTensor neg_grad = ComplexTensor({ttnn::neg(grad.real(),output_mem_config), ttnn::neg(grad.imag(),output_mem_config)}); - ComplexTensor grad_b = complex_mul(neg_grad, conj(complex_div(complex_div(input, other, output_mem_config), other, output_mem_config ),output_mem_config), output_mem_config); + ComplexTensor grad_b = ttnn::operations::complex_binary::_mul(neg_grad, ttnn::operations::complex_unary::_conj(ttnn::operations::complex_binary::_div(ttnn::operations::complex_binary::_div(input, other, output_mem_config), other, output_mem_config ),output_mem_config), output_mem_config); neg_grad.deallocate(); Tensor grad_b_r = where(condition_nan, ttnn::operations::creation::full_like(grad.real(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_real(grad_b,output_mem_config), output_mem_config); Tensor grad_b_i = where(condition_nan, ttnn::operations::creation::full_like(grad.imag(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_imag(grad_b,output_mem_config), output_mem_config); diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp index d284fe26fda7..32234ae69e5c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp @@ -3,8 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 -#include "third_party/magic_enum/magic_enum.hpp" - #include "ttnn/experimental/tt_dnn/op_library/bcast/bcast_op.hpp" #include "ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp" #include "tt_metal/common/constants.hpp" @@ -12,6 +10,7 @@ #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp" +#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" namespace ttnn::operations::complex_unary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp index 98c302f6242d..c4e0ea7e56eb 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp @@ -10,6 +10,7 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp" +#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" #include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp"