From 82223ec6256e4507c9b2d827457b230d1e5d0b3f Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Fri, 22 Nov 2024 01:37:39 +0000
Subject: [PATCH] [scan] Add a test under SPMD

Verifies that the GSPMD sharding annotation propagation pass can
propagate through a While op and through the Body computation just fine.
---
 test/run_tests.sh           |  1 +
 test/scan/test_scan.py      |  6 +--
 test/scan/test_scan_spmd.py | 87 +++++++++++++++++++++++++++++++++++++
 test/tpu/run_tests.sh       |  1 +
 4 files changed, 92 insertions(+), 3 deletions(-)
 create mode 100644 test/scan/test_scan_spmd.py

diff --git a/test/run_tests.sh b/test/run_tests.sh
index 543bc5f8403..0fca0bbcfe5 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
   run_test "$CDIR/scan/test_scan.py"
+  run_save_tensor_hlo "$CDIR/scan/test_scan_spmd.py"
   run_test "$CDIR/scan/test_scan_layers.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py
index f344386c852..47fd7fa611a 100644
--- a/test/scan/test_scan.py
+++ b/test/scan/test_scan.py
@@ -57,9 +57,6 @@ def compare_pytree(self, expected_pytree, actual_pytree):
     flat_actual_pytree = [x for x in flat_actual_pytree if x is not None]
     super().compareResults(flat_expected_pytree, flat_actual_pytree)
 
-
-class ScanTest(TestBase):
-
   def run_test(self,
                fn,
                init: PyTree,
@@ -104,6 +101,9 @@ def run_test(self,
 
     return final_carry, ys
 
+
+class ScanTest(TestBase):
+
   def test_scan_simple(self):
     """This test uses `scan` to implement `torch.cumsum`."""
 
diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py
new file mode 100644
index 00000000000..a087867ae2f
--- /dev/null
+++ b/test/scan/test_scan_spmd.py
@@ -0,0 +1,87 @@
+import sys
+import os
+import re
+import unittest
+from pathlib import Path
+
+import torch
+import torch_xla
+from torch_xla.experimental.scan import scan
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, Mesh
+import torch_xla.runtime as xr
+
+
+class ScanSpmdTest(unittest.TestCase):
+
+  def setUp(self):
+    # Set up a simple SPMD mesh for these tests.
+    num_devices = xr.global_runtime_device_count()
+    mesh_shape = (num_devices,)
+    self.spmd_mesh = Mesh(list(range(num_devices)), mesh_shape, ('model',))
+    set_global_mesh(self.spmd_mesh)
+    xr.use_spmd()
+    self.device = torch_xla.device()
+
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_cumsum(self):
+    """Use `scan` to implement `torch.cumsum` and check output sharding."""
+
+    save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
+    save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
+    if not save_file:
+      assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
+    save_file += '.0'
+    assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"
+
+    # Remove the save file (if it exists) to start from a clean slate.
+    try:
+      os.remove(save_file)
+    except FileNotFoundError:
+      pass
+
+    def fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
+    init = torch.zeros(1024, requires_grad=True, device=self.device)
+    mark_sharding(init, self.spmd_mesh, ('model',))
+    xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
+    mark_sharding(xs, self.spmd_mesh, (None, 'model'))
+    final_carry, ys = scan(fn, init, xs)
+    torch_xla.sync()
+
+    # Check the HLO
+    hlo_content = Path(save_file).read_text()
+    lines = hlo_content.splitlines()
+
+    # There should be only one graph.
+    assert len(re.findall('END_GRAPH', hlo_content)) == 1
+
+    # The graph should have output sharding.
+    begin_magic = '#OUTPUT_SHARDING_BEGIN'
+    end_magic = '#OUTPUT_SHARDING_END'
+    self.assertIn(end_magic, str(lines[-2]))
+
+    # Extract the output sharding descriptions.
+    start = hlo_content.find(begin_magic)
+    assert start != -1
+    start += len(begin_magic)
+    end = hlo_content.find(end_magic, start)
+    assert end != -1
+    # `end` already points at the start of `end_magic`, so slice up to it.
+    output_sharding = hlo_content[start:end].strip().splitlines()
+
+    # There should be 4 tensors in output sharding: init, xs, final_carry, ys.
+    self.assertEqual(len(output_sharding), 4)
+    for sharding in output_sharding:
+      self.assertIn('devices=[', sharding)
+
+    # Remove the save file again to avoid cluttering other tests.
+    os.remove(save_file)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 8d5e74bde03..605dc47240e 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -26,6 +26,7 @@ python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_while_loop.py
 python3 test/scan/test_scan.py
+XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" python3 test/scan/test_scan_spmd.py
 python3 test/scan/test_scan_layers.py
 python3 test/test_pallas.py -v
 python3 test/test_pallas_spmd.py
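
Note: the new test asserts that HLO dumping is enabled, so a standalone run needs the same environment variables as the test/tpu/run_tests.sh entry above (in test/run_tests.sh the run_save_tensor_hlo helper is expected to set the equivalent variables; the /tmp path is just the location that script happens to use):

  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" python3 test/scan/test_scan_spmd.py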