From 8fc8d57acc4752224077e75ecf432112b40432ec Mon Sep 17 00:00:00 2001
From: iefgnoix
Date: Fri, 2 Feb 2024 20:17:46 -0800
Subject: [PATCH] Fix global_device_count(), local_device_count() for single
 process on CUDA (#6022)

---
 WORKSPACE                                     |  1 +
 openxla_patches/gpu_hanging.diff              | 36 ++++++++++++
 test/cpp/test_replication.cpp                 | 16 ++++--
 test/cpp/test_xla_sharding.cpp                |  5 --
 ...ntime_gpu.py => test_runtime_multi_gpu.py} |  2 +-
 test/pjrt/test_runtime_single_proc_gpu.py     | 49 +++++++++++++++++
 test/pjrt/test_torchrun.py                    |  4 ++
 test/run_tests.sh                             |  3 +-
 test/spmd/test_xla_sharding.py                | 55 ++++++++++---------
 test/test_core_aten_ops.py                    |  2 -
 test/test_operations.py                       | 11 +++-
 torch_xla/core/xla_model.py                   | 12 +++-
 torch_xla/csrc/runtime/pjrt_registry.cc       | 22 +++++---
 torch_xla/csrc/tensor_util.cpp                |  1 +
 torch_xla/csrc/xla_sharding_util.cpp          |  3 +-
 15 files changed, 171 insertions(+), 51 deletions(-)
 create mode 100644 openxla_patches/gpu_hanging.diff
 rename test/pjrt/{test_runtime_gpu.py => test_runtime_multi_gpu.py} (99%)
 create mode 100644 test/pjrt/test_runtime_single_proc_gpu.py

diff --git a/WORKSPACE b/WORKSPACE
index 20aa282ef06..dc4c1eaab47 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -49,6 +49,7 @@ http_archive(
         "//openxla_patches:cache_urls.diff",
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:f16_abi_clang.diff",
+        "//openxla_patches:gpu_hanging.diff",
         "//openxla_patches:quant_dequant_converter.diff",
         "//openxla_patches:stablehlo_quant_seralization.diff",
     ],
