
Commit

Fix global_device_count(), local_device_count() for single process on CUDA (#6022)
vanbasten23 authored Feb 3, 2024
1 parent 3e68409 commit 8fc8d57
Showing 15 changed files with 171 additions and 51 deletions.
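
For context, a minimal sketch of the runtime APIs this commit fixes, assuming a single process with PJRT_DEVICE=CUDA set (the function names are real torch_xla APIs; the printed values depend on how many GPUs the process can see):

# Sketch only: after this fix, a single CUDA process should report every
# attached GPU (e.g. 4 on a 4-GPU host) instead of just 1.
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

print(xr.global_device_count())       # total device count across all processes
print(xr.local_device_count())        # devices attached to this host
print(xr.addressable_device_count())  # devices this process can address
print(xm.get_xla_supported_devices()) # e.g. ['xla:0', 'xla:1', 'xla:2', 'xla:3']
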
1 change: 1 addition & 0 deletions WORKSPACE
@@ -49,6 +49,7 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_hanging.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
36 changes: 36 additions & 0 deletions openxla_patches/gpu_hanging.diff
@@ -0,0 +1,36 @@
// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24.
// It can be removed in the next openXLA pin update after 01/26/2024.
diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
index 0f1818be2..c181f3025 100644
--- a/xla/service/gpu/gpu_executable.cc
+++ b/xla/service/gpu/gpu_executable.cc
@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name,
}
}

- // Maybe join a round of rendezvous after thunk initialization.
- TF_RETURN_IF_ERROR(
- MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+ // Maybe join a round of rendezvous after thunk initialization. We do this
+ // only in presence of collective cliques which means that we have collective
+ // operations in the XLA operations that tend to cause deadlocks.
+ if (!collective_cliques.empty()) {
+ TF_RETURN_IF_ERROR(
+ MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+ }

// Prepare parameters for thunks execution.
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h
index 51a566b8f..94bab421f 100644
--- a/xla/service/gpu/thunk.h
+++ b/xla/service/gpu/thunk.h
@@ -175,6 +175,8 @@ class Thunk {
absl::StatusOr<NcclComm::Lock> GetComm(const NcclCliqueKey& clique_key,
int32_t rank) const;

+ bool empty() const { return cliques_map_.empty(); }
+
private:
CliquesMap cliques_map_;
};
16 changes: 11 additions & 5 deletions test/cpp/test_replication.cpp
@@ -46,14 +46,17 @@ void TestSingleReplication(
instances.emplace_back(CreateCrsComputation(shape), device_str,
all_device_strings, &shape);
}
- auto compiled_computations =
- torch_xla::runtime::GetComputationClient()->Compile(std::move(instances));
+ std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
+ compiled_computations =
+ torch_xla::runtime::GetComputationClient()->Compile(
+ std::move(instances));

std::vector<at::Tensor> tensors;
for (size_t i = 0; i < device_strings.size(); ++i) {
tensors.push_back(at::ones({8, 8}, at::TensorOptions(at::kFloat)));
}
- auto tensors_data = CreateTensorsData(tensors, device_strings);
+ std::vector<torch::lazy::BackendDataPtr> tensors_data =
+ CreateTensorsData(tensors, device_strings);

std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
results(device_strings.size());
@@ -75,7 +78,7 @@ void TestSingleReplication(
counter.Wait();

for (size_t i = 0; i < results.size(); ++i) {
- auto literals =
+ std::vector<xla::Literal> literals =
torch_xla::runtime::GetComputationClient()->TransferFromDevice(
results[i]);
ASSERT_EQ(literals.size(), 1);
@@ -92,9 +95,12 @@ void TestSingleReplication(

class ReplicationTest : public AtenXlaTensorTestBase {};

+ // Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU
+ // device per process instead of relying on threads so we will not run the test
+ // on GPU.
TEST_F(ReplicationTest, TestNSingleReplication) {
WithAllDevices(
- {XlaDeviceType::TPU, XlaDeviceType::CUDA},
+ {XlaDeviceType::TPU},
[&](const std::vector<torch::lazy::BackendDevice>& devices,
const std::vector<torch::lazy::BackendDevice>& all_devices) {
TestSingleReplication(devices, all_devices);
5 changes: 0 additions & 5 deletions test/cpp/test_xla_sharding.cpp
@@ -309,11 +309,6 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
}

TEST_F(XLAShardingTest, CreateTensorsData) {
- if (torch_xla::runtime::sys_util::GetEnvString(
- torch_xla::runtime::env::kEnvPjRtDevice, "") == "") {
- GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
- }

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
test/pjrt/test_runtime_gpu.py → test/pjrt/test_runtime_multi_gpu.py (renamed)
@@ -19,7 +19,7 @@

@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
- class TestExperimentalPjrtGpu(parameterized.TestCase):
+ class TestExperimentalPjrtMultiGpu(parameterized.TestCase):

def setUp(self):
xr.set_device_type('CUDA')
49 changes: 49 additions & 0 deletions test/pjrt/test_runtime_single_proc_gpu.py
@@ -0,0 +1,49 @@
import concurrent.futures
import itertools
import os
import queue
import requests
import unittest
import subprocess

import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import runtime as xr
from torch_xla._internal import pjrt
from absl.testing import absltest, parameterized


@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
class TestExperimentalSingleProcPjrtGpu(parameterized.TestCase):

@classmethod
def setUpClass(cls):
command = 'nvidia-smi --list-gpus | wc -l'
result = subprocess.run(
command,
capture_output=True,
shell=True,
check=True,
text=True,
)
cls.num_cuda_devices = int(result.stdout)

def test_num_local_devices(self):
self.assertLen(xm.get_xla_supported_devices(),
xr.addressable_device_count())
self.assertEqual(self.num_cuda_devices, xr.addressable_device_count())

def test_num_global_devices(self):
self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
xr.global_device_count())
self.assertEqual(self.num_cuda_devices, xr.global_device_count())


if __name__ == '__main__':
absltest.main()
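
As a rough cross-check of what this new test asserts (a sketch, not part of the commit; it assumes PJRT_DEVICE=CUDA and a CUDA-enabled PyTorch build so that torch.cuda sees the same GPUs):

# Sketch: in a single process, the XLA device counts should match the
# physical GPU count reported by torch.cuda.
import torch
import torch_xla.runtime as xr

if xr.device_type() == 'CUDA':
    assert xr.addressable_device_count() == torch.cuda.device_count()
    assert xr.global_device_count() == torch.cuda.device_count()
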
4 changes: 4 additions & 0 deletions test/pjrt/test_torchrun.py
@@ -16,6 +16,10 @@ def setUp(self):
def tearDown(self) -> None:
dist.destroy_process_group()

+ def test_addressable_device_count(self):
+ devices_per_process = xr.addressable_device_count()
+ self.assertEqual(devices_per_process, 1)

def test_all_gather(self):
dist_world_size = xu.getenv_as('WORLD_SIZE', int)
devices_per_thread = xr.addressable_device_count()
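
The new assertion reflects the torchrun setup: each spawned process owns exactly one GPU, so the per-process (addressable) count is 1 while the global count spans all ranks. A minimal sketch of that invariant (hypothetical script, assumed to be launched via torchrun with one process per GPU, PJRT_DEVICE=CUDA set, and the process group initialized as in this test):

# check_torchrun_devices.py -- hypothetical helper, one process per GPU.
import os
import torch_xla.runtime as xr

assert xr.addressable_device_count() == 1
assert xr.global_device_count() == int(os.environ['WORLD_SIZE'])
print(f"rank {os.environ.get('RANK', '?')}: device counts look consistent")
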
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -166,7 +166,8 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/pjrt/test_runtime.py"
run_test "$CDIR/pjrt/test_runtime_gpu.py"
run_test "$CDIR/pjrt/test_runtime_single_proc_gpu.py"
run_test "$CDIR/pjrt/test_runtime_multi_gpu.py"
run_test "$CDIR/pjrt/test_runtime_multi_cpu.py"
run_test "$CDIR/pjrt/test_internal_tpu.py"
run_test "$CDIR/pjrt/test_ddp.py"
55 changes: 30 additions & 25 deletions test/spmd/test_xla_sharding.py
@@ -207,10 +207,13 @@ def test_xla_sharding_type(self):
t = torch.randn(10, 20).to(xm.xla_device())
self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None)

- x_dim = 2 if self.n_devices % 4 == 0 else 1
+ x_dim = 2 if self.n_devices >= 2 else 1
+ # if self.n_devices==4, mesh=(2,2)
+ # if self.n_devices==2, mesh=(2,1)
+ # if self.n_devices==1, mesh=(1,1)
mesh = self._get_mesh((x_dim, self.n_devices // x_dim))
xt = xs.mark_sharding(t, mesh, (0, 1))
- if self.n_devices > 1:
+ if self.n_devices >= 2:
self.assertEqual(xt.sharding_type, xs.ShardingType.TILED)
else:
self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -221,7 +224,7 @@ def test_xla_sharding_type(self):

xs.clear_sharding(t)
xt = xs.mark_sharding(t, mesh, (None, 1))
- if self.n_devices > 1:
+ if mesh.get_logical_mesh().shape[1] > 1:
self.assertEqual(xt.sharding_type, xs.ShardingType.PARTIAL)
else:
self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -339,14 +342,13 @@ def test_mark_sharding_partial(self):
mesh = self._get_mesh((z_dim, self.n_devices // z_dim))
xt1 = xs.mark_sharding(t1, mesh, (0, None))

- # partial replication requires >1 devices; otherwise, it's replicated.
- if self.n_devices > 1:
+ # partial replication requires >= 4 devices; otherwise, it's replicated.
+ if self.n_devices >= 4:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
- self.assertTrue('last_tile_dim_replicate' in
- torch_xla._XLAC._get_xla_sharding_spec(t1))
- self.assertTrue('[%d,1,%d]' %
- (z_dim, self.n_devices //
- z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+ self.assertIn('last_tile_dim_replicate',
+ torch_xla._XLAC._get_xla_sharding_spec(t1))
+ self.assertIn('[%d,1,%d]' % (z_dim, self.n_devices // z_dim),
+ torch_xla._XLAC._get_xla_sharding_spec(t1))
# replicated group should share the same data content.
if (self.n_devices // z_dim) > 1:
shards = xt1.local_shards
@@ -381,14 +383,13 @@ def test_mark_sharding_partial_unordered(self):
mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim))
xt1 = xs.mark_sharding(t1, mesh, (1, None, 0))

- # partial replication requires >1 devices; otherwise, it's replicated.
- if self.n_devices > 1:
+ # partial replication requires >= 4 devices; otherwise, it's replicated.
+ if self.n_devices >= 4:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
- self.assertTrue('last_tile_dim_replicate' in
- torch_xla._XLAC._get_xla_sharding_spec(t1))
- self.assertTrue('[1,1,%d,%d]' %
- (z_dim, self.n_devices //
- z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
+ self.assertIn('last_tile_dim_replicate',
+ torch_xla._XLAC._get_xla_sharding_spec(t1))
+ self.assertIn('[1,1,%d,%d]' % (z_dim, self.n_devices // z_dim),
+ torch_xla._XLAC._get_xla_sharding_spec(t1))
# replicated group should share the same data content.
if (self.n_devices // z_dim) > 1:
shards = xt1.local_shards
@@ -485,14 +486,14 @@ def test_partial_replication_addmm(self):
xs.mark_sharding(xw, mesh, (None, 1))

# Check if the partial replication annotations are passed to the compiler.
- # Note that partial replication requires >1 devices; otherwise, it's replicated.
- if self.n_devices > 1:
- self.assertTrue('last_tile_dim_replicate' in
- torch_xla._XLAC._get_xla_sharding_spec(xx))
- self.assertTrue('last_tile_dim_replicate' in
- torch_xla._XLAC._get_xla_sharding_spec(xw))
+ # Note that partial replication requires >= 4 devices; otherwise, it's replicated.
+ if self.n_devices >= 4:
+ self.assertIn('last_tile_dim_replicate',
+ torch_xla._XLAC._get_xla_sharding_spec(xx))
+ self.assertIn('last_tile_dim_replicate',
+ torch_xla._XLAC._get_xla_sharding_spec(xw))
actual = (xx @ xw + xb).cpu()
- self.assertTrue(torch.allclose(expected, actual))
+ self.assertTrue(torch.allclose(expected, actual, atol=1e-5))

def test_clear_sharding(self):
xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
@@ -723,10 +724,14 @@ def test_2d_tensor_3d_mesh(self):
# Meaningful test for higher-order mesh with extra replication
# requires multiple devices. Otherwise, this should defaults back to
# full replication.
- if self.n_devices > 1:
+ if self.n_devices >= 4:
mesh = self._get_mesh((2, self.n_devices // 2, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = 'sharding={devices=[1,%d,2]' % (self.n_devices // 2)
+ elif self.n_devices == 2:
+ mesh = self._get_mesh((2, 1, 1))
+ xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
+ sharding_annotation = "sharding={replicated}"
else:
mesh = self._get_mesh((1, 1, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
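
For readers skimming the sharding changes above: partial replication needs one mesh axis to shard over and a second non-trivial axis to replicate over, which is why the checks now require at least 4 devices (a 2x2 mesh) rather than simply more than 1. A rough sketch of the pattern these tests exercise, assuming 4 devices and the torch_xla.distributed.spmd spelling of the sharding API (the module path has moved between releases):

# Sketch: shard dim 0 across one mesh axis, replicate across the other.
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_device_count()
assert n >= 4, "partial replication needs a 2D mesh with both axes > 1"
mesh = xs.Mesh(np.arange(n), (2, n // 2))

t = torch.randn(8, 8).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, None))  # dim 0 sharded, dim 1 replicated
# Expect 'last_tile_dim_replicate' in the resulting sharding spec.
print(torch_xla._XLAC._get_xla_sharding_spec(t))
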
2 changes: 0 additions & 2 deletions test/test_core_aten_ops.py
@@ -16,8 +16,6 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
output2_cpu = output2.detach().cpu()
if output2_cpu.dtype != output1.dtype:
output2_cpu = output2_cpu.to(output1.dtype)
- # import pdb
- # pdb.set_trace()
testcase.assertTrue(
torch.allclose(
output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
11 changes: 9 additions & 2 deletions test/test_operations.py
@@ -237,9 +237,14 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


+ @unittest.skipIf(
+ xr.device_type() == 'CUDA',
+ 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+ )
class TestParallelTensorMNIST(test_utils.XlaTestCase):

def test(self):
+ # devices=['xla:0', 'xla:1', 'xla:2', 'xla:3'] for example.
devices = xm.get_xla_supported_devices()
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=8)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
@@ -267,6 +272,10 @@ def loop_fn(model, loader, device, context):
model_parallel(loop_fn, train_loader)


+ @unittest.skipIf(
+ xr.device_type() == 'CUDA',
+ 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
+ )
class TestParallelTensorResnet18(test_utils.XlaTestCase):

def test(self):
@@ -1247,8 +1256,6 @@ def test_fn(a):

self.runAtenTest(torch.zeros([4, 4]), test_fn)

- @unittest.skipIf(xr.device_type() == 'GPU',
- "This test fails only on GPU with 07/05 XLA pin update.")
def test_stack_pred(self):

def test_fn(a):
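
The new skip decorators above encode the same constraint as the C++ replication test: DataParallel-style parallelism drives multiple devices from threads in one process, while the CUDA client expects one GPU per process. On CUDA the multi-process path is the supported route; a minimal sketch using xla_multiprocessing (assuming it is run as a standalone script with PJRT_DEVICE=CUDA):

# Sketch: one process per GPU via xmp.spawn instead of thread-based DataParallel.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each process sees exactly one addressable XLA device.
    device = xm.xla_device()
    print(f"process {index} is using {device}")

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
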
12 changes: 11 additions & 1 deletion torch_xla/core/xla_model.py
@@ -90,7 +90,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
that kind.
Returns:
- The list of device strings.
+ The list of device strings such as ['xla:0', 'xla:1', ...]
"""
# TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
# multiple device types.
@@ -220,6 +220,14 @@ def _xla_real_device(device):


def xla_real_devices(devices: Optional[List[torch.device]] = None):
"""Returns the real devices' name.
Args:
devices: The list of torch devices such as ['xla:0', 'xla:1'].
Returns:
A list of real devices' name such as ['CUDA:0', 'CUDA:1'].
"""
if not devices:
devices = get_xla_supported_devices()

@@ -260,6 +268,7 @@ def xla_replication_devices(local_devices):
format(len(local_devices), len(kind_devices)))
replication_devices = []
for device in torch_xla._XLAC._xla_get_all_devices():
+ # device is like 'CUDA:0'
xdev = parse_xla_device(device)
if not xdev:
raise RuntimeError('Invalid device format: {}'.format(device))
@@ -287,6 +296,7 @@ def set_replication(device, devices):
devctx = _get_device_context(device=device)
devices = [str(x) for x in devices]
if devices:
+ # sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
replication_devices = xla_replication_devices(devices)
torch_xla._XLAC._xla_set_replication_devices(replication_devices)
devctx.device_index = devices.index(device)
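
A small sketch of the mapping that the new docstrings and comments describe, i.e. how the generic 'xla:N' device strings relate to the real backend devices (output shown for a hypothetical 2-GPU host):

import torch_xla.core.xla_model as xm

devices = xm.get_xla_supported_devices()     # e.g. ['xla:0', 'xla:1']
real_devices = xm.xla_real_devices(devices)  # e.g. ['CUDA:0', 'CUDA:1']
print(list(zip(devices, real_devices)))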
