diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index d55a8b7b310..f42289f8d26 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -208,9 +208,9 @@ def test_xla_sharding_type(self):
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None)
 
     x_dim = 2 if self.n_devices >= 2 else 1
-    # if self.n_devices>=4, mesh=(2, 2)
-    # if self.n_devices>=2, mesh=(2,1)
-    # if self.n_devices=1, mesh=(1,1)
+    # if self.n_devices==4, mesh=(2,2)
+    # if self.n_devices==2, mesh=(2,1)
+    # if self.n_devices==1, mesh=(1,1)
     mesh = self._get_mesh((x_dim, self.n_devices // x_dim))
     xt = xs.mark_sharding(t, mesh, (0, 1))
     if self.n_devices >= 2:
diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index 5f1e610d1d6..e3746710cb7 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -154,8 +154,6 @@ InitializePjRt(const std::string& device_type) {
     int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
     int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
 
-    std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
-    std::optional<std::set<int>> allowed_devices;
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size
                << ", spmd case=" << sys_util::GetEnvBool("XLA_USE_SPMD", false)
@@ -165,15 +163,24 @@ InitializePjRt(const std::string& device_type) {
                << ", LOCAL_WORLD_SIZE="
                << sys_util::GetEnvString("LOCAL_WORLD_SIZE", "")
                << ", WORLD_SIZE=" << sys_util::GetEnvString("WORLD_SIZE", "");
-    if (local_world_size == 1) {
-      if (global_world_size > 1) {
-        coordinator = SetGpuClientKVCallBack(global_process_rank,
-                                             global_world_size, kv_store);
-      }
-    } else {
+    std::optional<std::set<int>> allowed_devices;
+    if (local_world_size > 1) {
       allowed_devices = std::set{local_process_rank};
-      coordinator = SetGpuClientKVCallBack(global_process_rank,
-                                           global_world_size, kv_store);
+    }
+
+    std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
+    if (global_world_size > 1) {
+      // Use the distributed key-value store from DistributedRuntimeClient.
+      std::string master_addr =
+          runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
+      std::string port = runtime::sys_util::GetEnvString(
+          "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
+      coordinator = std::make_unique<XlaCoordinator>(
+          global_process_rank, global_world_size, master_addr, port);
+      std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
+          coordinator->GetClient();
+      kv_store = xla::GetDistributedKeyValueStore(distributed_client,
+                                                  /*key_prefix=*/"gpu:");
     }
 
     xla::GpuClientOptions options;