diff --git a/openxla_patches/gpu_hanging.diff b/openxla_patches/gpu_hanging.diff
new file mode 100644
index 00000000000..f64d18084b4
--- /dev/null
+++ b/openxla_patches/gpu_hanging.diff
@@ -0,0 +1,36 @@
+// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24.
+// It can be removed in the next openXLA pin update after 01/26/2024.
+diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
+index 0f1818be2..c181f3025 100644
+--- a/xla/service/gpu/gpu_executable.cc
++++ b/xla/service/gpu/gpu_executable.cc
+@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name,
+     }
+   }
+
+-  // Maybe join a round of rendezvous after thunk initialization.
+-  TF_RETURN_IF_ERROR(
+-      MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
++  // Maybe join a round of rendezvous after thunk initialization. We do this
++  // only in presence of collective cliques which means that we have collective
++  // operations in the XLA operations that tend to cause deadlocks.
++  if (!collective_cliques.empty()) {
++    TF_RETURN_IF_ERROR(
++        MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
++  }
+
+   // Prepare parameters for thunks execution.
+   Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
+diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h
+index 51a566b8f..94bab421f 100644
+--- a/xla/service/gpu/thunk.h
++++ b/xla/service/gpu/thunk.h
+@@ -175,6 +175,8 @@ class Thunk {
+   absl::StatusOr GetComm(const NcclCliqueKey& clique_key,
+                          int32_t rank) const;
+
++  bool empty() const { return cliques_map_.empty(); }
++
+  private:
+   CliquesMap cliques_map_;
+ };
diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp
index 8e4ed3c83eb..0fc20878253 100644
--- a/test/cpp/test_replication.cpp
+++ b/test/cpp/test_replication.cpp
@@ -46,14 +46,17 @@ void TestSingleReplication(
     instances.emplace_back(CreateCrsComputation(shape), device_str,
                            all_device_strings, &shape);
   }
-  auto compiled_computations =
-      torch_xla::runtime::GetComputationClient()->Compile(std::move(instances));
+  std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
+      compiled_computations =
+          torch_xla::runtime::GetComputationClient()->Compile(
+              std::move(instances));
 
   std::vector<at::Tensor> tensors;
   for (size_t i = 0; i < device_strings.size(); ++i) {
     tensors.push_back(at::ones({8, 8}, at::TensorOptions(at::kFloat)));
   }
-  auto tensors_data = CreateTensorsData(tensors, device_strings);
+  std::vector<torch_xla::runtime::ComputationClient::DataPtr> tensors_data =
+      CreateTensorsData(tensors, device_strings);
 
   std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
      results(device_strings.size());
@@ -75,7 +78,7 @@ void TestSingleReplication(
   counter.Wait();
 
   for (size_t i = 0; i < results.size(); ++i) {
-    auto literals =
+    std::vector<xla::Literal> literals =
        torch_xla::runtime::GetComputationClient()->TransferFromDevice(
            results[i]);
    ASSERT_EQ(literals.size(), 1);
@@ -92,9 +95,12 @@ void TestSingleReplication(
 
 class ReplicationTest : public AtenXlaTensorTestBase {};
 
+// Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU
+// device per process instead of relying on threads so we will not run the test
+// on GPU.
 TEST_F(ReplicationTest, TestNSingleReplication) {
   WithAllDevices(
-      {XlaDeviceType::TPU, XlaDeviceType::CUDA},
+      {XlaDeviceType::TPU},
       [&](const std::vector<torch::lazy::BackendDevice>& devices,
           const std::vector<torch::lazy::BackendDevice>& all_devices) {
         TestSingleReplication(devices, all_devices);
diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index 8cb448f108b..611e29a03f7 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -309,11 +309,6 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
 }
 
 TEST_F(XLAShardingTest, CreateTensorsData) {
-  if (torch_xla::runtime::sys_util::GetEnvString(
-          torch_xla::runtime::env::kEnvPjRtDevice, "") == "") {
-    GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
-  }
-
   std::vector<at::Tensor> tensors(2);
   auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
diff --git a/test/pjrt/test_runtime_gpu.py b/test/pjrt/test_runtime_multi_gpu.py
similarity index 99%
rename from test/pjrt/test_runtime_gpu.py
rename to test/pjrt/test_runtime_multi_gpu.py
index 207e37a9759..31623aecad9 100644
--- a/test/pjrt/test_runtime_gpu.py
+++ b/test/pjrt/test_runtime_multi_gpu.py
@@ -19,7 +19,7 @@
 
 @unittest.skipIf(xr.device_type() != "CUDA",
                  f"GPU tests should only run on GPU devices.")
-class TestExperimentalPjrtGpu(parameterized.TestCase):
+class TestExperimentalPjrtMultiGpu(parameterized.TestCase):
 
   def setUp(self):
     xr.set_device_type('CUDA')
diff --git a/test/pjrt/test_runtime_single_proc_gpu.py b/test/pjrt/test_runtime_single_proc_gpu.py
new file mode 100644
index 00000000000..7d8192e1f55
--- /dev/null
+++ b/test/pjrt/test_runtime_single_proc_gpu.py
@@ -0,0 +1,49 @@
+import concurrent.futures
+import itertools
+import os
+import queue
+import requests
+import unittest
+import subprocess
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch_xla
+import torch_xla.core.xla_env_vars as xenv
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
+from absl.testing import absltest, parameterized
+
+
+@unittest.skipIf(xr.device_type() != "CUDA",
+                 f"GPU tests should only run on GPU devices.")
+class TestExperimentalSingleProcPjrtGpu(parameterized.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    command = 'nvidia-smi --list-gpus | wc -l'
+    result = subprocess.run(
+        command,
+        capture_output=True,
+        shell=True,
+        check=True,
+        text=True,
+    )
+    cls.num_cuda_devices = int(result.stdout)
+
+  def test_num_local_devices(self):
+    self.assertLen(xm.get_xla_supported_devices(),
+                   xr.addressable_device_count())
+    self.assertEqual(self.num_cuda_devices, xr.addressable_device_count())
+
+  def test_num_global_devices(self):
+    self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
+                   xr.global_device_count())
+    self.assertEqual(self.num_cuda_devices, xr.global_device_count())
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py
index 9a3fce79499..1a8f9d93f60 100644
--- a/test/pjrt/test_torchrun.py
+++ b/test/pjrt/test_torchrun.py
@@ -16,6 +16,10 @@ def setUp(self):
   def tearDown(self) -> None:
     dist.destroy_process_group()
 
+  def test_addressable_device_count(self):
+    devices_per_process = xr.addressable_device_count()
+    self.assertEqual(devices_per_process, 1)
+
   def test_all_gather(self):
     dist_world_size = xu.getenv_as('WORLD_SIZE', int)
     devices_per_thread = xr.addressable_device_count()
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 1553d53e409..84debb84b25 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -166,7 +166,8 @@ function run_xla_op_tests1 {
   run_test "$CDIR/test_hlo_metadata.py"
   run_test "$CDIR/test_profiler.py"
   run_test "$CDIR/pjrt/test_runtime.py"
-  run_test "$CDIR/pjrt/test_runtime_gpu.py"
+  run_test "$CDIR/pjrt/test_runtime_single_proc_gpu.py"
+  run_test "$CDIR/pjrt/test_runtime_multi_gpu.py"
   run_test "$CDIR/pjrt/test_runtime_multi_cpu.py"
   run_test "$CDIR/pjrt/test_internal_tpu.py"
   run_test "$CDIR/pjrt/test_ddp.py"
diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 08143cf9402..f42289f8d26 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -207,10 +207,13 @@ def test_xla_sharding_type(self):
     t = torch.randn(10, 20).to(xm.xla_device())
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None)
 
-    x_dim = 2 if self.n_devices % 4 == 0 else 1
+    x_dim = 2 if self.n_devices >= 2 else 1
+    # if self.n_devices==4, mesh=(2,2)
+    # if self.n_devices==2, mesh=(2,1)
+    # if self.n_devices==1, mesh=(1,1)
     mesh = self._get_mesh((x_dim, self.n_devices // x_dim))
     xt = xs.mark_sharding(t, mesh, (0, 1))
-    if self.n_devices > 1:
+    if self.n_devices >= 2:
       self.assertEqual(xt.sharding_type, xs.ShardingType.TILED)
     else:
       self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -221,7 +224,7 @@ def test_xla_sharding_type(self):
 
     xs.clear_sharding(t)
     xt = xs.mark_sharding(t, mesh, (None, 1))
-    if self.n_devices > 1:
+    if mesh.get_logical_mesh().shape[1] > 1:
       self.assertEqual(xt.sharding_type, xs.ShardingType.PARTIAL)
     else:
       self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -339,14 +342,13 @@ def test_mark_sharding_partial(self):
     mesh = self._get_mesh((z_dim, self.n_devices // z_dim))
     xt1 = xs.mark_sharding(t1, mesh, (0, None))
 
-    # partial replication requires >1 devices; otherwise, it's replicated.
-    if self.n_devices > 1:
+    # partial replication requires >= 4 devices; otherwise, it's replicated.
+    if self.n_devices >= 4:
       # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
-      self.assertTrue('last_tile_dim_replicate' in
-                      torch_xla._XLAC._get_xla_sharding_spec(t1))
-      self.assertTrue('[%d,1,%d]' %
-                      (z_dim, self.n_devices //
-                       z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertIn('last_tile_dim_replicate',
+                    torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertIn('[%d,1,%d]' % (z_dim, self.n_devices // z_dim),
+                    torch_xla._XLAC._get_xla_sharding_spec(t1))
     # replicated group should share the same data content.
     if (self.n_devices // z_dim) > 1:
       shards = xt1.local_shards
@@ -381,14 +383,13 @@ def test_mark_sharding_partial_unordered(self):
     mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim))
     xt1 = xs.mark_sharding(t1, mesh, (1, None, 0))
 
-    # partial replication requires >1 devices; otherwise, it's replicated.
-    if self.n_devices > 1:
+    # partial replication requires >= 4 devices; otherwise, it's replicated.
+    if self.n_devices >= 4:
       # xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
-      self.assertTrue('last_tile_dim_replicate' in
-                      torch_xla._XLAC._get_xla_sharding_spec(t1))
-      self.assertTrue('[1,1,%d,%d]' %
-                      (z_dim, self.n_devices //
-                       z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertIn('last_tile_dim_replicate',
+                    torch_xla._XLAC._get_xla_sharding_spec(t1))
+      self.assertIn('[1,1,%d,%d]' % (z_dim, self.n_devices // z_dim),
+                    torch_xla._XLAC._get_xla_sharding_spec(t1))
     # replicated group should share the same data content.
     if (self.n_devices // z_dim) > 1:
       shards = xt1.local_shards
@@ -485,14 +486,14 @@ def test_partial_replication_addmm(self):
     xs.mark_sharding(xw, mesh, (None, 1))
 
     # Check if the partial replication annotations are passed to the compiler.
-    # Note that partial replication requires >1 devices; otherwise, it's replicated.
-    if self.n_devices > 1:
-      self.assertTrue('last_tile_dim_replicate' in
-                      torch_xla._XLAC._get_xla_sharding_spec(xx))
-      self.assertTrue('last_tile_dim_replicate' in
-                      torch_xla._XLAC._get_xla_sharding_spec(xw))
+    # Note that partial replication requires >= 4 devices; otherwise, it's replicated.
+    if self.n_devices >= 4:
+      self.assertIn('last_tile_dim_replicate',
+                    torch_xla._XLAC._get_xla_sharding_spec(xx))
+      self.assertIn('last_tile_dim_replicate',
+                    torch_xla._XLAC._get_xla_sharding_spec(xw))
     actual = (xx @ xw + xb).cpu()
-    self.assertTrue(torch.allclose(expected, actual))
+    self.assertTrue(torch.allclose(expected, actual, atol=1e-5))
 
   def test_clear_sharding(self):
     xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
@@ -723,10 +724,14 @@ def test_2d_tensor_3d_mesh(self):
     # Meaningful test for higher-order mesh with extra replication
     # requires multiple devices. Otherwise, this should defaults back to
     # full replication.
-    if self.n_devices > 1:
+    if self.n_devices >= 4:
       mesh = self._get_mesh((2, self.n_devices // 2, 1))
       xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
       sharding_annotation = 'sharding={devices=[1,%d,2]' % (self.n_devices // 2)
+    elif self.n_devices == 2:
+      mesh = self._get_mesh((2, 1, 1))
+      xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
+      sharding_annotation = "sharding={replicated}"
     else:
       mesh = self._get_mesh((1, 1, 1))
       xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 220883e6b4e..b1262fdd8cb 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -16,8 +16,6 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
     output2_cpu = output2.detach().cpu()
     if output2_cpu.dtype != output1.dtype:
       output2_cpu = output2_cpu.to(output1.dtype)
-    # import pdb
-    # pdb.set_trace()
     testcase.assertTrue(
         torch.allclose(
             output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
diff --git a/test/test_operations.py b/test/test_operations.py
index e67722dfec2..1499d5f1c75 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -237,9 +237,14 @@ def forward(self, x):
     return F.log_softmax(x, dim=1)
 
 
+@unittest.skipIf(
+    xr.device_type() == 'CUDA',
+    'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+)
 class TestParallelTensorMNIST(test_utils.XlaTestCase):
 
   def test(self):
+    # devices=['xla:0', 'xla:1', 'xla:2', 'xla:3'] for example.
     devices = xm.get_xla_supported_devices()
     batch_size = xu.getenv_as('BATCH_SIZE', int, defval=8)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
@@ -267,6 +272,10 @@ def loop_fn(model, loader, device, context):
     model_parallel(loop_fn, train_loader)
 
 
+@unittest.skipIf(
+    xr.device_type() == 'CUDA',
+    'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+)
 class TestParallelTensorResnet18(test_utils.XlaTestCase):
 
   def test(self):
@@ -1247,8 +1256,6 @@ def test_fn(a):
 
     self.runAtenTest(torch.zeros([4, 4]), test_fn)
 
-  @unittest.skipIf(xr.device_type() == 'GPU',
-                   "This test fails only on GPU with 07/05 XLA pin update.")
   def test_stack_pred(self):
 
     def test_fn(a):
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index b3f7b4c9ad1..28622fdafc2 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -90,7 +90,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
       that kind.
 
   Returns:
-    The list of device strings.
+    The list of device strings such as ['xla:0', 'xla:1', ...]
   """
   # TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
   # multiple device types.
@@ -220,6 +220,14 @@ def _xla_real_device(device):
 
 
 def xla_real_devices(devices: Optional[List[torch.device]] = None):
+  """Returns the real devices' name.
+
+  Args:
+    devices: The list of torch devices such as ['xla:0', 'xla:1'].
+
+  Returns:
+    A list of real devices' name such as ['CUDA:0', 'CUDA:1'].
+  """
   if not devices:
     devices = get_xla_supported_devices()
 
@@ -260,6 +268,7 @@ def xla_replication_devices(local_devices):
             format(len(local_devices), len(kind_devices)))
   replication_devices = []
   for device in torch_xla._XLAC._xla_get_all_devices():
+    # device is like 'CUDA:0'
     xdev = parse_xla_device(device)
     if not xdev:
       raise RuntimeError('Invalid device format: {}'.format(device))
@@ -287,6 +296,7 @@ def set_replication(device, devices):
   devctx = _get_device_context(device=device)
   devices = [str(x) for x in devices]
   if devices:
+    # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
     replication_devices = xla_replication_devices(devices)
     torch_xla._XLAC._xla_set_replication_devices(replication_devices)
     devctx.device_index = devices.index(device)
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index 0faa164a09d..f2465d68238 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -130,20 +130,28 @@ InitializePjRt(const std::string& device_type) {
     int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
     int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
     int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
-    std::string master_addr =
-        runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
-    std::string port = runtime::sys_util::GetEnvString(
-        "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
-    bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
+    TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
+               << global_process_rank << ", num_nodes=" << global_world_size
+               << ", spmd case=" << sys_util::GetEnvBool("XLA_USE_SPMD", false)
+               << ", PJRT_LOCAL_PROCESS_RANK="
+               << sys_util::GetEnvString(env::kEnvPjRtLocalRank, "")
+               << ", RANK=" << sys_util::GetEnvString("RANK", "")
+               << ", LOCAL_WORLD_SIZE="
+               << sys_util::GetEnvString("LOCAL_WORLD_SIZE", "")
+               << ", WORLD_SIZE=" << sys_util::GetEnvString("WORLD_SIZE", "");
     std::optional<std::set<int>> allowed_devices;
-    if (!spmd) {
+    if (local_world_size > 1) {
       allowed_devices = std::set{local_process_rank};
     }
 
     std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
     if (global_world_size > 1) {
       // Use the distributed key-value store from DistributedRuntimeClient.
+      std::string master_addr =
+          runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
+      std::string port = runtime::sys_util::GetEnvString(
+          "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
       coordinator = std::make_unique<XlaCoordinator>(
           global_process_rank, global_world_size, master_addr, port);
       std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
@@ -151,8 +159,6 @@ InitializePjRt(const std::string& device_type) {
       kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                                   /*key_prefix=*/"gpu:");
     }
-    TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
-               << global_process_rank << ", num_nodes=" << global_world_size;
 
     xla::GpuClientOptions options;
     options.allocator_config = GetGpuAllocatorConfig();
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 25af9ff2269..2e4f280ba66 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -733,6 +733,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
     // GetLocalDevices returns the list of local devices specified by their
     // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
+
     std::vector<std::string> local_devices =
         runtime::GetComputationClient()->GetLocalDevices();
     // Shards the input tensors with padding, to split evenly.
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index c0481e4bca2..bd838528c2e 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -273,7 +273,8 @@ xla::OpSharding ShardingUtil::CreateOpSharding(
     }
   }
   TF_VLOG(INFO) << "OpSharding (ShardingType: " << sharding_type << "):\n"
-                << sharding.DebugString();
+                << sharding.DebugString()
+                << ", sharding.type()=" << sharding.type();
   return sharding;
 }