diff --git a/test/run_tests.sh b/test/run_tests.sh index 453abb5e4692..c3fd72572592 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -203,6 +203,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_xla_sharding_hlo.py" run_test "$CDIR/spmd/test_xla_virtual_device.py" run_test "$CDIR/spmd/test_dynamo_spmd.py" + run_test "$CDIR/spmd/test_spmd_debugging.py" run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py" run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py new file mode 100644 index 000000000000..02139c55a5a2 --- /dev/null +++ b/test/spmd/test_spmd_debugging.py @@ -0,0 +1,809 @@ +import sys + +import unittest +from unittest.mock import patch +import math +import numpy as np +import os +import io +import rich + +import torch +import torch_xla +import torch_xla.runtime as xr +import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv +import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs +from torch_xla.distributed.spmd import XLAShardedTensor +from torch_xla.distributed.spmd import Mesh + +import test_xla_sharding_base + + +class DebuggingSpmdTest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + xr.use_spmd() + super().setUpClass() + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_debugging_spmd_single_host_tiled_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + t = torch.randn(8, 4, device=device) + partition_spec = (0, 1) + xs.mark_sharding(t, mesh, partition_spec) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 0', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 1', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 2', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 3', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 4', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 5', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 6', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 7', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_single_host_partial_replication_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + + partition_spec = (0, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, (0, None)) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [0, 1, 2, 3]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [4, 5, 6, 7]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_single_host_replicated_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (2, num_devices // 2) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + + partition_spec_replicated = (None, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, partition_spec_replicated) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align( + 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_debugging_spmd_single_host_tiled_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (1, num_devices) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + t = torch.randn(8, 4, device=device) + partition_spec = (0, 1) + xs.mark_sharding(t, mesh, partition_spec) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [0]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_single_host_partial_replication_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (1, num_devices) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + + partition_spec = (0, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, (0, None)) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [0]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_single_host_replicated_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding + device = xm.xla_device() + num_devices = self.n_devices + mesh_shape = (1, num_devices) + device_ids = np.array(range(num_devices)) + mesh = self._get_mesh(mesh_shape) + + partition_spec_replicated = (None, None) + t = torch.randn(8, 32, device=device) + xs.mark_sharding(t, mesh, partition_spec_replicated) + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + generated_table = visualize_tensor_sharding(t) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [0]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + +# Multi-host tests +# e.g.: sharding={devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15} +# e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate} +# e.g.: sharding={replicated} + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_debugging_spmd_multi_host_tiled_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 0', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 4', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 8', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 12', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 2', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 6', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 10', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 14', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU 1', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 5', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 9', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 13', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 3', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 7', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 11', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('TPU 15', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_multi_host_partial_replication_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [0, 1]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [4, 5]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [8, 9]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [12, 13]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [2, 3]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [6, 7]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [10, 11]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('TPU [14, 15]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'CPU'), + f"Requires PJRT_DEVICE set to `TPU`.") + def test_multi_host_replicated_tpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{replicated}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align( + 'TPU [0, 1, 2, 3, 4, 5, 6, 7]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_debugging_spmd_multi_host_tiled_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU 0', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 4', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 8', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 12', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 2', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 6', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 10', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 14', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU 1', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 5', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 9', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 13', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 3', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 7', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 11', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + col.append( + rich.padding.Padding( + rich.align.Align('CPU 15', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_multi_host_partial_replication_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [0, 1]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [4, 5]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [8, 9]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [12, 13]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [2, 3]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [6, 7]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [10, 11]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [14, 15]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + + @unittest.skipIf( + not xr.using_pjrt() or + xu.getenv_as(xenv.PJRT_DEVICE, str) in ("GPU", 'CUDA', 'ROCM', 'TPU'), + f"Requires PJRT_DEVICE set to `CPU`.") + def test_multi_host_replicated_cpu(self): + from torch_xla.distributed.spmd.debugging import visualize_sharding + sharding = '{replicated}' + generated_table = visualize_sharding(sharding) + console = rich.console.Console() + with console.capture() as capture: + console.print(generated_table) + output = capture.get() + + color = None + text_color = None + fake_table = rich.table.Table( + show_header=False, + show_lines=True, + padding=0, + highlight=True, + pad_edge=False, + box=rich.box.SQUARE) + col = [] + col.append( + rich.padding.Padding( + rich.align.Align('CPU [0]', "center", vertical="middle"), + (1, 1, 1, 1), + style=rich.style.Style(bgcolor=color, color=text_color))) + fake_table.add_row(*col) + fake_console = rich.console.Console() + with fake_console.capture() as fake_capture: + fake_console.print(fake_table) + fake_output = fake_capture.get() + assert output == fake_output + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index 99d59d286dcd..e727953ddc43 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -42,6 +42,7 @@ spec: - -cxe - | pip install expecttest + pip install rich python3 /src/pytorch/xla/test/test_operations.py -v python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py @@ -55,6 +56,7 @@ spec: XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v python3 /src/pytorch/xla/test/test_autocast.py python3 /src/pytorch/xla/test/dynamo/test_dynamo.py + python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py volumeMounts: - mountPath: /dev/shm name: dshm diff --git a/torch_xla/distributed/spmd/__init__.py b/torch_xla/distributed/spmd/__init__.py index 3cd50e1e7c05..7f494b74c9d3 100644 --- a/torch_xla/distributed/spmd/__init__.py +++ b/torch_xla/distributed/spmd/__init__.py @@ -5,8 +5,17 @@ from .api import xla_distribute_tensor, xla_distribute_module __all__ = [ - "XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType", - "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding", - "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module", - "xla_patched_nn_linear_forward" + "XLAShard", + "XLAShardedTensor", + "Mesh", + "HybridMesh", + "ShardingType", + "ShardingSpec", + "XLAPatchedLinear", + "mark_sharding", + "clear_sharding", + "wrap_if_sharded", + "xla_distribute_tensor", + "xla_distribute_module", + "xla_patched_nn_linear_forward", ] diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py new file mode 100644 index 000000000000..508d8cbb371c --- /dev/null +++ b/torch_xla/distributed/spmd/debugging.py @@ -0,0 +1,166 @@ +from collections.abc import Sequence +import functools +import string +import sys +from typing import Any, Callable, Optional, Union +import weakref + +import numpy as np +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr +from torch_xla.distributed.spmd.xla_sharding import * +import torch_xla.utils.utils as xu +import torch_xla.core.xla_env_vars as xenv +from torch_xla.distributed.spmd import XLAShardedTensor + +try: + import rich + import rich.align + import rich.box + import rich.console + import rich.padding + import rich.style + import rich.table + RICH_ENABLED = True +except: + RICH_ENABLED = False + + +def visualize_sharding(sharding: str, + use_color: bool = True, + scale: float = 1., + min_width: int = 9, + max_width: int = 80): + """Visualizes a ``Sharding`` using ``rich``. + Args: + sharding (`str`): sharding of given tensor with SPMD + use_color (`bool`): whether use color or not + scale (`float`): scale of table visualized in console + min_width (`int`): min width used to setup table to visualize + max_width (`int`): max width used to setup table to visualize + Returns: + table to visualize given tensor sharding. This function + will also visualize the sharding of the tensor without as return. + """ + + if not RICH_ENABLED: + raise ValueError("`visualize_sharding` requires `rich` to be installed.") + + slices: dict[tuple[int, ...], set[int]] = {} + heights: dict[tuple[int, ...], Optional[float]] = {} + widths: dict[tuple[int, ...], float] = {} + + if len(sharding) >= 0: + # sharding is longer than 0 + # eg: '{devices=[2,2]0,1,2,3}' + # eg: '{replicated}' + # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' + if sharding == '{replicated}' or len(sharding) == 0: + heights = 1 + widths = 1 + num_devices = xr.global_runtime_device_count() + device_ids = list(range(num_devices)) + slices.setdefault((0, 0), device_ids) + else: + sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1] + device_list_original = sharding.split(' last_tile_dim_replicate') + if len(device_list_original) == 2 and device_list_original[1] == '}': + try: + device_list_original_first = device_list_original[0] + device_list = device_list_original_first[device_list_original_first. + index(']') + 1:] + device_indices_map = [int(s) for s in device_list.split(',')] + heights = int(sharding_spac[1]) + widths = int(sharding_spac[3]) + last_dim_depth = int(sharding_spac[5]) + devices_len = len(device_indices_map) + len_after_dim_down = devices_len // last_dim_depth + for i in range(len_after_dim_down): + slices.setdefault( + (i // widths, i % widths), + device_indices_map[i * last_dim_depth:(i + 1) * last_dim_depth]) + except: + raise ValueError("sharding ", sharding, + " is not organized as expected") + else: + # eg: '{devices=[2,2]0,1,2,3}' + try: + assert device_list_original[0][-1] == '}' + except: + raise ValueError("sharding ", sharding, + " is not organized as expected") + try: + device_list_original_first = device_list_original[0] + device_list = device_list_original_first[device_list_original_first. + index(']') + 1:-1] + device_indices_map = [int(i) for i in device_list.split(',')] + heights = int(sharding_spac[1]) + widths = int(sharding_spac[3]) + devices_len = len(device_indices_map) + for i in range(devices_len): + slices.setdefault((i // widths, i % widths), device_indices_map[i]) + except: + raise ValueError("sharding ", sharding, + " is not organized as expected") + else: + raise ValueError("sharding length should >= 0") + + num_rows = heights + num_cols = widths + + console = rich.console.Console(width=max_width) + use_color = use_color and console.color_system is not None + + base_height = int(3 * scale) + aspect_ratio = 1 + base_width = int(base_height * aspect_ratio) + height_to_width_ratio = 1.5 + + pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str) + device_kind = pjrt_device + + table = rich.table.Table( + show_header=False, + show_lines=not use_color, + padding=0, + highlight=not use_color, + pad_edge=False, + box=rich.box.SQUARE if not use_color else None) + for i in range(num_rows): + col = [] + for j in range(num_cols): + entry = f"{device_kind} " + str(slices[i, j]) + width, maybe_height = widths, heights + width = int(width * base_width * height_to_width_ratio) + if maybe_height is None: + height = 1 + else: + height = int(maybe_height * base_height) + width = min(max(width, min_width), max_width) + + color = None + text_color = None + + padding = (1, 1, 1, 1) + + col.append( + rich.padding.Padding( + rich.align.Align(entry, "center", vertical="middle"), + padding, + style=rich.style.Style(bgcolor=color, color=text_color))) + table.add_row(*col) + console.print(table, end='\n\n') + return table + + +def visualize_tensor_sharding(t, **kwargs): + """Visualizes an array's sharding.""" + + # XLAShardedTensor is-a torch.Tensor + def maybe_unwrap(t: torch.Tensor) -> torch.Tensor: + return t.global_tensor if isinstance(t, XLAShardedTensor) else t + + sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t)) + return visualize_sharding(sharding, **kwargs)