diff --git a/test/spmd/test_xla_spmd_python_api_interaction.py b/test/spmd/test_xla_spmd_python_api_interaction.py
index 8ea4db3e051..8f9c319ac54 100644
--- a/test/spmd/test_xla_spmd_python_api_interaction.py
+++ b/test/spmd/test_xla_spmd_python_api_interaction.py
@@ -3,7 +3,9 @@
 import sys
 
 import torch
+import torch.distributed as dist
 import torch_xla
+import torch_xla.distributed.xla_backend
 import torch_xla.core.xla_model as xm
 from torch_xla import runtime as xr
 from torch_xla.amp import autocast
@@ -132,6 +134,19 @@ def test_xla_autocast_api(self):
     self.assertTrue(t3.dtype == expected_dtype)
 
 
+class BasicDistributedTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    xr.use_spmd()
+    return super().setUpClass()
+
+  def test_xla_backend(self):
+    # XLA backend is not supported with SPMD
+    with self.assertRaises(AssertionError):
+      dist.init_process_group('xla', init_method='xla://')
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index d448b09dd84..aa2769cb94d 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
 from torch_xla._internal import rendezvous
 import logging
 import os
@@ -8,6 +9,8 @@
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
+  assert not xr.is_spmd(
+  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
   return ProcessGroupXla(prefix_store, rank, size, timeout)