Skip to content

Commit

Permalink
Update pjrt_computation_client.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 29, 2023
1 parent 7d50110 commit c95ab73
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,28 @@ PjRtComputationClient::PjRtComputationClient() {
xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
if (distributed_client != nullptr) {
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](const std::string& k,
absl::Duration timeout) {
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
kv_put = [distributed_client, key_prefix](const std::string& k,
const std::string& v) {
kv_put = [distributed_client, key_prefix](
std::string_view k, std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
/*asynchronous=*/async,
/*allocator_config=*/GetGpuAllocatorConfig(),
/*node_id=*/global_process_rank,
/*num_nodes=*/global_world_size,
/*allowed_devices=*/allowed_devices,
/*platform_name=*/"gpu",
/*should_stage_host_to_device_transfers=*/true,
/*kv_get=*/kv_get,
/*kv_put=*/kv_put)
.value());
xla::GpuClientOptions options;
options.allocator_config = GetGpuAllocatorConfig();
options.node_id = global_process_rank;
options.num_nodes = global_world_size;
options.allowed_devices = allowed_devices;
options.platform_name = "gpu";
options.kv_get = kv_get;
options.kv_put = kv_put;
client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(
Expand Down

0 comments on commit c95ab73

Please sign in to comment.