From c41d2076f9ae611ceb46c6973b66c5922b26d1e0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 10 Jan 2024 22:50:46 +0000 Subject: [PATCH 01/14] Support `create_options` --- torch_xla/csrc/init_python_bindings.cpp | 7 +++++-- torch_xla/csrc/runtime/pjrt_registry.cc | 26 ++++++++++++++++--------- torch_xla/csrc/runtime/pjrt_registry.h | 2 +- torch_xla/experimental/plugins.py | 4 ++-- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 86b9c896b1a..71bd8248006 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2319,8 +2319,11 @@ void InitXlaModuleBindings(py::module m) { }); // -------------Dynamo Integration API End------------------------- m.def("_register_pjrt_plugin", - [](std::string name, std::string library_path) { - runtime::RegisterPjRtPlugin(name, library_path); + [](std::string name, std::string library_path, std::unordered_map create_options) { + // for (auto item : create_options) { + // std::cout << "key: " << item.first << std::endl; + // } + runtime::RegisterPjRtPlugin(name, library_path, {create_options.begin(), create_options.end()}); }); } } // namespace diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index f79b24fe92f..ae0775802ce 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -14,10 +14,15 @@ namespace torch_xla { namespace runtime { -std::unordered_map pjrt_plugins_; - namespace { +struct PluginEntry { + std::string library_path; + absl::flat_hash_map create_options; +}; + +std::unordered_map pjrt_plugins_; + xla::GpuAllocatorConfig GetGpuAllocatorConfig() { auto allocator_config = xla::GpuAllocatorConfig{}; if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() && @@ -35,7 +40,7 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() { return allocator_config; } -std::optional GetPjRtPluginPath(const std::string& device_type) { +std::optional GetPjRtPlugin(const std::string& device_type) { auto plugin_path = pjrt_plugins_.find(device_type); return plugin_path != pjrt_plugins_.end() ? 
std::optional(plugin_path->second) : std::nullopt; @@ -43,9 +48,12 @@ std::optional GetPjRtPluginPath(const std::string& device_type) { } // namespace -void RegisterPjRtPlugin(std::string name, std::string library_path) { +void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options) { TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; - pjrt_plugins_[name] = library_path; + for (auto item : create_options) { + std::cout << "key: " << item.first << std::endl; + } + pjrt_plugins_[name] = {std::move(library_path), std::move(create_options)}; } std::tuple, std::unique_ptr> @@ -54,13 +62,13 @@ InitializePjRt(const std::string& device_type) { std::unique_ptr coordinator; if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) { - std::optional plugin_path = GetPjRtPluginPath(device_type); - if (plugin_path) { + std::optional plugin = GetPjRtPlugin(device_type); + if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( - absl::AsciiStrToLower(device_type), *plugin_path); + absl::AsciiStrToLower(device_type), plugin->library_path); XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); - client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value(); + client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), plugin->create_options).value(); profiler::RegisterProfilerForPlugin(c_api); } } else if (device_type == "CPU") { diff --git a/torch_xla/csrc/runtime/pjrt_registry.h b/torch_xla/csrc/runtime/pjrt_registry.h index 4cb7b70a661..2eed1fb7eb8 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.h +++ b/torch_xla/csrc/runtime/pjrt_registry.h @@ -6,7 +6,7 @@ namespace torch_xla { namespace runtime { -void RegisterPjRtPlugin(std::string name, std::string library_path); +void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options = {}); std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type); diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 90b0ff2b8b0..43ad501bc49 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -62,6 +62,6 @@ def default() -> DevicePlugin: return _plugin_registry[xr.device_type()] -def register_plugin(name: str, device_plugin: DevicePlugin): +def register_plugin(name: str, device_plugin: DevicePlugin, create_options = {}): _plugin_registry[name.upper()] = device_plugin - torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path()) + torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), create_options) From 673db212f0ccd6bb8073fb6dbb929ebf5f8c9f17 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 10 Jan 2024 22:54:05 +0000 Subject: [PATCH 02/14] add options to DevicePlugin --- torch_xla/experimental/plugins.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 43ad501bc49..4fd7058f7e9 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -41,6 +41,9 @@ def physical_chip_count(self): """ return 1 + def client_create_options(self) -> dict: + return {} + _plugin_registry = {} @@ -62,6 +65,6 @@ def default() -> DevicePlugin: return _plugin_registry[xr.device_type()] -def register_plugin(name: str, device_plugin: DevicePlugin, create_options = {}): +def register_plugin(name: str, device_plugin: DevicePlugin): 
_plugin_registry[name.upper()] = device_plugin - torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), create_options) + torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), device_plugin.client_create_options()) From a01c5858300064d5c2e3f2a1448f0284bb9f48a5 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 11 Jan 2024 22:13:24 +0000 Subject: [PATCH 03/14] init XlaCoordinator --- torch_xla/csrc/init_python_bindings.cpp | 4 +-- torch_xla/csrc/runtime/pjrt_registry.cc | 37 +++++++++++++++++++++++-- torch_xla/csrc/runtime/pjrt_registry.h | 2 +- torch_xla/experimental/plugins.py | 10 ++++++- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 71bd8248006..c11db060e9e 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2319,11 +2319,11 @@ void InitXlaModuleBindings(py::module m) { }); // -------------Dynamo Integration API End------------------------- m.def("_register_pjrt_plugin", - [](std::string name, std::string library_path, std::unordered_map create_options) { + [](std::string name, std::string library_path, std::unordered_map create_options, bool init_coordinator) { // for (auto item : create_options) { // std::cout << "key: " << item.first << std::endl; // } - runtime::RegisterPjRtPlugin(name, library_path, {create_options.begin(), create_options.end()}); + runtime::RegisterPjRtPlugin(name, library_path, {create_options.begin(), create_options.end()}, init_coordinator); }); } } // namespace diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index ae0775802ce..256286e1190 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -19,6 +19,7 @@ namespace { struct PluginEntry { std::string library_path; absl::flat_hash_map create_options; + bool init_coordinator; }; std::unordered_map pjrt_plugins_; @@ -48,12 +49,12 @@ std::optional GetPjRtPlugin(const std::string& device_type) { } // namespace -void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options) { +void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options, bool init_coordinator) { TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; for (auto item : create_options) { std::cout << "key: " << item.first << std::endl; } - pjrt_plugins_[name] = {std::move(library_path), std::move(create_options)}; + pjrt_plugins_[name] = {std::move(library_path), std::move(create_options), init_coordinator}; } std::tuple, std::unique_ptr> @@ -65,10 +66,40 @@ InitializePjRt(const std::string& device_type) { std::optional plugin = GetPjRtPlugin(device_type); if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; + + xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; + xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; + if (plugin->init_coordinator) { + int global_process_rank = sys_util::GetEnvInt("RANK", 0); + int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); + std::string master_addr = + runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); + std::string port = runtime::sys_util::GetEnvString( + "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); + + if (global_world_size > 1) { + // Use the XlaCoordinator as the distributed key-value store. 
+ coordinator = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator->GetClient(); + std::string key_prefix = "gpu:"; + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { + return distributed_client->BlockingKeyValueGet( + absl::StrCat(key_prefix, k), timeout); + }; + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + }; + } + } const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( absl::AsciiStrToLower(device_type), plugin->library_path); XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); - client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), plugin->create_options).value(); + client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), plugin->create_options, kv_get, kv_put).value(); profiler::RegisterProfilerForPlugin(c_api); } } else if (device_type == "CPU") { diff --git a/torch_xla/csrc/runtime/pjrt_registry.h b/torch_xla/csrc/runtime/pjrt_registry.h index 2eed1fb7eb8..0881fa52e8f 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.h +++ b/torch_xla/csrc/runtime/pjrt_registry.h @@ -6,7 +6,7 @@ namespace torch_xla { namespace runtime { -void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options = {}); +void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options = {}, bool init_coordinator = true); std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type); diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 4fd7058f7e9..2727cea50d9 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -44,6 +44,14 @@ def physical_chip_count(self): def client_create_options(self) -> dict: return {} + def requires_xla_coordinator(self) -> bool: + """Whether to initialize the XLA coordinator before plugin client. + + Expects `torchrun` variables such as RANK, WORLD_SIZE, MASTER_ADDR to be + set. 
+ """ + return False + _plugin_registry = {} @@ -67,4 +75,4 @@ def default() -> DevicePlugin: def register_plugin(name: str, device_plugin: DevicePlugin): _plugin_registry[name.upper()] = device_plugin - torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), device_plugin.client_create_options()) + torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), device_plugin.client_create_options(), device_plugin.requires_xla_coordinator()) From 5771baa25f583d961c8ebeed9e08f310cc4ee220 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 12 Jan 2024 22:57:27 +0000 Subject: [PATCH 04/14] implement create options for GPU plugin --- .../cuda/torch_xla_cuda_plugin/__init__.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index f10a412bfaa..cdf59b2b045 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -1,6 +1,6 @@ import os from torch_xla.experimental import plugins -from torch_xla._internal import tpu +import torch_xla.utils.utils as xu class GpuPlugin(plugins.DevicePlugin): def library_path(self) -> str: @@ -8,4 +8,24 @@ def library_path(self) -> str: def physical_chip_count(self) -> int: # TODO: default to actual device count - return int(os.getenv('GPU_NUM_DEVICES', '1')) + return xu.getenv_as('GPU_NUM_DEVICES', int, 1) + + def client_create_options(self) -> dict: + local_process_rank = xu.getenv_as("LOCAL_RANK", int, 0) + global_process_rank = xu.getenv_as("RANK", int, local_process_rank) + local_world_size = xu.getenv_as("LOCAL_WORLD_SIZE", int, 1) + global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) + + return { + "platform_name": "gpu", + # TODO(wcromar): make this configurable + "allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default", + "memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75), + "preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True), + "visible_devices": [local_process_rank], + "node_id": global_process_rank, + "num_nodes": global_world_size, + } + + def requires_xla_coordinator(self) -> bool: + return True From ba24333e2c925f854b8c2766ad3b49cd00e8975f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 12 Jan 2024 23:10:05 +0000 Subject: [PATCH 05/14] formatting --- torch_xla/csrc/init_python_bindings.cpp | 11 ++++++----- torch_xla/csrc/runtime/pjrt_registry.cc | 24 ++++++++++++++---------- torch_xla/csrc/runtime/pjrt_registry.h | 5 ++++- torch_xla/experimental/plugins.py | 4 +++- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c11db060e9e..0584b1f0a01 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -2319,11 +2319,12 @@ void InitXlaModuleBindings(py::module m) { }); // -------------Dynamo Integration API End------------------------- m.def("_register_pjrt_plugin", - [](std::string name, std::string library_path, std::unordered_map create_options, bool init_coordinator) { - // for (auto item : create_options) { - // std::cout << "key: " << item.first << std::endl; - // } - runtime::RegisterPjRtPlugin(name, library_path, {create_options.begin(), create_options.end()}, init_coordinator); + [](std::string name, std::string library_path, + std::unordered_map create_options, + bool init_coordinator) 
{ + runtime::RegisterPjRtPlugin( + name, library_path, + {create_options.begin(), create_options.end()}, init_coordinator); }); } } // namespace diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 256286e1190..fdf3d38501f 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -49,12 +49,13 @@ std::optional GetPjRtPlugin(const std::string& device_type) { } // namespace -void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options, bool init_coordinator) { +void RegisterPjRtPlugin( + std::string name, std::string library_path, + absl::flat_hash_map create_options, + bool init_coordinator) { TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path; - for (auto item : create_options) { - std::cout << "key: " << item.first << std::endl; - } - pjrt_plugins_[name] = {std::move(library_path), std::move(create_options), init_coordinator}; + pjrt_plugins_[name] = {std::move(library_path), std::move(create_options), + init_coordinator}; } std::tuple, std::unique_ptr> @@ -85,21 +86,24 @@ InitializePjRt(const std::string& device_type) { coordinator->GetClient(); std::string key_prefix = "gpu:"; kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { return distributed_client->BlockingKeyValueGet( absl::StrCat(key_prefix, k), timeout); }; kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); + std::string_view k, std::string_view v) -> xla::Status { + return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), + v); }; } } const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( absl::AsciiStrToLower(device_type), plugin->library_path); XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); - client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), plugin->create_options, kv_get, kv_put).value(); + client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), + plugin->create_options, kv_get, kv_put) + .value(); profiler::RegisterProfilerForPlugin(c_api); } } else if (device_type == "CPU") { diff --git a/torch_xla/csrc/runtime/pjrt_registry.h b/torch_xla/csrc/runtime/pjrt_registry.h index 0881fa52e8f..24e80a298a7 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.h +++ b/torch_xla/csrc/runtime/pjrt_registry.h @@ -6,7 +6,10 @@ namespace torch_xla { namespace runtime { -void RegisterPjRtPlugin(std::string name, std::string library_path, absl::flat_hash_map create_options = {}, bool init_coordinator = true); +void RegisterPjRtPlugin( + std::string name, std::string library_path, + absl::flat_hash_map create_options = {}, + bool init_coordinator = true); std::tuple, std::unique_ptr> InitializePjRt(const std::string& device_type); diff --git a/torch_xla/experimental/plugins.py b/torch_xla/experimental/plugins.py index 2727cea50d9..34bdaed40c7 100644 --- a/torch_xla/experimental/plugins.py +++ b/torch_xla/experimental/plugins.py @@ -75,4 +75,6 @@ def default() -> DevicePlugin: def register_plugin(name: str, device_plugin: DevicePlugin): _plugin_registry[name.upper()] = device_plugin - torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path(), device_plugin.client_create_options(), device_plugin.requires_xla_coordinator()) + torch_xla._XLAC._register_pjrt_plugin( + name, 
device_plugin.library_path(), device_plugin.client_create_options(), + device_plugin.requires_xla_coordinator()) From a0968ce1e24e442545ebfff75c0052fb9793f70c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 12 Jan 2024 23:15:11 +0000 Subject: [PATCH 06/14] link openxla gpu plugin options --- plugins/cuda/torch_xla_cuda_plugin/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index cdf59b2b045..d501ca0e449 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -16,6 +16,7 @@ def client_create_options(self) -> dict: local_world_size = xu.getenv_as("LOCAL_WORLD_SIZE", int, 1) global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) + # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67 return { "platform_name": "gpu", # TODO(wcromar): make this configurable From 36d48e70da2ac5bee4d599953fa4615a58356e9d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 18 Jan 2024 00:09:22 +0000 Subject: [PATCH 07/14] `gpu`->`pjrt` --- torch_xla/csrc/runtime/pjrt_registry.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index fdf3d38501f..8845f462d5c 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -84,7 +84,7 @@ InitializePjRt(const std::string& device_type) { global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = coordinator->GetClient(); - std::string key_prefix = "gpu:"; + std::string key_prefix = "pjrt:"; kv_get = [distributed_client, key_prefix]( std::string_view k, absl::Duration timeout) -> xla::StatusOr { From 0236b0cf5341caa6e381b437b31a90423e80a1f0 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 19:22:54 +0000 Subject: [PATCH 08/14] remove some create options when unset --- .../cuda/torch_xla_cuda_plugin/__init__.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index d501ca0e449..ff3cacf76a2 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -2,7 +2,9 @@ from torch_xla.experimental import plugins import torch_xla.utils.utils as xu + class GpuPlugin(plugins.DevicePlugin): + def library_path(self) -> str: return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so') @@ -17,16 +19,25 @@ def client_create_options(self) -> dict: global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67 - return { - "platform_name": "gpu", - # TODO(wcromar): make this configurable - "allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default", - "memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75), - "preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True), - "visible_devices": [local_process_rank], - "node_id": global_process_rank, - "num_nodes": global_world_size, + options = { + "platform_name": + "gpu", + # TODO(wcromar): make this configurable 
+ "allocator": + "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, + False) else "default", + "memory_fraction": + xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None), + "preallocate": + xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None), + "visible_devices": [local_process_rank], + "node_id": + global_process_rank, + "num_nodes": + global_world_size, } + return {k: v for k, v in options.items() if v is not None} + def requires_xla_coordinator(self) -> bool: return True From 25d154788295546edb4c7327697831df2a385dee Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 21:51:12 +0000 Subject: [PATCH 09/14] fix pin update --- torch_xla/csrc/runtime/pjrt_registry.cc | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index 3e91ac7a914..f8ccec07adb 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -70,8 +70,7 @@ InitializePjRt(const std::string& device_type) { if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; - xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; + std::shared_ptr kv_store; if (plugin->init_coordinator) { int global_process_rank = sys_util::GetEnvInt("RANK", 0); int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); @@ -86,25 +85,15 @@ InitializePjRt(const std::string& device_type) { global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = coordinator->GetClient(); - std::string key_prefix = "pjrt:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), - v); - }; + kv_store = xla::GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"pjrt:"); } } const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( absl::AsciiStrToLower(device_type), plugin->library_path); XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type)); client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type), - plugin->create_options, kv_get, kv_put) + plugin->create_options, kv_store) .value(); profiler::RegisterProfilerForPlugin(c_api); } From 61bf72f376b1fce8e51d400ddf881a230a8e838c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 23:27:47 +0000 Subject: [PATCH 10/14] use all devices in SPMD case (see #6022) --- plugins/cuda/torch_xla_cuda_plugin/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index ff3cacf76a2..4dacd9457c4 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -2,7 +2,6 @@ from torch_xla.experimental import plugins import torch_xla.utils.utils as xu - class GpuPlugin(plugins.DevicePlugin): def library_path(self) -> str: @@ -15,6 +14,7 @@ def physical_chip_count(self) -> int: def client_create_options(self) -> dict: local_process_rank = xu.getenv_as("LOCAL_RANK", int, 0) global_process_rank = xu.getenv_as("RANK", int, local_process_rank) + local_world_size = xu.getenv_as("LOCAL_WORLD_SIZE", int, 1) global_world_size = 
xu.getenv_as("WORLD_SIZE", int, local_world_size) @@ -30,11 +30,10 @@ def client_create_options(self) -> dict: xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None), "preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None), - "visible_devices": [local_process_rank], - "node_id": - global_process_rank, - "num_nodes": - global_world_size, + # Use all devices by default and when using SPMD + "visible_devices": [local_process_rank] if local_world_size > 1 else None, + "node_id": global_process_rank, + "num_nodes": global_world_size, } return {k: v for k, v in options.items() if v is not None} From 775f9367979c4a116110907825da28e2ac8ecc1d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 23:28:17 +0000 Subject: [PATCH 11/14] Add some logging --- torch_xla/csrc/runtime/pjrt_registry.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index f8ccec07adb..a9d299ecd63 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -70,7 +70,7 @@ InitializePjRt(const std::string& device_type) { if (plugin) { TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type; - std::shared_ptr kv_store; + std::shared_ptr kv_store = nullptr; if (plugin->init_coordinator) { int global_process_rank = sys_util::GetEnvInt("RANK", 0); int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); @@ -79,15 +79,15 @@ InitializePjRt(const std::string& device_type) { std::string port = runtime::sys_util::GetEnvString( "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - if (global_world_size > 1) { - // Use the XlaCoordinator as the distributed key-value store. - coordinator = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator->GetClient(); - kv_store = xla::GetDistributedKeyValueStore(distributed_client, - /*key_prefix=*/"pjrt:"); - } + TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank << ", world size=" << global_world_size << ", coordinator address=" << master_addr << ":" << port; + + // Use the XlaCoordinator as the distributed key-value store. 
+ coordinator = std::make_unique( + global_process_rank, global_world_size, master_addr, port); + std::shared_ptr distributed_client = + coordinator->GetClient(); + kv_store = xla::GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"pjrt:"); } const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin( absl::AsciiStrToLower(device_type), plugin->library_path); From 193951b11df55fe4e722eed3eec3144864756f2a Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 23:28:58 +0000 Subject: [PATCH 12/14] formatting --- plugins/cpu/torch_xla_cpu_plugin/__init__.py | 2 ++ plugins/cuda/torch_xla_cuda_plugin/__init__.py | 10 +++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/plugins/cpu/torch_xla_cpu_plugin/__init__.py b/plugins/cpu/torch_xla_cpu_plugin/__init__.py index da7a3234267..db0662f4aa7 100644 --- a/plugins/cpu/torch_xla_cpu_plugin/__init__.py +++ b/plugins/cpu/torch_xla_cpu_plugin/__init__.py @@ -2,7 +2,9 @@ from torch_xla.experimental import plugins from torch_xla._internal import tpu + class CpuPlugin(plugins.DevicePlugin): + def library_path(self) -> str: return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_cpu_plugin.so') diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index 4dacd9457c4..88ca859e4d1 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -2,6 +2,7 @@ from torch_xla.experimental import plugins import torch_xla.utils.utils as xu + class GpuPlugin(plugins.DevicePlugin): def library_path(self) -> str: @@ -31,9 +32,12 @@ def client_create_options(self) -> dict: "preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None), # Use all devices by default and when using SPMD - "visible_devices": [local_process_rank] if local_world_size > 1 else None, - "node_id": global_process_rank, - "num_nodes": global_world_size, + "visible_devices": [local_process_rank] + if local_world_size > 1 else None, + "node_id": + global_process_rank, + "num_nodes": + global_world_size, } return {k: v for k, v in options.items() if v is not None} From b9ddd6e8dc5169c040e457aec9b20672f6358245 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Fri, 19 Jan 2024 23:30:04 +0000 Subject: [PATCH 13/14] more formatting --- torch_xla/csrc/runtime/pjrt_registry.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index a9d299ecd63..f7f5601bba5 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -79,7 +79,9 @@ InitializePjRt(const std::string& device_type) { std::string port = runtime::sys_util::GetEnvString( "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank << ", world size=" << global_world_size << ", coordinator address=" << master_addr << ":" << port; + TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank + << ", world size=" << global_world_size + << ", coordinator address=" << master_addr << ":" << port; // Use the XlaCoordinator as the distributed key-value store. 
coordinator = std::make_unique( From 9b6dc0d75cc947b6d68c449faf47240a4602683f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 25 Jan 2024 19:31:35 +0000 Subject: [PATCH 14/14] Address review comments --- .../cuda/torch_xla_cuda_plugin/__init__.py | 24 ++++++++++++++----- torch_xla/csrc/runtime/pjrt_registry.cc | 12 ++++++++-- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py index 898f218a3a0..9321d26a1a6 100644 --- a/plugins/cuda/torch_xla_cuda_plugin/__init__.py +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -5,6 +5,20 @@ class CudaPlugin(plugins.DevicePlugin): + def _get_process_rank(self) -> int: + local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int, + xu.getenv_as("LOCAL_RANK", int, 0)) + global_process_rank = xu.getenv_as("RANK", int, local_process_rank) + + return local_process_rank, global_process_rank + + def _get_world_size(self) -> int: + local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int, + xu.getenv_as("LOCAL_WORLD_SIZE", int, 1)) + global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) + + return local_world_size, global_world_size + def library_path(self) -> str: return os.path.join( os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so') @@ -14,11 +28,8 @@ def physical_chip_count(self) -> int: return xu.getenv_as('GPU_NUM_DEVICES', int, 1) def client_create_options(self) -> dict: - local_process_rank = xu.getenv_as("LOCAL_RANK", int, 0) - global_process_rank = xu.getenv_as("RANK", int, local_process_rank) - - local_world_size = xu.getenv_as("LOCAL_WORLD_SIZE", int, 1) - global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size) + local_process_rank, global_process_rank = self._get_process_rank() + local_world_size, global_world_size = self._get_world_size() # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67 options = { @@ -44,4 +55,5 @@ def client_create_options(self) -> dict: return {k: v for k, v in options.items() if v is not None} def requires_xla_coordinator(self) -> bool: - return True + _, global_world_size = self._get_world_size() + return global_world_size > 1 diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index f7f5601bba5..1738f21a829 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -72,8 +72,16 @@ InitializePjRt(const std::string& device_type) { std::shared_ptr kv_store = nullptr; if (plugin->init_coordinator) { - int global_process_rank = sys_util::GetEnvInt("RANK", 0); - int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1); + int local_process_rank = sys_util::GetEnvInt( + env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0)); + int global_process_rank = + sys_util::GetEnvInt("RANK", local_process_rank); + int local_world_size = + sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount, + sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1)); + int global_world_size = + sys_util::GetEnvInt("WORLD_SIZE", local_world_size); + std::string master_addr = runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); std::string port = runtime::sys_util::GetEnvString(
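
Usage sketch: with the hooks added in this series, a downstream plugin package supplies PJRT client create options and opts into coordinator setup by overriding the new `DevicePlugin` methods. The example below is a minimal sketch only — the class name, registered device name, library filename, and option key are hypothetical placeholders; the method names and the `register_plugin` call are the ones introduced above.

import os

from torch_xla.experimental import plugins


class MyAcceleratorPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    # Hypothetical location of the vendor's PJRT C API plugin binary.
    return os.path.join(
        os.path.dirname(__file__), 'pjrt_c_api_myaccel_plugin.so')

  def client_create_options(self) -> dict:
    # Forwarded through _XLAC._register_pjrt_plugin and passed to
    # xla::GetCApiClient when the client is created. The key below is a
    # placeholder; each plugin defines which keys it accepts.
    return {"example_option": 1}

  def requires_xla_coordinator(self) -> bool:
    # If True, InitializePjRt starts an XlaCoordinator (using torchrun-style
    # variables such as RANK, WORLD_SIZE, MASTER_ADDR and the optional
    # XLA_COORDINATOR_PORT) and hands the resulting "pjrt:"-prefixed
    # key-value store to the PJRT client.
    return False


plugins.register_plugin("MYACCEL", MyAcceleratorPlugin())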
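
For the CUDA plugin as it stands after the last patch, every create option is derived from environment variables and entries that resolve to None are dropped. As a worked example under a hypothetical single-host torchrun launch with LOCAL_RANK=0, RANK=0, LOCAL_WORLD_SIZE=2, WORLD_SIZE=2 and no PJRT_ALLOCATOR_* variables set, client_create_options() evaluates to roughly:

{
    "platform_name": "gpu",
    "allocator": "default",  # "cuda_async" only when PJRT_ALLOCATOR_CUDA_ASYNC is enabled
    "visible_devices": [0],  # present only when the local world size is > 1
    "node_id": 0,
    "num_nodes": 2,
}

memory_fraction and preallocate are omitted because PJRT_ALLOCATOR_FRACTION and PJRT_ALLOCATOR_PREALLOCATE are unset, and requires_xla_coordinator() returns True because the global world size is greater than 1.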