
Commit

add spmd debug
ManfeiBai committed Nov 29, 2023
1 parent 402166b commit 795ef1f
Showing 5 changed files with 444 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -203,6 +203,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
run_test "$CDIR/spmd/test_dynamo_spmd.py"
run_test "$CDIR/spmd/test_spmd_debugging.py"
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
214 changes: 214 additions & 0 deletions test/spmd/test_spmd_debugging.py
@@ -0,0 +1,214 @@
import sys

import unittest
from unittest.mock import patch
import math
import numpy as np
import os
import io
import rich
import rich.align
import rich.box
import rich.padding
import rich.style
import rich.table
from rich.console import Console

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

import test_xla_sharding_base


class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ('GPU', 'CUDA', 'ROCM', 'CPU'),
"Requires PJRT_DEVICE set to `TPU`.")
def test_debugging_spmd_single_host_tiled(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices # xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)
t = torch.randn(8, 4, device=device)
partition_spec = (0, 1)
xs.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generated_table = visualize_tensor_sharding(t)
console = Console()
with console.capture() as capture:
console.print(generated_table)
output = capture.get()

# fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fake_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 0', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 1', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 2', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 3', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fake_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 4', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 5', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 6', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 7', "center", vertical="middle"),
(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fake_table.add_row(*col)
fake_console = Console()
with fake_console.capture() as fake_capture:
fake_console.print(fake_table)
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ('GPU', 'CUDA', 'ROCM', 'CPU'),
"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)

partition_spec = (0, None)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generated_table = visualize_tensor_sharding(t)
console = Console()
with console.capture() as capture:
console.print(generated_table)
output = capture.get()

color = None
text_color = None
fake_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
(2, 0, 2, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fake_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
(2, 0, 2, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fake_table.add_row(*col)
fake_console = Console()
with fake_console.capture() as fake_capture:
fake_console.print(fake_table)
fake_output = fake_capture.get()
assert output == fake_output

@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ('GPU', 'CUDA', 'ROCM', 'CPU'),
"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
num_devices = self.n_devices
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
mesh = self._get_mesh(mesh_shape)

partition_spec_replicated = (None, None)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, partition_spec_replicated)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
generated_table = visualize_tensor_sharding(t)
console = Console()
with console.capture() as capture:
console.print(generated_table)
output = capture.get()

color = None
text_color = None
fake_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
highlight=True,
pad_edge=False,
box=rich.box.SQUARE)
col = []
col.append(
rich.padding.Padding(
rich.align.Align(
'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
(0, 0, 1, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fake_table.add_row(*col)
fake_console = Console()
with fake_console.capture() as fake_capture:
fake_console.print(fake_table)
fake_output = fake_capture.get()
assert output == fake_output


if __name__ == '__main__':
test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -55,6 +55,7 @@ spec:
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
python3 /src/pytorch/xla/test/test_autocast.py
python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
volumeMounts:
- mountPath: /dev/shm
name: dshm
3 changes: 2 additions & 1 deletion torch_xla/distributed/spmd/__init__.py
@@ -3,10 +3,11 @@
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded, xla_patched_nn_linear_forward)
from .api import xla_distribute_tensor, xla_distribute_module
from .debugging import visualize_tensor_sharding

__all__ = [
"XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
"ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
"wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
"xla_patched_nn_linear_forward"
"xla_patched_nn_linear_forward", "visualize_tensor_sharding",
]
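
For context, here is a minimal usage sketch of the newly exported visualize_tensor_sharding. It is an illustration only: it assumes an SPMD-enabled PJRT TPU run with at least two addressable devices, and the tensor shape, mesh shape, and axis names are made up rather than taken from this commit.

# Minimal sketch, assuming PJRT_DEVICE=TPU and >= 2 addressable devices;
# the tensor shape, mesh shape, and axis names are illustrative.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD mode before creating sharded tensors

num_devices = xr.global_runtime_device_count()
# Arrange all addressable devices into a 2 x (num_devices // 2) logical mesh.
mesh = xs.Mesh(np.arange(num_devices), (2, num_devices // 2), ('x', 'y'))

t = torch.randn(8, 4, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))  # shard dim 0 over 'x' and dim 1 over 'y'

# Returns a rich table mapping each shard of t to the device(s) that hold it,
# the same object that test/spmd/test_spmd_debugging.py prints and checks above.
table = xs.visualize_tensor_sharding(t)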