Commit

Add all-gather coalesce tests; fix all-gather coalesce bug "len(input)"
jeffhataws authored and Arjunbala committed Dec 9, 2023
1 parent 4058ee9 commit 2df3c56
Showing 2 changed files with 50 additions and 3 deletions.
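
The new tests below exercise the coalesced form of xm.all_gather, which takes a list of tensors and returns a list of gathered results. A minimal sketch of that call pattern, assuming torch_xla is available and the code runs inside a multiprocessing worker (e.g. launched via torch_xla.distributed.xla_multiprocessing.spawn); the helper name _coalesced_example, the tensor values, and the list size of 3 are illustrative, not part of this change:

import torch
import torch_xla.core.xla_model as xm


def _coalesced_example(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # One tensor per list slot; each replica contributes its own value.
  inputs = [
      torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
      for i in range(3)
  ]
  # List input currently requires pin_layout=False (see the TODOs in the test below).
  gathered = xm.all_gather(inputs, dim=0, pin_layout=False)
  # Each gathered[i] has shape [world_size] after concatenation along dim 0.
  return [t.cpu() for t in gathered]
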
test/test_mp_all_gather.py (45 changes: 44 additions & 1 deletion)
@@ -13,7 +13,8 @@ def all_gather(tensor, dim):
 def _mp_fn(index):
   device = xm.xla_device()
   world_size = xm.xrt_world_size()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+  input_list_size = 5
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)
@@ -57,6 +58,48 @@ def _mp_fn(index):
           f'Failed to create two replica groups with {world_size} replicas',
           file=sys.stderr)
 
+    # Testing with a single replica group and tensor list as input
+    ordinal_tensors = [
+        torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output=None
+    result_list = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+
+    # Testing with a single replica group and tensor list as input and output!=None (out-of-place)
+    ordinal_tensors = [
+        torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    output_tensors = [
+        torch.zeros([world_size], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output!=None
+    result_list = xm.all_gather(
+        ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+    # TODO: add test for torch.compile when support for list input is ready
+
   else:
     print(f'{device} is not a TPU or GPU device', file=sys.stderr)
 
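The second block of the test above covers the out-of-place form, where pre-allocated buffers receive the gathered values. A sketch of that variant under the same assumptions as the earlier example (the helper name and the list size are again illustrative):

import torch
import torch_xla.core.xla_model as xm


def _coalesced_out_of_place_example(index, num_tensors=3):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  inputs = [
      torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
      for i in range(num_tensors)
  ]
  # Pre-allocated destination buffers, one per input tensor.
  outputs = [
      torch.zeros([world_size], dtype=torch.float).to(device)
      for _ in range(num_tensors)
  ]
  # Passing output= selects the out-of-place path; list input still needs pin_layout=False.
  return xm.all_gather(inputs, dim=0, output=outputs, pin_layout=False)
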
torch_xla/core/xla_model.py (8 changes: 6 additions & 2 deletions)
@@ -553,7 +553,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     A tensor which has, in the ``dim`` dimension, all the values from the
     participating replicas.
   """
-  if pin_layout and (output == None or xla_device_hw(value.device) == 'NEURON'):
+  if pin_layout and output == None and isinstance(value, torch.Tensor):
     # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON,
     # use all_reduce based all_gather for this purpose.
     return _all_gather_using_all_reduce(
@@ -587,13 +587,17 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   # Now the input should be a list of Tensors.
   elif isinstance(value, list) and all(
       isinstance(v, torch.Tensor) for v in value):
+    if pin_layout:
+      raise RuntimeError(
+          "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
+      )
     if output != None:
       if not isinstance(output, list) or any(
           not isinstance(v, torch.Tensor) for v in output):
         raise TypeError(
             f"`output` needs to be a list of Tensors, but given {type(output)}."
         )
-      if len(output) != len(input):
+      if len(output) != len(value):
         raise ValueError("`output` length doesn't match `input` length: "
                          f"{len(output)} vs {len(input)}.")
     # Call the out of place version of the reduce_scatter
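For context on the one-line fix above: the list argument of all_gather is named value, so the old comparison len(output) != len(input) picked up Python's builtin input function and raised a TypeError whenever an output list was supplied with a list input. A standalone sketch of the corrected validation; _check_coalesced_args is a hypothetical helper that mirrors the checks shown in the diff, not a torch_xla API:

import torch


def _check_coalesced_args(value, output):
  # `value` must be a list of Tensors for the coalesced path.
  if not (isinstance(value, list) and
          all(isinstance(v, torch.Tensor) for v in value)):
    raise TypeError("`value` needs to be a list of Tensors.")
  if output is not None:
    if not isinstance(output, list) or any(
        not isinstance(v, torch.Tensor) for v in output):
      raise TypeError(
          f"`output` needs to be a list of Tensors, but given {type(output)}.")
    # The fix: compare against len(value), not len(input) (the Python builtin).
    if len(output) != len(value):
      raise ValueError("`output` length doesn't match `input` length: "
                       f"{len(output)} vs {len(value)}.")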
