From 2df3c56d2fcfa9a00a3c26894c75b529669a108a Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Sat, 9 Dec 2023 07:00:39 +0000
Subject: [PATCH] Add all-gather coalesce tests; fix all-gather coalesce bug
 "len(input)"

---
 test/test_mp_all_gather.py  | 45 ++++++++++++++++++++++++++++++++++++++++++++++-
 torch_xla/core/xla_model.py |  8 ++++++--
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index 3ffeebc963d6..626573aa6d20 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -13,7 +13,8 @@ def all_gather(tensor, dim):
 def _mp_fn(index):
   device = xm.xla_device()
   world_size = xm.xrt_world_size()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+  input_list_size = 5
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)
@@ -57,6 +58,48 @@ def _mp_fn(index):
           f'Failed to create two replica groups with {world_size} replicas',
           file=sys.stderr)
 
+    # Testing with a single replica group and tensor list as input
+    ordinal_tensors = [
+        torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output=None
+    result_list = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+
+    # Testing with a single replica group and tensor list as input and output!=None (out-of-place)
+    ordinal_tensors = [
+        torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    output_tensors = [
+        torch.zeros([world_size], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output!=None
+    result_list = xm.all_gather(
+        ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+  # TODO: add test for torch.compile when support for list input is ready
+
   else:
     print(f'{device} is not a TPU or GPU device', file=sys.stderr)
 
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 8bf298be2def..9f7cb74e6b96 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -553,7 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     A tensor which has, in the ``dim`` dimension, all the values from the
     participating replicas.
   """
-  if pin_layout and (output == None or xla_device_hw(value.device) == 'NEURON'):
+  if pin_layout and output == None and isinstance(value, torch.Tensor):
     # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON,
     # use all_reduce based all_gather for this purpose.
     return _all_gather_using_all_reduce(
@@ -587,13 +587,17 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   # Now the input should be a list of Tensors.
   elif isinstance(value, list) and all(
       isinstance(v, torch.Tensor) for v in value):
+    if pin_layout:
+      raise RuntimeError(
+          "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
+      )
     if output != None:
       if not isinstance(output, list) or any(
           not isinstance(v, torch.Tensor) for v in output):
         raise TypeError(
             f"`output` needs to be a list of Tensors, but given {type(output)}."
         )
-      if len(output) != len(input):
+      if len(output) != len(value):
         raise ValueError("`output` length doesn't match `input` length: "
                          f"{len(output)} vs {len(input)}.")
       # Call the out of place version of the reduce_scatter
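
For reference, a minimal usage sketch of the coalesced, list-input xm.all_gather path that the new tests exercise. It mirrors the test code above; the _demo_fn name and the xmp.spawn launch are illustrative assumptions, and, as enforced in xla_model.py above, pin_layout=False is required for list input until pinned layouts are supported.

# Editor's sketch, not part of the patch: gather a list of tensors with one
# coalesced collective instead of one all_gather call per tensor.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _demo_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # One small tensor per bucket, tagged with its bucket id and replica ordinal.
  tensors = [
      torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
      for i in range(3)
  ]
  # pin_layout=True with list input raises RuntimeError (see xla_model.py above).
  gathered = xm.all_gather(tensors, dim=0, pin_layout=False)
  for i, g in enumerate(gathered):
    # Each result holds bucket i from every replica, concatenated along dim 0.
    expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
    assert g.cpu().allclose(expected)


if __name__ == '__main__':
  xmp.spawn(_demo_fn, args=())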