Skip to content

Commit

Permalink
Add full sweeps for maxpool2d. (#14711)
Browse files Browse the repository at this point in the history
Add full sweeps for maxpool2d.

Signed-off-by: Nilaykumar K Patel <[email protected]>
  • Loading branch information
nkpatel-tt authored Nov 11, 2024
1 parent fec3ebc commit 498db4e
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 105 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ttnn-run-sweeps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ on:
- conv2d.full.conv2d_sliding_window
- conv2d.short.conv2d_short_sweep
- max_pool2d.short.max_pool2d_short_sweep
- max_pool2d.full.max_pool2d_params
- max_pool2d.full.max_pool2d_large_dims
- transformer.concatenate_heads.concatenate_heads
- transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads
- transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads_kv_input
Expand Down
131 changes: 131 additions & 0 deletions tests/sweep_framework/sweep_utils/max_pool2d_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, List
import itertools
import random
import torch
import math

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random


def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
[pad_h, pad_w] = test_vector["padding"]
[_, _, kernel_h, kernel_w] = test_vector["shape"]
if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
return True, "double of padding can not be greater than kernel size."
return False, None


def mesh_device_fixture():
num_devices = ttnn.GetNumPCIeDevices()
# As of now take device id as 0.
device_id = 0
assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
ttnn.SetDefaultDevice(device)

device_name = "Unknown"
if ttnn.device.is_grayskull(device):
device_name = "grayskull"
elif ttnn.device.is_wormhole_b0(device):
device_name = "wormhole_b0"
yield device, device_name

ttnn.close_device(device)


def run_max_pool2d(
in_n,
in_c,
in_h,
in_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dtype,
device,
sharding=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ceil_mode=False,
):
act_shape = [in_n, in_c, in_h, in_w]
kernel_size = [kernel_h, kernel_w]
stride = [stride_h, stride_h]
padding = [pad_h, pad_w]
dilation = [dilation_h, dilation_w]

out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1

torch.manual_seed(0)
torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)

act = torch.randn(act_shape, dtype=torch.bfloat16)
act_shape = (1, 1, in_n * in_h * in_w, in_c)
act_permuted = torch.permute(act, (0, 2, 3, 1))
act_reshaped = act_permuted.reshape(act_shape)

if dtype == ttnn.bfloat8_b:
ttact = ttnn.from_torch(act_reshaped, dtype, layout=ttnn.TILE_LAYOUT)
else:
ttact = ttnn.from_torch(act_reshaped, dtype)

ttact_device = ttnn.to_device(ttact, device)
start_time = start_measuring_time()
output = ttnn.max_pool2d(
input_tensor=ttact_device,
batch_size=in_n,
input_h=in_h,
input_w=in_w,
channels=in_c,
kernel_size=[kernel_h, kernel_w],
stride=[stride_h, stride_w],
padding=[pad_h, pad_w],
dilation=[dilation_h, dilation_w],
memory_config=None,
applied_shard_scheme=sharding,
)

output_host = output.cpu()
output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
output_pytorch = output_pytorch_padded[:, :, :, :in_c]
e2e_perf = stop_measuring_time(start_time)

## reference
golden_pytorch = torch.nn.MaxPool2d(
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=False,
ceil_mode=False,
)(act)

golden_shape = golden_pytorch.shape
output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2)) ## N, C, H, W

atol, rtol = torch.testing._comparison.default_tolerances(torch.bfloat16)
if dtype == ttnn.bfloat8_b:
atol = 0.35

## test for equivalance
allclose = torch.allclose(output_pytorch, golden_pytorch, atol=atol)
isequal = torch.equal(output_pytorch, golden_pytorch)

assert allclose, " Reference and output tensor are not close"
if dtype == ttnn.bfloat16:
assert isequal, " Reference and output tensor are not equal"

# check pcc and return
return [check_with_pcc(output_pytorch, golden_pytorch, pcc=0.998), e2e_perf]
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from tests.sweep_framework.sweep_utils.max_pool2d_common import run_max_pool2d, mesh_device_fixture, invalidate_vector

# Total test cases
# max_pool2d_full_sweep_suite_large_dims = 17 * 4 * 4 * 3 * 4 * 2 = 6528
# There can be invalid test cases in here based on conditions in invalidate_vector.

parameters = {
"max_pool2d_full_sweep_suite_large_dims": {
"kernel_size": [[j for i in range(2)] for j in range(15, 32)], # square kernels only
"padding": [[7, 7], [8, 8], [15, 15], [16, 16]],
"stride": [[7, 7], [8, 8], [15, 15], [16, 16]],
"sharding": [
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.TensorMemoryLayout.BLOCK_SHARDED,
],
"shape": [
[4, 16, 1056, 160],
[1, 32, 599, 503], # prime number in height and width
[7, 31, 512, 512], # prime numbers in batch size and channels
[3, 17, 503, 503], # prime numbers for all
],
"dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
}
}


def run(
kernel_size,
padding,
stride,
sharding,
shape,
dtype,
*,
device,
):
[in_n, in_c, in_h, in_w] = shape
[kernel_h, kernel_w] = kernel_size
[stride_h, stride_w] = stride
[pad_h, pad_w] = padding
[dilation_h, dilation_w] = [1, 1] # dilation is fix

return run_max_pool2d(
in_n,
in_c,
in_h,
in_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dtype,
device,
sharding,
ceil_mode=False,
)
76 changes: 76 additions & 0 deletions tests/sweep_framework/sweeps/max_pool2d/full/max_pool2d_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from tests.sweep_framework.sweep_utils.max_pool2d_common import run_max_pool2d, mesh_device_fixture, invalidate_vector

# Shapes are taken from existing unit tests
input_shapes = [
[[1, 256, 56, 56]],
[[1, 512, 10, 10]],
[[2, 32, 23, 23]],
[[4, 16, 1056, 160]],
[[8, 4096, 10, 16]],
[[16, 16, 528, 80]],
]

# Total test cases
# max_pool2d_full_sweep_suite_params_{idx} = 13 * 7 * 7 * 3 * 6(input_shapes) * 2 * 2 = 45864
# There can be invalid test cases in here based on conditions in invalidate_vector.

parameters = {
f"max_pool2d_full_sweep_suite_params_{idx}": {
"kernel_size": [[j for i in range(2)] for j in range(2, 15)], # square kernels only
"padding": [[j for i in range(2)] for j in range(1, 8)],
"stride": [[j for i in range(2)] for j in range(1, 8)],
"sharding": [
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.TensorMemoryLayout.BLOCK_SHARDED,
],
"shape": shape_,
"dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"ceil_mode": [True, False],
}
for idx, shape_ in enumerate(input_shapes)
}


def run(
kernel_size,
padding,
stride,
sharding,
shape,
dtype,
ceil_mode=False,
*,
device,
):
[in_n, in_c, in_h, in_w] = shape
[kernel_h, kernel_w] = kernel_size
[stride_h, stride_w] = stride
[pad_h, pad_w] = padding
[dilation_h, dilation_w] = [1, 1] # dilation is fix

return run_max_pool2d(
in_n,
in_c,
in_h,
in_w,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
dtype,
device,
sharding,
ceil_mode,
)
Loading

0 comments on commit 498db4e

Please sign in to comment.