From fbc8a9dc8984e0b38b99a2d74ab6861fc2c83435 Mon Sep 17 00:00:00 2001 From: Nenad Petrovic <109360062+npetrovic-tenstorrent@users.noreply.github.com> Date: Sun, 10 Nov 2024 23:55:49 +0100 Subject: [PATCH] New complex sweeps (#14496) Two new sweeps: - is_imag - is_real - conj_bw --- .github/workflows/ttnn-run-sweeps.yaml | 3 + .../sweeps/eltwise/unary_complex/conj_bw.py | 137 ++++++++++++++++++ .../sweeps/eltwise/unary_complex/is_imag.py | 100 +++++++++++++ .../sweeps/eltwise/unary_complex/is_real.py | 91 ++++++++++++ 4 files changed, 331 insertions(+) create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_complex/conj_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_complex/is_imag.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_complex/is_real.py diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index bb894990686..4e0ddec26a8 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -164,6 +164,8 @@ on: - eltwise.unary_backward.hardswish_bw.hardswish_bw - eltwise.unary_backward.rpow_bw.rpow_bw - eltwise.unary_complex.conj + - eltwise.unary_complex.is_real + - eltwise.unary_complex.is_imag - eltwise.unary_complex.reciprocal - eltwise.unary_complex.reciprocal_bw - eltwise.binary_complex.div_bw.div_bw @@ -193,6 +195,7 @@ on: - eltwise.unary_complex.angle.angle - eltwise.unary_complex.polar_bw.polar_bw - eltwise.unary_complex.angle_bw.angle_bw + - eltwise.unary_complex.conj_bw - eltwise.binary.subtract.subtract - eltwise.binary.subtract.subtract_tensor_pytorch2 - eltwise.binary.multiply.multiply diff --git a/tests/sweep_framework/sweeps/eltwise/unary_complex/conj_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_complex/conj_bw.py new file mode 100644 index 00000000000..471a9664a0e --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_complex/conj_bw.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def str_to_float(x): + try: + return float(x) + except: + return 0.0 + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor_r = gen_func_with_cast_tt( + partial(torch_random, low=0.01, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_grad_tensor_r.requires_grad = True + torch_grad_tensor_r.retain_grad() + + torch_grad_tensor_c = gen_func_with_cast_tt( + partial(torch_random, low=0.01, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_grad_tensor_c.requires_grad = True + torch_grad_tensor_c.retain_grad() + + torch_input_tensor_ar = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_ar.requires_grad = True + torch_input_tensor_ar.retain_grad() + + torch_input_tensor_ac = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_ac.requires_grad = True + torch_input_tensor_ac.retain_grad() + + torch_grad_tensor = torch.complex(torch_grad_tensor_r.to(torch.float32), torch_grad_tensor_c.to(torch.float32)) + torch_input_tensor_a = torch.complex( + torch_input_tensor_ar.to(torch.float32), torch_input_tensor_ac.to(torch.float32) + ) + + golden_function = ttnn.get_golden_function(ttnn.conj_bw) + torch_output_tensor = golden_function(torch_grad_tensor, torch_input_tensor_a)[0] + + grad_tensor_r = ttnn.from_torch( + torch_grad_tensor_r, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + grad_tensor_c = ttnn.from_torch( + torch_grad_tensor_c, dtype=grad_dtype, layout=grad_layout, device=device, memory_config=grad_memory_config + ) + + input_tensor_ar = ttnn.from_torch( + torch_input_tensor_ar, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_ac = ttnn.from_torch( + torch_input_tensor_ac, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + grad_tensor = ttnn.complex_tensor(grad_tensor_r, grad_tensor_c) + input_tensor_a = ttnn.complex_tensor(input_tensor_ar, input_tensor_ac) + + start_time = start_measuring_time() + output_tensor = ttnn.conj_bw(grad_tensor, input_tensor_a, memory_config=output_memory_config)[0] + e2e_perf = stop_measuring_time(start_time) + + output_tensor = torch.cat((ttnn.to_torch(output_tensor.real), ttnn.to_torch(output_tensor.imag)), dim=-1) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_complex/is_imag.py b/tests/sweep_framework/sweeps/eltwise/unary_complex/is_imag.py new file mode 100644 index 00000000000..30f538c2b33 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_complex/is_imag.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, + "xfail": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_ar = gen_func_with_cast_tt( + partial(torch_random, low=0.01, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_input_tensor_ac = gen_func_with_cast_tt( + partial(torch_random, low=0.01, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_output_tensor = torch.isreal( + torch.complex(torch_input_tensor_ar.to(torch.float32), torch_input_tensor_ac.to(torch.float32)) + ) + + input_tensor_ar = ttnn.from_torch( + torch_input_tensor_ar, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_ac = ttnn.from_torch( + torch_input_tensor_ac, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + input_tensor_a = ttnn.complex_tensor(input_tensor_ar, input_tensor_ac) + + start_time = start_measuring_time() + result = ttnn.is_imag(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_complex/is_real.py b/tests/sweep_framework/sweeps/eltwise/unary_complex/is_real.py new file mode 100644 index 00000000000..1ddd60592e7 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_complex/is_real.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.sweep_utils.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_ar = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_input_tensor_ac = gen_func_with_cast_tt( + partial(torch_random, low=100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_output_tensor = torch.isreal( + torch.complex(torch_input_tensor_ar.to(torch.float32), torch_input_tensor_ac.to(torch.float32)) + ) + + input_tensor_ar = ttnn.from_torch( + torch_input_tensor_ar, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_ac = ttnn.from_torch( + torch_input_tensor_ac, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + input_tensor_a = ttnn.complex_tensor(input_tensor_ar, input_tensor_ac) + + start_time = start_measuring_time() + result = ttnn.is_real(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99) + return [pcc, e2e_perf]