Add memory configs for forge test cases.
Signed-off-by: Nilaykumar Patel <[email protected]>
nkpatel-tt committed Dec 24, 2024
1 parent c01b9a6 commit 957e890
Showing 2 changed files with 8,810 additions and 1,960 deletions.
79 changes: 54 additions & 25 deletions tests/sweep_framework/sweep_utils/conv2d_common.py
@@ -178,8 +178,9 @@ def run_conv2d_short_sweep(
     input_specs,
     device,
 ) -> list:
-    # for tt-forge suite, extra argument is input datatype
-    if len(input_specs) == 15:
+    # for tt-forge suite, extra arguments are tensor configs
+    is_forge_suite = False
+    if len(input_specs) > 15:
         [
             batch_size,
             output_channels,
@@ -193,10 +194,14 @@ def run_conv2d_short_sweep(
             pad_h,
             pad_w,
             groups,
+            dilation_h,
+            dilation_w,
             has_bias,
-            dilation,
-            datatype,
+            [input_layout, input_buffer_type, input_dtype],
+            [weights_layout, weights_buffer_type, weights_dtype],
+            [output_layout, output_buffer_type, output_dtype],
         ] = input_specs
+        is_forge_suite = True
     else:
         [
             batch_size,
@@ -211,51 +216,74 @@ def run_conv2d_short_sweep(
             pad_h,
             pad_w,
             groups,
+            dilation_h,
+            dilation_w,
             has_bias,
-            dilation,
         ] = input_specs
-        datatype = int(ttnn.bfloat16)
     print(input_specs)
 
-    if datatype == int(ttnn.float32):
-        ttnn_datatype = ttnn.float32
-        torch_datatype = torch.float32
-    else:
-        ttnn_datatype = ttnn.bfloat16
-        torch_datatype = torch.bfloat16
+    if is_forge_suite:
+        torch_input_dtype = torch.bfloat16 if input_dtype == ttnn.DataType(ttnn.bfloat16) else torch.float32
+        torch_weight_dtype = torch.bfloat16 if weights_dtype == ttnn.DataType(ttnn.bfloat16) else torch.float32
 
     conv_input_shape = [batch_size, input_channels, input_height, input_width]
     conv_weight_shape = [output_channels, input_channels // groups, kernel_height, kernel_width]
     conv_bias_shape = [1, 1, 1, output_channels]
-    torch_input_tensor_nchw = torch.randn(conv_input_shape, dtype=torch_datatype).float()
+    torch_input_tensor_nchw = torch.randn(
+        conv_input_shape, dtype=torch_input_dtype if is_forge_suite else torch.bfloat16
+    ).float()
 
     torch_input_tensor = torch.permute(torch_input_tensor_nchw, (0, 2, 3, 1))
-    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch_datatype).float()
+    torch_weight_tensor = torch.randn(
+        conv_weight_shape, dtype=torch_weight_dtype if is_forge_suite else torch.bfloat16
+    ).float()
 
-    torch_bias_tensor = None
-    if has_bias:
-        torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch_datatype).float() if has_bias else None
+    torch_bias_tensor = (
+        torch.randn(conv_bias_shape, dtype=torch_weight_dtype if is_forge_suite else torch.bfloat16).float()
+        if has_bias
+        else None
+    )
     torch_out_golden_tensor = torch.nn.functional.conv2d(
         torch_input_tensor_nchw,
         torch_weight_tensor,
         bias=torch_bias_tensor.reshape(-1) if has_bias else None,
         stride=(stride_h, stride_w),
         padding=(pad_h, pad_w),
-        dilation=(dilation, dilation),
+        dilation=(dilation_h, dilation_w),
         groups=groups,
     )
 
-    tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn_datatype)
-    tt_bias_tensor = None
-    if has_bias:
-        tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn_datatype)
+    if is_forge_suite:
+        input_layout = ttnn.Layout(input_layout)
+        input_dtype = ttnn.DataType(input_dtype)
+        input_memory_config = (
+            ttnn.DRAM_MEMORY_CONFIG if input_buffer_type == int(ttnn.BufferType.DRAM) else ttnn.L1_MEMORY_CONFIG
+        )
+        tt_input_tensor = ttnn.from_torch(
+            torch_input_tensor, dtype=input_dtype, layout=input_layout, device=device, memory_config=input_memory_config
+        )
+        weights_dtype = ttnn.DataType(weights_dtype)
+        tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, weights_dtype)
+        if has_bias:
+            tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, weights_dtype)
+        output_layout = ttnn.Layout(output_layout)
+        output_dtype = ttnn.DataType(output_dtype)
+        conv_config = ttnn.Conv2dConfig(
+            dtype=output_dtype,
+            weights_dtype=weights_dtype,
+            output_layout=output_layout,
+        )
+    else:
+        tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
+        if has_bias:
+            tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)
 
-    tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn_datatype)
+        tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)
+        conv_config = ttnn.Conv2dConfig()
 
-    conv_config = ttnn.Conv2dConfig(
-        dtype=ttnn_datatype,
-        weights_dtype=ttnn_datatype,
-    )
     start_time = start_measuring_time()
     [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d(
         input_tensor=tt_input_tensor,
@@ -267,7 +295,7 @@ def run_conv2d_short_sweep(
         kernel_size=(kernel_height, kernel_width),
         stride=(stride_h, stride_w),
         padding=(pad_h, pad_w),
-        dilation=(dilation, dilation),
+        dilation=(dilation_h, dilation_w),
         batch_size=batch_size,
         input_height=input_height,
         input_width=input_width,
@@ -288,6 +316,7 @@ def run_conv2d_short_sweep(
 
     torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))
 
+    print("End of test case")
     return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]


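For context, a forge-suite entry carries the 15 scalar conv parameters followed by three [layout, buffer_type, dtype] triples (input, weights, output), which is what pushes len(input_specs) past 15 and selects the forge path above. Below is a minimal sketch of such an entry and of the buffer-type-to-memory-config mapping this commit introduces; the concrete shape values, the ttnn.TILE_LAYOUT choice, and the memory_config_for helper are illustrative assumptions, not taken from the commit itself.

import ttnn

# Hypothetical forge-suite spec: 15 scalars plus three tensor-config triples,
# encoded as ints so run_conv2d_short_sweep can rebuild the ttnn enums.
forge_input_specs = [
    1, 64, 32, 56, 56,  # batch_size, output_channels, input_channels, input_height, input_width
    3, 3,  # kernel_height, kernel_width
    1, 1,  # stride_h, stride_w
    1, 1,  # pad_h, pad_w
    1,  # groups
    1, 1,  # dilation_h, dilation_w
    True,  # has_bias
    [int(ttnn.TILE_LAYOUT), int(ttnn.BufferType.DRAM), int(ttnn.bfloat16)],  # input config
    [int(ttnn.TILE_LAYOUT), int(ttnn.BufferType.DRAM), int(ttnn.bfloat16)],  # weights config
    [int(ttnn.TILE_LAYOUT), int(ttnn.BufferType.L1), int(ttnn.bfloat16)],  # output config
]
assert len(forge_input_specs) > 15  # routes the spec down the forge branch


def memory_config_for(buffer_type: int):
    # Mirrors the branch added in this commit: DRAM buffers map to
    # ttnn.DRAM_MEMORY_CONFIG, anything else falls back to L1.
    if buffer_type == int(ttnn.BufferType.DRAM):
        return ttnn.DRAM_MEMORY_CONFIG
    return ttnn.L1_MEMORY_CONFIG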

0 comments on commit 957e890

Please sign in to comment.