From 6aba540b6f85cdc06d8c5031a0408d0ea50b8756 Mon Sep 17 00:00:00 2001 From: asaigal Date: Wed, 26 Jun 2024 15:22:14 +0000 Subject: [PATCH] #0: Decrease num loops in trace stress tests --- tests/ttnn/unit_tests/test_multi_device_trace.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index c8718e97da8..1aa72ccb755 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -175,7 +175,12 @@ def run_op_chain_2(input_0, input_1, weight): # Execute and verify trace against pytorch torch_silu = torch.nn.SiLU() torch_softmax = torch.nn.Softmax(dim=1) - for i in range(NUM_TRACE_LOOPS): + # Decrease loop count for larger shapes, since they time out on CI + num_trace_loops = NUM_TRACE_LOOPS + if shape == (1, 3, 1024, 1024): + num_trace_loops = 7 + + for i in range(num_trace_loops): # Create torch inputs torch_input_tensor_0 = torch.rand( (t3k_device_mesh.get_num_devices(), shape[1], shape[2], shape[3]), dtype=torch.bfloat16