diff --git a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py
index 7d899eb5edb..af7ae1781a3 100644
--- a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py
+++ b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py
@@ -45,6 +45,7 @@
         ],
         "enable_async": [True],
         "num_iters": [1],
+        "tile": [(32, 32)],
     },
 }
 
@@ -58,6 +59,7 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
         test_vector["num_links"],
         test_vector["input_dtype"],
         test_vector["layout"],
+        test_vector["tile"],
     )
     if is_known_failure:
         return True, f"Skipping unsupported case {message}."
@@ -90,6 +92,7 @@ def run(
     mem_config,
     enable_async,
     num_iters,
+    tile,
     *,
     device,
 ) -> list:
@@ -108,7 +111,9 @@ def run(
     input_tensors = torch.chunk(input_tensor, num_devices, dim)
     tt_input_tensors = []
     for i, t in enumerate(input_tensors):
-        tt_input_tensors.append(ttnn.Tensor(t, input_dtype).to(layout).to(all_devices[i], mem_config))
+        tt_input_tensors.append(
+            ttnn.Tensor(t, input_dtype, {}, ttnn.Tile(tile)).to(layout).to(all_devices[i], mem_config)
+        )
     input_tensor_mesh = ttnn.aggregate_as_tensor(tt_input_tensors)
 
     for i in range(num_iters):
diff --git a/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py b/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py
index 710ccb9f21a..e64a913fd0e 100644
--- a/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py
+++ b/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py
@@ -121,6 +121,7 @@
         ],
         "enable_async": [True],
         "num_iters": [1],
+        "tile": [(32, 32)],
     },
     "all_gather_n300_focused_large": {
         "num_devices": [2],
@@ -134,6 +135,7 @@
         ],
         "enable_async": [True],
         "num_iters": [1],
+        "tile": [(32, 32)],
     },
 }
 
@@ -179,6 +181,7 @@ def run(
     mem_config,
     enable_async,
     num_iters,
+    tile,
     *,
     device,
 ) -> list:
@@ -204,7 +207,7 @@ def run(
     input_tensors = torch.chunk(input_tensor, num_devices, dim)
     tt_input_tensors = []
     for i, t in enumerate(input_tensors):
-        t = ttnn.from_torch(t, input_dtype, layout=ttnn.Layout.TILE)
+        t = ttnn.from_torch(t, input_dtype, tile=ttnn.Tile(tile), layout=ttnn.Layout.TILE)
         t = t.to(all_devices[i], mem_config)
         tt_input_tensors.append(t)
 
diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
index f6a621a338f..bdd30d3393c 100644
--- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py
+++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py
@@ -39,6 +39,7 @@
         "mem_config": [ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM)],
         "enable_async": [True, False],
         "num_iters": [1],
+        "tile": [(32, 32)],
     },
 }
 
@@ -52,6 +53,7 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
         test_vector["num_links"],
         test_vector["input_dtype"],
         test_vector["layout"],
+        test_vector["tile"],
     )
     if is_known_failure:
         return True, f"Skipping unsupported case {message}."
@@ -89,6 +91,7 @@ def run(
     mem_config,
     enable_async,
     num_iters,
+    tile,
     *,
     device,
 ) -> list:
@@ -100,7 +103,9 @@ def run(
 
     input_tensor = torch.rand(input_shape).bfloat16()
 
-    ttnn_tensor = ttnn.from_torch(input_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim))
+    ttnn_tensor = ttnn.from_torch(
+        input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim)
+    )
     input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_mesh_device)
 
     for i in range(num_iters):