diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_complex.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_complex.py index 8c29c744482..268f6b0b949 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_complex.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_complex.py @@ -437,7 +437,7 @@ def test_level1_polar(bs, memcfg, dtype, device, function_level_defaults): # we set imag = angle theta x = Complex(None, re=torch.ones(input_shape), im=torch.rand(input_shape)) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py index 8a3f19ba6ab..8426a74de57 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py @@ -41,7 +41,7 @@ def test_level2_abs_bw(bs, hw, memcfg, dtype, device, function_level_defaults): in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -81,7 +81,7 @@ def test_level2_abs_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_de in_data = random_complex_tensor(input_shape, (0, 0), (0, 0)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py index 40efba39d99..2afaf21d6a1 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py @@ -40,7 +40,7 @@ def test_level2_angle_bw(bs, hw, memcfg, dtype, device, function_level_defaults) in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py index 3848e06e1f0..574f4f26e94 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py @@ -45,17 +45,17 @@ def test_level2_complex_add_bw(bs, hw, alpha, memcfg, dtype, device, function_le other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100)) other_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - other_tensor = ttl.tensor.complex_tensor( + other_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py index c03f07a4623..959d39f992c 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py @@ -48,17 +48,17 @@ def test_level2_complex_div_bw(bs, hw, memcfg, dtype, device, function_level_def other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100)) other_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - other_tensor = ttl.tensor.complex_tensor( + other_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -99,17 +99,17 @@ def test_level2_complex_div_bw_other_zero(bs, hw, memcfg, dtype, device, functio other_data = random_complex_tensor(input_shape, (0, 0), (0, 0)) other_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - other_tensor = ttl.tensor.complex_tensor( + other_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py index 6249d84a397..4777af0195a 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py @@ -45,17 +45,17 @@ def test_level2_complex_mul_bw(bs, hw, memcfg, dtype, device, function_level_def other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100)) other_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - other_tensor = ttl.tensor.complex_tensor( + other_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py index 592e869f1e4..8be916bfa5c 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py @@ -40,17 +40,17 @@ def test_level2_complex_sub_bw(bs, hw, alpha, memcfg, dtype, device, function_le other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100)) other_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - other_tensor = ttl.tensor.complex_tensor( + other_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py index 28425bf0fcc..5142f1b0e4a 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py @@ -40,13 +40,13 @@ def test_level2_conj_bw(bs, hw, memcfg, dtype, device, function_level_defaults): in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py index f15cc03a269..7db566665b5 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py @@ -40,7 +40,7 @@ def test_level2_imag_bw(bs, hw, memcfg, dtype, device, function_level_defaults): in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py index 09a56545a28..ad0e1ffb7d8 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py @@ -42,13 +42,13 @@ def test_level2_polar_bw(bs, hw, memcfg, dtype, device, function_level_defaults) in_data = random_complex_tensor(input_shape, (-90, 90), (0, 2 * pi)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py index 4bfcebc2880..ab9e7bc4d67 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py @@ -40,7 +40,7 @@ def test_level2_real_bw(bs, hw, memcfg, dtype, device, function_level_defaults): in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py index a421f200ddf..a8067b0f005 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py +++ b/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py @@ -41,13 +41,13 @@ def test_level2_recip_bw(bs, hw, memcfg, dtype, device, function_level_defaults) in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -93,13 +93,13 @@ def test_level2_recip_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_ in_data = random_complex_tensor(input_shape, (0, 0), (0, 0)) in_data.requires_grad = True - input_tensor = ttl.tensor.complex_tensor( + input_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60)) - grad_tensor = ttl.tensor.complex_tensor( + grad_tensor = ttnn.complex_tensor( ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/complex/test_complex_conj.py b/tests/ttnn/unit_tests/operations/complex/test_complex_conj.py new file mode 100644 index 00000000000..47e8d1b84df --- /dev/null +++ b/tests/ttnn/unit_tests/operations/complex/test_complex_conj.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import tt_lib as ttl +import pytest +import ttnn +from loguru import logger +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose + +from models.utility_functions import is_wormhole_b0, skip_for_grayskull +from tests.ttnn.unit_tests.operations.complex.utility_funcs import ( + convert_complex_to_torch_tensor, + random_complex_tensor, +) + + +@skip_for_grayskull() +@pytest.mark.parametrize( + "memcfg", + ( + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), + ), + ids=["out_DRAM", "out_L1"], +) +@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.FLOAT32,))) +@pytest.mark.parametrize("bs", ((1, 1),)) +@pytest.mark.parametrize("hw", ((32, 32),)) +def test_conj(bs, hw, memcfg, dtype, device, function_level_defaults): + input_shape = torch.Size([bs[0], bs[1], hw[0], hw[1]]) + + in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) + + input_tensor = ttnn.complex_tensor( + ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), + ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), + ) + + tt_dev = ttnn.conj(input_tensor, memory_config=memcfg) + + tt_to_torch = convert_complex_to_torch_tensor(tt_dev) + + golden_function = ttnn.get_golden_function(ttnn.conj) + golden_tensor = golden_function(in_data) + + passing, output = comp_pcc(golden_tensor, tt_to_torch) + logger.info(output) + assert passing diff --git a/tests/ttnn/unit_tests/operations/complex/utility_funcs.py b/tests/ttnn/unit_tests/operations/complex/utility_funcs.py new file mode 100644 index 00000000000..68f8b697d39 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/complex/utility_funcs.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import tt_lib as ttl + + +def random_complex_tensor(shape, real_range=(-100, 100), imag_range=(-100, 100)): + torch.manual_seed(213919) + real_part = (real_range[1] - real_range[0]) * torch.rand(shape) + real_range[0] + imag_part = (imag_range[1] - imag_range[0]) * torch.rand(shape) + imag_range[0] + return torch.complex(real_part, imag_part) + + +def convert_complex_to_torch_tensor(tt_dev): + tt_dev_r = tt_dev.real.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + tt_dev_i = tt_dev.imag.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch() + tt_to_torch = torch.complex(tt_dev_r, tt_dev_i) + return tt_to_torch diff --git a/tests/ttnn/unit_tests/operations/test_complex.py b/tests/ttnn/unit_tests/operations/test_complex.py index dd8d257d15c..5c912785fcf 100644 --- a/tests/ttnn/unit_tests/operations/test_complex.py +++ b/tests/ttnn/unit_tests/operations/test_complex.py @@ -97,7 +97,7 @@ def test_level2_real(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check real x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -123,7 +123,7 @@ def test_level2_imag(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check imag x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -149,7 +149,7 @@ def test_level2_abs(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check abs x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -178,7 +178,7 @@ def test_level2_abs(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check abs x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -207,7 +207,7 @@ def test_level2_conj(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check abs x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -239,7 +239,7 @@ def test_level2_recip(bs, memcfg, dtype, device, function_level_defaults): # check abs x = Complex(input_shape) x = x.div(x * 0.5) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -274,11 +274,11 @@ def test_level2_add(bs, memcfg, dtype, device, function_level_defaults): x = Complex(input_shape) y = Complex(input_shape) * -0.5 - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - ytt = ttl.tensor.complex_tensor( + ytt = ttnn.complex_tensor( ttl.tensor.Tensor(y.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(y.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -311,11 +311,11 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults): x = Complex(input_shape) y = Complex(input_shape) * -0.5 - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - ytt = ttl.tensor.complex_tensor( + ytt = ttnn.complex_tensor( ttl.tensor.Tensor(y.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(y.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -349,11 +349,11 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults): x = Complex(input_shape) y = Complex(input_shape) * -0.5 - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - ytt = ttl.tensor.complex_tensor( + ytt = ttnn.complex_tensor( ttl.tensor.Tensor(y.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(y.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -387,11 +387,11 @@ def test_level2_div(bs, memcfg, dtype, device, function_level_defaults): x = Complex(input_shape) * 0.5 y = Complex(input_shape) * 1 - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) - ytt = ttl.tensor.complex_tensor( + ytt = ttnn.complex_tensor( ttl.tensor.Tensor(y.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(y.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -422,7 +422,7 @@ def test_level2_is_real(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check abs x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(0 * x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -452,7 +452,7 @@ def test_level2_is_imag(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check abs x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(0 * x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -481,7 +481,7 @@ def test_level2_angle(bs, memcfg, dtype, device, function_level_defaults): input_shape = torch.Size([bs[0], bs[1], 32, 64]) # check imag x = Complex(input_shape) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) @@ -514,7 +514,7 @@ def test_level2_polar(bs, memcfg, dtype, device, function_level_defaults): # we set imag = angle theta x = Complex(None, re=torch.ones(input_shape), im=torch.rand(input_shape)) - xtt = ttl.tensor.complex_tensor( + xtt = ttnn.complex_tensor( ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), ) diff --git a/tests/ttnn/unit_tests/operations/test_complex_tensor.py b/tests/ttnn/unit_tests/operations/test_complex_tensor.py new file mode 100644 index 00000000000..2eaa8653dbc --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_complex_tensor.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import tt_lib as ttl +import pytest +import ttnn +from loguru import logger +from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose + +from models.utility_functions import is_wormhole_b0 +from tests.ttnn.unit_tests.operations.complex.utility_funcs import ( + convert_complex_to_torch_tensor, + random_complex_tensor, +) + + +@pytest.mark.parametrize( + "memcfg", + ( + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), + ), + ids=["out_DRAM", "out_L1"], +) +@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.FLOAT32,))) +@pytest.mark.parametrize("bs", ((1, 1),)) +@pytest.mark.parametrize("hw", ((32, 32),)) +def test_create_complex_tensor(bs, hw, memcfg, dtype, device, function_level_defaults): + input_shape = torch.Size([bs[0], bs[1], hw[0], hw[1]]) + + in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70)) + + input_tensor = ttnn.complex_tensor( + ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), + ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg), + ) + + tt_dev = input_tensor + tt_to_torch = convert_complex_to_torch_tensor(tt_dev) + + golden_tensor = in_data + + passing, output = comp_pcc(golden_tensor, tt_to_torch) + logger.info(output) + assert passing diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 6fc29565b99..617480f8702 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -30,19 +30,20 @@ set(TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/ternary_backward/device/ternary_backward_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/upsample/upsample_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/upsample/device/upsample_op_multi_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/upsample/device/upsample_op_single_core.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/unary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex/complex.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/ternary/where_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 5d61aea1ab4..6d9174b2c99 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -32,9 +32,9 @@ #include "ttnn/operations/embedding/embedding_pybind.hpp" #include "ttnn/operations/matmul/matmul_pybind.hpp" #include "ttnn/operations/transformer/transformer_pybind.hpp" -#include "ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp" -#include "ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp" #include "ttnn/operations/experimental/experimental_pybind.hpp" +#include "ttnn/operations/eltwise/complex/complex_pybind.hpp" +#include "ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp" namespace py = pybind11; @@ -49,9 +49,6 @@ void py_module(py::module& module) { auto m_unary = module.def_submodule("unary", "unary operations"); unary::py_module(m_unary); - auto m_complex_unary_backward = module.def_submodule("complex_unary_backward", "complex_unary_backward operations"); - complex_unary_backward::py_module(m_complex_unary_backward); - auto m_binary = module.def_submodule("binary", "binary operations"); binary::py_module(m_binary); @@ -69,9 +66,15 @@ void py_module(py::module& module) { ccl::py_bind_line_all_gather(m_ccl); ccl::py_bind_reduce_scatter(m_ccl); + auto m_complex = module.def_submodule("complex", "complex tensor creation"); + complex::py_module(m_complex); + auto m_complex_unary = module.def_submodule("complex_unary", "complex_unary operations"); complex_unary::py_module(m_complex_unary); + auto m_complex_unary_backward = module.def_submodule("complex_unary_backward", "complex_unary_backward operations"); + complex_unary_backward::py_module(m_complex_unary_backward); + auto m_ternary = module.def_submodule("ternary", "ternary operations"); ternary::py_module(m_ternary); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/complex/complex_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/complex/complex_ops.cpp index d64caf4158a..0546ef0b7ea 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/complex/complex_ops.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/complex/complex_ops.cpp @@ -12,7 +12,6 @@ #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" namespace tt { diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index fa211da6218..dd608bbc9f2 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -8,10 +8,10 @@ #include "ttnn/deprecated/tt_dnn/op_library/optimizer/optimizer_ops.hpp" #include "tt_lib_bindings_tensor.hpp" #include "tt_lib_bindings_tensor_impl.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" + namespace tt::tt_metal::detail { -using ComplexTensor = ttnn::operations::complex_binary::ComplexTensor; + void TensorModuleCompositeOPs(py::module& m_tensor) { @@ -936,22 +936,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) { detail::bind_binary_op(m_tensor, "scatter", &tt::tt_metal::scatter, R"doc(Performs scatter operation on elements of the input tensors ``{0}`` and ``{1}``,specifically to copy channel data.)doc"); - // *** type-2 complex operations in new submodule 'type2_complex' *** - auto m_type2_cplx = m_tensor.def_submodule("complex", "Complex type2"); - py::class_ pycplx_cls(m_type2_cplx, "ComplexTensor"); - - pycplx_cls.def_property_readonly("real",&ComplexTensor::real); - pycplx_cls.def_property_readonly("imag",&ComplexTensor::imag); - pycplx_cls.def("deallocate",&ComplexTensor::deallocate); - - m_tensor.def("complex_tensor", - [](Tensor& r, Tensor& i) -> ComplexTensor { - return ComplexTensor({r,i}); - }, - py::arg("real"), - py::arg("imag"), - R"doc(Create a complex tensor object from real and imag parts ``{0}`` and ``{1}``.)doc" - ); // loss functions m_tensor.def( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index 6dcb41f9934..010daa1b15f 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -10,9 +10,8 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" -#include "ttnn/operations/eltwise/complex_binary/complex_binary.hpp" #include "ttnn/types.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace py = pybind11; @@ -23,7 +22,6 @@ namespace binary { namespace detail { -using ComplexTensor = complex_binary::ComplexTensor; template void bind_binary_operation(py::module& module, const binary_operation_t& operation, const std::string& description) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp new file mode 100644 index 00000000000..523537c3688 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + + +#include "complex.hpp" + + +namespace ttnn { +namespace operations::complex { + + +ComplexTensor::ComplexTensor(std::array val): m_real_imag(val) { + TT_ASSERT( m_real_imag[0].get_legacy_shape() == m_real_imag[1].get_legacy_shape() , "Tensor shapes of real and imag should be identical"); + } + +const Tensor& ComplexTensor::operator[](uint32_t index) const { + return m_real_imag[index]; + } + +const Tensor& ComplexTensor::real() const { + return m_real_imag[0]; + } + +const Tensor& ComplexTensor::imag() const { + return m_real_imag[1]; + } + +void ComplexTensor::deallocate() { + m_real_imag[0].deallocate(); + m_real_imag[1].deallocate(); + } + + +ComplexTensor CreateComplexTensor::operator()( + const Tensor &input_tensor_a_arg, + const Tensor &input_tensor_b_arg) { + return ComplexTensor({input_tensor_a_arg, input_tensor_b_arg}); + } + +} // namespace operations::complex + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp new file mode 100644 index 00000000000..296a6e3e801 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/decorators.hpp" + +namespace ttnn { +namespace operations::complex { + +class ComplexTensor { + private: + std::array m_real_imag; + + public: + ComplexTensor(std::array val); + const Tensor& operator[](uint32_t index) const; + const Tensor& real() const; + const Tensor& imag() const; + void deallocate(); +}; + +struct CreateComplexTensor { + + static ComplexTensor operator()( + const Tensor &input_tensor_a_arg, + const Tensor &input_tensor_b_arg); +}; + +} // namespace operations::complex + +using ComplexTensor = operations::complex::ComplexTensor; + +constexpr auto complex_tensor = ttnn::register_operation< + "ttnn::complex_tensor", + operations::complex::CreateComplexTensor>(); + + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp new file mode 100644 index 00000000000..6a41e931eff --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/complex/complex_pybind.hpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "complex.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::complex { + +namespace detail { + +void bind_complex_tensor_type(py::module& m) { + py::class_(m, "ComplexTensor") + .def(py::init>()) + .def_property_readonly("real", &ComplexTensor::real) + .def_property_readonly("imag", &ComplexTensor::imag) + .def("deallocate", &ComplexTensor::deallocate) + .def("__getitem__", &ComplexTensor::operator[]); + +} + +void bind_complex_tensor(py::module& module) { + auto doc = fmt::format( + R"doc({0}real: ttnn.Tensor, imag: ttnn.Tensor -> ComplexTensor + + Create a complex tensor from real and imaginary part tensors. + + Args: + * :attr:`real` + * :attr:`imag` + + Example: + + >>> real = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) + >>> imag = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) + >>> complex_tensor = ttnn.complex_tensor(real, imag) + )doc", + ttnn::complex_tensor.base_name()); + + bind_registered_operation( + module, + ttnn::complex_tensor, + doc, + ttnn::pybind_arguments_t{ + py::arg("real"), + py::arg("imag")} + ); +} + +} // detail + + +void py_module(py::module& module) { + detail::bind_complex_tensor_type(module); + detail::bind_complex_tensor(module); +} + +} // ttnn::operations::complex diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp index bd8123e9943..fa2133ee86e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/complex_binary_pybind.hpp @@ -8,6 +8,7 @@ #include #include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" #include "ttnn/types.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp index 93059758bb0..7f0ec8ecf85 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.cpp @@ -6,6 +6,7 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" namespace ttnn::operations::complex_binary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp index 7d294d97262..b9532371032 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp @@ -8,6 +8,7 @@ #include #include "ttnn/tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_binary { @@ -19,34 +20,6 @@ enum class ComplexBinaryOpType { DIV, }; -class ComplexTensor { - private: - std::array m_real_imag; - - public: - - ComplexTensor(std::array val): m_real_imag(val) { - TT_ASSERT( m_real_imag[0].get_legacy_shape() == m_real_imag[1].get_legacy_shape() , "Tensor shapes of real and imag should be identical"); - } - - const Tensor& operator[](uint32_t index) const { - return m_real_imag[index]; - } - - const Tensor& real() const { - return m_real_imag[0]; - } - - const Tensor& imag() const { - return m_real_imag[1]; - } - - void deallocate() { - m_real_imag[0].deallocate(); - m_real_imag[1].deallocate(); - } -}; - // OpHandler_complex_binary_type1 = get_function_complex_binary ComplexTensor _add(const ComplexTensor& input_a, const ComplexTensor& input_b, const MemoryConfig& output_mem_config); ComplexTensor _sub(const ComplexTensor& input_a, const ComplexTensor& input_b, const MemoryConfig& output_mem_config); diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp index cc7fc8521f0..05174bbdfdf 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp @@ -12,7 +12,6 @@ namespace ttnn { namespace operations::complex_binary_backward { -using ComplexTensor = complex_binary::ComplexTensor; template struct ExecuteComplexBinaryBackwardWFloat { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp index 5d8894d7890..85a9593eb7e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward_pybind.hpp @@ -9,6 +9,7 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" #include "ttnn/types.hpp" namespace py = pybind11; @@ -18,7 +19,6 @@ namespace operations { namespace complex_binary_backward { namespace detail { -using ComplexTensor = complex_binary::ComplexTensor; template void bind_complex_binary_backward_w_float(py::module& module, const complex_binary_backward_operation_t& operation, const std::string& description) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp index c3ffccf99f4..fd88e586187 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.cpp @@ -2,18 +2,18 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "complex_binary_backward_op.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" -#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" #include "ttnn/operations/eltwise/complex_unary/complex_unary.hpp" #include "ttnn/cpp/ttnn/operations/eltwise/ternary/where_op.hpp" + namespace ttnn::operations::complex_binary_backward { -using ComplexTensor = complex_binary::ComplexTensor; // complex add // self: grad, other: grad * alpha diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp index 3e088af935d..9434b4655b7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp @@ -8,10 +8,9 @@ #include #include "ttnn/tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_binary_backward { -using ComplexTensor = complex_binary::ComplexTensor; constexpr uint8_t DefaultQueueId = 0; enum class ComplexBinaryBackwardOpType { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp index d884ae53e3f..52150a11216 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary.hpp @@ -12,7 +12,6 @@ namespace ttnn { namespace operations::complex_unary { -using ComplexTensor = complex_binary::ComplexTensor; template struct ExecuteComplexUnaryTensor { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary_pybind.hpp index 998f7b1c136..3688602618a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/complex_unary_pybind.hpp @@ -9,6 +9,7 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/eltwise/complex_unary/complex_unary.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" #include "ttnn/types.hpp" namespace py = pybind11; @@ -18,7 +19,6 @@ namespace operations { namespace complex_unary { namespace detail { -using ComplexTensor = complex_binary::ComplexTensor; template void bind_complex_unary_tensor(py::module& module, const complex_unary_operation_t& operation, const std::string& description) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp index 622943a90ca..6d67e892dd0 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp @@ -2,18 +2,16 @@ // // SPDX-License-Identifier: Apache-2.0 - +#include "complex_unary_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/bcast/bcast_op.hpp" #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" -#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_unary { -using ComplexTensor = complex_binary::ComplexTensor; Tensor _real(const ComplexTensor& input, const MemoryConfig& output_mem_config) { return input[0]; diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp index 9f66bb5c618..d1973c94515 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp @@ -8,10 +8,9 @@ #include #include "ttnn/tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_unary { -using ComplexTensor = complex_binary::ComplexTensor; constexpr uint8_t DefaultQueueId = 0; enum class ComplexUnaryOpType { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp index b775f6270bb..9aecbae57ca 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp @@ -12,7 +12,7 @@ namespace ttnn { namespace operations::complex_unary_backward { -using ComplexTensor = complex_binary::ComplexTensor; + template struct ExecuteComplexUnaryBackward { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp index 99323034ff1..d3143c10c09 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward_pybind.hpp @@ -9,6 +9,7 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" #include "ttnn/types.hpp" namespace py = pybind11; @@ -18,7 +19,6 @@ namespace operations { namespace complex_unary_backward { namespace detail { -using ComplexTensor = complex_binary::ComplexTensor; template void bind_complex_unary_backward(py::module& module, const complex_unary_backward_operation_t& operation, const std::string& description) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp index 31f6d92ddc5..2e6f30b20b9 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.cpp @@ -3,13 +3,11 @@ // SPDX-License-Identifier: Apache-2.0 +#include "complex_unary_backward_op.hpp" #include "third_party/magic_enum/magic_enum.hpp" - #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" -#include "ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp" -#include "ttnn/operations/eltwise/complex_binary_backward/device/complex_binary_backward_op.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp" #include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" @@ -18,7 +16,6 @@ namespace ttnn::operations::complex_unary_backward { -using ComplexTensor = complex_binary::ComplexTensor; // polar // grad_abs = torch.real(grad_conj * torch.sgn(result)) diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp index 359af5f71e3..3ceb86eadb5 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp @@ -8,10 +8,9 @@ #include #include "ttnn/tensor/tensor.hpp" #include "third_party/magic_enum/magic_enum.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_unary_backward { -using ComplexTensor = complex_binary::ComplexTensor; constexpr uint8_t DefaultQueueId = 0; enum class ComplexUnaryBackwardOpType { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 5dd7269a2ec..270c86e67df 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -12,7 +12,7 @@ #include "ttnn/operations/eltwise/unary/unary_composite.hpp" #include "ttnn/operations/eltwise/complex_unary/complex_unary.hpp" #include "ttnn/types.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" +#include "ttnn/operations/eltwise/complex/complex.hpp" namespace py = pybind11; @@ -23,7 +23,6 @@ namespace unary { namespace detail { using FusedActivations = std::vector; -using ComplexTensor = complex_binary::ComplexTensor; template void bind_unary_operation(py::module& module, const unary_operation_t& operation, const std::string& info_doc = "" ) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp index e0fa0ad836b..7bbf6d6c7d7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp @@ -9,9 +9,6 @@ #include "ttnn/cpp/pybind11/decorators.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward.hpp" -#include "ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp" -#include "ttnn/operations/eltwise/complex_binary_backward/complex_binary_backward.hpp" -#include "ttnn/operations/eltwise/complex_unary_backward/complex_unary_backward.hpp" #include "ttnn/operations/eltwise/unary_backward/unary_backward.hpp" #include "ttnn/types.hpp"