diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml
index 1bd08c3025a..20f23c51c74 100644
--- a/.github/workflows/ttnn-run-sweeps.yaml
+++ b/.github/workflows/ttnn-run-sweeps.yaml
@@ -313,6 +313,8 @@ on:
           - conv2d.full.conv2d_sliding_window
           - conv2d.short.conv2d_short_sweep
           - max_pool2d.short.max_pool2d_short_sweep
+          - max_pool2d.full.max_pool2d_params
+          - max_pool2d.full.max_pool2d_large_dims
           - transformer.concatenate_heads.concatenate_heads
           - transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads
           - transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads_kv_input
diff --git a/tests/sweep_framework/sweep_utils/max_pool2d_common.py b/tests/sweep_framework/sweep_utils/max_pool2d_common.py
new file mode 100644
index 00000000000..87cfcb6c5ac
--- /dev/null
+++ b/tests/sweep_framework/sweep_utils/max_pool2d_common.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple, List
+import itertools
+import random
+import torch
+import math
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from models.utility_functions import torch_random
+
+
+def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
+    [pad_h, pad_w] = test_vector["padding"]
+    [kernel_h, kernel_w] = test_vector["kernel_size"]
+    if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
+        return True, "Padding cannot be greater than half the kernel size."
+    return False, None
+
+
+def mesh_device_fixture():
+    num_devices = ttnn.GetNumPCIeDevices()
+    # For now, use device id 0.
+    device_id = 0
+    assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
+    device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
+    ttnn.SetDefaultDevice(device)
+
+    device_name = "Unknown"
+    if ttnn.device.is_grayskull(device):
+        device_name = "grayskull"
+    elif ttnn.device.is_wormhole_b0(device):
+        device_name = "wormhole_b0"
+    yield device, device_name
+
+    ttnn.close_device(device)
+
+
+def run_max_pool2d(
+    in_n,
+    in_c,
+    in_h,
+    in_w,
+    kernel_h,
+    kernel_w,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    dtype,
+    device,
+    sharding=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+    ceil_mode=False,
+):
+    act_shape = [in_n, in_c, in_h, in_w]
+    kernel_size = [kernel_h, kernel_w]
+    stride = [stride_h, stride_w]
+    padding = [pad_h, pad_w]
+    dilation = [dilation_h, dilation_w]
+
+    out_h = math.floor((in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1
+    out_w = math.floor((in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1
+
+    torch.manual_seed(0)
+    torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)
+
+    act = torch.randn(act_shape, dtype=torch.bfloat16)
+    act_shape = (1, 1, in_n * in_h * in_w, in_c)
+    act_permuted = torch.permute(act, (0, 2, 3, 1))
+    act_reshaped = act_permuted.reshape(act_shape)
+
+    if dtype == ttnn.bfloat8_b:
+        ttact = ttnn.from_torch(act_reshaped, dtype, layout=ttnn.TILE_LAYOUT)
+    else:
+        ttact = ttnn.from_torch(act_reshaped, dtype)
+
+    ttact_device = ttnn.to_device(ttact, device)
+    start_time = start_measuring_time()
+    output = ttnn.max_pool2d(
+        input_tensor=ttact_device,
+        batch_size=in_n,
+        input_h=in_h,
+        input_w=in_w,
+        channels=in_c,
+        kernel_size=[kernel_h, kernel_w],
+        stride=[stride_h, stride_w],
+        padding=[pad_h, pad_w],
+        dilation=[dilation_h, dilation_w],
+        memory_config=None,
+        applied_shard_scheme=sharding,
+    )
+
+    output_host = output.cpu()
+    output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
+    output_pytorch = output_pytorch_padded[:, :, :, :in_c]
+    e2e_perf = stop_measuring_time(start_time)
+
+    ## reference
+    golden_pytorch = torch.nn.MaxPool2d(
+        kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        return_indices=False,
+        ceil_mode=ceil_mode,
+    )(act)
+
+    golden_shape = golden_pytorch.shape
+    output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
+    output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2))  ## N, C, H, W
+
+    atol, rtol = torch.testing._comparison.default_tolerances(torch.bfloat16)
+    if dtype == ttnn.bfloat8_b:
+        atol = 0.35
+
+    ## test for equivalence
+    allclose = torch.allclose(output_pytorch, golden_pytorch, atol=atol)
+    isequal = torch.equal(output_pytorch, golden_pytorch)
+
+    assert allclose, "Reference and output tensors are not close"
+    if dtype == ttnn.bfloat16:
+        assert isequal, "Reference and output tensors are not equal"
+
+    # check PCC and return
+    return [check_with_pcc(output_pytorch, golden_pytorch, pcc=0.998), e2e_perf]
diff --git a/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_large_dims.py b/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_large_dims.py
new file mode 100644
index 00000000000..bf6aec947ea
--- /dev/null
+++ b/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_large_dims.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from tests.sweep_framework.sweep_utils.max_pool2d_common import run_max_pool2d, mesh_device_fixture, invalidate_vector
+
+# Total test cases:
+# max_pool2d_full_sweep_suite_large_dims = 17 * 4 * 4 * 3 * 4 * 2 = 6528
+# Some of these test cases may be invalid, based on the conditions in invalidate_vector.
+
+parameters = {
+    "max_pool2d_full_sweep_suite_large_dims": {
+        "kernel_size": [[j for i in range(2)] for j in range(15, 32)],  # square kernels only
+        "padding": [[7, 7], [8, 8], [15, 15], [16, 16]],
+        "stride": [[7, 7], [8, 8], [15, 15], [16, 16]],
+        "sharding": [
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+        ],
+        "shape": [
+            [4, 16, 1056, 160],
+            [1, 32, 599, 503],  # prime numbers in height and width
+            [7, 31, 512, 512],  # prime numbers in batch size and channels
+            [3, 17, 503, 503],  # prime numbers for all
+        ],
+        "dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
+    }
+}
+
+
+def run(
+    kernel_size,
+    padding,
+    stride,
+    sharding,
+    shape,
+    dtype,
+    *,
+    device,
+):
+    [in_n, in_c, in_h, in_w] = shape
+    [kernel_h, kernel_w] = kernel_size
+    [stride_h, stride_w] = stride
+    [pad_h, pad_w] = padding
+    [dilation_h, dilation_w] = [1, 1]  # dilation is fixed at 1
+
+    return run_max_pool2d(
+        in_n,
+        in_c,
+        in_h,
+        in_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        dtype,
+        device,
+        sharding,
+        ceil_mode=False,
+    )
diff --git a/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_params.py b/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_params.py
new file mode 100644
index 00000000000..28dc20b4b8f
--- /dev/null
+++ b/tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_params.py
@@ -0,0 +1,76 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
+from tests.sweep_framework.sweep_utils.max_pool2d_common import run_max_pool2d, mesh_device_fixture, invalidate_vector
+
+# Shapes are taken from existing unit tests
+input_shapes = [
+    [[1, 256, 56, 56]],
+    [[1, 512, 10, 10]],
+    [[2, 32, 23, 23]],
+    [[4, 16, 1056, 160]],
+    [[8, 4096, 10, 16]],
+    [[16, 16, 528, 80]],
+]
+
+# Total test cases:
+# max_pool2d_full_sweep_suite_params_{idx} = 13 * 7 * 7 * 3 * 6 (input_shapes) * 2 * 2 = 45864
+# Some of these test cases may be invalid, based on the conditions in invalidate_vector.
+
+parameters = {
+    f"max_pool2d_full_sweep_suite_params_{idx}": {
+        "kernel_size": [[j for i in range(2)] for j in range(2, 15)],  # square kernels only
+        "padding": [[j for i in range(2)] for j in range(1, 8)],
+        "stride": [[j for i in range(2)] for j in range(1, 8)],
+        "sharding": [
+            ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            ttnn.TensorMemoryLayout.BLOCK_SHARDED,
+        ],
+        "shape": shape_,
+        "dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
+        "ceil_mode": [True, False],
+    }
+    for idx, shape_ in enumerate(input_shapes)
+}
+
+
+def run(
+    kernel_size,
+    padding,
+    stride,
+    sharding,
+    shape,
+    dtype,
+    ceil_mode=False,
+    *,
+    device,
+):
+    [in_n, in_c, in_h, in_w] = shape
+    [kernel_h, kernel_w] = kernel_size
+    [stride_h, stride_w] = stride
+    [pad_h, pad_w] = padding
+    [dilation_h, dilation_w] = [1, 1]  # dilation is fixed at 1
+
+    return run_max_pool2d(
+        in_n,
+        in_c,
+        in_h,
+        in_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        dtype,
+        device,
+        sharding,
+        ceil_mode,
+    )
diff --git a/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py b/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py
index de56bebcc30..f70a17b9488 100644
--- a/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py
+++ b/tests/sweep_framework/sweeps/max_pool2d/short/max_pool2d_short_sweep.py
@@ -12,6 +12,7 @@
 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random
+from tests.sweep_framework.sweep_utils.max_pool2d_common import run_max_pool2d, mesh_device_fixture
 
 parameters = {
@@ -63,24 +64,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
     return False, None
 
 
-def mesh_device_fixture():
-    num_devices = ttnn.GetNumPCIeDevices()
-    # As of now take device id as 0.
-    device_id = 0
-    assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
-    device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
-    ttnn.SetDefaultDevice(device)
-
-    device_name = "Unknown"
-    if device.arch() == "grayskull":
-        device_name = "grayskull"
-    elif device.arch() == "wormhole_b0":
-        device_name = "wormhole_b0"
-    yield device, device_name
-
-    ttnn.close_device(device)
-
-
 def run(
     input_specs,
     dtype,
@@ -102,91 +85,22 @@ def run(
         dilation_w,
         ceil_mode,
     ) = input_specs
-    act_shape = in_n, in_c, in_h, in_w
-    kernel_size = kernel_h, kernel_w
-    padding = pad_h, pad_w
-    stride = stride_h, stride_w
-    dilation = dilation_h, dilation_w
-
-    out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
-    out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1
-
-    torch.manual_seed(0)
-    torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)
-
-    act = torch.randn(act_shape, dtype=torch.bfloat16)
-    act_shape = (1, 1, in_n * in_h * in_w, in_c)
-    act_permuted = torch.permute(act, (0, 2, 3, 1))
-    act_reshaped = act_permuted.reshape(act_shape)
-
-    if dtype == ttnn.bfloat8_b:
-        ttact = ttnn.from_torch(act_reshaped, dtype, layout=ttnn.TILE_LAYOUT)
-    else:
-        ttact = ttnn.from_torch(act_reshaped, dtype)
-
-    ttact_device = ttnn.to_device(ttact, device)
-    parallel_config = ttnn._ttnn.operations.conv.determine_parallel_config(
-        is_1d_systolic=True,
-        batch_size=in_n,
-        input_channels=in_c,
-        output_height=out_h,
-        output_width=out_w,
-        output_channels=in_c,
-        compute_grid_size=device.compute_with_storage_grid_size(),
-        is_out_tiled=False,
-    )
-    sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
-        tensor_shape=act_shape,
-        parallel_config=parallel_config,
-        tile_size=32 if dtype == ttnn.bfloat8_b else 1,
-    )
-    ttact_device = ttnn.to_memory_config(ttact_device, sharded_memory_config)
-    start_time = start_measuring_time()
-    output = ttnn.max_pool2d(
-        input_tensor=ttact_device,
-        batch_size=in_n,
-        input_h=in_h,
-        input_w=in_w,
-        channels=in_c,
-        kernel_size=[kernel_h, kernel_w],
-        stride=[stride_h, stride_w],
-        padding=[pad_h, pad_w],
-        dilation=[dilation_h, dilation_w],
-        device=device,
+    sharding = ttnn.TensorMemoryLayout.HEIGHT_SHARDED
+    return run_max_pool2d(
+        in_n,
+        in_c,
+        in_h,
+        in_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        dilation_h,
+        dilation_w,
+        dtype,
+        device,
+        sharding,
+        ceil_mode,
     )
-
-    # interleaved_mem_config = ttnn.L1_MEMORY_CONFIG
-    # output = ttnn.to_memory_config(output, interleaved_mem_config)
-    output_host = output.cpu()
-    output_pytorch_padded = ttnn.to_torch(output_host)
-    output_pytorch = output_pytorch_padded[:, :, :, :in_c]
-    e2e_perf = stop_measuring_time(start_time)
-
-    ## reference
-    golden_pytorch = torch.nn.MaxPool2d(
-        kernel_size,
-        stride=stride,
-        padding=padding,
-        dilation=dilation,
-        return_indices=False,
-        ceil_mode=ceil_mode,
-    )(act)
-
-    ## test for equivalance
-    golden_shape = golden_pytorch.shape
-    output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
-    output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2))  ## N, C, H, W
-
-    atol, rtol = torch.testing._comparison.default_tolerances(torch.bfloat16)
-    if dtype == ttnn.bfloat8_b:
-        atol = 0.35
-
-    allclose = torch.allclose(output_pytorch, golden_pytorch, atol=atol)
-    isequal = torch.equal(output_pytorch, golden_pytorch)
-
-    assert allclose, " Reference and output tensor are not close"
-    if dtype == ttnn.bfloat16:
-        assert isequal, " Reference and output tensor are not equal"
-
-    # check pcc and return
-    return [check_with_pcc(output_pytorch, golden_pytorch, pcc=0.998), e2e_perf]
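
Note: the invalidate_vector check in max_pool2d_common.py mirrors PyTorch's own precondition that padding may be at most half the kernel size; torch.nn.MaxPool2d rejects anything larger at call time, so such vectors would fail before the device op even runs. A quick torch-only illustration (error message paraphrased from PyTorch):

    import torch

    pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=2)  # 2 * 2 > 3
    try:
        pool(torch.randn(1, 1, 8, 8))
    except RuntimeError as err:
        print(err)  # "pad should be at most half of kernel size ..."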
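The out_h and out_w computed in run_max_pool2d follow torch.nn.MaxPool2d's documented output-size formula, out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1. A minimal sketch checking the formula against PyTorch; the helper name pool_out_size is illustrative, not part of the PR:

    import math
    import torch

    # Floor-mode output size per torch.nn.MaxPool2d's docs. ceil_mode would
    # round up instead, with an extra clamp so the last window starts inside
    # the input rather than entirely in padding.
    def pool_out_size(in_size, kernel, stride, pad, dilation=1):
        return math.floor((in_size + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1

    x = torch.randn(1, 16, 56, 56)
    out = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x)
    assert out.shape[-1] == pool_out_size(56, kernel=3, stride=2, pad=1)  # 28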
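run_max_pool2d also does some host-side layout juggling that is easy to miss: the NCHW activation is permuted to channels-last and flattened to (1, 1, N*H*W, C) before ttnn.from_torch, and the device output comes back as (1, 1, N*OH*OW, C) with the channel axis possibly padded, hence the [:, :, :, :in_c] slice and the final reshape/permute back to NCHW before comparing against the golden. A torch-only sketch of the round trip, with no pooling so the inverse is exact:

    import torch

    n, c, h, w = 2, 32, 8, 8
    act = torch.randn(n, c, h, w)

    # NCHW -> NHWC -> (1, 1, N*H*W, C): the flattened channels-last layout
    # handed to ttnn.from_torch in run_max_pool2d.
    flat = act.permute(0, 2, 3, 1).reshape(1, 1, n * h * w, c)

    # Inverse of the post-op reshape: (1, 1, N*OH*OW, C) -> NCHW.
    restored = flat.reshape(n, h, w, c).permute(0, 3, 1, 2)
    assert torch.equal(restored, act)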
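The "Total test cases" comments in both new sweep files are plain cross-product counts over the parameter lists; the sweep framework expands every combination and then drops the ones invalidate_vector rejects. A sketch of the same arithmetic for the large-dims suite, with placeholder strings standing in for the ttnn enum values:

    import itertools
    import math

    suite = {
        "kernel_size": [[k, k] for k in range(15, 32)],  # 17 square kernels
        "padding": [[7, 7], [8, 8], [15, 15], [16, 16]],
        "stride": [[7, 7], [8, 8], [15, 15], [16, 16]],
        "sharding": ["HEIGHT_SHARDED", "WIDTH_SHARDED", "BLOCK_SHARDED"],
        "shape": [[4, 16, 1056, 160], [1, 32, 599, 503], [7, 31, 512, 512], [3, 17, 503, 503]],
        "dtype": ["bfloat16", "bfloat8_b"],
    }
    assert math.prod(len(v) for v in suite.values()) == 6528
    vectors = list(itertools.product(*suite.values()))  # one tuple per test vector
    assert len(vectors) == 6528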