Skip to content

Commit

Permalink
Allow scalar all gather (#5797)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfontain authored Dec 11, 2023
1 parent 9be1e94 commit b224350
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/test_torch_distributed_xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ def test_allgather(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

@patch_world(rank=3, size=8)
def test_all_scalar_allgather(self):
device = xm.xla_device()
tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank()
output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)]
all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
dist.all_gather(output_tensors, tensor)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

@patch_world(rank=3, size=8)
def test_allgather_coalesced(self):
device = xm.xla_device()
Expand Down
6 changes: 5 additions & 1 deletion torch_xla/distributed/xla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ def allreduce(self, tensors, all_reduce_options):

def allgather(self, output_tensors_list, input_tensors, opts=None):
for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
is_scalar = (input_tensor.dim() == 0)
if is_scalar:
input_tensor = torch.reshape(input_tensor, (1,))
result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
with torch.no_grad():
output_tensors[i].copy_(slice)
output_tensors[i].copy_(
slice if not is_scalar else torch.reshape(slice, ()))

return _ret_work([t for sublist in output_tensors_list for t in sublist])

Expand Down

0 comments on commit b224350

Please sign in to comment.