diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 68ab7b99882..be58fb355b7 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -38,7 +38,8 @@ def test_debugging_spmd_single_host_tiled(self): num_devices = xr.global_runtime_device_count() mesh_shape = (2, num_devices // 2) device_ids = np.array(range(num_devices)) - mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) t = torch.randn(8, 4, device=device) partition_spec = (0, 1) xs.mark_sharding(t, mesh, partition_spec) @@ -116,7 +117,8 @@ def test_single_host_partial_replication(self): num_devices = xr.global_runtime_device_count() mesh_shape = (2, num_devices // 2) device_ids = np.array(range(num_devices)) - mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) partition_spec = (0, None) t = torch.randn(8, 32, device=device) @@ -166,7 +168,8 @@ def test_single_host_replicated(self): num_devices = xr.global_runtime_device_count() mesh_shape = (2, num_devices // 2) device_ids = np.array(range(num_devices)) - mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) partition_spec_replicated = (None, None) t = torch.randn(8, 32, device=device)