diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py index e23d90637873..abcc1bd81568 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py @@ -25,7 +25,11 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_transformer_inference(mesh_device, use_program_cache, reset_seeds): diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 613fd2a30219..bea85a0a16f7 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -33,7 +33,11 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_program_cache, reset_seeds, ensure_gc): diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index 1fee8a125c44..d042eb1e6833 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -33,7 +33,11 @@ ) @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_image_transformer_inference( diff --git a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py index 50064a2e4800..61824eb484e1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py +++ b/models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py @@ -23,7 +23,11 @@ @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "mesh_device", - [{"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get(os.environ.get("FAKE_DEVICE"), None)], + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], indirect=True, ) def test_llama_vision_encoder_inference(mesh_device, use_program_cache, reset_seeds):