Support PJRT C API create_options #6289

Merged (16 commits) on Jan 25, 2024
48 changes: 47 additions & 1 deletion plugins/cuda/torch_xla_cuda_plugin/__init__.py
@@ -1,13 +1,59 @@
import os
from torch_xla.experimental import plugins
import torch_xla.utils.utils as xu


class CudaPlugin(plugins.DevicePlugin):

  def _get_process_rank(self) -> int:
    local_process_rank = xu.getenv_as("PJRT_LOCAL_PROCESS_RANK", int,
                                      xu.getenv_as("LOCAL_RANK", int, 0))
    global_process_rank = xu.getenv_as("RANK", int, local_process_rank)

    return local_process_rank, global_process_rank

  def _get_world_size(self) -> int:
    local_world_size = xu.getenv_as("PJRT_LOCAL_PROCESS_COUNT", int,
                                    xu.getenv_as("LOCAL_WORLD_SIZE", int, 1))
    global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)

    return local_world_size, global_world_size

  def library_path(self) -> str:
    return os.path.join(
        os.path.dirname(__file__), 'lib', 'pjrt_c_api_gpu_plugin.so')

  def physical_chip_count(self) -> int:
    # TODO: default to actual device count
-    return int(os.getenv('GPU_NUM_DEVICES', '1'))
    return xu.getenv_as('GPU_NUM_DEVICES', int, 1)
Comment thread:

Collaborator: Perhaps irrelevant to this PR, but I just want to confirm that the "# TODO: default to actual device count" still holds, since GPU_NUM_DEVICES is not always set and the default value may not be 1.

Collaborator: I see it's only used in run_multiprocess, so it sounds like using GPU_NUM_DEVICES preserves the current behavior. Looks good to me then.

Collaborator (Author): Yeah, this is the same as the current behavior. Ideally this should check the PCI device IDs like we do for TPUs.

  def client_create_options(self) -> dict:
Comment thread:

Collaborator: Who will call this client_create_options method here?

Collaborator (Author): plugins.py will call it during plugin registration in this PR. In the follow-up, this will be called when the client is created.

    local_process_rank, global_process_rank = self._get_process_rank()
    local_world_size, global_world_size = self._get_world_size()

    # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
    options = {
        "platform_name":
            "gpu",
        # TODO(wcromar): make this configurable
        "allocator":
            "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool,
                                         False) else "default",
        "memory_fraction":
            xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, None),
        "preallocate":
            xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, None),
        # Use all devices by default and when using SPMD
        "visible_devices": [local_process_rank]
                           if local_world_size > 1 else None,
        "node_id":
            global_process_rank,
        "num_nodes":
            global_world_size,
    }

    return {k: v for k, v in options.items() if v is not None}

  def requires_xla_coordinator(self) -> bool:
    _, global_world_size = self._get_world_size()
    return global_world_size > 1
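As a quick illustration of what this plugin now feeds into client creation, here is a hypothetical sketch that exercises the new hooks under torchrun-style environment variables; the import path and the exact values are assumptions for illustration, not part of this diff.

# Hypothetical sketch, not part of the PR: inspect the options the CUDA
# plugin would pass along for PJRT client creation.
import os

os.environ.update({
    "LOCAL_RANK": "0",
    "RANK": "0",
    "LOCAL_WORLD_SIZE": "2",
    "WORLD_SIZE": "2",
    "PJRT_ALLOCATOR_FRACTION": "0.8",
})

from torch_xla_cuda_plugin import CudaPlugin  # assumed import path

plugin = CudaPlugin()
print(plugin.client_create_options())
# Roughly: {'platform_name': 'gpu', 'allocator': 'default',
#           'memory_fraction': 0.8, 'visible_devices': [0],
#           'node_id': 0, 'num_nodes': 2}
print(plugin.requires_xla_coordinator())  # True, since WORLD_SIZE > 1

Note that None-valued entries (for example preallocate when PJRT_ALLOCATOR_PREALLOCATE is unset) are filtered out before the options are handed to registration.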
8 changes: 6 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2320,8 +2320,12 @@ void InitXlaModuleBindings(py::module m) {
  });
  // -------------Dynamo Integration API End-------------------------
  m.def("_register_pjrt_plugin",
-        [](std::string name, std::string library_path) {
-          runtime::RegisterPjRtPlugin(name, library_path);
        [](std::string name, std::string library_path,
           std::unordered_map<std::string, xla::PjRtValueType> create_options,
           bool init_coordinator) {
          runtime::RegisterPjRtPlugin(
              name, library_path,
              {create_options.begin(), create_options.end()}, init_coordinator);
        });
}
} // namespace
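A hypothetical sketch of calling the updated binding directly (normally register_plugin in torch_xla/experimental/plugins.py does this for you); the device name and library path are made up, and the option values simply mirror the kinds of types xla::PjRtValueType can carry (roughly: string, bool, 64-bit int, list of 64-bit ints, float).

# Illustrative only; the values below reuse keys the CUDA plugin above passes.
import torch_xla

create_options = {
    "platform_name": "gpu",     # string
    "preallocate": False,       # bool
    "node_id": 0,               # int64
    "visible_devices": [0, 1],  # list of int64
    "memory_fraction": 0.75,    # float
}
torch_xla._XLAC._register_pjrt_plugin(
    "CUDA",                               # hypothetical device type
    "/path/to/pjrt_c_api_gpu_plugin.so",  # hypothetical library path
    create_options,
    True)  # init_coordinator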
60 changes: 51 additions & 9 deletions torch_xla/csrc/runtime/pjrt_registry.cc
@@ -16,10 +16,16 @@
namespace torch_xla {
namespace runtime {

-std::unordered_map<std::string, std::string> pjrt_plugins_;

namespace {

struct PluginEntry {
  std::string library_path;
  absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
  bool init_coordinator;

Comment thread:

jonb377 (Collaborator), Jan 17, 2024: I wonder if we should drop init_coordinator and instead always initialize the coordinator when the distributed env vars are set, even on TPU where it's not strictly necessary. As long as torchrun launches the training in a distributed context, the env vars should be set, which I believe covers all GPU use cases, since we plan to use torchrun for GPU SPMD (cc @vanbasten23).

On TPU it's not currently required, but if that changes we can always detect the environment from the GCE metadata and set the env vars automatically for the user in a distributed context, since we don't require torchrun for multicontroller execution.

Collaborator (Author): > I wonder if we should drop init_coordinator and instead always initialize the coordinator when the distributed env vars are set, even on TPU where it's not strictly necessary.

In this case, I think we still want to keep the requires_xla_coordinator option. We would just throw an error immediately if we don't have enough information to start the coordinator.

My other idea initially was that we could ask the plugin for the master IP, local rank, global rank, and world size, perhaps just asking torch.distributed for those values in the default implementation.

Collaborator: I see - that makes sense. Just for context on why I brought this up: JAX recently started requiring the coordinator to be initialized before the backend can be used (even on TPUs), but I'm not sure of the reason.

Keeping init_coordinator sounds fine to me. Thanks Will!

};

std::unordered_map<std::string, PluginEntry> pjrt_plugins_;

xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
  auto allocator_config = xla::GpuAllocatorConfig{};
  if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() &&
@@ -37,17 +43,21 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
  return allocator_config;
}

-std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
  auto plugin_path = pjrt_plugins_.find(device_type);
  return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
                                            : std::nullopt;
}

} // namespace

-void RegisterPjRtPlugin(std::string name, std::string library_path) {
void RegisterPjRtPlugin(
    std::string name, std::string library_path,
    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
    bool init_coordinator) {
  TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
-  pjrt_plugins_[name] = library_path;
  pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
                         init_coordinator};
}

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
@@ -56,13 +66,45 @@ InitializePjRt(const std::string& device_type) {
  std::unique_ptr<XlaCoordinator> coordinator;

  if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
-    std::optional<std::string> plugin_path = GetPjRtPluginPath(device_type);
-    if (plugin_path) {
    std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
    if (plugin) {
      TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;

      std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
      if (plugin->init_coordinator) {
        int local_process_rank = sys_util::GetEnvInt(
            env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
        int global_process_rank =
            sys_util::GetEnvInt("RANK", local_process_rank);
        int local_world_size =
            sys_util::GetEnvInt(env::kEnvPjRtLocalProcessCount,
                                sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
        int global_world_size =
            sys_util::GetEnvInt("WORLD_SIZE", local_world_size);

        std::string master_addr =
            runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
        std::string port = runtime::sys_util::GetEnvString(
            "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

        TF_VLOG(3) << "Creating coordinator for rank=" << global_process_rank
                   << ", world size=" << global_world_size
                   << ", coordinator address=" << master_addr << ":" << port;

        // Use the XlaCoordinator as the distributed key-value store.
        coordinator = std::make_unique<XlaCoordinator>(
            global_process_rank, global_world_size, master_addr, port);
        std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
            coordinator->GetClient();
        kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                                    /*key_prefix=*/"pjrt:");
      }
      const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
-          absl::AsciiStrToLower(device_type), *plugin_path);
          absl::AsciiStrToLower(device_type), plugin->library_path);
      XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
-      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value();
      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
                                  plugin->create_options, kv_store)
                   .value();
      profiler::RegisterProfilerForPlugin(c_api);
    }
  } else if (device_type == "CPU") {
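For readers mapping the coordinator setup above back to a torchrun launch, the following sketch mirrors the precedence of the environment variables the registry reads, assuming env::kEnvPjRtLocalRank and env::kEnvPjRtLocalProcessCount correspond to the same PJRT_LOCAL_PROCESS_RANK and PJRT_LOCAL_PROCESS_COUNT variables used by the Python CUDA plugin above; the default port value is an assumption.

# Python mirror of the rank/world-size fallbacks applied when init_coordinator
# is set; illustrative only, the authoritative logic is the C++ above.
import os

def coordinator_params():
  local_rank = int(os.environ.get("PJRT_LOCAL_PROCESS_RANK",
                                  os.environ.get("LOCAL_RANK", "0")))
  global_rank = int(os.environ.get("RANK", local_rank))
  local_world = int(os.environ.get("PJRT_LOCAL_PROCESS_COUNT",
                                   os.environ.get("LOCAL_WORLD_SIZE", "1")))
  global_world = int(os.environ.get("WORLD_SIZE", local_world))
  addr = os.environ.get("MASTER_ADDR", "localhost")
  # kDefaultCoordinatorPort assumed to be "8547" for this sketch.
  port = os.environ.get("XLA_COORDINATOR_PORT", "8547")
  return global_rank, global_world, f"{addr}:{port}"

Under a typical torchrun launch, RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE, and MASTER_ADDR are set by the launcher, so the PJRT_* overrides only matter when you need to diverge from what torchrun provides.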
5 changes: 4 additions & 1 deletion torch_xla/csrc/runtime/pjrt_registry.h
@@ -6,7 +6,10 @@
namespace torch_xla {
namespace runtime {

-void RegisterPjRtPlugin(std::string name, std::string library_path);
void RegisterPjRtPlugin(
    std::string name, std::string library_path,
    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {},
    bool init_coordinator = true);

std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
InitializePjRt(const std::string& device_type);
15 changes: 14 additions & 1 deletion torch_xla/experimental/plugins.py
@@ -50,6 +50,17 @@ def physical_chip_count(self):
"""
return 1

  def client_create_options(self) -> dict:
    return {}

  def requires_xla_coordinator(self) -> bool:
"""Whether to initialize the XLA coordinator before plugin client.

Expects `torchrun` variables such as RANK, WORLD_SIZE, MASTER_ADDR to be
set.
"""
return False


_plugin_registry = {}

@@ -73,7 +84,9 @@ def default() -> DevicePlugin:

def register_plugin(name: str, device_plugin: DevicePlugin):
  _plugin_registry[name.upper()] = device_plugin
-  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin.library_path())
  torch_xla._XLAC._register_pjrt_plugin(
      name, device_plugin.library_path(), device_plugin.client_create_options(),
      device_plugin.requires_xla_coordinator())


def register_installed_plugins():
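To round out the picture, here is a minimal sketch of how a third-party plugin might use these new hooks; the device name, library path, and env-var handling are hypothetical, and the option keys simply reuse ones the CUDA plugin above passes.

# Hypothetical third-party plugin; everything device-specific here is made up.
import os
from torch_xla.experimental import plugins


class MyAcceleratorPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return "/opt/my_accel/lib/pjrt_c_api_my_accel_plugin.so"  # hypothetical

  def client_create_options(self) -> dict:
    # Forwarded to PJRT client creation as create_options.
    return {
        "node_id": int(os.environ.get("RANK", 0)),
        "num_nodes": int(os.environ.get("WORLD_SIZE", 1)),
    }

  def requires_xla_coordinator(self) -> bool:
    # Ask torch_xla to start the XLA coordinator (and hand its key-value
    # store to the client) whenever we run with more than one process.
    return int(os.environ.get("WORLD_SIZE", 1)) > 1


plugins.register_plugin("MY_ACCEL", MyAcceleratorPlugin())

Registration forwards the library path, the create_options dict, and the coordinator flag through the `_register_pjrt_plugin` binding shown earlier, so a plugin only needs to override the methods whose defaults (empty options, no coordinator) do not fit its device.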