Skip to content

Commit

Permalink
fix the broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Jan 16, 2024
1 parent 656f944 commit fcad102
Showing 1 changed file with 50 additions and 26 deletions.
76 changes: 50 additions & 26 deletions torch_xla/csrc/runtime/pjrt_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,40 @@ std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
: std::nullopt;
}

std::unique_ptr<XlaCoordinator> SetKeyValueCallback(
int global_process_rank, int global_world_size,
std::unique_ptr<XlaCoordinator> coordinator,
xla::PjRtClient::KeyValueGetCallback& kv_get,
xla::PjRtClient::KeyValuePutCallback& kv_put) {
<< "function=" << __FUNCTION__ << ": " << std::endl;
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

// Use the XlaCoordinator as the distributed key-value store.
TF_VLOG(3) << "Creating a XlaCoordinator for global_process_rank="
<< global_process_rank
<< ", global_world_size=" << global_world_size
<< ", master_addr=" << master_addr << ", port=" << port;
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(absl::StrCat(key_prefix, k),
timeout);
};
kv_put = [distributed_client, key_prefix](std::string_view k,
std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
return coordinator;
}

} // namespace

void RegisterPjRtPlugin(std::string name, std::string library_path) {
Expand Down Expand Up @@ -90,39 +124,29 @@ InitializePjRt(const std::string& device_type) {
int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);

xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
std::optional<std::set<int>> allowed_devices;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
if (!spmd) {
allowed_devices = std::set{local_process_rank};
} else if (global_world_size > 1) {
allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
// Use the XlaCoordinator as the distributed key-value store.
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
std::string key_prefix = "gpu:";
kv_get = [distributed_client, key_prefix](
std::string_view k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
kv_put = [distributed_client, key_prefix](
std::string_view k, std::string_view v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
if (spmd) {
if (global_world_size > 1) {
coordinator =
SetKeyValueCallback(global_process_rank, global_world_size,
std::move(coordinator), kv_get, kv_put);
}
} else {
if (global_world_size > 1) {
allowed_devices = std::set{local_process_rank};
coordinator =
SetKeyValueCallback(global_process_rank, global_world_size,
std::move(coordinator), kv_get, kv_put);
}
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
<< global_process_rank << ", num_nodes=" << global_world_size
<< ", spmd case=" << spmd;

xla::GpuClientOptions options;
options.allocator_config = GetGpuAllocatorConfig();
options.node_id = global_process_rank;
Expand Down

0 comments on commit fcad102

Please sign in to comment.