diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py
index b19a7a65e7ff..3fff925c7a01 100644
--- a/test/spmd/test_spmd_debugging.py
+++ b/test/spmd/test_spmd_debugging.py
@@ -14,7 +14,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.xla_sharding as xs
+import torch_xla.distributed.spmd as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
 from torch_xla.experimental.xla_sharding import Mesh
@@ -35,20 +35,22 @@ def setUpClass(cls):
   def test_debugging_spmd_single_host_tiled(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
-    num_devices = xr.global_runtime_device_count()
+    num_devices = self.n_devices  # xr.global_runtime_device_count()
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
-    mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    # mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    mesh = self._get_mesh(mesh_shape)
     t = torch.randn(8, 4, device=device)
     partition_spec = (0, 1)
     xs.mark_sharding(t, mesh, partition_spec)
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    generatedtable = visualize_tensor_sharding(t)
-    console = rich.console.Console(file=io.StringIO(), width=120)
-    console.print(generatedtable)
-    output = console.file.getvalue()
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console()
+    with console.capture() as capture:
+      console.print(generated_table)
+    output = capture.get()
 
-    fake_console = rich.console.Console(file=io.StringIO(), width=120)
+    # fake_console = rich.console.Console(file=io.StringIO(), width=120)
     color = None
     text_color = None
     fask_table = rich.table.Table(
@@ -62,51 +64,52 @@ def test_debugging_spmd_single_host_tiled(self):
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 0', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 1', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 2', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 3', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fask_table.add_row(*col)
     col = []
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 4', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 5', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 6', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU 7', "center", vertical="middle"),
-            (2,1,2,1),
+            (2, 1, 2, 1),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fask_table.add_row(*col)
-    fake_console.print(fask_table)
-    fake_output = fake_console.file.getvalue()
+    fake_console = rich.console.Console()
+    with fake_console.capture() as fake_capture:
+      fake_console.print(fask_table)
+    fake_output = fake_capture.get()
     assert output == fake_output
 
-
   @unittest.skipIf(
       not xr.using_pjrt() or
       xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
@@ -114,23 +117,25 @@ def test_debugging_spmd_single_host_tiled(self):
   def test_single_host_partial_replication(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
-    num_devices = xr.global_runtime_device_count()
+    num_devices = self.n_devices
     mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
-    mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    # mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    mesh = self._get_mesh(mesh_shape)
     partition_spec = (0, None)
     t = torch.randn(8, 32, device=device)
     xs.mark_sharding(t, mesh, (0, None))
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    generatedtable = visualize_tensor_sharding(t)
-    console = rich.console.Console(file=io.StringIO(), width=120)
-    console.print(generatedtable)
-    output = console.file.getvalue()
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console()
+    with console.capture() as capture:
+      console.print(generated_table)
+    output = capture.get()
 
     color = None
     text_color = None
-    fask_table = rich.table.Table(
+    fake_table = rich.table.Table(
         show_header=False,
         show_lines=True,
         padding=0,
@@ -141,23 +146,22 @@ def test_single_host_partial_replication(self):
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"),
-            (2,0,2,0),
+            (2, 0, 2, 0),
             style=rich.style.Style(bgcolor=color, color=text_color)))
-    fask_table.add_row(*col)
+    fake_table.add_row(*col)
     col = []
     col.append(
         rich.padding.Padding(
             rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"),
-            (2,0,2,0),
+            (2, 0, 2, 0),
             style=rich.style.Style(bgcolor=color, color=text_color)))
-    fask_table.add_row(*col)
-    console.print(fask_table)
-    fake_console = rich.console.Console(file=io.StringIO(), width=120)
-    fake_console.print(fask_table)
-    fake_output = fake_console.file.getvalue()
+    fake_table.add_row(*col)
+    fake_console = rich.console.Console()
+    with fake_console.capture() as fake_capture:
+      fake_console.print(fake_table)
+    fake_output = fake_capture.get()
     assert output == fake_output
 
-
   @unittest.skipIf(
       not xr.using_pjrt() or
       xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'),
@@ -165,19 +169,21 @@ def test_single_host_partial_replication(self):
   def test_single_host_replicated(self):
     from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
     device = xm.xla_device()
-    num_devices = xr.global_runtime_device_count()
+    num_devices = self.n_devices
    mesh_shape = (2, num_devices // 2)
     device_ids = np.array(range(num_devices))
-    mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    # mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
+    mesh = self._get_mesh(mesh_shape)
     partition_spec_replicated = (None, None)
     t = torch.randn(8, 32, device=device)
     xs.mark_sharding(t, mesh, partition_spec_replicated)
     sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
-    generatedtable = visualize_tensor_sharding(t)
-    console = rich.console.Console(file=io.StringIO(), width=120)
-    console.print(generatedtable)
-    output = console.file.getvalue()
+    generated_table = visualize_tensor_sharding(t)
+    console = rich.console.Console()
+    with console.capture() as capture:
+      console.print(generated_table)
+    output = capture.get()
 
     color = None
     text_color = None
@@ -191,13 +197,15 @@ def test_single_host_replicated(self):
     col = []
     col.append(
         rich.padding.Padding(
-            rich.align.Align('TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
-            (0,0,1,0),
+            rich.align.Align(
+                'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"),
+            (0, 0, 1, 0),
             style=rich.style.Style(bgcolor=color, color=text_color)))
     fask_table.add_row(*col)
-    fake_console = rich.console.Console(file=io.StringIO(), width=120)
-    fake_console.print(fask_table)
-    fake_output = fake_console.file.getvalue()
+    fake_console = rich.console.Console()
+    with fake_console.capture() as fake_capture:
+      fake_console.print(fask_table)
+    fake_output = fake_capture.get()
     assert output == fake_output
diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml
index 9ddb03dd8571..e727953ddc43 100644
--- a/test/tpu/xla_test_job.yaml
+++ b/test/tpu/xla_test_job.yaml
@@ -44,18 +44,18 @@ spec:
           pip install expecttest
           pip install rich
 
-          # python3 /src/pytorch/xla/test/test_operations.py -v
-          # python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
-          # python3 /src/pytorch/xla/test/pjrt/test_collective_ops_tpu.py
-          # python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
-          # python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
-          # python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
-          # python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
-          # python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py
-          # XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
-          # XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
-          # python3 /src/pytorch/xla/test/test_autocast.py
-          # python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
+          python3 /src/pytorch/xla/test/test_operations.py -v
+          python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
+          python3 /src/pytorch/xla/test/pjrt/test_collective_ops_tpu.py
+          python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
+          python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
+          python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
+          python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
+          python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py
+          XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
+          XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
+          python3 /src/pytorch/xla/test/test_autocast.py
+          python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
           python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
         volumeMounts:
        - mountPath: /dev/shm
diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py
index 1c5f39be5710..5d26580540de 100644
--- a/torch_xla/distributed/spmd/debugging.py
+++ b/torch_xla/distributed/spmd/debugging.py
@@ -71,13 +71,28 @@ def visualize_sharding(shape: torch.Size,
                        min_width: int = 9,
                        max_width: int = 80,
                        color_map: Optional[ColorMap] = None):
-  """Visualizes a ``Sharding`` using ``rich``."""
+  """Visualizes a ``Sharding`` using ``rich``.
+
+  Args:
+    shape (`torch.Size`): shape of the tensor to be visualized
+    sharding (`str`): sharding spec of the given tensor under SPMD
+    use_color (`bool`): whether to use color in the visualization
+    scale (`float`): scale of the table rendered in the console
+    min_width (`int`): minimum width used to set up the table
+    max_width (`int`): maximum width used to set up the table
+    color_map (`Optional[ColorMap]`): color map used to paint the table
+
+  Returns:
+    A table visualizing the given tensor sharding. The function also
+    displays the sharding visualization rather than only returning it.
+  """
+
   if not RICH_ENABLED:
     raise ValueError("`visualize_sharding` requires `rich` to be installed.")
-  if len(shape) > 2 or len(shape) < 1:
-    raise ValueError(
-        "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")
+  # if len(shape) > 2 or len(shape) < 1:
+  #   raise ValueError(
+  #       "`visualize_sharding` only works for shapes with 1 and 2 dimensions.")
 
   slices: dict[tuple[int, ...], set[int]] = {}
   heights: dict[tuple[int, ...], Optional[float]] = {}
@@ -110,7 +125,7 @@ def visualize_sharding(shape: torch.Size,
     for i in range(len_after_dim_down):
       slices.setdefault(
           (i // widths, i % widths),
-          device_indices_map[i*last_dim_depth:(i + 1)*last_dim_depth])
+          device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth])
   elif sharding[-1] == "}":
     # eg: '{devices=[2,2]0,1,2,3}'  # 13
     device_list = list(sharding[sharding.index(']') + 1:-1])
@@ -188,7 +203,7 @@ def visualize_sharding(shape: torch.Size,
       text_color = None
 
     padding = (top_padding, right_padding, bottom_padding, left_padding)
-    padding = tuple(max(x, 0) for x in padding)  # type: ignore
+    padding = tuple(max(x, 0) for x in padding)
 
     col.append(
         rich.padding.Padding(
@@ -200,8 +215,14 @@ def visualize_sharding(shape: torch.Size,
   return table
 
 
-def visualize_tensor_sharding(ter, **kwargs):
+def visualize_tensor_sharding(t, **kwargs):
   """Visualizes an array's sharding."""
-  import torch_xla
-  sharding = torch_xla._XLAC._get_xla_sharding_spec(ter)
-  return visualize_sharding(ter.shape, sharding, **kwargs)
+  import torch_xla
+  # XLAShardedTensor is a torch.Tensor subclass, so check it first.
+  from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
+  if isinstance(t, XLAShardedTensor):
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t.global_tensor)
+    return visualize_sharding(t.global_tensor.shape, sharding, **kwargs)
+  elif isinstance(t, torch.Tensor):
+    sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
+    return visualize_sharding(t.shape, sharding, **kwargs)
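For context, a usage sketch of the updated `visualize_tensor_sharding` entry point, mirroring the calls exercised in `test_debugging_spmd_single_host_tiled`. This is illustrative only and not part of the patch: it assumes an SPMD-capable PJRT runtime (e.g. a TPU host) with an even number of devices, and the mesh shape, tensor size, and partition spec are placeholders.

```python
# Illustrative sketch: visualize the sharding of a tensor tiled over a 2D mesh.
# Assumes PJRT_DEVICE points at an SPMD-capable backend (e.g. TPU).
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, (2, num_devices // 2), ('x', 'y'))

t = torch.randn(8, 4, device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))  # tile dim 0 over mesh axis 'x', dim 1 over 'y'

table = visualize_tensor_sharding(t)  # rich table describing the device layout
```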
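The tests now compare rendered output through `rich`'s capture API rather than a `StringIO`-backed console. A minimal sketch of that pattern, independent of torch_xla (the table contents below are placeholders, not the real visualization):

```python
# Capture what rich would print to the terminal and compare it as a string.
import rich.console
import rich.table

table = rich.table.Table(
    show_header=False, show_lines=True, padding=0, highlight=True)
table.add_row("TPU 0", "TPU 1")

console = rich.console.Console()
with console.capture() as capture:
  console.print(table)
rendered = capture.get()  # same text console.print() would have emitted

assert "TPU 0" in rendered
```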