diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
index d8d159de2ef..9321d26a1a6 100644
--- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py
+++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -1,13 +1,59 @@
 import os
 from torch_xla.experimental import plugins
+import torch_xla.utils.utils as xu
 
 
 class CudaPlugin(plugins.DevicePlugin):
 
+  def _get_process_rank(self) -> int:
+    local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int,
+                                      xu.getenv_as("LOCAL_RANK", int, 0))
+    global_process_rank = xu.getenv_as("RANK", int, local_process_rank)
+
+    return local_process_rank, global_process_rank
+
+  def _get_world_size(self) -> int:
+    local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int,
+                                    xu.getenv_as("LOCAL_WORLD_SIZE", int, 1))
+    global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)
+
+    return local_world_size, global_world_size
+
   def library_path(self) -> str:
     return os.path.join(
         os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')
 
   def physical_chip_count(self) -> int:
     # TODO: default to actual device count
-    return int(os.getenv('GPU_NUM_DEVICES', '1'))
+    return xu.getenv_as('GPU_NUM_DEVICES', int, 1)
+
+  def client_create_options(self) -> dict:
+    local_process_rank, global_process_rank = self._get_process_rank()
+    local_world_size, global_world_size = self._get_world_size()
+
+    # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
+    options = {
+        "platform_name":
+            "gpu",
+        # TODO(wcromar): make this configurable
+        "allocator":
+            "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool,
+                                         False) else "default",
+        "memory_fraction":
+            xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None),
+        "preallocate":
+            xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None),
+        # Use all devices by default and when using SPMD
+        "visible_devices": [local_process_rank]
+                           if local_world_size > 1 else None,
+        "node_id":
+            global_process_rank,
+        "num_nodes":
+            global_world_size,
+    }
+
+    return {k: v for k, v in options.items() if v is not None}
+
+  def requires_xla_coordinator(self) -> bool:
+    _, global_world_size = self._get_world_size()
+    return global_world_size > 1
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 3281f0e9a67..1cb2dc2ba4a 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2320,8 +2320,12 @@ void InitXlaModuleBindings(py::module m) {
         });
   // -------------Dynamo Integration API End-------------------------
   m.def("_register_pjrt_plugin",
-        [](std::string name, std::string library_path) {
-          runtime::RegisterPjRtPlugin(name, library_path);
+        [](std::string name, std::string library_path,
+           std::unordered_map<std::string, xla::PjRtValueType> create_options,
+           bool init_coordinator) {
+          runtime::RegisterPjRtPlugin(
+              name, library_path,
+              {create_options.begin(), create_options.end()}, init_coordinator);
         });
 }
 }  // namespace
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index ce8b3f029e5..1738f21a829 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -16,10 +16,16 @@
 namespace torch_xla {
 namespace runtime {
 
-std::unordered_map<std::string, std::string> pjrt_plugins_;
-
 namespace {
 
+struct PluginEntry {
+  std::string library_path;
+  absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
+  bool init_coordinator;
+};
+
+std::unordered_map<std::string, PluginEntry> pjrt_plugins_;
+
 xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   auto allocator_config = xla::GpuAllocatorConfig{};
   if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() &&
@@ -37,7 +43,7 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   return allocator_config;
 }
 
-std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
+std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
   auto plugin_path = pjrt_plugins_.find(device_type);
   return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
                                             : std::nullopt;
@@ -45,9 +51,13 @@ std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
 
 }  // namespace
 
-void RegisterPjRtPlugin(std::string name, std::string library_path) {
+void RegisterPjRtPlugin(
+    std::string name, std::string library_path,
+    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
+    bool init_coordinator) {
   TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
-  pjrt_plugins_[name] = library_path;
+  pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
+                         init_coordinator};
 }
 
 std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
@@ -56,13 +66,45 @@ InitializePjRt(const std::string& device_type) {
   std::unique_ptr<XlaCoordinator> coordinator;
 
   if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
-    std::optional<std::string> plugin_path = GetPjRtPluginPath(device_type);
-    if (plugin_path) {
+    std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
+    if (plugin) {
       TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
+
+      std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
+      if (plugin->init_coordinator) {
+        int local_process_rank = sys_util::GetEnvInt(
+            env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
+        int global_process_rank =
+            sys_util::GetEnvInt("RANK", local_process_rank);
+        int local_world_size =
+            sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount,
+                                sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
+        int global_world_size =
+            sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
+
+        std::string master_addr =
+            runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
+        std::string port = runtime::sys_util::GetEnvString(
+            "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
+
+        TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank
+                   << ", world size=" << global_world_size
+                   << ", coordinator address=" << master_addr << ":" << port;
+
+        // Use the XlaCoordinator as the distributed key-value store.
+        coordinator = std::make_unique<XlaCoordinator>(
+            global_process_rank, global_world_size, master_addr, port);
+        std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
+            coordinator->GetClient();
+        kv_store = xla::GetDistributedKeyValueStore(distributed_client,
+                                                    /*key_prefix=*/"pjrt:");
+      }
       const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
-          absl::AsciiStrToLower(device_type), *plugin_path);
+          absl::AsciiStrToLower(device_type), plugin->library_path);
       XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
-      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value();
+      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
+                                  plugin->create_options, kv_store)
+                   .value();
       profiler::RegisterProfilerForPlugin(c_api);
     }
   } else if (device_type == "CPU") {
diff --git a/torch_xla/csrc/runtime/pjrt_registry.h b/torch_xla/csrc/runtime/pjrt_registry.h
index 4cb7b70a661..24e80a298a7 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.h
+++ b/torch_xla/csrc/runtime/pjrt_registry.h
@@ -6,7 +6,10 @@
 namespace torch_xla {
 namespace runtime {
 
-void RegisterPjRtPlugin(std::string name, std::string library_path);
+void RegisterPjRtPlugin(
+    std::string name, std::string library_path,
+    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {},
+    bool init_coordinator = true);
 
 std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
 InitializePjRt(const std::string& device_type);
diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py
index b7b51d53242..99abf53f8da 100644
--- a/torch_xla/experimental/plugins.py
+++ b/torch_xla/experimental/plugins.py
@@ -50,6 +50,17 @@ def physical_chip_count(self):
     """
     return 1
 
+  def client_create_options(self) -> dict:
+    return {}
+
+  def requires_xla_coordinator(self) -> bool:
+    """Whether to initialize the XLA coordinator before plugin client.
+
+    Expects `torchrun` variables such as RANK, WORLD_SIZE, MASTER_ADDR to be
+    set.
+    """
+    return False
+
 
 _plugin_registry = {}
 
@@ -73,7 +84,9 @@ def default() -> DevicePlugin:
 
 def register_plugin(name: str, device_plugin: DevicePlugin):
   _plugin_registry[name.upper()] = device_plugin
-  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
+  torch_xla._XLAC._register_pjrt_plugin(
+      name, device_plugin.library_path(), device_plugin.client_create_options(),
+      device_plugin.requires_xla_coordinator())
 
 
 def register_installed_plugins():
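
For reference, a minimal sketch (not part of this diff) of how a third-party PJRT plugin package could use the new DevicePlugin hooks. The plugin name, class name, library filename, and option keys below are illustrative assumptions; only the APIs added or already present in torch_xla/experimental/plugins.py above are relied on.

import os

from torch_xla.experimental import plugins


class MyAcceleratorPlugin(plugins.DevicePlugin):
  # Hypothetical plugin: names and paths are placeholders, not a real device.

  def library_path(self) -> str:
    # Assumed location of the PJRT C API shared library shipped with the package.
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_my_plugin.so')

  def client_create_options(self) -> dict:
    # Keys must match the create options understood by this plugin's
    # PJRT_Client_Create; drop None values, as CudaPlugin does above.
    return {
        "node_id": int(os.getenv("RANK", "0")),
        "num_nodes": int(os.getenv("WORLD_SIZE", "1")),
    }

  def requires_xla_coordinator(self) -> bool:
    # Only start the XLA coordinator (and distributed KV store) for
    # multi-process runs launched via torchrun.
    return int(os.getenv("WORLD_SIZE", "1")) > 1


plugins.register_plugin("MY_ACCELERATOR", MyAcceleratorPlugin())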