diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 5cfa4518cb4..214c5f08afa 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -97,6 +97,16 @@ def test_allgather(self):
     hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
     hlo_matches(hlo, all_gather_pattern)
 
+  @patch_world(rank=3, size=8)
+  def test_all_scalar_allgather(self):
+    device = xm.xla_device()
+    tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank()
+    output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)]
+    all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
+    dist.all_gather(output_tensors, tensor)
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
+    hlo_matches(hlo, all_gather_pattern)
+
   @patch_world(rank=3, size=8)
   def test_allgather_coalesced(self):
     device = xm.xla_device()
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index 401cc623622..0348f2e9a6d 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -73,10 +73,14 @@ def allreduce(self, tensors, all_reduce_options):
 
   def allgather(self, output_tensors_list, input_tensors, opts=None):
     for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
+      is_scalar = (input_tensor.dim() == 0)
+      if is_scalar:
+        input_tensor = torch.reshape(input_tensor, (1,))
       result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
       for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
         with torch.no_grad():
-          output_tensors[i].copy_(slice)
+          output_tensors[i].copy_(
+              slice if not is_scalar else torch.reshape(slice, ()))
 
     return _ret_work([t for sublist in output_tensors_list for t in sublist])
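The backend change above handles 0-dim (scalar) inputs by promoting them to shape `(1,)` before the gather and restoring the 0-dim shape when copying each gathered slice into the output tensors. Below is a minimal CPU-only sketch of that reshape round-trip; `fake_all_gather` is a hypothetical stand-in for `xm.all_gather` (not part of the patch) so the pattern can be exercised without an XLA device.

```python
import torch

world_size = 8


def fake_all_gather(t: torch.Tensor) -> torch.Tensor:
  # Hypothetical stand-in for xm.all_gather: concatenates world_size copies
  # along dim 0, mimicking the output shape contract of the real gather.
  return torch.cat([t] * world_size, dim=0)


def allgather_with_scalar_support(output_tensors, input_tensor):
  # Mirrors the scalar handling added in ProcessGroupXla.allgather.
  is_scalar = (input_tensor.dim() == 0)
  if is_scalar:
    input_tensor = torch.reshape(input_tensor, (1,))  # promote 0-dim to 1-D
  result = fake_all_gather(input_tensor)
  for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
    with torch.no_grad():
      # Restore the original 0-dim shape before copying into the 0-dim output.
      output_tensors[i].copy_(
          slice if not is_scalar else torch.reshape(slice, ()))


scalar = torch.tensor(7.0)  # 0-dim tensor, like the one in the new test
outputs = [torch.zeros_like(scalar) for _ in range(world_size)]
allgather_with_scalar_support(outputs, scalar)
assert all(o.item() == 7.0 and o.dim() == 0 for o in outputs)
```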