diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc index f1b6a6711370..ce8b3f029e5c 100644 --- a/torch_xla/csrc/runtime/pjrt_registry.cc +++ b/torch_xla/csrc/runtime/pjrt_registry.cc @@ -5,7 +5,9 @@ #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/distributed/client.h" #include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" @@ -95,33 +97,25 @@ InitializePjRt(const std::string& device_type) { std::string port = runtime::sys_util::GetEnvString( "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); - xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; - xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false); std::optional> allowed_devices; if (!spmd) { allowed_devices = std::set{local_process_rank}; } + + std::shared_ptr kv_store; if (global_world_size > 1) { - // Use the XlaCoordinator as the distributed key-value store. + // Use the distributed key-value store from DistributedRuntimeClient. coordinator = std::make_unique( global_process_rank, global_world_size, master_addr, port); std::shared_ptr distributed_client = coordinator->GetClient(); - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; + kv_store = xla::GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; + xla::GpuClientOptions options; options.allocator_config = GetGpuAllocatorConfig(); options.node_id = global_process_rank; @@ -129,8 +123,7 @@ InitializePjRt(const std::string& device_type) { options.allowed_devices = allowed_devices; options.platform_name = "gpu"; options.should_stage_host_to_device_transfers = true; - options.kv_get = kv_get; - options.kv_put = kv_put; + options.kv_store = kv_store; client = std::move(xla::GetStreamExecutorGpuClient(options).value()); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client...";