sliding window ops in c++ #8389

Closed
Commits (47)
3869f84
#0: wip composite conv cpp refactoring
tt-nshanker May 9, 2024
396b3b8
#0: halo config generation in cpp done, need testing
mywoodstock May 2, 2024
a1705f6
#0: Halo op using new c++ config generation
mywoodstock May 7, 2024
1878fd8
#0: wip micro conv op
tt-nshanker May 9, 2024
8f1c126
#0: wip
tt-nshanker May 10, 2024
7f744d0
#0: halo config generation in cpp done, need testing
mywoodstock May 2, 2024
6615839
#0: New Halo struct which stores the sliding window config, and creat…
mywoodstock May 9, 2024
70c6a1d
#0: missed arg
mywoodstock May 9, 2024
cd0df15
#0: config tensors as part of program
mywoodstock May 10, 2024
5fba5d4
#0: on device config
mywoodstock May 10, 2024
3c3acf5
#0: wip
tt-nshanker May 10, 2024
59edaf4
#0: the friend lives in detail namespace, include tensor.hpp issues
mywoodstock May 10, 2024
365f97b
#0: wip
tt-nshanker May 10, 2024
d3de252
#0: fix more compile errors
mywoodstock May 10, 2024
e6704af
#0: wip
mywoodstock May 10, 2024
32b6e01
#0: wip compiler error fixes
tt-nshanker May 10, 2024
40d6323
#0: wip compiler error fixes
tt-nshanker May 11, 2024
be7b097
#0: composite conv updates + pybind
tt-nshanker May 11, 2024
9dd0d55
#0: WIP : 2 Tests failing
sankarmanoj-tt May 6, 2024
40c14d5
#0: adding composite conv op and new halo op to makefile wip
tt-nshanker May 13, 2024
43e9584
#0: wip
tt-nshanker May 13, 2024
7029ae8
#0: config tensor generation before create_program, move to device in…
mywoodstock May 13, 2024
0506ab4
#0: wip
tt-nshanker May 13, 2024
fbb14e6
#0: conv compilation fixes
tt-nshanker May 13, 2024
e483505
#0: more fixes
tt-nshanker May 13, 2024
9ab9cc5
#0: clean compile
mywoodstock May 13, 2024
14b83d3
#0: clean compile
mywoodstock May 13, 2024
badc53e
#0: one more thing
mywoodstock May 13, 2024
4d35d80
#0: swap with new c++ conv wip
tt-nshanker May 13, 2024
20d01fc
#0: fixes
tt-nshanker May 14, 2024
763b1ba
#0: fix to run conv unit test
tt-nshanker May 14, 2024
c857095
#0: maxpool composite in c++
mywoodstock May 14, 2024
57375a7
#0: maxpool2d c++ macro-op pybind
mywoodstock May 14, 2024
2045299
#0: fixed compilation errors
mywoodstock May 14, 2024
c5b830b
#0: some more fixes
mywoodstock May 14, 2024
dd8f188
#0: new maxpool unit tests
mywoodstock May 14, 2024
3a322fb
#0: resize instead of reserve!
mywoodstock May 14, 2024
eea55c8
#0: bugfixes
mywoodstock May 15, 2024
2de043f
#0: some fixes
mywoodstock May 15, 2024
35dece2
#0: updated to_weight_special_padding_tile_layout to use borrowed or …
arakhmati May 15, 2024
403b655
#0: remote config fix
mywoodstock May 15, 2024
14936e5
#0: config identical to orig version
mywoodstock May 15, 2024
7672b51
#0: wip fixes
tt-nshanker May 16, 2024
87faaa3
#0: more fixes
mywoodstock May 16, 2024
29c868a
#0: fix broken height sharded
mywoodstock May 16, 2024
e29dcca
#0: height sharded tests pass on GS
tt-nshanker May 17, 2024
996b29d
#0: snap to tile fix
mywoodstock May 17, 2024
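
The commits above refactor the sliding-window building blocks (halo config generation, the composite conv op, and a maxpool2d macro-op) into C++ and expose them through pybind. As a quick orientation, here is a minimal sketch of driving the new maxpool2d op from Python; it simply mirrors the unit test added below (tests/ttnn/unit_tests/operations/test_maxpool2d.py), with an illustrative shape taken from that test and a `device` handle assumed to come from the usual ttnn device fixture.

import torch
import ttnn

## illustrative shape only: one of the ResNet cases from the new test (NCHW)
n, c, h, w = 1, 64, 112, 112
act = torch.randn((n, c, h, w), dtype=torch.bfloat16)

## the op expects the activation as {N, 1, H * W, C}: permute NCHW -> NHWC, then flatten H and W
act_reshaped = torch.permute(act, (0, 2, 3, 1)).reshape(n, 1, h * w, c)

ttact = ttnn.to_device(ttnn.from_torch(act_reshaped, ttnn.bfloat16), device)
output = ttnn.maxpool2d(
    input_tensor=ttact,
    batch_size=n,
    input_height=h,
    input_width=w,
    channels=c,
    kernel_size=(3, 3),
    stride=(2, 2),
    padding=(1, 1),
    dilation=(1, 1),
    device=device,
)

## back to torch; the result is {N, 1, out_h * out_w, C} and may be channel-padded,
## which is why the test slices off [:, :, :, :c]
result = ttnn.to_torch(ttnn.from_device(output))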
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -22,4 +22,4 @@ repos:
     rev: 23.10.1
     hooks:
       - id: black
-        language_version: python3.8
+        language_version: python3
202 changes: 202 additions & 0 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -0,0 +1,202 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from loguru import logger

import torch
import pytest
import math
from models.utility_functions import is_wormhole_b0
from tests.ttnn.utils_for_testing import assert_with_pcc
import ttnn


## NOTE: this is the new C++ TTNN version


@pytest.mark.parametrize("device_l1_small_size", [24576], indirect=True)
@pytest.mark.parametrize(
"act_shape", ## NCHW
(
( ## resnet shapes
[1, 64, 112, 112],
[4, 64, 112, 112],
[8, 64, 112, 112],
[16, 64, 112, 112],
# [20, 64, 112, 112],
## hpr shapes
[8, 32, 132, 20], ## pass
[16, 32, 132, 20], ## pass
[32, 32, 132, 20], ## pass
[64, 32, 132, 20], ## pass
[128, 32, 132, 20], ## pass
# [256, 32, 132, 20], ## oom
[8, 32, 264, 40], ## pass
[16, 32, 264, 40], ## pass
[32, 32, 264, 40], ## pass
# [64, 32, 264, 40], ## oom
# [128, 32, 264, 40], ## oom
# [256, 32, 264, 40], ## oom
[4, 16, 1056, 160], ## pass
# [8, 16, 1056, 160], ## oom
# [16, 16, 1056, 160], ## oom
# [32, 16, 1056, 160], ## oom
# [64, 16, 1056, 160], ## oom
# [128, 16, 1056, 160], ## oom
# [256, 16, 1056, 160], ## oom
[8, 16, 528, 80], ## pass
[16, 16, 528, 80], ## pass
# [32, 16, 528, 80], ## oom
# [64, 16, 528, 80], ## oom
# [128, 16, 528, 80], ## oom
# [256, 16, 528, 80], ## oom
)
),
)
@pytest.mark.parametrize(
"kernel_size",
(
(2, 2),
(3, 3),
),
)
@pytest.mark.parametrize(
"padding",
(
(0, 0),
(1, 1),
),
)
@pytest.mark.parametrize(
"stride",
((2, 2),),
)
@pytest.mark.parametrize("dilation", ((1, 1),)) ## default
@pytest.mark.parametrize(
"nblocks",
(1,),
)
@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b])
def test_run_max_pool(
    act_shape,
    kernel_size,
    padding,
    stride,
    dilation,
    nblocks,
    device,
    dtype,
):
    in_n, in_c, in_h, in_w = act_shape
    kernel_h, kernel_w = kernel_size
    pad_h, pad_w = padding
    stride_h, stride_w = stride
    dilation_h, dilation_w = dilation

    if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
        pytest.skip("Invalid case")

    if (kernel_h == 3 and pad_h != 1) or (kernel_h == 2 and pad_h != 0):
        pytest.skip("kernel size and padding combination not supported")

    out_h = math.floor((in_h + 2 * pad_h - (dilation_h * kernel_h - 1) - 1) / stride_h) + 1
    out_w = math.floor((in_w + 2 * pad_w - (dilation_w * kernel_w - 1) - 1) / stride_w) + 1
    if out_w % nblocks != 0:
        pytest.skip(f"Unsupported case when out_w ({out_w}) % nblocks ({nblocks}) != 0")

    if in_c % 16 != 0:
        pytest.skip("Current maxpool writer needs nchannels to be multiple of 16!")

    if in_c == 16 and dtype == ttnn.bfloat8_b and in_n * in_h * in_w > 600000:
        pytest.skip("This case runs out of memory on Grayskull")

    if in_n >= 16 and in_c >= 64 and dtype == ttnn.bfloat8_b and is_wormhole_b0():
        pytest.skip("This case runs out of memory on Wormhole b0")

    if (
        is_wormhole_b0()
        and act_shape == [16, 64, 112, 112]
        and kernel_size == (3, 3)
        and padding == (1, 1)
        and stride == (2, 2)
        and dilation == (1, 1)
        and dtype == ttnn.bfloat16
    ):
        pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)

    ## construct the tensor in NCHW shape
    act = torch.randn(act_shape, dtype=torch.bfloat16)
    # act = torch.zeros(act_shape, dtype=torch.bfloat16)
    # act = torch.ones(act_shape, dtype=torch.bfloat16)
    # act = torch.arange(0, volume(act_shape), dtype=torch.bfloat16).reshape(act_shape)
    # for n in range(act_shape[0]):
    #     for c in range(act_shape[1]):
    #         for h in range(act_shape[2]):
    #             for w in range(act_shape[3]):
    #                 act[n, c, h, w] = 1 + n + h + w + c  # + torch.rand(1) * 0.15
    # torch.save(act, "act.pt")
    # act = torch.load("act.pt")

    ## this op expects the input tensor as {N, 1, H * W, C}, so rearrange and reshape the tensor,
    ## but before that make sure in_c is a multiple of the tile width
    act_shape = (in_n, 1, in_h * in_w, in_c)
    act_permuted = torch.permute(act, (0, 2, 3, 1))
    act_reshaped = act_permuted.reshape(act_shape)

    if dtype == ttnn.bfloat8_b:
        if (in_h * in_w) % 32 != 0:
            pytest.skip("For BFP8_B datatype, input height * width should be multiple of 32")
        ttact = ttnn.from_torch(act_reshaped, dtype, layout=ttnn.TILE_LAYOUT)
    else:
        ttact = ttnn.from_torch(act_reshaped, dtype)

    ttact_device = ttnn.to_device(ttact, device)
    output = ttnn.maxpool2d(
        input_tensor=ttact_device,
        batch_size=in_n,
        input_height=in_h,
        input_width=in_w,
        channels=in_c,
        kernel_size=(kernel_h, kernel_w),
        stride=(stride_h, stride_w),
        padding=(pad_h, pad_w),
        dilation=(dilation_h, dilation_w),
        device=device,
    )
    output_host = ttnn.from_device(output)
    output_pytorch_padded = ttnn.to_torch(output_host)
    output_pytorch = output_pytorch_padded[:, :, :, :in_c]

    ## reference
    golden_pytorch = torch.nn.MaxPool2d(
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        return_indices=False,
        ceil_mode=False,
    )(act)

    ## test for equivalence
    golden_shape = golden_pytorch.shape
    output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
    output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2))  ## N, C, H, W
    assert_with_pcc(output_pytorch, golden_pytorch)

    ## do a more rigorous comparison of each element
    atol, rtol = torch.testing._comparison.default_tolerances(torch.bfloat16)
    if dtype == ttnn.bfloat8_b:
        atol = 0.35

    allclose = torch.allclose(output_pytorch, golden_pytorch, atol=atol)
    isclose = torch.all(torch.isclose(output_pytorch, golden_pytorch, atol=atol))
    isequal = torch.equal(output_pytorch, golden_pytorch)

    assert allclose
    assert isclose
    if dtype == ttnn.bfloat16:
        assert isequal
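
Assuming a working tt-metal development environment with a device attached, the new tests can be run directly with pytest, for example:

pytest tests/ttnn/unit_tests/operations/test_maxpool2d.py

Each parametrized combination of shape, kernel size, stride, padding, and dtype is compared against torch.nn.MaxPool2d using the PCC check plus the allclose/isclose/equal checks above.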