Support PJRT C API create_options (#6289)
will-cromar authored Jan 25, 2024
1 parent a4d2d76 commit c738156
Showing 5 changed files with 122 additions and 14 deletions.
48 changes: 47 additions & 1 deletion plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -1,13 +1,59 @@
import os
from torch_xla.experimental import plugins
import torch_xla.utils.utils as xu


class CudaPlugin(plugins.DevicePlugin):

  def _get_process_rank(self) -> tuple[int, int]:
local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int,
xu.getenv_as("LOCAL_RANK", int, 0))
global_process_rank = xu.getenv_as("RANK", int, local_process_rank)

return local_process_rank, global_process_rank

  def _get_world_size(self) -> tuple[int, int]:
local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int,
xu.getenv_as("LOCAL_WORLD_SIZE", int, 1))
global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)

return local_world_size, global_world_size

def library_path(self) -> str:
return os.path.join(
os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

def physical_chip_count(self) -> int:
# TODO: default to actual device count
-    return int(os.getenv('GPU_NUM_DEVICES', '1'))
return xu.getenv_as('GPU_NUM_DEVICES', int, 1)

def client_create_options(self) -> dict:
local_process_rank, global_process_rank = self._get_process_rank()
local_world_size, global_world_size = self._get_world_size()

# The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
options = {
"platform_name":
"gpu",
# TODO(wcromar): make this configurable
"allocator":
"cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool,
False) else "default",
"memory_fraction":
xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None),
"preallocate":
xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None),
# Use all devices by default and when using SPMD
"visible_devices": [local_process_rank]
if local_world_size > 1 else None,
"node_id":
global_process_rank,
"num_nodes":
global_world_size,
}

return {k: v for k, v in options.items() if v is not None}

def requires_xla_coordinator(self) -> bool:
_, global_world_size = self._get_world_size()
return global_world_size > 1
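
As a quick illustration (not part of the diff): under a hypothetical single-host, two-GPU torchrun launch with RANK=1, WORLD_SIZE=2, LOCAL_RANK=1, LOCAL_WORLD_SIZE=2 and no PJRT_ALLOCATOR_* variables set, client_create_options() would resolve to roughly the following; the None-valued entries are stripped by the final dict comprehension.

# Hypothetical result of CudaPlugin().client_create_options() for the
# environment described above.
{
    "platform_name": "gpu",
    "allocator": "default",   # PJRT_ALLOCATOR_CUDA_ASYNC unset
    "visible_devices": [1],   # local process rank, since local_world_size > 1
    "node_id": 1,             # global process rank
    "num_nodes": 2,           # global world size
}
# "memory_fraction" and "preallocate" are omitted because their env vars are
# unset, so their values are None.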
8 changes: 6 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2320,8 +2320,12 @@ void InitXlaModuleBindings(py::module m) {
});
// -------------Dynamo Integration API End-------------------------
m.def("_register_pjrt_plugin",
-        [](std::string name, std::string library_path) {
-          runtime::RegisterPjRtPlugin(name, library_path);
[](std::string name, std::string library_path,
std::unordered_map<std::string, xla::PjRtValueType> create_options,
bool init_coordinator) {
runtime::RegisterPjRtPlugin(
name, library_path,
{create_options.begin(), create_options.end()}, init_coordinator);
});
}
} // namespace
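
For reference, a hypothetical direct call to the rebound function from Python (in practice plugins.register_plugin makes this call with values from the DevicePlugin, as shown in torch_xla/experimental/plugins.py below):

import torch_xla

# Hypothetical arguments; option values must be convertible to
# xla::PjRtValueType (str, int, float, bool, or a list of ints).
torch_xla._XLAC._register_pjrt_plugin(
    "CUDA",                               # plugin name
    "/path/to/pjrt_c_api_gpu_plugin.so",  # library_path (illustrative)
    {"node_id": 0, "num_nodes": 1},       # create_options
    True)                                 # init_coordinator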
60 changes: 51 additions & 9 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -16,10 +16,16 @@
namespace torch_xla {
namespace runtime {

-std::unordered_map<std::string, std::string> pjrt_plugins_;

namespace {

struct PluginEntry {
std::string library_path;
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
bool init_coordinator;
};

std::unordered_map<std::string, PluginEntry> pjrt_plugins_;

xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
auto allocator_config = xla::GpuAllocatorConfig{};
if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() &&
@@ -37,17 +43,21 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
return allocator_config;
}

-std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
auto plugin_path = pjrt_plugins_.find(device_type);
return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
: std::nullopt;
}

} // namespace

-void RegisterPjRtPlugin(std::string name, std::string library_path) {
void RegisterPjRtPlugin(
std::string name, std::string library_path,
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
bool init_coordinator) {
TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
-  pjrt_plugins_[name] = library_path;
pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
init_coordinator};
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
@@ -56,13 +66,45 @@ InitializePjRt(const std::string& device_type) {
std::unique_ptr<XlaCoordinator> coordinator;

if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
-    std::optional<std::string> plugin_path = GetPjRtPluginPath(device_type);
-    if (plugin_path) {
std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
if (plugin) {
TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;

std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
if (plugin->init_coordinator) {
int local_process_rank = sys_util::GetEnvInt(
env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
int global_process_rank =
sys_util::GetEnvInt("RANK", local_process_rank);
int local_world_size =
sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount,
sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
int global_world_size =
sys_util::GetEnvInt("WORLD_SIZE", local_world_size);

std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank
<< ", world size=" << global_world_size
<< ", coordinator address=" << master_addr << ":" << port;

// Use the XlaCoordinator as the distributed key-value store.
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"pjrt:");
}
const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
-          absl::AsciiStrToLower(device_type), *plugin_path);
absl::AsciiStrToLower(device_type), plugin->library_path);
XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
-      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value();
client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
plugin->create_options, kv_store)
.value();
profiler::RegisterProfilerForPlugin(c_api);
}
} else if (device_type == "CPU") {
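
For orientation (an assumption-laden sketch, not part of the change): the coordinator branch above is driven entirely by environment variables, so a two-process launch on one host would set roughly the following before InitializePjRt runs.

# Hypothetical torchrun-style environment; the names mirror the variables
# read in the coordinator branch above.
import os
os.environ.update({
    "RANK": "0",                     # global process rank
    "WORLD_SIZE": "2",               # global world size
    "LOCAL_RANK": "0",
    "LOCAL_WORLD_SIZE": "2",
    "MASTER_ADDR": "localhost",      # coordinator address
    "XLA_COORDINATOR_PORT": "8547",  # illustrative; otherwise the default port is used
})
# The dynamic-plugin path additionally requires the flag behind
# env::kEnvPjrtDynamicPlugins to be enabled; its exact variable name is not
# shown in this diff.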
5 changes: 4 additions & 1 deletion torch_xla/csrc/runtime/pjrt_registry.h
@@ -6,7 +6,10 @@
namespace torch_xla {
namespace runtime {

-void RegisterPjRtPlugin(std::string name, std::string library_path);
void RegisterPjRtPlugin(
std::string name, std::string library_path,
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {},
bool init_coordinator = true);

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type);
15 changes: 14 additions & 1 deletion torch_xla/experimental/plugins.py
@@ -50,6 +50,17 @@ def physical_chip_count(self):
"""
return 1

def client_create_options(self) -> dict:
return {}

def requires_xla_coordinator(self) -> bool:
"""Whether to initialize the XLA coordinator before plugin client.
Expects `torchrun` variables such as RANK, WORLD_SIZE, MASTER_ADDR to be
set.
"""
return False


_plugin_registry = {}

@@ -73,7 +84,9 @@ def default() -> DevicePlugin:

def register_plugin(name: str, device_plugin: DevicePlugin):
_plugin_registry[name.upper()] = device_plugin
-  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
torch_xla._XLAC._register_pjrt_plugin(
name, device_plugin.library_path(), device_plugin.client_create_options(),
device_plugin.requires_xla_coordinator())


def register_installed_plugins():
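
Example (not part of the commit): a third-party device plugin could use the new client_create_options and requires_xla_coordinator hooks roughly as follows; the class name, library path, and option keys are illustrative, and valid keys depend on what the vendor's PJRT C API plugin accepts.

import os

from torch_xla.experimental import plugins


class MyAcceleratorPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    # Illustrative path to the vendor's PJRT C API plugin.
    return "/opt/myvendor/lib/pjrt_c_api_plugin.so"

  def client_create_options(self) -> dict:
    # Forwarded to xla::GetCApiClient as the client's create_options.
    return {
        "node_id": int(os.environ.get("RANK", "0")),
        "num_nodes": int(os.environ.get("WORLD_SIZE", "1")),
    }

  def requires_xla_coordinator(self) -> bool:
    # Ask torch_xla to start the XLA coordinator for multi-process runs.
    return int(os.environ.get("WORLD_SIZE", "1")) > 1


plugins.register_plugin("MYACCEL", MyAcceleratorPlugin())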