Update xla_test_job.yaml
ManfeiBai committed Nov 29, 2023
1 parent 7e3096f commit 281685b
Showing 3 changed files with 97 additions and 69 deletions.
test/spmd/test_spmd_debugging.py (55 additions, 47 deletions)
@@ -14,7 +14,7 @@
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

@@ -35,20 +35,22 @@ def setUpClass(cls):
def test_debugging_spmd_single_host_tiled(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
-num_devices = xr.global_runtime_device_count()
+num_devices = self.n_devices  # xr.global_runtime_device_count()
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
-mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+mesh = self._get_mesh(mesh_shape)
t = torch.randn(8, 4, device=device)
partition_spec = (0, 1)
-xs.mark_sharding(t, mesh, partition_spec)
+Mesh.mark_sharding(t, mesh, partition_spec)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-generatedtable = visualize_tensor_sharding(t)
-console = rich.console.Console(file=io.StringIO(), width=120)
-console.print(generatedtable)
-output = console.file.getvalue()
+generated_table = visualize_tensor_sharding(t)
+console = Console()
+with console.capture() as capture:
+console.print(generated_table)
+output = capture.get()

-fake_console = rich.console.Console(file=io.StringIO(), width=120)
+# fake_console = rich.console.Console(file=io.StringIO(), width=120)
color = None
text_color = None
fask_table = rich.table.Table(
@@ -62,75 +64,78 @@ def test_debugging_spmd_single_host_tiled(self):
col.append(
rich.padding.Padding(
rich.align.Align('TPU 0', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 1', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 2', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 3', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU 4', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 5', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 6', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
col.append(
rich.padding.Padding(
rich.align.Align('TPU 7', "center", vertical="middle"),
-(2,1,2,1),
+(2, 1, 2, 1),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
-fake_console.print(fask_table)
-fake_output = fake_console.file.getvalue()
+fake_console = Console()
+with fake_console.capture() as fake_capture:
+fake_console.print(fake_table)
+fake_output = fake_capture.get()
assert output == fake_output


@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_partial_replication(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
-num_devices = xr.global_runtime_device_count()
+num_devices = self.n_devices
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
-mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+mesh = self._get_mesh(mesh_shape)

partition_spec = (0, None)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, (0, None))
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-generatedtable = visualize_tensor_sharding(t)
-console = rich.console.Console(file=io.StringIO(), width=120)
-console.print(generatedtable)
-output = console.file.getvalue()
+generated_table = visualize_tensor_sharding(t)
+console = Console()
+with console.capture() as capture:
+console.print(generated_table)
+output = capture.get()

color = None
text_color = None
-fask_table = rich.table.Table(
+fake_table = rich.table.Table(
show_header=False,
show_lines=True,
padding=0,
@@ -141,43 +146,44 @@ def test_single_host_partial_replication(self):
col.append(
rich.padding.Padding(
rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
-(2,0,2,0),
+(2, 0, 2, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
-fask_table.add_row(*col)
+fake_table.add_row(*col)
col = []
col.append(
rich.padding.Padding(
rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
-(2,0,2,0),
+(2, 0, 2, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
-fask_table.add_row(*col)
-console.print(fask_table)
-fake_console = rich.console.Console(file=io.StringIO(), width=120)
-fake_console.print(fask_table)
-fake_output = fake_console.file.getvalue()
+fake_table.add_row(*col)
+fake_console = Console()
+with fake_console.capture() as fake_capture:
+fake_console.print(fake_table)
+fake_output = fake_capture.get()
assert output == fake_output


@unittest.skipIf(
not xr.using_pjrt() or
xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
f"Requires PJRT_DEVICE set to `TPU`.")
def test_single_host_replicated(self):
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
device = xm.xla_device()
-num_devices = xr.global_runtime_device_count()
+num_devices = self.n_devices
mesh_shape = (2, num_devices // 2)
device_ids = np.array(range(num_devices))
-mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+# mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+mesh = self._get_mesh(mesh_shape)

partition_spec_replicated = (None, None)
t = torch.randn(8, 32, device=device)
xs.mark_sharding(t, mesh, partition_spec_replicated)
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-generatedtable = visualize_tensor_sharding(t)
-console = rich.console.Console(file=io.StringIO(), width=120)
-console.print(generatedtable)
-output = console.file.getvalue()
+generated_table = visualize_tensor_sharding(t)
+console = Console()
+with console.capture() as capture:
+console.print(generated_table)
+output = capture.get()

color = None
text_color = None
@@ -191,13 +197,15 @@ def test_single_host_replicated(self):
col = []
col.append(
rich.padding.Padding(
-rich.align.Align('TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
-(0,0,1,0),
+rich.align.Align(
+'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
+(0, 0, 1, 0),
style=rich.style.Style(bgcolor=color, color=text_color)))
fask_table.add_row(*col)
-fake_console = rich.console.Console(file=io.StringIO(), width=120)
-fake_console.print(fask_table)
-fake_output = fake_console.file.getvalue()
+fake_console = Console()
+with fake_console.capture() as fake_capture:
+fake_console.print(fake_table)
+fake_output = fake_capture.get()
assert output == fake_output


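
The diff above swaps the StringIO-backed consoles for rich's capture API. Below is a minimal sketch of that capture pattern on its own, assuming only that rich is installed; the table cells are placeholders rather than the real sharding visualization built by visualize_tensor_sharding.

import rich.console
import rich.table

console = rich.console.Console()

# Build a stand-in table; the tests build theirs via visualize_tensor_sharding.
table = rich.table.Table(show_header=False, show_lines=True, padding=0)
table.add_row("TPU 0", "TPU 1")

# Console.capture() records everything printed inside the block instead of
# writing it to stdout, so the rendered text can be compared as a string.
with console.capture() as capture:
  console.print(table)
output = capture.get()
assert "TPU 0" in output  # in the tests: assert output == fake_output
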
test/tpu/xla_test_job.yaml (12 additions, 12 deletions)
@@ -44,18 +44,18 @@ spec:
pip install expecttest
pip install rich
-# python3 /src/pytorch/xla/test/test_operations.py -v
-# python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
-# python3 /src/pytorch/xla/test/pjrt/test_collective_ops_tpu.py
-# python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
-# python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
-# python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
-# python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
-# python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py
-# XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
-# XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
-# python3 /src/pytorch/xla/test/test_autocast.py
-# python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
+python3 /src/pytorch/xla/test/test_operations.py -v
+python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
+python3 /src/pytorch/xla/test/pjrt/test_collective_ops_tpu.py
+python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
+python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
+python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
+python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
+python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py
+XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
+XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
+python3 /src/pytorch/xla/test/test_autocast.py
+python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
volumeMounts:
- mountPath: /dev/shm
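
The commands above run unconditionally inside the TPU CI job; the new SPMD debugging test gates itself on PJRT_DEVICE through the unittest.skipIf decorator shown in the first file. Below is a minimal sketch of that guard using plain os.environ in place of the torch_xla helpers (xr.using_pjrt, xu.getenv_as); the test class and method names are hypothetical.

import os
import unittest


class SpmdDebuggingGuardExample(unittest.TestCase):

  @unittest.skipIf(
      os.environ.get("PJRT_DEVICE") in (None, "GPU", "CUDA", "ROCM", "CPU"),
      "Requires PJRT_DEVICE set to `TPU`.")
  def test_only_runs_on_tpu(self):
    # If the guard did not skip, PJRT_DEVICE is set and is not a CPU/GPU value.
    self.assertNotIn(
        os.environ.get("PJRT_DEVICE"), (None, "GPU", "CUDA", "ROCM", "CPU"))


if __name__ == "__main__":
  unittest.main()
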
torch_xla/distributed/spmd/debugging.py (30 additions, 10 deletions)
@@ -71,13 +71,28 @@ def visualize_sharding(shape: torch.Size,
min_width: int = 9,
max_width: int = 80,
color_map: Optional[ColorMap] = None):
"""Visualizes a ``Sharding`` using ``rich``."""
"""Visualizes a ``Sharding`` using ``rich``.
Args:
shape (`torch.Size`): shape of tensor to be visualized
sharding (`str`): sharding of given tensor with SPMD
use_color (`bool`): whether use color or not
scale (`float`): scale of table visualized in console
min_width (`int`): min width used to setup table to visualize
max_width (`int`): max width used to setup table to visualize
color_map (`Optional[ColorMap]`): color_map used to paint table to visualize
Returns:
table to visualize given tensor sharding. This function
will also visualize the sharding of the tensor without as return.
"""

if not RICH_ENABLED:
raise ValueError("`visualize_sharding` requires `rich` to be installed.")

-if len(shape) > 2 or len(shape) < 1:
-raise ValueError(
-"`visualize_sharding` only works for shapes with 1 and 2 dimensions.")
+# if len(shape) > 2 or len(shape) < 1:
+#   raise ValueError(
+#       "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")

slices: dict[tuple[int, ...], set[int]] = {}
heights: dict[tuple[int, ...], Optional[float]] = {}
@@ -110,7 +125,7 @@ def visualize_sharding(shape: torch.Size,
for i in range(len_after_dim_down):
slices.setdefault(
(i // widths, i % widths),
-device_indices_map[i*last_dim_depth:(i + 1)*last_dim_depth])
+device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth])
elif sharding[-1] == "}":
# eg: '{devices=[2,2]0,1,2,3}' # 13
device_list = list(sharding[sharding.index(']') + 1:-1])
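
The '# eg' comment above shows the sharding spec string this branch walks: a device mesh shape inside the square brackets followed by a flat list of device ids. Below is a simplified parsing sketch, not the actual helper in debugging.py, which handles more spec variants in the branches shown above.

def parse_tiled_sharding(sharding: str):
  """Split '{devices=[2,2]0,1,2,3}' into its mesh shape and device ids."""
  mesh_part = sharding[sharding.index('[') + 1:sharding.index(']')]
  mesh_shape = tuple(int(dim) for dim in mesh_part.split(','))
  ids_part = sharding[sharding.index(']') + 1:-1]
  device_ids = [int(d) for d in ids_part.split(',')]
  return mesh_shape, device_ids


print(parse_tiled_sharding('{devices=[2,2]0,1,2,3}'))  # ((2, 2), [0, 1, 2, 3])
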
@@ -188,7 +203,7 @@ def visualize_sharding(shape: torch.Size,
text_color = None

padding = (top_padding, right_padding, bottom_padding, left_padding)
-padding = tuple(max(x, 0) for x in padding)  # type: ignore
+padding = tuple(max(x, 0) for x in padding)

col.append(
rich.padding.Padding(
@@ -200,8 +215,13 @@ def visualize_sharding(shape: torch.Size,
return table


-def visualize_tensor_sharding(ter, **kwargs):
+def visualize_tensor_sharding(t, **kwargs):
"""Visualizes an array's sharding."""
-import torch_xla
-sharding = torch_xla._XLAC._get_xla_sharding_spec(ter)
-return visualize_sharding(ter.shape, sharding, **kwargs)
+if isinstance(t, torch.Tensor):
+import torch_xla
+sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+return visualize_sharding(t.shape, sharding, **kwargs)
+elif isinstance(t, XLAShardedTensor):
+import torch_xla
+sharding = torch_xla._XLAC._get_xla_sharding_spec(t.global_tensor)
+return visualize_sharding(t.global_tensor.shape, sharding, **kwargs)
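
Pieced together from the test file in this commit, a usage sketch for the updated function; it assumes a TPU host with SPMD available, that xr is torch_xla.runtime, and that Mesh and mark_sharding are reachable through torch_xla.distributed.spmd as in the new import.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
import rich.console

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (2, num_devices // 2), ('x', 'y'))

t = torch.randn(8, 4, device=device)
xs.mark_sharding(t, mesh, (0, 1))  # shard dim 0 over 'x', dim 1 over 'y'

# Returns a rich table showing which devices hold each shard of `t`.
table = visualize_tensor_sharding(t)
rich.console.Console().print(table)
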
