Skip to content

Commit

Permalink
#16178:Add tt-forge sweep for conv2d.
Browse files Browse the repository at this point in the history
Add sweep suite for tt-forge

Signed-off-by: Nilaykumar Patel <[email protected]>
  • Loading branch information
nkpatel-tt authored Jan 10, 2025
1 parent a94c89e commit 5cdf0fa
Show file tree
Hide file tree
Showing 4 changed files with 8,888 additions and 1,594 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ttnn-run-sweeps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ on:
- conv2d.full.conv2d_sharding
- conv2d.full.conv2d_sliding_window
- conv2d.short.conv2d_short_sweep
- conv2d.short.conv2d_ttforge_sweep
- pooling.global_avg_pool2d
- pooling.max_pool2d
- max_pool2d.short.max_pool2d_short_sweep
Expand Down
110 changes: 85 additions & 25 deletions tests/sweep_framework/sweep_utils/conv2d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,51 +178,109 @@ def run_conv2d_short_sweep(
input_specs,
device,
) -> list:
[
batch_size,
output_channels,
input_channels,
input_height,
input_width,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
groups,
has_bias,
dilation,
] = input_specs
# for tt-forge suite, extra arguments are tensor configs
is_forge_suite = False
if len(input_specs) > 15:
[
batch_size,
output_channels,
input_channels,
input_height,
input_width,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
groups,
dilation_h,
dilation_w,
has_bias,
[input_layout, input_buffer_type, input_dtype],
[weights_layout, weights_buffer_type, weights_dtype],
[output_layout, output_buffer_type, output_dtype],
] = input_specs
is_forge_suite = True
else:
[
batch_size,
output_channels,
input_channels,
input_height,
input_width,
kernel_height,
kernel_width,
stride_h,
stride_w,
pad_h,
pad_w,
groups,
dilation_h,
dilation_w,
has_bias,
] = input_specs
print(input_specs)

if is_forge_suite:
torch_input_dtype = torch.bfloat16 if input_dtype == ttnn.DataType(ttnn.bfloat16) else torch.float32
torch_weight_dtype = torch.bfloat16 if weights_dtype == ttnn.DataType(ttnn.bfloat16) else torch.float32

conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels // groups, kernel_height, kernel_width]
conv_bias_shape = [1, 1, 1, output_channels]
torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
torch_input_tensor_nchw = torch.randn(
conv_input_shape, dtype=torch_input_dtype if is_forge_suite else torch.bfloat16
).float()

torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
torch_weight_tensor = torch.randn(
conv_weight_shape, dtype=torch_weight_dtype if is_forge_suite else torch.bfloat16
).float()

torch_bias_tensor = None
if has_bias:
torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
torch_bias_tensor = (
torch.randn(conv_bias_shape, dtype=torch_weight_dtype if is_forge_suite else torch.bfloat16).float()
if has_bias
else None
)
torch_out_golden_tensor = torch.nn.functional.conv2d(
torch_input_tensor_nchw,
torch_weight_tensor,
bias=torch_bias_tensor.reshape(-1) if has_bias else None,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(dilation, dilation),
dilation=(dilation_h, dilation_w),
groups=groups,
)

tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
tt_bias_tensor = None
if has_bias:
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)
if is_forge_suite:
input_layout = ttnn.Layout(input_layout)
input_dtype = ttnn.DataType(input_dtype)
input_memory_config = ttnn.DRAM_MEMORY_CONFIG if input_buffer_type == "dram" else ttnn.L1_MEMORY_CONFIG
tt_input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=input_dtype, layout=input_layout, device=device, memory_config=input_memory_config
)
weights_dtype = ttnn.DataType(weights_dtype)
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, weights_dtype)
if has_bias:
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, weights_dtype)
output_layout = ttnn.Layout(output_layout)
output_dtype = ttnn.DataType(output_dtype)
conv_config = ttnn.Conv2dConfig(
dtype=output_dtype,
weights_dtype=weights_dtype,
output_layout=output_layout,
)
else:
tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
if has_bias:
tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)

tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
conv_config = ttnn.Conv2dConfig()

start_time = start_measuring_time()
[tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
Expand All @@ -235,11 +293,12 @@ def run_conv2d_short_sweep(
kernel_size=(kernel_height, kernel_width),
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(dilation, dilation),
dilation=(dilation_h, dilation_w),
batch_size=batch_size,
input_height=input_height,
input_width=input_width,
groups=groups,
conv_config=conv_config,
return_output_dim=True,
return_weights_and_bias=True,
)
Expand All @@ -255,6 +314,7 @@ def run_conv2d_short_sweep(

torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

print("End of test case")
return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]


Expand Down
Loading

0 comments on commit 5cdf0fa

Please sign in to comment.