Skip to content

Commit

Permalink
fix comments
Browse files: browse the repository at this point in the history
  • Loading branch information
vanbasten23 committed Feb 2, 2024
1 parent aee08df commit b3991a7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
6 changes: 3 additions & 3 deletions test/spmd/test_xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def test_xla_sharding_type(self):
self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None)

x_dim = 2 if self.n_devices >= 2 else 1
# if self.n_devices>=4, mesh=(2, 2)
# if self.n_devices>=2, mesh=(2,1)
# if self.n_devices=1, mesh=(1,1)
# if self.n_devices==4, mesh=(2,2)
# if self.n_devices==2, mesh=(2,1)
# if self.n_devices==1, mesh=(1,1)
mesh = self._get_mesh((x_dim, self.n_devices // x_dim))
xt = xs.mark_sharding(t, mesh, (0, 1))
if self.n_devices >= 2:
Expand Down
27 changes: 17 additions & 10 deletions torch_xla/csrc/runtime/pjrt_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ InitializePjRt(const std::string& device_type) {
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
std::optional<std::set<int>> allowed_devices;
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size
<< ", spmd case=" << sys_util::GetEnvBool("XLA_USE_SPMD", false)
Expand All @@ -165,15 +163,24 @@ InitializePjRt(const std::string& device_type) {
<< ", LOCAL_WORLD_SIZE="
<< sys_util::GetEnvString("LOCAL_WORLD_SIZE", "")
<< ", WORLD_SIZE=" << sys_util::GetEnvString("WORLD_SIZE", "");
if (local_world_size == 1) {
if (global_world_size > 1) {
coordinator = SetGpuClientKVCallBack(global_process_rank,
global_world_size, kv_store);
}
} else {
std::optional<std::set<int>> allowed_devices;
if (local_world_size > 1) {
allowed_devices = std::set{local_process_rank};
coordinator = SetGpuClientKVCallBack(global_process_rank,
global_world_size, kv_store);
}

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
// Use the distributed key-value store from DistributedRuntimeClient.
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
coordinator = std::make_unique<XlaCoordinator>(
global_process_rank, global_world_size, master_addr, port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
coordinator->GetClient();
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
/*key_prefix=*/"gpu:");
}

xla::GpuClientOptions options;
Expand Down

0 comments on commit b3991a7

Please sign in to comment.