Skip to content

Commit

Permalink
Update test_spmd_debugging.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 29, 2023
1 parent d46323b commit 17fcfa6
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/spmd/test_spmd_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ def test_debugging_spmd_single_host_tiled(self):
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
print("num_devices: ")
print(num_devices)
print("device: ")
print(device)
device_ids = np.array(range(num_devices))
print("device_ids: ")
print(device_ids)
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)
t = torch.randn(8, 4, device=device)
Expand Down

0 comments on commit 17fcfa6

Please sign in to comment.