Skip to content

Commit

Permalink
Use xla::KeyValueStoreInterface in GpuClientOptions.
Browse files: browse the repository at this point in the history
  • Loading branch information
yeounoh committed Jan 17, 2024
1 parent 364cd7f commit cc3393b
Showing 1 changed file with 9 additions and 16 deletions.
25 changes: 9 additions & 16 deletions torch_xla/csrc/runtime/pjrt_registry.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,9 @@
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_c_api_client.h"
Expand Down Expand Up @@ -95,42 +97,33 @@ InitializePjRt(const std::string& device_type) {
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
std::optional<std::set<int>> allowed_devices;
if (!spmd) {
allowed_devices = std::set{local_process_rank};
}

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
// Use the XlaCoordinator as the distributed key-value store.
// Use the distributed key-value store from DistributedRuntimeClient.
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
kv_put = [distributed_client, key_prefix](
std::string_view k, std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;

xla::GpuClientOptions options;
options.allocator_config = GetGpuAllocatorConfig();
options.node_id = global_process_rank;
options.num_nodes = global_world_size;
options.allowed_devices = allowed_devices;
options.platform_name = "gpu";
options.should_stage_host_to_device_transfers = true;
options.kv_get = kv_get;
options.kv_put = kv_put;
options.kv_store = kv_store;
client = std::move(xla::GetStreamExecutorGpuClient(options).value());
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
Expand Down

0 comments on commit cc3393b

Please sign in to comment.