diff --git a/test/pjrt/test_runtime_gpu.py b/test/pjrt/test_runtime_gpu.py
index 07cb40d02b9..d82144b2c1a 100644
--- a/test/pjrt/test_runtime_gpu.py
+++ b/test/pjrt/test_runtime_gpu.py
@@ -149,9 +149,7 @@ def _all_gather(pin_layout):
     out = xm.all_gather(ordinal, pin_layout=pin_layout)
     xm.mark_step()
 
-    ret = out.cpu().numpy()
-    print('_all_gather returning ret=', ret)
-    return ret
+    return out.cpu().numpy()
 
   @parameterized.named_parameters(('pinned', True), ('unpinned', False))
   def test_all_gather(self, pin_layout):
diff --git a/test/pjrt/test_torchrun.py b/test/pjrt/test_torchrun.py
index cb891e15a9b..6ed5118f2fd 100644
--- a/test/pjrt/test_torchrun.py
+++ b/test/pjrt/test_torchrun.py
@@ -6,7 +6,6 @@
 import torch_xla.distributed.xla_backend
 import torch_xla.runtime as xr
 import torch_xla.utils.utils as xu
-from torch_xla._internal import gpu
 
 
 class TestTorchrun(absltest.TestCase):
@@ -37,8 +36,8 @@ def test_all_reduce(self):
     dist_world_size = xu.getenv_as('WORLD_SIZE', int)
     devices_per_thread = xr.addressable_device_count()
 
-    expected_world_size = dist_world_size * devices_per_thread
-    tensors = [torch.arange(2, dtype=torch.int64) + 1 + 2 * r for r in range(expected_world_size)]
+    world_size = dist_world_size * devices_per_thread
+    tensors = [torch.arange(2, dtype=torch.int64) + 1 + 2 * r for r in range(world_size)]
     expected = sum(tensors)
 
     xla_tensor = torch.arange(2, dtype=torch.int64, device=xm.xla_device()) + 1 + 2 * dist.get_rank()
@@ -48,6 +47,23 @@ def test_all_reduce(self):
     torch.testing.assert_close(xla_tensor.cpu(), expected)
     dist.destroy_process_group()
 
+  def test_reduce_scatter(self):
+    dist.init_process_group('xla', init_method='xla://')
+
+    dist_world_size = xu.getenv_as('WORLD_SIZE', int)
+    devices_per_thread = xr.addressable_device_count()
+    world_size = dist_world_size * devices_per_thread
+    tensor = world_size * torch.arange(world_size * 2, dtype=torch.int64)
+    expected = torch.split(tensor, world_size)[dist.get_rank()]
+
+    tensor_out = torch.zeros(world_size, dtype=torch.int64, device=xm.xla_device())
+    tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=xm.xla_device())
+    dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM)
+    xm.mark_step()
+
+    torch.testing.assert_close(tensor_out.cpu(), expected)
+    dist.destroy_process_group()
+
 
 if __name__ == '__main__':
   if not dist.is_torchelastic_launched():
diff --git a/torch_xla/core/xla_env_vars.py b/torch_xla/core/xla_env_vars.py
index f8c573c73ff..42209193d81 100644
--- a/torch_xla/core/xla_env_vars.py
+++ b/torch_xla/core/xla_env_vars.py
@@ -27,5 +27,4 @@
 PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR'
 LOCAL_RANK = 'LOCAL_RANK'
 RANK = 'RANK'
-LOCAL_WORLD_SIZE = 'LOCAL_WORLD_SIZE'
 WORLD_SIZE = 'WORLD_SIZE'
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 3803ddf71e9..14ef023aa0f 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -49,8 +49,9 @@ MaybeInitializeDistributedRuntimeClient(int local_rank) {
   std::string port = runtime::sys_util::GetEnvString("COORDINATOR_PORT", "8547");
   std::string dist_service_addr = master_addr + ":" + port;
   xla::DistributedRuntimeClient::Options options;
-  /* TODO(jonbolin): Use global rank for multi-host setup */
   options.node_id = local_rank;
+  TF_VLOG(3) << "Getting distributed runtime client for address="
+             << dist_service_addr << ", node_id=" << options.node_id;
   client = xla::GetDistributedRuntimeClient(dist_service_addr, options);
   XLA_CHECK(client->Connect().ok())
       << "Failed to initialize distributed runtime client";
@@ -155,6 +156,8 @@ PjRtComputationClient::PjRtComputationClient() {
     };
   }
   int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1);
+  TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
+             << global_rank << ", num_nodes=" << global_world_size;
   client_ = std::move(xla::GetStreamExecutorGpuClient(
       /*asynchronous=*/async,
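For reference, a minimal CPU-only sketch (not part of the patch) of the arithmetic that the new test_reduce_scatter case asserts, assuming a single-host, two-process torchrun launch so that world_size == 2; the world_size value and variable names below are illustrative:

import torch

world_size = 2  # assumed: dist_world_size * devices_per_thread for a 2-process launch

# Every rank feeds the same arange(world_size * 2) into reduce_scatter(SUM) ...
inputs = [torch.arange(world_size * 2, dtype=torch.int64) for _ in range(world_size)]

# ... so the fully reduced tensor is world_size * arange(world_size * 2),
reduced = torch.stack(inputs).sum(dim=0)
assert torch.equal(reduced, world_size * torch.arange(world_size * 2, dtype=torch.int64))

# ... and rank r keeps the r-th world_size-sized chunk, matching `expected` in the test.
for rank in range(world_size):
    print(f'rank {rank} expects {torch.split(reduced, world_size)[rank].tolist()}')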