
Commit

add reduce scatter test
vanbasten23 committed Oct 5, 2023
1 parent 2b1b2e4 commit 0cb25b4
Showing 4 changed files with 24 additions and 8 deletions.
4 changes: 1 addition & 3 deletions test/pjrt/test_runtime_gpu.py
@@ -149,9 +149,7 @@ def _all_gather(pin_layout):
     out = xm.all_gather(ordinal, pin_layout=pin_layout)
     xm.mark_step()
 
-    ret = out.cpu().numpy()
-    print('_all_gather returning ret=', ret)
-    return ret
+    return out.cpu().numpy()
 
   @parameterized.named_parameters(('pinned', True), ('unpinned', False))
   def test_all_gather(self, pin_layout):
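As an aside on the surrounding structure, the simplified helper still feeds the parameterized test shown above: absl's parameterized.named_parameters expands one test method into a separately named case per tuple. A minimal self-contained sketch of that mechanism (the Demo class and test_flag name are hypothetical, not from this commit):

from absl.testing import absltest, parameterized


class Demo(parameterized.TestCase):

  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
  def test_flag(self, pin_layout):
    # Runs twice: test_flag_pinned (pin_layout=True) and
    # test_flag_unpinned (pin_layout=False).
    self.assertIsInstance(pin_layout, bool)


if __name__ == '__main__':
  absltest.main()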
22 changes: 19 additions & 3 deletions test/pjrt/test_torchrun.py
@@ -6,7 +6,6 @@
 import torch_xla.distributed.xla_backend
 import torch_xla.runtime as xr
 import torch_xla.utils.utils as xu
-from torch_xla._internal import gpu
 
 
 class TestTorchrun(absltest.TestCase):
@@ -37,8 +36,8 @@ def test_all_reduce(self):
 
     dist_world_size = xu.getenv_as('WORLD_SIZE', int)
     devices_per_thread = xr.addressable_device_count()
-    expected_world_size = dist_world_size * devices_per_thread
-    tensors = [torch.arange(2, dtype=torch.int64) + 1 + 2 * r for r in range(expected_world_size)]
+    world_size = dist_world_size * devices_per_thread
+    tensors = [torch.arange(2, dtype=torch.int64) + 1 + 2 * r for r in range(world_size)]
     expected = sum(tensors)
 
     xla_tensor = torch.arange(2, dtype=torch.int64, device=xm.xla_device()) + 1 + 2 * dist.get_rank()
@@ -48,6 +47,23 @@ def test_all_reduce(self):
     torch.testing.assert_close(xla_tensor.cpu(), expected)
     dist.destroy_process_group()
 
+  def test_reduce_scatter(self):
+    dist.init_process_group('xla', init_method='xla://')
+
+    dist_world_size = xu.getenv_as('WORLD_SIZE', int)
+    devices_per_thread = xr.addressable_device_count()
+    world_size = dist_world_size * devices_per_thread
+    tensor = world_size * torch.arange(world_size * 2, dtype=torch.int64)
+    expected = torch.split(tensor, world_size)[dist.get_rank()]
+
+    tensor_out = torch.zeros(world_size, dtype=torch.int64, device=xm.xla_device())
+    tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=xm.xla_device())
+    dist.reduce_scatter(tensor_out, [tensor_in], op=dist.ReduceOp.SUM)
+    xm.mark_step()
+
+    torch.testing.assert_close(tensor_out.cpu(), expected)
+    dist.destroy_process_group()
+
 
 if __name__ == '__main__':
   if not dist.is_torchelastic_launched():
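To trace the new test's arithmetic: every rank feeds the identical arange input, so a summing reduce-scatter produces world_size * torch.arange(world_size * 2), split into contiguous chunks with rank r receiving chunk r. A minimal sketch of that expected-value math for a two-worker run (world_size = 2 is an assumption, and it is the case where tensor_out and the split chunks line up exactly in shape and count):

import torch

# Assumed launch: torchrun --nproc_per_node=2 test/pjrt/test_torchrun.py
world_size = 2
tensor_in = torch.arange(world_size * 2, dtype=torch.int64)  # [0, 1, 2, 3] on every rank
reduced = world_size * tensor_in                             # SUM over identical inputs: [0, 2, 4, 6]
chunks = torch.split(reduced, world_size)                    # ([0, 2], [4, 6])
print(chunks[0])  # what rank 0 expects in tensor_out
print(chunks[1])  # what rank 1 expects in tensor_out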
1 change: 0 additions & 1 deletion torch_xla/core/xla_env_vars.py
@@ -27,5 +27,4 @@
 PJRT_DIST_SERVICE_ADDR = 'PJRT_DIST_SERVICE_ADDR'
 LOCAL_RANK = 'LOCAL_RANK'
 RANK = 'RANK'
-LOCAL_WORLD_SIZE = 'LOCAL_WORLD_SIZE'
 WORLD_SIZE = 'WORLD_SIZE'
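The names that remain mirror what torchrun (torchelastic) already exports to every worker process, which is how the tests above obtain WORLD_SIZE. A minimal sketch of a worker reading them, assuming a torchrun launch (not code from this commit):

import os

world_size = int(os.environ['WORLD_SIZE'])  # total worker count across all hosts
rank = int(os.environ['RANK'])              # this worker's global rank
local_rank = int(os.environ['LOCAL_RANK'])  # this worker's rank within its host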
5 changes: 4 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -49,8 +49,9 @@ MaybeInitializeDistributedRuntimeClient(int local_rank) {
   std::string port = runtime::sys_util::GetEnvString("COORDINATOR_PORT", "8547");
   std::string dist_service_addr = master_addr+":"+port ;
   xla::DistributedRuntimeClient::Options options;
-  /* TODO(jonbolin): Use global rank for multi-host setup */
   options.node_id = local_rank;
+  TF_VLOG(3) << "Getting distributed runtime client for address="
+             << dist_service_addr << ", node_id=" << options.node_id;
   client = xla::GetDistributedRuntimeClient(dist_service_addr, options);
   XLA_CHECK(client->Connect().ok())
       << "Failed to initialize distributed runtime client";
@@ -155,6 +156,8 @@ PjRtComputationClient::PjRtComputationClient() {
   };
 }
 int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", 1);
+TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
+           << global_rank << ", num_nodes=" << global_world_size;
 client_ =
     std::move(xla::GetStreamExecutorGpuClient(
         /*asynchronous=*/async,
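The two new messages are emitted at TF_VLOG(3), which is silent by default. One way to surface them, assuming the standard TSL/TensorFlow logging environment variables apply to this build (they are not part of the commit):

import os

# Configure logging before torch_xla creates the PJRT client.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'                    # keep INFO-level output visible
os.environ['TF_CPP_VMODULE'] = 'pjrt_computation_client=3'  # verbose logs for this file only

import torch_xla.core.xla_model as xm  # import only after configuring logging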
