Skip to content

Commit

Permalink
Update test_spmd_debugging.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 29, 2023
1 parent e31665f commit 1312801
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions test/spmd/test_spmd_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def test_debugging_spmd_single_host_tiled(self):
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)
t = torch.randn(8, 4, device=device)
partition_spec = (0, 1)
xs.mark_sharding(t, mesh, partition_spec)
Expand Down Expand Up @@ -116,7 +117,8 @@ def test_single_host_partial_replication(self):
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)

partition_spec = (0, None)
t = torch.randn(8, 32, device=device)
Expand Down Expand Up @@ -166,7 +168,8 @@ def test_single_host_replicated(self):
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)

partition_spec_replicated = (None, None)
t = torch.randn(8, 32, device=device)
Expand Down

0 comments on commit 1312801

Please sign in to comment.