#15183: Update conv_transpose2d with auto_shard, matmul for 1x1 conv & output mem_config support. Implemented sweep tests. (#15256)

### Ticket
#15183 

### Problem description
- Added support for using a matmul for 1x1 convolutions (see the sketch below).
- Enabled auto_shard support.
- Added support for passing an output mem_config.
- Created short sweep tests for conv_transpose2d.
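
A 1x1 transposed convolution with stride 1, no padding, and no dilation reduces to a matmul over the channel dimension, which is presumably what the matmul path exploits. The sketch below is plain PyTorch with illustrative shapes (not the ttnn implementation) and checks that equivalence:

```python
import torch

# Illustrative sizes only.
n, h, w, c_in, c_out = 1, 8, 8, 32, 64

x = torch.randn(n, c_in, h, w)
weight = torch.randn(c_in, c_out, 1, 1)  # conv_transpose2d weights: (in, out, kH, kW)

ref = torch.nn.functional.conv_transpose2d(x, weight, stride=1, padding=0)

# Same result as a channel matmul: (N*H*W, Cin) @ (Cin, Cout)
x_flat = x.permute(0, 2, 3, 1).reshape(-1, c_in)
out_flat = x_flat @ weight.reshape(c_in, c_out)
out = out_flat.reshape(n, h, w, c_out).permute(0, 3, 1, 2)

assert torch.allclose(ref, out, atol=1e-4)
```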



### Checklist
- [ ] Post commit CI passes
- [ ] New/Existing tests provide coverage for changes
sankarmanoj-tt authored Dec 4, 2024
1 parent 56bce0a commit 232b9ac
Showing 9 changed files with 431 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ttnn-run-sweeps.yaml
@@ -329,6 +329,7 @@ on:
- data_movement.repeat_interleave.repeat_interleave
- data_movement.nonzero.nonzero
- data_movement.backward.concat_bw.concat_bw
- conv_transpose2d.short.conv_transpose2d_short_sweep
- conv2d.full.conv2d_misc
- conv2d.full.conv2d_sharding
- conv2d.full.conv2d_sliding_window
260 changes: 260 additions & 0 deletions tests/sweep_framework/sweep_utils/conv_transpose2d_common.py
@@ -0,0 +1,260 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, List
import itertools
import random
import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30


def get_input_specs(
batch_list: List[int],
acts_list: List[int],
kernel_list: List[int],
stride_list: List[int],
padding_list: List[int],
dilation_list: List[int],
) -> Tuple[int, int, int, int, int, int, int, int, int, int]:
for batch_size, activation, kernel, stride, padding, dilation in itertools.product(
batch_list, acts_list, kernel_list, stride_list, padding_list, dilation_list
):
yield (batch_size, activation, activation, kernel, kernel, stride, stride, padding, padding, dilation)
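# Illustrative usage (hypothetical values): get_input_specs([1], [32], [3], [2], [1], [1])
# yields a single spec (1, 32, 32, 3, 3, 2, 2, 1, 1, 1); with longer lists it yields one
# tuple per combination in the Cartesian product.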


def mesh_device_fixture():
num_devices = ttnn.GetNumPCIeDevices()
# For now, use device id 0.
device_id = 0
assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
ttnn.SetDefaultDevice(device)

device_name = "Unknown"
if ttnn.device.is_grayskull(device):
device_name = "grayskull"
elif ttnn.device.is_wormhole_b0(device):
device_name = "wormhole_b0"
yield device, device_name

ttnn.close_device(device)


def run_full(
input_specs,
input_channels,
output_channels,
transpose_mcast,
output_layout,
has_bias,
enable_act_double_buffer,
enable_split_reader,
enable_subblock_padding,
activations_dtype,
weights_dtype,
math_fidelity,
fp32_accum,
packer_l1_acc,
groups,
override_sharding_config,
core_grid,
use_shallow_conv_variant,
deallocate_activation,
enable_auto_formatting,
device,
padded_input_channels=None,
) -> list:
[
batch_size,
input_height,
input_width,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
out_pad_h,
out_pad_w,
] = input_specs
conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_weight_shape = [input_channels, output_channels // groups, kernel_height, kernel_width]  # conv_transpose2d weight layout: (in, out // groups, kH, kW)
conv_bias_shape = [1, 1, 1, output_channels]
torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()

torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()

torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_out_golden_tensor = torch.nn.functional.conv_transpose2d(
torch_input_tensor_nchw,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
output_padding=(out_pad_h, out_pad_w),
dilation=(dilation_h, dilation_w),
groups=groups,
)

tt_weight_tensor = ttnn.from_torch(
torch_weight_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
)
tt_bias_tensor = None
if has_bias:
tt_bias_tensor = ttnn.from_torch(
torch_bias_tensor, weights_dtype if weights_dtype != ttnn.bfloat8_b else ttnn.float32
)

tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
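# shard_layout is left as None in the Conv2dConfig below so the op can pick a shard layout
# automatically (the auto_shard support this PR enables).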

conv_config = ttnn.Conv2dConfig(
dtype=activations_dtype,
weights_dtype=weights_dtype,
math_fidelity=math_fidelity,
shard_layout=None,
deallocate_activation=deallocate_activation,
fp32_dest_acc_enabled=fp32_accum,
packer_l1_accum_enabled=packer_l1_acc,
override_sharding_config=override_sharding_config,
output_layout=output_layout,
enable_act_double_buffer=enable_act_double_buffer,
enable_split_reader=enable_split_reader,
enable_subblock_padding=enable_subblock_padding,
)

if override_sharding_config:
if len(core_grid) == 2:
conv_config.core_grid = ttnn.CoreRangeSet({ttnn.CoreRange(core_grid[0], core_grid[1])})
elif len(core_grid) == 4:
conv_config.core_grid = ttnn.CoreRangeSet(
{ttnn.CoreRange(core_grid[0], core_grid[1]), ttnn.CoreRange(core_grid[2], core_grid[3])}
)
start_time = start_measuring_time()
[tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv_transpose2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
out_channels=output_channels,
device=device,
bias_tensor=tt_bias_tensor,
kernel_size=(kernel_height, kernel_width),
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
output_padding=(out_pad_h, out_pad_w),
dilation=(dilation_h, dilation_w),
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
conv_config=conv_config,
groups=groups,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)
e2e_perf = stop_measuring_time(start_time)

# torch_output_tensor is in row major layout and NHWC shape
# NHWC to NCHW
torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]


def run_short(
input_specs,
device,
) -> list:
[
batch_size,
input_channels,
input_height,
input_width,
output_channels,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
out_pad_h,
out_pad_w,
] = input_specs
print(input_specs)
groups = 1
has_bias = True

conv_input_shape = [batch_size, input_channels, input_height, input_width]
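# conv_transpose2d weights are laid out as (in_channels, out_channels // groups, kH, kW),
# unlike conv2d's (out_channels, in_channels // groups, kH, kW).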
conv_weight_shape = [input_channels, output_channels // groups, kernel_height, kernel_width]
conv_bias_shape = [1, 1, 1, output_channels]
torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()

torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()

torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_out_golden_tensor = torch.nn.functional.conv_transpose2d(
torch_input_tensor_nchw,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(dilation_h, dilation_w),
output_padding=(out_pad_h, out_pad_w),
groups=groups,
)

tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
tt_bias_tensor = None
if has_bias:
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)

tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)

start_time = start_measuring_time()
[tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv_transpose2d(
input_tensor=tt_input_tensor,
weight_tensor=tt_weight_tensor,
in_channels=input_channels,
out_channels=output_channels,
device=device,
bias_tensor=tt_bias_tensor,
kernel_size=(kernel_height, kernel_width),
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
output_padding=(out_pad_h, out_pad_w),
dilation=(dilation_h, dilation_w),
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
groups=groups,
)

tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
torch_output_tensor = ttnn.to_torch(tt_output_tensor)
e2e_perf = stop_measuring_time(start_time)

# torch_output_tensor is in row major layout and NHWC shape
# NHWC to NCHW
torch_output_tensor = torch_output_tensor.reshape(batch_size, out_height, out_width, torch_output_tensor.shape[-1])
torch_output_tensor = torch_output_tensor[:, :, :, :output_channels]

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]
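
Both helpers reshape the device output using the out_height and out_width returned by ttnn.conv_transpose2d. Those dimensions should follow the standard transposed-convolution size formula (the same one torch uses); the helper below is illustrative only and is not part of the files in this commit:

def conv_transpose2d_output_dim(size, kernel, stride, padding, dilation, output_padding):
    # out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Example with the first short-sweep spec below (H=50, k=3, s=2, p=0, d=1, out_pad=0):
assert conv_transpose2d_output_dim(50, 3, 2, 0, 1, 0) == 101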
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, List
import os
import itertools
import random
import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
from tests.sweep_framework.sweep_utils.conv_transpose2d_common import run_short, mesh_device_fixture

parameters = {
"short_sweep_suite": {
"input_specs": [
# Contains following params
# [batch_size, input_channels, input_height, input_width, output_channels, kernel_height, kernel_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, out_pad_h, out_pad_w]
# [20, 16, 50, 100, 33, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0], Batch size too big
[1, 16, 50, 100, 33, 3, 3, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 1024, 14, 14, 512, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 128, 112, 112, 64, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 128, 64, 64, 64, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 16, 14, 14, 1, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 256, 32, 32, 128, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 256, 56, 56, 128, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 4, 7, 7, 16, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 512, 16, 16, 256, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
# [1, 512, 28, 28, 256, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
[1, 64, 128, 128, 32, 2, 2, 2, 2, 0, 0, 1, 1, 0, 0],
]
},
}


def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
return False, None


def run(
input_specs,
*,
device,
) -> list:
return run_short(
input_specs,
device,
)


import pytest


@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite"]["input_specs"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_conv_transpose2d_localrun(device, input_spec):
run_short(
input_spec,
device,
)
23 changes: 6 additions & 17 deletions tests/ttnn/unit_tests/operations/test_conv_transpose2d.py
@@ -165,6 +165,8 @@ def run_conv_transpose2d(
assert passing


@skip_for_blackhole()
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 64 * 1024}], indirect=True)
@pytest.mark.parametrize(
"batch_size, input_height, input_width, input_channels, output_channels, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, out_pad_h, out_pad_w, config, shard_layout",
@@ -173,26 +175,11 @@ def run_conv_transpose2d(
(1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),
(1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),
(1, 256, 256, 32, 32, 3, 3, 1, 1, 1, 1, 0, 0, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
(1, 256, 256, 32, 32, 1, 1, 1, 1, 0, 0, 0, 0, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
# Stride = 2
(1, 8, 8, 32, 64, 3, 3, 2, 2, 1, 1, 1, 1, None, ttnn.TensorMemoryLayout.WIDTH_SHARDED),
(1, 128, 128, 32, 64, 3, 3, 2, 2, 1, 1, 1, 1, {"act_block_h": 64}, ttnn.TensorMemoryLayout.HEIGHT_SHARDED),
(
1,
16,
16,
256,
256,
3,
3,
2,
2,
1,
1,
1,
1,
None,
ttnn.TensorMemoryLayout.BLOCK_SHARDED,
), # Fails with error : act_block_w_datums == round_up(conv_act_size_c * filter_w, TILE_WIDTH)
(1, 16, 16, 256, 256, 3, 3, 2, 2, 1, 1, 1, 1, None, ttnn.TensorMemoryLayout.BLOCK_SHARDED),
# # (1, 16, 16, 32, 32, 3, 3, 2, 2, 1, 1, 0, 0, None, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), # Issue with reading block sharded tensor
# Vanilla Unet
# Filter Size = 2 not supported in Block sharded
@@ -232,6 +219,8 @@ def test_simple_conv_t2d(
config,
shard_layout,
):
if device.core_grid.y != 8:
pytest.skip("Needs 8x8 Grid")
run_conv_transpose2d(
device,
math_fidelity=ttnn.MathFidelity.HiFi4,
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -838,7 +838,7 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
}
}

static void adjust_conv_op_config_for_auto_shard(
void adjust_conv_op_config_for_auto_shard(
bool is_mm_conv,
uint32_t batch_size,
uint32_t in_channels,