diff --git a/test/run_tests.sh b/test/run_tests.sh
index 543bc5f8403..a3a8c74cedd 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -209,6 +209,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_while_loop.py"
   run_test "$CDIR/scan/test_scan.py"
+  run_test "$CDIR/scan/test_scan_spmd.py"
   run_test "$CDIR/scan/test_scan_layers.py"
   run_test "$CDIR/test_autocast.py"
   run_test "$CDIR/eager/test_eager.py"
diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py
new file mode 100644
index 00000000000..cde7fb7bb65
--- /dev/null
+++ b/test/scan/test_scan_spmd.py
@@ -0,0 +1,51 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+from torch_xla.experimental.scan import scan
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh
+import torch_xla.runtime as xr
+
+
+class ScanSpmdTest(unittest.TestCase):
+
+  def setUp(self):
+    # Activate SPMD
+    xr.use_spmd()
+
+    # Set up a simple SPMD mesh for these tests.
+    self.spmd_mesh = get_1d_mesh(axis_name="model")
+    set_global_mesh(self.spmd_mesh)
+    self.device = torch_xla.device()
+
+  @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
+                       "Multiple devices required")
+  def test_scan_cumsum(self):
+    """This test uses `scan` to implement `torch.cumsum`."""
+
+    def fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
+    init = torch.zeros(1024, requires_grad=True, device=self.device)
+    mark_sharding(init, self.spmd_mesh, ('model',))
+    xs = torch.randn([8, 1024], requires_grad=True, device=self.device)
+    mark_sharding(xs, self.spmd_mesh, (None, 'model'))
+    final_carry, ys = scan(fn, init, xs)
+    torch_xla.sync()
+
+    # Check the input and output sharding. Note that we do this after
+    # `torch_xla.sync()` to ensure the output tensors are materialized and
+    # have taken on sharding annotations propagated by the compiler.
+    for tensor in [init, xs, final_carry, ys]:
+      self.assertIn('ShardingSpec: {devices=[',
+                    torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
+      self.assertIn('OpSharding: {devices=[',
+                    torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 6ad06b07740..fb5cdd51c8e 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_while_loop.py
 python3 test/scan/test_scan.py
+python3 test/scan/test_scan_spmd.py
 python3 test/scan/test_scan_layers.py
 python3 test/test_pallas.py -v
 python3 test/test_pallas_spmd.py