diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index f9e46dce55d..18964f358b1 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -152,29 +152,28 @@ PjRtComputationClient::PjRtComputationClient() {
     xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
     if (distributed_client != nullptr) {
       std::string key_prefix = "gpu:";
-      kv_get = [distributed_client, key_prefix](const std::string& k,
-                                                absl::Duration timeout) {
+      kv_get = [distributed_client, key_prefix](
+                   std::string_view k,
+                   absl::Duration timeout) -> xla::StatusOr<std::string> {
         return distributed_client->BlockingKeyValueGet(
             absl::StrCat(key_prefix, k), timeout);
       };
-      kv_put = [distributed_client, key_prefix](const std::string& k,
-                                                const std::string& v) {
+      kv_put = [distributed_client, key_prefix](
+                   std::string_view k, std::string_view v) -> xla::Status {
         return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
       };
     }
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size;
-    client_ = std::move(xla::GetStreamExecutorGpuClient(
-                            /*asynchronous=*/async,
-                            /*allocator_config=*/GetGpuAllocatorConfig(),
-                            /*node_id=*/global_process_rank,
-                            /*num_nodes=*/global_world_size,
-                            /*allowed_devices=*/allowed_devices,
-                            /*platform_name=*/"gpu",
-                            /*should_stage_host_to_device_transfers=*/true,
-                            /*kv_get=*/kv_get,
-                            /*kv_put=*/kv_put)
-                        .value());
+    xla::GpuClientOptions options;
+    options.allocator_config = GetGpuAllocatorConfig();
+    options.node_id = global_process_rank;
+    options.num_nodes = global_world_size;
+    options.allowed_devices = allowed_devices;
+    options.platform_name = "gpu";
+    options.kv_get = kv_get;
+    options.kv_put = kv_put;
+    client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
   } else if (device_type == "XPU") {
     TF_VLOG(1) << "Initializing PjRt XPU client...";
     XLA_CHECK_OK(