Skip to content

Commit

Permalink
#10919: Create pybinding for complex_tensor in ttnn (#10921)
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW authored Jul 31, 2024
1 parent 2b908fb commit fbf58c9
Show file tree
Hide file tree
Showing 41 changed files with 351 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def test_level1_polar(bs, memcfg, dtype, device, function_level_defaults):
# we set imag = angle theta
x = Complex(None, re=torch.ones(input_shape), im=torch.rand(input_shape))

xtt = ttl.tensor.complex_tensor(
xtt = ttnn.complex_tensor(
ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_level2_abs_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_level2_abs_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_de
in_data = random_complex_tensor(input_shape, (0, 0), (0, 0))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_level2_angle_bw(bs, hw, memcfg, dtype, device, function_level_defaults)
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ def test_level2_complex_add_bw(bs, hw, alpha, memcfg, dtype, device, function_le
other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100))
other_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
other_tensor = ttl.tensor.complex_tensor(
other_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ def test_level2_complex_div_bw(bs, hw, memcfg, dtype, device, function_level_def
other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100))
other_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
other_tensor = ttl.tensor.complex_tensor(
other_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down Expand Up @@ -99,17 +99,17 @@ def test_level2_complex_div_bw_other_zero(bs, hw, memcfg, dtype, device, functio
other_data = random_complex_tensor(input_shape, (0, 0), (0, 0))
other_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
other_tensor = ttl.tensor.complex_tensor(
other_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ def test_level2_complex_mul_bw(bs, hw, memcfg, dtype, device, function_level_def
other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100))
other_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
other_tensor = ttl.tensor.complex_tensor(
other_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ def test_level2_complex_sub_bw(bs, hw, alpha, memcfg, dtype, device, function_le
other_data = random_complex_tensor(input_shape, (-20, 90), (-30, 100))
other_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
other_tensor = ttl.tensor.complex_tensor(
other_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(other_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(other_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def test_level2_conj_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_level2_imag_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def test_level2_polar_bw(bs, hw, memcfg, dtype, device, function_level_defaults)
in_data = random_complex_tensor(input_shape, (-90, 90), (0, 2 * pi))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_level2_real_bw(bs, hw, memcfg, dtype, device, function_level_defaults):
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def test_level2_recip_bw(bs, hw, memcfg, dtype, device, function_level_defaults)
in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down Expand Up @@ -93,13 +93,13 @@ def test_level2_recip_bw_inp_zero(bs, hw, memcfg, dtype, device, function_level_
in_data = random_complex_tensor(input_shape, (0, 0), (0, 0))
in_data.requires_grad = True

input_tensor = ttl.tensor.complex_tensor(
input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

grad_data = random_complex_tensor(input_shape, (-50, 50), (-60, 60))
grad_tensor = ttl.tensor.complex_tensor(
grad_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(grad_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(grad_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
Expand Down
52 changes: 52 additions & 0 deletions tests/ttnn/unit_tests/operations/complex/test_complex_conj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import tt_lib as ttl
import pytest
import ttnn
from loguru import logger
from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose

from models.utility_functions import is_wormhole_b0, skip_for_grayskull
from tests.ttnn.unit_tests.operations.complex.utility_funcs import (
convert_complex_to_torch_tensor,
random_complex_tensor,
)


@skip_for_grayskull()
@pytest.mark.parametrize(
"memcfg",
(
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
),
ids=["out_DRAM", "out_L1"],
)
@pytest.mark.parametrize("dtype", ((ttl.tensor.DataType.FLOAT32,)))
@pytest.mark.parametrize("bs", ((1, 1),))
@pytest.mark.parametrize("hw", ((32, 32),))
def test_conj(bs, hw, memcfg, dtype, device, function_level_defaults):
input_shape = torch.Size([bs[0], bs[1], hw[0], hw[1]])

in_data = random_complex_tensor(input_shape, (-90, 90), (-70, 70))

input_tensor = ttnn.complex_tensor(
ttl.tensor.Tensor(in_data.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(in_data.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)

tt_dev = ttnn.conj(input_tensor, memory_config=memcfg)

tt_to_torch = convert_complex_to_torch_tensor(tt_dev)

golden_function = ttnn.get_golden_function(ttnn.conj)
golden_tensor = golden_function(in_data)

passing, output = comp_pcc(golden_tensor, tt_to_torch)
logger.info(output)
assert passing
21 changes: 21 additions & 0 deletions tests/ttnn/unit_tests/operations/complex/utility_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import tt_lib as ttl


def random_complex_tensor(shape, real_range=(-100, 100), imag_range=(-100, 100)):
torch.manual_seed(213919)
real_part = (real_range[1] - real_range[0]) * torch.rand(shape) + real_range[0]
imag_part = (imag_range[1] - imag_range[0]) * torch.rand(shape) + imag_range[0]
return torch.complex(real_part, imag_part)


def convert_complex_to_torch_tensor(tt_dev):
tt_dev_r = tt_dev.real.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_dev_i = tt_dev.imag.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_to_torch = torch.complex(tt_dev_r, tt_dev_i)
return tt_to_torch
Loading

0 comments on commit fbf58c9

Please sign in to comment.