Support PJRT C API create_options
#6289
Changes from 6 commits
@@ -1,11 +1,32 @@
 import os
 from torch_xla.experimental import plugins
 from torch_xla._internal import tpu
+import torch_xla.utils.utils as xu


 class GpuPlugin(plugins.DevicePlugin):
   def library_path(self) -> str:
     return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')

   def physical_chip_count(self) -> int:
     # TODO: default to actual device count
-    return int(os.getenv('GPU_NUM_DEVICES', '1'))
+    return xu.getenv_as('GPU_NUM_DEVICES', int, 1)

+  def client_create_options(self) -> dict:
Review comment: Who will call this?
+    local_process_rank = xu.getenv_as("LOCAL_RANK", int, 0)
+    global_process_rank = xu.getenv_as("RANK", int, local_process_rank)
+    local_world_size = xu.getenv_as("LOCAL_WORLD_SIZE", int, 1)
+    global_world_size = xu.getenv_as("WORLD_SIZE", int, local_world_size)
+
+    # The available options are defined in OpenXLA: https://github.com/openxla/xla/blob/1bb2a74be91fabf5f9aa2702b2592b5b022c9052/xla/pjrt/c/pjrt_c_api_gpu_internal.cc#L58-L67
+    return {
+        "platform_name": "gpu",
+        # TODO(wcromar): make this configurable
+        "allocator": "cuda_async" if xu.getenv_as("PJRT_ALLOCATOR_CUDA_ASYNC", bool, False) else "default",
Review comment: For these 3 settings, is it possible not to use the hardcoded defaults (False, .75, True)? For example, see xla/torch_xla/csrc/runtime/pjrt_registry.cc, lines 22 to 27 at 4bf8d44.
Reply: Good catch. Removed some options entirely when the environment variable is not set.
"memory_fraction": xu.getenv_as("PJRT_ALLOCATOR_FRACTION", float, .75), | ||||||||||||||
"preallocate": xu.getenv_as("PJRT_ALLOCATOR_PREALLOCATE", bool, True), | ||||||||||||||
Review comment: Are these env vars new?
Reply: These all exist in …
"visible_devices": [local_process_rank], | ||||||||||||||
will-cromar marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
"node_id": global_process_rank, | ||||||||||||||
"num_nodes": global_world_size, | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
def requires_xla_coordinator(self) -> bool: | ||||||||||||||
return True | ||||||||||||||
Review comment: I wonder why this always returns True. For single-process runs we probably don't need the coordinator, so should it depend on whether this is a single-process run?
Reply: I think in a previous draft, I caught this case in …
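As a rough illustration of what client_create_options returns under a torchrun-style launch, here is a self-contained sketch (not part of this PR) that mirrors the logic above with plain os.environ lookups. The env-var names come from the diff; the boolean parsing, the helper name, and the demo values are assumptions.

# Minimal sketch: mirrors client_create_options() above using only the
# standard library, so it can run without torch_xla installed.
import os

def sketch_create_options() -> dict:
  # torchrun populates LOCAL_RANK/RANK/LOCAL_WORLD_SIZE/WORLD_SIZE.
  local_rank = int(os.environ.get('LOCAL_RANK', 0))
  global_rank = int(os.environ.get('RANK', local_rank))
  local_world = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
  global_world = int(os.environ.get('WORLD_SIZE', local_world))

  options = {
      'platform_name': 'gpu',
      'visible_devices': [local_rank],
      'node_id': global_rank,
      'num_nodes': global_world,
  }
  # Per the review discussion, allocator options are only added when the
  # corresponding environment variable is actually set. The '== 1'
  # truthiness check is a simplification of xu.getenv_as.
  if 'PJRT_ALLOCATOR_CUDA_ASYNC' in os.environ:
    options['allocator'] = ('cuda_async'
                            if os.environ['PJRT_ALLOCATOR_CUDA_ASYNC'] == '1'
                            else 'default')
  if 'PJRT_ALLOCATOR_FRACTION' in os.environ:
    options['memory_fraction'] = float(os.environ['PJRT_ALLOCATOR_FRACTION'])
  if 'PJRT_ALLOCATOR_PREALLOCATE' in os.environ:
    options['preallocate'] = os.environ['PJRT_ALLOCATOR_PREALLOCATE'] == '1'
  return options

if __name__ == '__main__':
  # Simulate rank 1 of a 4-process single-host torchrun launch.
  os.environ.update({'LOCAL_RANK': '1', 'RANK': '1',
                     'LOCAL_WORLD_SIZE': '4', 'WORLD_SIZE': '4'})
  print(sketch_create_options())
  # -> {'platform_name': 'gpu', 'visible_devices': [1], 'node_id': 1, 'num_nodes': 4}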
torch_xla/csrc/runtime/pjrt_registry.cc
@@ -14,10 +14,16 @@
 namespace torch_xla {
 namespace runtime {

-std::unordered_map<std::string, std::string> pjrt_plugins_;
-
 namespace {

+struct PluginEntry {
+  std::string library_path;
+  absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
+  bool init_coordinator;
Review comment: I wonder if we should drop this flag. On TPU it's not currently required, but if that changes, we can always detect the environment from the GCE metadata and set the env vars automatically for the user in a distributed context, since we don't require torchrun for multicontroller execution.
Reply: In this case, I think we still want to keep it. My other idea initially was that we could ask the plugin for the master IP, local rank, global rank, and world size, perhaps just asking …
Reply: I see, that makes sense. Just for context on why I brought this up: JAX recently started requiring the coordinator to be initialized before the backend can be used (even on TPUs), but I'm not sure of the reason. Keeping …
+};
+
+std::unordered_map<std::string, PluginEntry> pjrt_plugins_;
+
 xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   auto allocator_config = xla::GpuAllocatorConfig{};
   if (sys_util::GetEnvString(env::kEnvPjrtAllocatorCudaAsync, "").empty() &&

@@ -35,17 +41,21 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   return allocator_config;
 }

-std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
+std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
   auto plugin_path = pjrt_plugins_.find(device_type);
   return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
                                             : std::nullopt;
 }

 }  // namespace

-void RegisterPjRtPlugin(std::string name, std::string library_path) {
+void RegisterPjRtPlugin(
+    std::string name, std::string library_path,
+    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
+    bool init_coordinator) {
   TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
-  pjrt_plugins_[name] = library_path;
+  pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
+                         init_coordinator};
 }

 std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>

@@ -54,13 +64,46 @@ InitializePjRt(const std::string& device_type) {
   std::unique_ptr<XlaCoordinator> coordinator;

   if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
-    std::optional<std::string> plugin_path = GetPjRtPluginPath(device_type);
-    if (plugin_path) {
+    std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
+    if (plugin) {
       TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;

+      xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
+      xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
+      if (plugin->init_coordinator) {
+        int global_process_rank = sys_util::GetEnvInt("RANK", 0);
+        int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1);
Review comment: Didn't we already get it in the …
Reply: Sorry I didn't catch it earlier. What if the user starts single-host training in a non-torchrun way, such as …
Reply: This case is slightly wrong. The precedence should be …
Reply: Shouldn't it be WORLD_SIZE -> PJRT_LOCAL_PROCESS_COUNT or LOCAL_WORLD_SIZE -> 1?
Reply: Yeah, you're right. I tripped over this while testing as well.
+        std::string master_addr =
+            runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
+        std::string port = runtime::sys_util::GetEnvString(
+            "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
+
+        if (global_world_size > 1) {
+          // Use the XlaCoordinator as the distributed key-value store.
+          coordinator = std::make_unique<XlaCoordinator>(
+              global_process_rank, global_world_size, master_addr, port);
+          std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
+              coordinator->GetClient();
+          std::string key_prefix = "gpu:";
+          kv_get = [distributed_client, key_prefix](
+                       std::string_view k,
+                       absl::Duration timeout) -> xla::StatusOr<std::string> {
+            return distributed_client->BlockingKeyValueGet(
+                absl::StrCat(key_prefix, k), timeout);
+          };
+          kv_put = [distributed_client, key_prefix](
+                       std::string_view k, std::string_view v) -> xla::Status {
+            return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k),
+                                                   v);
+          };
+        }
+      }
       const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
-          absl::AsciiStrToLower(device_type), *plugin_path);
+          absl::AsciiStrToLower(device_type), plugin->library_path);
       XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
-      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type)).value();
+      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
+                                  plugin->create_options, kv_get, kv_put)
+                   .value();
       profiler::RegisterProfilerForPlugin(c_api);
     }
   } else if (device_type == "CPU") {
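To summarize the new coordinator wiring in InitializePjRt, here is a small Python sketch (illustrative only, not part of the PR). It reads the same RANK, WORLD_SIZE, MASTER_ADDR, and XLA_COORDINATOR_PORT variables as the diff above; the helper name and the plain dict standing in for the XlaCoordinator-backed key-value store are assumptions.

# Illustrative sketch only: mirrors the init_coordinator branch above.
# A plain dict stands in for the XlaCoordinator/DistributedRuntimeClient
# key-value store; the real C++ code passes kv_get/kv_put callbacks into
# xla::GetCApiClient.
import os

def sketch_coordinator_setup():
  rank = int(os.environ.get('RANK', 0))
  world_size = int(os.environ.get('WORLD_SIZE', 1))
  master_addr = os.environ.get('MASTER_ADDR', 'localhost')
  # The real code falls back to XlaCoordinator::kDefaultCoordinatorPort.
  port = os.environ.get('XLA_COORDINATOR_PORT')

  kv_get = kv_put = None  # like the nullptr defaults above
  if world_size > 1:
    store = {}        # stand-in for the distributed key-value store
    prefix = 'gpu:'   # same key prefix as the C++ lambdas
    kv_get = lambda key: store[prefix + key]
    kv_put = lambda key, value: store.update({prefix + key: value})
  return rank, world_size, (master_addr, port), kv_get, kv_put

# Single process (no WORLD_SIZE set): kv callbacks stay None, so no
# coordinator is started, matching the review discussion above.
print(sketch_coordinator_setup())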
Review comment: Perhaps irrelevant to this PR, but I just want to confirm that the "# TODO: default to actual device count" still holds, since GPU_NUM_DEVICES is not always set and the default value may not be 1.
Review comment (follow-up): I see it's only used in run_multiprocess, so using GPU_NUM_DEVICES preserves the current behavior. Looks good to me then.
Reply: Yeah, this is the same as the current behavior. Ideally this should check the PCI device IDs like we do for TPUs.
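Regarding the TODO discussed in this last thread, one possible way to "default to actual device count" would be to mirror the TPU plugin's PCI-based detection. The sketch below is only an illustration: the sysfs paths, fallback order, and helper names are assumptions, not part of this PR.

# Illustrative sketch only: count NVIDIA display controllers on the PCI bus
# and fall back to GPU_NUM_DEVICES, then 1. Linux-only.
import glob
import os

NVIDIA_VENDOR_ID = '0x10de'

def detected_gpu_count() -> int:
  count = 0
  for device_dir in glob.glob('/sys/bus/pci/devices/*'):
    try:
      with open(os.path.join(device_dir, 'vendor')) as f:
        vendor = f.read().strip()
      with open(os.path.join(device_dir, 'class')) as f:
        device_class = f.read().strip()
    except OSError:
      continue
    # PCI class 0x03xxxx covers display/3D controllers, which excludes the
    # GPU's companion audio functions.
    if vendor == NVIDIA_VENDOR_ID and device_class.startswith('0x03'):
      count += 1
  return count

def physical_chip_count_sketch() -> int:
  if 'GPU_NUM_DEVICES' in os.environ:   # explicit override wins
    return int(os.environ['GPU_NUM_DEVICES'])
  return detected_gpu_count() or 1      # preserve the current default of 1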