diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index 70bfe5c589df..b19a7a65e7ff 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -5,30 +5,35 @@
 import math
 import numpy as np
 import os
+import io
+import rich
 import torch
 import torch_xla
 import torch_xla.runtime as xr
+import torch_xla.utils.utils as xu
+import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
 from torch_xla.experimental.xla_sharding import Mesh
-from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-
 import test_xla_sharding_base
 
+
 class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest):
 
   @classmethod
   def setUpClass(cls):
-    xr.use_spmd()# os.environ["XLA_USE_SPMD"] = "1"
+    xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
-  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
-                   "TODO(manfei): enable it.")
+  @unittest.skipIf(
+      not xr.using_pjrt() or
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+      "Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled(self):
+    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = xr.global_runtime_device_count()
     mesh_shape = (2, num_devices // 2)
@@ -38,16 +43,76 @@ def test_debugging_spmd_single_host_tiled(self):
     partition_spec = (0, 1)
     xs.mark_sharding(t, mesh, partition_spec)
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    print("sharding is:")
-    print(sharding)
-    print("then print:")
-    visualize_tensor_sharding(t)
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console(file=io.StringIO(), width=120)
+    console.print(generated_table)
+    output = console.file.getvalue()
+    fake_console = rich.console.Console(file=io.StringIO(), width=120)
+    color = None
+    text_color = None
+    fake_table = rich.table.Table(
+        show_header=False,
+        show_lines=True,
+        padding=0,
+        highlight=True,
+        pad_edge=False,
+        box=rich.box.SQUARE)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 0', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 1', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 2', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 3', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 4', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 5', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 6', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU 7', "center", vertical="middle"),
+            (2, 1, 2, 1),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    fake_console.print(fake_table)
+    fake_output = fake_console.file.getvalue()
+    assert output == fake_output
 
-  @unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
-  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
-                   "TODO(manfei): enable it.")
+
+  @unittest.skipIf(
+      not xr.using_pjrt() or
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+      "Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication(self):
+    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = xr.global_runtime_device_count()
     mesh_shape = (2, num_devices // 2)
@@ -55,19 +120,49 @@ def test_single_host_partial_replication(self):
     mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
 
     partition_spec = (0, None)
-    t = torch.randn(8, 32, device=device)
+    t = torch.randn(8, 32, device=device)
     xs.mark_sharding(t, mesh, (0, None))
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    print("sharding is: ")
-    print(sharding)
-    print("then print: ")
-    visualize_tensor_sharding(t)
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console(file=io.StringIO(), width=120)
+    console.print(generated_table)
+    output = console.file.getvalue()
+    color = None
+    text_color = None
+    fake_table = rich.table.Table(
+        show_header=False,
+        show_lines=True,
+        padding=0,
+        highlight=True,
+        pad_edge=False,
+        box=rich.box.SQUARE)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
+            (2, 0, 2, 0),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
+            (2, 0, 2, 0),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    fake_console = rich.console.Console(file=io.StringIO(), width=120)
+    fake_console.print(fake_table)
+    fake_output = fake_console.file.getvalue()
+    assert output == fake_output
 
-  @unittest.skipIf(xr.device_type() == 'CPU', "skipped on CPU before enable")
-  @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
-                   "TODO(manfei): enable it.")
+
+  @unittest.skipIf(
+      not xr.using_pjrt() or
+      xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
+      "Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated(self):
+    from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
     num_devices = xr.global_runtime_device_count()
     mesh_shape = (2, num_devices // 2)
@@ -78,10 +173,32 @@ def test_single_host_replicated(self):
     t = torch.randn(8, 32, device=device)
     xs.mark_sharding(t, mesh, partition_spec_replicated)
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    print("sharding is: ")
-    print(sharding)
-    print("then print: ")
-    visualize_tensor_sharding(t)
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console(file=io.StringIO(), width=120)
+    console.print(generated_table)
+    output = console.file.getvalue()
+
+    color = None
+    text_color = None
+    fake_table = rich.table.Table(
+        show_header=False,
+        show_lines=True,
+        padding=0,
+        highlight=True,
+        pad_edge=False,
+        box=rich.box.SQUARE)
+    col = []
+    col.append(
+        rich.padding.Padding(
+            rich.align.Align('TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
+            (0, 0, 1, 0),
+            style=rich.style.Style(bgcolor=color, color=text_color)))
+    fake_table.add_row(*col)
+    fake_console = rich.console.Console(file=io.StringIO(), width=120)
+    fake_console.print(fake_table)
+    fake_output = fake_console.file.getvalue()
+    assert output == fake_output
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py
index 3cd50e1e7c05..fe6613648429 100644
--- a/torch_xla/distributed/spmd/__init__.py
+++ b/torch_xla/distributed/spmd/__init__.py
@@ -3,10 +3,21 @@
                           XLAPatchedLinear, mark_sharding, clear_sharding,
                           wrap_if_sharded, xla_patched_nn_linear_forward)
 from .api import xla_distribute_tensor, xla_distribute_module
+# from .debugging import visualize_tensor_sharding
 
 __all__ = [
-    "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
-    "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
-    "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module",
-    "xla_patched_nn_linear_forward"
+    "XLAShard",
+    "XLAShardedTensor",
+    "Mesh",
+    "HybridMesh",
+    "ShardingType",
+    "ShardingSpec",
+    "XLAPatchedLinear",
+    "mark_sharding",
+    "clear_sharding",
+    "wrap_if_sharded",
+    "xla_distribute_tensor",
+    "xla_distribute_module",
+    "xla_patched_nn_linear_forward",
+    "visualize_tensor_sharding",
 ]
diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py
index 86506df9065f..1c5f39be5710 100644
--- a/torch_xla/distributed/spmd/debugging.py
+++ b/torch_xla/distributed/spmd/debugging.py
@@ -11,9 +11,7 @@
 import torch_xla.runtime as xr
 import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
-from torch_xla.experimental.xla_sharding import Mesh
 
-# pytype: disable=import-error
 try:
   import rich
   import rich.align
@@ -27,7 +25,7 @@
   RICH_ENABLED = False
 
 # Sharding visualization
-sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore
+sharding_callbacks = weakref.WeakValueDictionary()
 _INSPECT_SHARDING_CALL_NAME = "InspectSharding"
 
@@ -81,7 +79,6 @@ def visualize_sharding(shape: torch.Size,
     raise ValueError(
         "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")
 
-  # sharding[sharding.index(']')+1:-1]# sharding.devices_indices_map(tuple(shape))
   slices: dict[tuple[int, ...], set[int]] = {}
   heights: dict[tuple[int, ...], Optional[float]] = {}
   widths: dict[tuple[int, ...], float] = {}
@@ -102,11 +99,8 @@ def visualize_sharding(shape: torch.Size,
   # `device_indices_map`: [0, 1, 2, 3]
   # `sharding_spac`: [2, 2]
   sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1]
-  print('sharding_spac: ', sharding_spac)
   if len(sharding) >= 25 and sharding[-24:-1] == 'last_tile_dim_replicate':
     device_list = list(sharding[sharding.index(']') + 1:-24])
-    print("device_list")
-    print(device_list)
     device_indices_map = [int(i) for i in device_list[:-1] if i != ',']
     heights = int(sharding_spac[1])
     widths = int(sharding_spac[3])
@@ -114,20 +108,16 @@ def visualize_sharding(shape: torch.Size,
     devices_len = len(device_indices_map)
     len_after_dim_down = devices_len // last_dim_depth
     for i in range(len_after_dim_down):
-      slices.setdefault((i // widths, i % widths),
-                        device_indices_map[i:i + last_dim_depth])
+      slices.setdefault(
+          (i // widths, i % widths),
+          device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth])
   elif sharding[-1] == "}":
     # eg: '{devices=[2,2]0,1,2,3}' # 13
     device_list = list(sharding[sharding.index(']') + 1:-1])
-    # print('device_list: ', device_list)
     device_indices_map = [int(i) for i in device_list if i != ',']
-    # print('device_indices_map: ', device_indices_map)
     heights = int(sharding_spac[1])
-    # print('heights: ', heights)
     widths = int(sharding_spac[3])
-    # print('widths: ', widths)
     devices_len = len(device_indices_map)
-    # print('devices_len: ', devices_len)
     for i in range(devices_len):
       slices.setdefault((i // widths, i % widths), device_indices_map[i])
   else:
@@ -137,21 +127,20 @@ def visualize_sharding(shape: torch.Size,
 
   num_rows = heights
   num_cols = widths
-  print('slices', slices)
 
   console = rich.console.Console(width=max_width)
   use_color = use_color and console.color_system is not None
   if use_color and not color_map:
     try:
-      import matplotlib as mpl  # pytype: disable=import-error
+      import matplotlib as mpl
       color_map = mpl.colormaps["tab20b"]
     except ModuleNotFoundError:
       use_color = False
 
-  base_height = int(10 * scale)
+  base_height = int(3 * scale)
   aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
   base_width = int(base_height * aspect_ratio)
-  height_to_width_ratio = 2.5
+  height_to_width_ratio = 1.5
 
   # eg: '{devices=[2,2]0,1,2,3}' # 13
   # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' # 15
@@ -174,9 +163,7 @@ def visualize_sharding(shape: torch.Size,
   for i in range(num_rows):
     col = []
     for j in range(num_cols):
-      entry = f"{device_kind} " + str(
-          slices[i,
-                 j])  # "entry"# .join([str(s) for s in sorted(slices[i, j])])
+      entry = f"{device_kind} " + str(slices[i, j])
       width, maybe_height = widths, heights  # widths[i, j], heights[i, j]
       width = int(width * base_width * height_to_width_ratio)
       if maybe_height is None:
@@ -188,6 +175,7 @@ def visualize_sharding(shape: torch.Size,
       right_padding = left_padding + remainder
       top_padding, remainder = divmod(height - 2, 2)
       bottom_padding = top_padding + remainder
+
       if use_color:
         color = _canonicalize_color(next(color_iter)[:3])
         text_color = _get_text_color(color)
@@ -198,8 +186,10 @@ def visualize_sharding(shape: torch.Size,
       else:
         color = None
         text_color = None
+
       padding = (top_padding, right_padding, bottom_padding, left_padding)
       padding = tuple(max(x, 0) for x in padding)  # type: ignore
+
       col.append(
           rich.padding.Padding(
              rich.align.Align(entry, "center", vertical="middle"),
diff --git a/torch_xla/distributed/spmd/test_debugging.py b/torch_xla/distributed/spmd/test_debugging.py
deleted file mode 100644
index 9d04bd7ea62f..000000000000
--- a/torch_xla/distributed/spmd/test_debugging.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import torch
-import torch_xla
-from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
-
-import numpy as np
-import torch
-import torch_xla.core.xla_model as xm
-import torch_xla.runtime as xr
-import torch_xla.experimental.xla_sharding as xs
-from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
-from torch_xla.experimental.xla_sharding import Mesh
-
-
-def test_single_host_tiled():
-  xr.use_spmd()
-  num_devices = xr.global_runtime_device_count()
-  mesh_shape = (2, num_devices // 2)
-  device_ids = np.array(range(num_devices))
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-  t = torch.randn(8, 4).to(xm.xla_device())
-  partition_spec = (0, 1)
-  m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-  print("sharding is:")
-  print(sharding)
-  print("then print:")
-  visualize_tensor_sharding(t)
-
-
-def test_single_host_partial_replication():
-  xr.use_spmd()
-  num_devices = xr.global_runtime_device_count()
-  mesh_shape = (2, num_devices // 2)
-  device_ids = np.array(range(num_devices))
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
-  partition_spec = (0, None)
-  t = torch.randn(8, 32).to(xm.xla_device())
-  xs.mark_sharding(t, mesh, (0, None))
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-  print("sharding is: ")
-  print(sharding)
-  print("then print: ")
-  visualize_tensor_sharding(t)
-
-
-def test_single_host_replicated():
-  xr.use_spmd()
-  num_devices = xr.global_runtime_device_count()
-  mesh_shape = (2, num_devices // 2)
-  device_ids = np.array(range(num_devices))
-  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
-
-  partition_spec_replicated = (None, None)
-  t = torch.randn(8, 32).to(xm.xla_device())
-  # xs.mark_sharding(t, mesh, (0, None))
-  sharded = xs.mark_sharding(t, mesh, partition_spec_replicated)
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-  print("sharding is: ")
-  print(sharding)
-  print("then print: ")
-  visualize_tensor_sharding(t)
-
-
-test_single_host_tiled()
-test_single_host_partial_replication()
-test_single_host_replicated()
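
A minimal usage sketch of the updated helper, assuming an SPMD-enabled single host with at least two devices and the `rich` package installed; every call below mirrors code that already appears in this patch, and is not part of the diff itself:

import io
import numpy as np
import rich.console
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (2, num_devices // 2), ('x', 'y'))

# Shard an 8x4 tensor over both mesh axes, as in test_debugging_spmd_single_host_tiled.
t = torch.randn(8, 4, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))

# visualize_tensor_sharding now returns the rich table it renders, so the
# visualization can be captured on an in-memory console instead of only
# being printed to stdout; the tests above compare this captured string
# against a hand-built expected table.
table = visualize_tensor_sharding(t)
console = rich.console.Console(file=io.StringIO(), width=120)
console.print(table)
print(console.file.getvalue())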