-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix global_device_count(), local_device_count() for single process on…
… CUDA (#6022)
- Loading branch information
1 parent
3e68409
commit 8fc8d57
Showing
15 changed files
with
171 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24. | ||
// It can be removed in the next openXLA pin update after 01/26/2024. | ||
diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc | ||
index 0f1818be2..c181f3025 100644 | ||
--- a/xla/service/gpu/gpu_executable.cc | ||
+++ b/xla/service/gpu/gpu_executable.cc | ||
@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name, | ||
} | ||
} | ||
|
||
- // Maybe join a round of rendezvous after thunk initialization. | ||
- TF_RETURN_IF_ERROR( | ||
- MaybeRendezvousAfterInitialization(run_options, thunks_initialized)); | ||
+ // Maybe join a round of rendezvous after thunk initialization. We do this | ||
+ // only in presence of collective cliques which means that we have collective | ||
+ // operations in the XLA operations that tend to cause deadlocks. | ||
+ if (!collective_cliques.empty()) { | ||
+ TF_RETURN_IF_ERROR( | ||
+ MaybeRendezvousAfterInitialization(run_options, thunks_initialized)); | ||
+ } | ||
|
||
// Prepare parameters for thunks execution. | ||
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( | ||
diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h | ||
index 51a566b8f..94bab421f 100644 | ||
--- a/xla/service/gpu/thunk.h | ||
+++ b/xla/service/gpu/thunk.h | ||
@@ -175,6 +175,8 @@ class Thunk { | ||
absl::StatusOr<NcclComm::Lock> GetComm(const NcclCliqueKey& clique_key, | ||
int32_t rank) const; | ||
|
||
+ bool empty() const { return cliques_map_.empty(); } | ||
+ | ||
private: | ||
CliquesMap cliques_map_; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import concurrent.futures | ||
import itertools | ||
import os | ||
import queue | ||
import requests | ||
import unittest | ||
import subprocess | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch_xla | ||
import torch_xla.core.xla_env_vars as xenv | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.distributed.xla_multiprocessing as xmp | ||
from torch_xla import runtime as xr | ||
from torch_xla._internal import pjrt | ||
from absl.testing import absltest, parameterized | ||
|
||
|
||
@unittest.skipIf(xr.device_type() != "CUDA", | ||
f"GPU tests should only run on GPU devices.") | ||
class TestExperimentalSingleProcPjrtGpu(parameterized.TestCase): | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
command = 'nvidia-smi --list-gpus | wc -l' | ||
result = subprocess.run( | ||
command, | ||
capture_output=True, | ||
shell=True, | ||
check=True, | ||
text=True, | ||
) | ||
cls.num_cuda_devices = int(result.stdout) | ||
|
||
def test_num_local_devices(self): | ||
self.assertLen(xm.get_xla_supported_devices(), | ||
xr.addressable_device_count()) | ||
self.assertEqual(self.num_cuda_devices, xr.addressable_device_count()) | ||
|
||
def test_num_global_devices(self): | ||
self.assertLen(torch_xla._XLAC._xla_get_all_devices(), | ||
xr.global_device_count()) | ||
self.assertEqual(self.num_cuda_devices, xr.global_device_count()) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.