Skip to content

Commit

Permalink
#0: Rebased
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Jul 20, 2024
1 parent ac6b8c4 commit ea1b724
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ std::vector<ComplexTensor> complex_recip_bw(const ComplexTensor& grad, const Com
std::vector<ComplexTensor> grad_tensor;
Tensor condition_nan = ttnn::logical_and(ttnn::eqz(input.real(),output_mem_config), ttnn::eqz(input.imag(),output_mem_config), std::nullopt, output_mem_config);
ComplexTensor neg_grad = ComplexTensor({ttnn::neg(grad.real(),output_mem_config), ttnn::neg(grad.imag(),output_mem_config)});
ComplexTensor inp_recip = ttnn::operations::complex_unary::_complex_recip(input, output_mem_config);
ComplexTensor grad_inp = complex_mul(neg_grad, ttnn::operations::complex_unary::_conj(complex_mul(inp_recip, inp_recip, output_mem_config), output_mem_config), output_mem_config) ;
ComplexTensor inp_recip = ttnn::operations::complex_unary::_reciprocal(input, output_mem_config);
ComplexTensor grad_inp = ttnn::operations::complex_binary::_mul(neg_grad, ttnn::operations::complex_unary::_conj(ttnn::operations::complex_binary::_mul(inp_recip, inp_recip, output_mem_config), output_mem_config), output_mem_config) ;
neg_grad.deallocate();
inp_recip.deallocate();
Tensor grad_inp_r = where(condition_nan, full_like(input.real(), std::nanf(""), output_mem_config), grad_inp.real(), output_mem_config);
Expand Down Expand Up @@ -300,18 +300,18 @@ std::vector<ComplexTensor> angle_bw(const Tensor& grad, const ComplexTensor& inp
// polar fwd op uses sin and cos hence input_b range is (0, 2*pi)
std::vector<ComplexTensor> polar_bw(const ComplexTensor& grad, const ComplexTensor& input, const MemoryConfig& output_mem_config) {
std::vector<ComplexTensor> grad_tensor;
ComplexTensor result = polar(input, output_mem_config);
ComplexTensor result = ttnn::operations::complex_unary::_polar(input, output_mem_config);
Tensor abs_result = ttnn::operations::complex_unary::_abs(result, output_mem_config);
Tensor sgn_result_r = where(ttnn::eqz(abs_result, output_mem_config), zeros_like(result.real(), output_mem_config), ttnn::multiply(result.real(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Tensor sgn_result_i = where(ttnn::eqz(abs_result, output_mem_config), zeros_like(result.imag(), output_mem_config), ttnn::multiply(result.imag(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
abs_result.deallocate();
ComplexTensor sgn_result = ComplexTensor({ sgn_result_r, sgn_result_i });
sgn_result_r.deallocate();
sgn_result_i.deallocate();
Tensor grad_abs = ttnn::operations::complex_unary::_real(complex_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config);
Tensor grad_abs = ttnn::operations::complex_unary::_real(ttnn::operations::complex_binary::_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config);
sgn_result.deallocate();
ComplexTensor flip_tensor = ComplexTensor({zeros_like(input.real(), output_mem_config), full_like(input.imag(), 1.0, output_mem_config) });
Tensor grad_angle = ttnn::operations::complex_unary::_real(complex_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), complex_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config);
Tensor grad_angle = ttnn::operations::complex_unary::_real(ttnn::operations::complex_binary::_mul(ttnn::operations::complex_unary::_conj(grad, output_mem_config), ttnn::operations::complex_binary::_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config);
result.deallocate();
flip_tensor.deallocate();
ComplexTensor grad_result = ComplexTensor({grad_abs, grad_angle});
Expand All @@ -321,10 +321,6 @@ std::vector<ComplexTensor> polar_bw(const ComplexTensor& grad, const ComplexTens
return grad_tensor;
}

>>>>>>> #10382: Migrate conj to TTNN with generalized structure
>>>>>>> 6023a929db... #10382: Migrate complex_Recip to TTNN:tt_eager/tt_dnn/op_library/complex/complex_ops.cpp
=======
>>>>>>> 51aa63370e... #10382: Update ops implementation:tt_eager/tt_dnn/op_library/complex/complex_ops.cpp
}//namespace tt_metal

}//namespace tt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"
#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/types.hpp"

namespace py = pybind11;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
//
// SPDX-License-Identifier: Apache-2.0


#include "third_party/magic_enum/magic_enum.hpp"

#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp"
#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"

namespace ttnn::operations::complex_binary {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <optional>
#include "tensor/tensor.hpp"
#include "third_party/magic_enum/magic_enum.hpp"
#include "tt_eager/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp"

namespace ttnn::operations::complex_binary {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
//
// SPDX-License-Identifier: Apache-2.0


#include "third_party/magic_enum/magic_enum.hpp"

#include "ttnn/experimental/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp"


namespace ttnn::operations::complex_binary_backward {

Expand Down Expand Up @@ -52,9 +49,9 @@ std::vector<ComplexTensor> _complex_sub_bw(const ComplexTensor& grad, const Comp
// grad_other = grad * input.conj()
// Backward pass for complex multiplication out = input * other.
// Wirtinger-calculus gradients:
//   grad_input = grad * conj(other)
//   grad_other = grad * conj(input)
// Returns {grad_input, grad_other}; all intermediates allocated per output_mem_config.
std::vector<ComplexTensor> _complex_mul_bw(const ComplexTensor& grad, const ComplexTensor& input, const ComplexTensor& other, const MemoryConfig& output_mem_config) {
    std::vector<ComplexTensor> grad_tensor;
    // d(out)/d(input): upstream grad times the conjugate of the other operand.
    // Uses the migrated TTNN namespaced ops (complex_binary::_mul / complex_unary::_conj),
    // not the legacy free functions complex_mul/conj.
    ComplexTensor grad_a = ttnn::operations::complex_binary::_mul(grad, ttnn::operations::complex_unary::_conj(other, output_mem_config), output_mem_config);
    grad_tensor.emplace_back(grad_a);
    // d(out)/d(other): symmetric — conjugate of the first operand.
    ComplexTensor grad_b = ttnn::operations::complex_binary::_mul(grad, ttnn::operations::complex_unary::_conj(input, output_mem_config), output_mem_config);
    grad_tensor.emplace_back(grad_b);
    return grad_tensor;
}
Expand All @@ -65,15 +62,15 @@ std::vector<ComplexTensor> _complex_mul_bw(const ComplexTensor& grad, const Comp
std::vector<ComplexTensor> _complex_div_bw(const ComplexTensor& grad, const ComplexTensor& input, const ComplexTensor& other, const MemoryConfig& output_mem_config) {
std::vector<ComplexTensor> grad_tensor;
Tensor condition_nan = ttnn::logical_and(ttnn::eqz(other.real(),output_mem_config), ttnn::eqz(other.imag(),output_mem_config), std::nullopt, output_mem_config);
ComplexTensor grad_a = complex_div(grad, conj(other,output_mem_config), output_mem_config);
ComplexTensor grad_a = ttnn::operations::complex_binary::_div(grad, ttnn::operations::complex_unary::_conj(other,output_mem_config), output_mem_config);
Tensor grad_a_r = where(condition_nan, ttnn::operations::creation::full_like(grad.real(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_real(grad_a,output_mem_config), output_mem_config);
Tensor grad_a_i = where(condition_nan, ttnn::operations::creation::full_like(grad.imag(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_imag(grad_a,output_mem_config), output_mem_config);
grad_a = ComplexTensor({grad_a_r, grad_a_i});
grad_a_r.deallocate();
grad_a_i.deallocate();
grad_tensor.emplace_back(grad_a);
ComplexTensor neg_grad = ComplexTensor({ttnn::neg(grad.real(),output_mem_config), ttnn::neg(grad.imag(),output_mem_config)});
ComplexTensor grad_b = complex_mul(neg_grad, conj(complex_div(complex_div(input, other, output_mem_config), other, output_mem_config ),output_mem_config), output_mem_config);
ComplexTensor grad_b = ttnn::operations::complex_binary::_mul(neg_grad, ttnn::operations::complex_unary::_conj(ttnn::operations::complex_binary::_div(ttnn::operations::complex_binary::_div(input, other, output_mem_config), other, output_mem_config ),output_mem_config), output_mem_config);
neg_grad.deallocate();
Tensor grad_b_r = where(condition_nan, ttnn::operations::creation::full_like(grad.real(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_real(grad_b,output_mem_config), output_mem_config);
Tensor grad_b_i = where(condition_nan, ttnn::operations::creation::full_like(grad.imag(), std::nanf(""), std::nullopt, std::nullopt, std::nullopt, output_mem_config), ttnn::operations::complex_unary::_imag(grad_b,output_mem_config), output_mem_config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
// SPDX-License-Identifier: Apache-2.0


#include "third_party/magic_enum/magic_enum.hpp"

#include "ttnn/experimental/tt_dnn/op_library/bcast/bcast_op.hpp"
#include "ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/experimental/tt_dnn/op_library/complex/complex_ops.hpp"
#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp"

namespace ttnn::operations::complex_unary {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "tt_metal/host_api.hpp"
#include "tt_metal/tools/profiler/op_profiler.hpp"
#include "ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp"
#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"
#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp"
Expand Down

0 comments on commit ea1b724

Please sign in to comment.