Skip to content

Commit

Permalink
#0: testing
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Jul 13, 2024
1 parent 1285ff8 commit 56343b2
Showing 1 changed file with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,15 @@

class ResNet50TestInfra:
def __init__(
self, device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight, dealloc_input, final_output_mem_config
self,
device,
batch_size,
act_dtype,
weight_dtype,
math_fidelity,
use_pretrained_weight,
dealloc_input,
final_output_mem_config,
):
super().__init__()
torch.manual_seed(0)
Expand Down Expand Up @@ -234,7 +242,14 @@ def create_test_infra(
final_output_mem_config=ttnn.L1_MEMORY_CONFIG,
):
return ResNet50TestInfra(
device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight, dealloc_input, final_output_mem_config
device,
batch_size,
act_dtype,
weight_dtype,
math_fidelity,
use_pretrained_weight,
dealloc_input,
final_output_mem_config,
)


Expand Down Expand Up @@ -269,6 +284,8 @@ def test_resnet_50(
pytest.skip("Skipping batch size 8 due to memory config issue")
if is_wormhole_b0() and batch_size == 20:
pytest.skip("Skipping batch size 20 for Wormhole B0 due to fitting issue")
if (device.compute_with_storage_grid_size().x, device.compute_with_storage_grid_size().y) == (8, 7):
pytest.skip("Test is not supported on n300 (8,7) grid")

test_infra = create_test_infra(device, batch_size, act_dtype, weight_dtype, math_fidelity, use_pretrained_weight)
enable_memory_reports()
Expand Down

0 comments on commit 56343b2

Please sign in to comment.