diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index bfa7d8eff78..b790658993d 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -46,14 +46,17 @@ void TestSingleReplication( instances.emplace_back(CreateCrsComputation(shape), device_str, all_device_strings, &shape); } - std::vector compiled_computations = - torch_xla::runtime::GetComputationClient()->Compile(std::move(instances)); + std::vector + compiled_computations = + torch_xla::runtime::GetComputationClient()->Compile( + std::move(instances)); std::vector tensors; for (size_t i = 0; i < device_strings.size(); ++i) { tensors.push_back(at::ones({8, 8}, at::TensorOptions(at::kFloat))); } - std::vector tensors_data = CreateTensorsData(tensors, device_strings); + std::vector tensors_data = + CreateTensorsData(tensors, device_strings); std::vector> results(device_strings.size()); diff --git a/test/pjrt/test_runtime.py b/test/pjrt/test_runtime.py index 2b91f7c149b..9e2b05d6a9c 100644 --- a/test/pjrt/test_runtime.py +++ b/test/pjrt/test_runtime.py @@ -71,13 +71,11 @@ def test_num_local_devices(self): xr.addressable_device_count()) if xr.device_type() == 'CUDA': self.assertEqual(self.num_cuda_devices, xr.addressable_device_count()) - def test_num_global_devices(self): self.assertLen(torch_xla._XLAC._xla_get_all_devices(), xr.global_device_count()) if xr.device_type() == 'CUDA': - print('test_num_global_devices is run for cuda. self.num_cuda_devices=', self.num_cuda_devices) self.assertEqual(self.num_cuda_devices, xr.global_device_count()) def test_world_size(self): diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py index 9a3fce79499..cd9e746f174 100644 --- a/test/pjrt/test_torchrun.py +++ b/test/pjrt/test_torchrun.py @@ -15,6 +15,10 @@ def setUp(self): def tearDown(self) -> None: dist.destroy_process_group() + + def test_addressable_device_count(self): + devices_per_process = xr.addressable_device_count() + self.assertEqual(devices_per_process, 1) def test_all_gather(self): dist_world_size = xu.getenv_as('WORLD_SIZE', int) diff --git a/test/test_operations.py b/test/test_operations.py index e57629f0563..3ab7ddc9c5a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -223,8 +223,11 @@ def forward(self, x): x = self.fc2(x) return F.log_softmax(x, dim=1) -@unittest.skipIf(xr.device_type() == 'CUDA', - 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.') + +@unittest.skipIf( + xr.device_type() == 'CUDA', + 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.' +) class TestParallelTensorMNIST(test_utils.XlaTestCase): def test(self): @@ -255,8 +258,11 @@ def loop_fn(model, loader, device, context): model_parallel = dp.DataParallel(XlaMNIST, device_ids=devices) model_parallel(loop_fn, train_loader) -@unittest.skipIf(xr.device_type() == 'CUDA', - 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.') + +@unittest.skipIf( + xr.device_type() == 'CUDA', + 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.' +) class TestParallelTensorResnet18(test_utils.XlaTestCase): def test(self): diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 1c9139960d0..6123b8fd889 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -101,10 +101,9 @@ InitializePjRt(const std::string& device_type) { bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false); if (!spmd) { allowed_devices = std::set{local_process_rank}; - } - else if (global_world_size > 1) { + } else if (global_world_size > 1) { allowed_devices = - std::make_optional>(std::set{local_process_rank}); + std::make_optional>(std::set{local_process_rank}); // Use the XlaCoordinator as the distributed key-value store. coordinator = std::make_unique( global_process_rank, global_world_size, master_addr, port);