diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_complex.py b/tests/ttnn/unit_tests/operations/eltwise/test_complex.py index b94b552317c..4db8c640952 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_complex.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_complex.py @@ -490,7 +490,7 @@ def test_level2_angle(bs, memcfg, dtype, device, function_level_defaults): x_imag = torch.tensor(x.imag, dtype=torch.bfloat16) x_torch = torch.complex(x_real.float(), x_imag.float()) tt_cpu = torch.angle(x_torch).to(torch.bfloat16) - passing, output = comp_pcc(tt_cpu, tt_dev) + passing, output = comp_pcc(tt_cpu, tt_dev, 0.98) logger.info(output) assert passing diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index e46b86fa072..950a6db4768 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -134,31 +134,38 @@ Tensor ExecuteMaximum::invoke(const Tensor& input_a, float value, const std::opt Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional& output_mem_config) { Tensor result(input_a); { - Tensor atan_input = ttnn::multiply( - ttnn::abs(input_b, output_mem_config), - ttnn::reciprocal(ttnn::abs(input_a, output_mem_config), output_mem_config), + Tensor atan_input = ttnn::multiply(input_b, + ttnn::reciprocal(input_a, output_mem_config), std::nullopt, output_mem_config); result = ttnn::atan(atan_input, output_mem_config); } Tensor res(result); { - Tensor ib_gtz = ttnn::gtz(input_b, output_mem_config); - Tensor ib_gt = ttnn::gtz(input_b, output_mem_config); - Tensor ib_lt = ttnn::ltz(input_b, output_mem_config); - float pi_2 = M_PI_2; - Tensor neg_result = ttnn::neg(result, output_mem_config); + Tensor ia_gtz = ttnn::gtz(input_a, output_mem_config); + Tensor ia_ltz = ttnn::ltz(input_a, output_mem_config); + Tensor ib_ltz = ttnn::ltz(input_b, output_mem_config); - res = ttnn::where( - ttnn::gtz(input_a, output_mem_config), - ttnn::where(ib_gtz, result, neg_result), - ttnn::where( - ttnn::ltz(input_a, output_mem_config), - ttnn::where( - ib_gt, - ttnn::add(neg_result, M_PI, std::nullopt, output_mem_config), - ttnn::where(ib_lt, ttnn::subtract(result, M_PI, std::nullopt, output_mem_config), M_PI)), - ttnn::where(ib_gt, pi_2, ttnn::where(ib_lt, -pi_2, 0.0f)))); + Tensor altz_bgte = ttnn::logical_and(ia_ltz, ttnn::ge(input_b, 0.0), std::nullopt, output_mem_config); + Tensor altz_bltz = ttnn::logical_and(ia_ltz, ib_ltz, std::nullopt, output_mem_config); + + Tensor a_eqz = ttnn::eqz(input_a, output_mem_config); + Tensor b_gtz = ttnn::gtz(input_b, output_mem_config); + Tensor b_eqz = ttnn::eqz(input_b, output_mem_config); + + + Tensor az_bltz = ttnn::logical_and(a_eqz, ib_ltz, std::nullopt, output_mem_config); + Tensor az_bgtz = ttnn::logical_and(a_eqz, b_gtz, std::nullopt, output_mem_config); + Tensor az_bz = ttnn::logical_and(a_eqz, b_eqz, std::nullopt, output_mem_config); + float pi_2 = M_PI_2; + res = ttnn::where(ia_gtz, result, + ttnn::where(altz_bgte, ttnn::add(result, M_PI, std::nullopt, output_mem_config), + ttnn::where(altz_bltz, ttnn::subtract(result, M_PI, std::nullopt, output_mem_config), + ttnn::where(az_bltz , M_PI_2, ttnn::where(az_bgtz, -M_PI_2, 0.0, output_mem_config), + output_mem_config), + output_mem_config), + output_mem_config), + output_mem_config); } return res; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp index 278a08bd844..2ab26b45f13 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp @@ -22,7 +22,7 @@ Tensor _imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) } Tensor _angle(const ComplexTensor& input, const MemoryConfig& output_mem_config) { - return ttnn::neg( atan2(input[1],input[0],output_mem_config), output_mem_config ); + return atan2(input[0],input[1],output_mem_config); } Tensor _is_imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) {