Skip to content

Commit

Permalink
#14995: Angle issue - Fix (#15213)
Browse files Browse the repository at this point in the history
### Ticket
#14995 

### Problem description
Low pcc issue in angle op

### What's changed
Updated the logic to fix the pcc issue. Arguments were passed wrongly

### Checklist
- [ ] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/11969170125)
  • Loading branch information
umadevimcw authored Nov 22, 2024
1 parent af434c6 commit d60d064
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def test_level2_angle(bs, memcfg, dtype, device, function_level_defaults):
x_imag = torch.tensor(x.imag, dtype=torch.bfloat16)
x_torch = torch.complex(x_real.float(), x_imag.float())
tt_cpu = torch.angle(x_torch).to(torch.bfloat16)
passing, output = comp_pcc(tt_cpu, tt_dev)
passing, output = comp_pcc(tt_cpu, tt_dev, 0.98)
logger.info(output)
assert passing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,38 @@ Tensor ExecuteMaximum::invoke(const Tensor& input_a, float value, const std::opt
Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor result(input_a);
{
Tensor atan_input = ttnn::multiply(
ttnn::abs(input_b, output_mem_config),
ttnn::reciprocal(ttnn::abs(input_a, output_mem_config), output_mem_config),
Tensor atan_input = ttnn::multiply(input_b,
ttnn::reciprocal(input_a, output_mem_config),
std::nullopt,
output_mem_config);
result = ttnn::atan(atan_input, output_mem_config);
}
Tensor res(result);
{
Tensor ib_gtz = ttnn::gtz(input_b, output_mem_config);
Tensor ib_gt = ttnn::gtz(input_b, output_mem_config);
Tensor ib_lt = ttnn::ltz(input_b, output_mem_config);
float pi_2 = M_PI_2;
Tensor neg_result = ttnn::neg(result, output_mem_config);
Tensor ia_gtz = ttnn::gtz(input_a, output_mem_config);
Tensor ia_ltz = ttnn::ltz(input_a, output_mem_config);
Tensor ib_ltz = ttnn::ltz(input_b, output_mem_config);

res = ttnn::where(
ttnn::gtz(input_a, output_mem_config),
ttnn::where(ib_gtz, result, neg_result),
ttnn::where(
ttnn::ltz(input_a, output_mem_config),
ttnn::where(
ib_gt,
ttnn::add(neg_result, M_PI, std::nullopt, output_mem_config),
ttnn::where(ib_lt, ttnn::subtract(result, M_PI, std::nullopt, output_mem_config), M_PI)),
ttnn::where(ib_gt, pi_2, ttnn::where(ib_lt, -pi_2, 0.0f))));
Tensor altz_bgte = ttnn::logical_and(ia_ltz, ttnn::ge(input_b, 0.0), std::nullopt, output_mem_config);
Tensor altz_bltz = ttnn::logical_and(ia_ltz, ib_ltz, std::nullopt, output_mem_config);

Tensor a_eqz = ttnn::eqz(input_a, output_mem_config);
Tensor b_gtz = ttnn::gtz(input_b, output_mem_config);
Tensor b_eqz = ttnn::eqz(input_b, output_mem_config);


Tensor az_bltz = ttnn::logical_and(a_eqz, ib_ltz, std::nullopt, output_mem_config);
Tensor az_bgtz = ttnn::logical_and(a_eqz, b_gtz, std::nullopt, output_mem_config);
Tensor az_bz = ttnn::logical_and(a_eqz, b_eqz, std::nullopt, output_mem_config);
float pi_2 = M_PI_2;
res = ttnn::where(ia_gtz, result,
ttnn::where(altz_bgte, ttnn::add(result, M_PI, std::nullopt, output_mem_config),
ttnn::where(altz_bltz, ttnn::subtract(result, M_PI, std::nullopt, output_mem_config),
ttnn::where(az_bltz , M_PI_2, ttnn::where(az_bgtz, -M_PI_2, 0.0, output_mem_config),
output_mem_config),
output_mem_config),
output_mem_config),
output_mem_config);
}
return res;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Tensor _imag(const ComplexTensor& input, const MemoryConfig& output_mem_config)
}

Tensor _angle(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
return ttnn::neg( atan2(input[1],input[0],output_mem_config), output_mem_config );
return atan2(input[0],input[1],output_mem_config);
}

Tensor _is_imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
Expand Down

0 comments on commit d60d064

Please sign in to comment.