diff --git a/test/run_tests.sh b/test/run_tests.sh index a4c82a6d4c7..240731c702b 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -203,6 +203,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_xla_sharding_hlo.py" run_test "$CDIR/spmd/test_xla_virtual_device.py" run_test "$CDIR/spmd/test_dynamo_spmd.py" + run_test "$CDIR/spmd/test_spmd_debugging.py" run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py" run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py new file mode 100644 index 00000000000..3fff925c7a0 --- /dev/null +++ b/test/spmd/test_spmd_debugging.py @@ -0,0 +1,214 @@ +import sys + +import unittest +from unittest.mock import patch +import math +import numpy as np +import os +import io +import rich + +import torch +import torch_xla +import torch_xla.runtime as xr +import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv +import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs +from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor +from torch_xla.experimental.xla_sharding import Mesh + +import test_xla_sharding_base + + +class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd() + super().setUpClass() + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_debugging_spmd_single_host_tiled(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices # xr.global_runtime_device_count() + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) + t = torch.randn(8, 4, device=device) + partition_spec = (0, 1) + Mesh.mark_sharding(t, mesh, partition_spec) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + # fake_console = rich.console.Console(file=io.StringIO(), width=120) + color = None + text_color = None + fask_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 0', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 1', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 2', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 3', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fask_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 4', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 5', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 6', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 7', "center", vertical="middle"), + (2, 1, 2, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fask_table.add_row(*col) + fake_console = Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_single_host_partial_replication(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) + + partition_spec = (0, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, (0, None)) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"), + (2, 0, 2, 0), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"), + (2, 0, 2, 0), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_single_host_replicated(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + # mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) + mesh = self._get_mesh(mesh_shape) + + partition_spec_replicated = (None, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, partition_spec_replicated) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fask_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align( + 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), + (0, 0, 1, 0), + style=rich.style.Style(bgcolor=color, color=text_color))) + fask_table.add_row(*col) + fake_console = Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index 99d59d286dc..b02356064db 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -55,6 +55,7 @@ spec: XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v python3 /src/pytorch/xla/test/test_autocast.py python3 /src/pytorch/xla/test/dynamo/test_dynamo.py + python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py volumeMounts: - mountPath: /dev/shm name: dshm diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py index 3cd50e1e7c0..802d8fbed53 100644 --- a/torch_xla/distributed/spmd/__init__.py +++ b/torch_xla/distributed/spmd/__init__.py @@ -3,10 +3,11 @@ XLAPatchedLinear, mark_sharding, clear_sharding, wrap_if_sharded, xla_patched_nn_linear_forward) from .api import xla_distribute_tensor, xla_distribute_module +from .debugging import visualize_tensor_sharding __all__ = [ "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType", "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding", "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module", - "xla_patched_nn_linear_forward" + "xla_patched_nn_linear_forward", "visualize_tensor_sharding", ] diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py new file mode 100644 index 00000000000..3586cde5693 --- /dev/null +++ b/torch_xla/distributed/spmd/debugging.py @@ -0,0 +1,226 @@ +from collections.abc import Sequence +import functools +import string +import sys +from typing import Any, Callable, Optional, Union +import weakref + +import numpy as np +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +import torch_xla.experimental.xla_sharding as xs +from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor + +try: + import rich + import rich.align + import rich.box + import rich.console + import rich.padding + import rich.style + import rich.table + RICH_ENABLED = True +except: + RICH_ENABLED = False + +# Sharding visualization +sharding_callbacks = weakref.WeakValueDictionary() +_INSPECT_SHARDING_CALL_NAME = "InspectSharding" + + +class ShardingCallbackInfo: + + def __init__(self, callback, module_context): + self.callback = callback + self.module_context = module_context + + +Color = Union[tuple[float, float, float], str] +ColorMap = Callable[[float], tuple[float, float, float, float]] + + +def _canonicalize_color(color: Color) -> str: + if isinstance(color, str): + return color + r, g, b = (int(a * 255) for a in color) + return f"#{r:02X}{g:02X}{b:02X}" + + +def _get_text_color(color: str) -> str: + r, g, b = torch.map(lambda x: int(x, 16), + (color[1:3], color[3:5], color[5:7])) + if (r * 0.299 + g * 0.587 + b * 0.114) > 186: + return "#000000" + return "#ffffff" + + +def make_color_iter(color_map, num_rows, num_cols): + num_colors = num_rows * num_cols + color_values = np.linspace(0, 1, num_colors) + idx = 0 + for _ in range(num_colors): + yield color_map(color_values[idx]) + idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors + + +def visualize_sharding(shape: torch.Size, + sharding: str, + use_color: bool = True, + scale: float = 1., + min_width: int = 9, + max_width: int = 80, + color_map: Optional[ColorMap] = None): + """Visualizes a ``Sharding`` using ``rich``. + Args: + shape (`torch.Size`): shape of tensor to be visualized + sharding (`str`): sharding of given tensor with SPMD + use_color (`bool`): whether use color or not + scale (`float`): scale of table visualized in console + min_width (`int`): min width used to setup table to visualize + max_width (`int`): max width used to setup table to visualize + color_map (`Optional[ColorMap]`): color_map used to paint table to visualize + Returns: + table to visualize given tensor sharding. This function + will also visualize the sharding of the tensor without as return. + """ + + if not RICH_ENABLED: + raise ValueError("`visualize_sharding` requires `rich` to be installed.") + + # if len(shape) > 2 or len(shape) < 1: + # raise ValueError( + # "`visualize_sharding` only works for shapes with 1 and 2 dimensions.") + + slices: dict[tuple[int, ...], set[int]] = {} + heights: dict[tuple[int, ...], Optional[float]] = {} + widths: dict[tuple[int, ...], float] = {} + + if len(sharding) > 0: + # sharding is longer than 0 + # eg: '{devices=[2,2]0,1,2,3}' # 13 + # eg: '{replicated}' + # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' # 15 + if sharding == '{replicated}': + # eg: '{replicated}' + heights = 1 + widths = 1 + num_devices = xr.global_runtime_device_count() + device_ids = list(range(num_devices)) + slices.setdefault((0, 0), device_ids) + else: + # `device_indices_map`: [0, 1, 2, 3] + # `sharding_spac`: [2, 2] + sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1] + if len(sharding) >= 25 and sharding[-24:-1] == 'last_tile_dim_replicate': + device_list = list(sharding[sharding.index(']') + 1:-24]) + device_indices_map = [int(i) for i in device_list[:-1] if i != ','] + heights = int(sharding_spac[1]) + widths = int(sharding_spac[3]) + last_dim_depth = int(sharding_spac[5]) + devices_len = len(device_indices_map) + len_after_dim_down = devices_len // last_dim_depth + for i in range(len_after_dim_down): + slices.setdefault( + (i // widths, i % widths), + device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth]) + elif sharding[-1] == "}": + # eg: '{devices=[2,2]0,1,2,3}' # 13 + device_list = list(sharding[sharding.index(']') + 1:-1]) + device_indices_map = [int(i) for i in device_list if i != ','] + heights = int(sharding_spac[1]) + widths = int(sharding_spac[3]) + devices_len = len(device_indices_map) + for i in range(devices_len): + slices.setdefault((i // widths, i % widths), device_indices_map[i]) + else: + raise ValueError("sharding is not organized as expected") + else: + raise ValueError("sharding has no value") + + num_rows = heights + num_cols = widths + + console = rich.console.Console(width=max_width) + use_color = use_color and console.color_system is not None + if use_color and not color_map: + try: + import matplotlib as mpl + color_map = mpl.colormaps["tab20b"] + except ModuleNotFoundError: + use_color = False + + base_height = int(3 * scale) + aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0] + base_width = int(base_height * aspect_ratio) + height_to_width_ratio = 1.5 + + # eg: '{devices=[2,2]0,1,2,3}' # 13 + # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' # 15 + + # slcs is the data we saved on this slice + # `device_indices_map`: [0, 1, 2, 3] + # `sharding_spac`: [2, 2] + + # set the device kind to TPU as default since `sharding` here is `str`, TODO(@manfei): get device kind from commands for TPU/GPU/CPU + device_kind = 'TPU' # next(iter(sharding.device_set)).platform.upper() + + color_iter = make_color_iter(color_map, num_rows, num_cols) + table = rich.table.Table( + show_header=False, + show_lines=not use_color, + padding=0, + highlight=not use_color, + pad_edge=False, + box=rich.box.SQUARE if not use_color else None) + for i in range(num_rows): + col = [] + for j in range(num_cols): + entry = f"{device_kind} " + str(slices[i, j]) + width, maybe_height = widths, heights # widths[i, j], heights[i, j] + width = int(width * base_width * height_to_width_ratio) + if maybe_height is None: + height = 1 + else: + height = int(maybe_height * base_height) + width = min(max(width, min_width), max_width) + left_padding, remainder = divmod(width - len(entry) - 2, 2) + right_padding = left_padding + remainder + top_padding, remainder = divmod(height - 2, 2) + bottom_padding = top_padding + remainder + + if use_color: + color = _canonicalize_color(next(color_iter)[:3]) + text_color = _get_text_color(color) + top_padding += 1 + bottom_padding += 1 + left_padding += 1 + right_padding += 1 + else: + color = None + text_color = None + + padding = (top_padding, right_padding, bottom_padding, left_padding) + padding = tuple(max(x, 0) for x in padding) + + col.append( + rich.padding.Padding( + rich.align.Align(entry, "center", vertical="middle"), + padding, + style=rich.style.Style(bgcolor=color, color=text_color))) + table.add_row(*col) + console.print(table, end='\n\n') + return table + + +def visualize_tensor_sharding(t, **kwargs): + """Visualizes an array's sharding.""" + if (assert instanceof(t, torch.tensor)): + import torch_xla + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + return visualize_sharding(t.shape, sharding, **kwargs) + elif (assert instanceof(t, XLAShardedTensor)): + import torch_xla + sharding = torch_xla._XLAC._get_xla_sharding_spec(t.global_tensor) + return visualize_sharding(t.global_tensor.shape, sharding, **kwargs) +