Support list of ShardingSpec in MpDeviceLoader #5789

Open · jonb377 wants to merge 2 commits into master
Changes from 1 commit
17 changes: 17 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -653,6 +653,23 @@ def test_send_cpu_data_to_device_with_sharding(self):
         torch_xla._XLAC._get_xla_sharding_spec(xt),
         torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))

+  def test_send_cpu_data_to_device_with_multiple_sharding(self):
+    tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
+    mesh = self._get_mesh((self.n_devices, 1))
+    specs = [
+        xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]
+    ]
+    xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)
+    str_specs = [torch_xla._XLAC._get_xla_sharding_spec(t) for t in xtensors]
+    self.assertEqual(str_specs[0], '{replicated}')
+    if self.n_devices > 1:
+      dev_fmt = (self.n_devices, ','.join(map(str, range(self.n_devices))))
+      self.assertEqual(str_specs[1], "{devices=[%d,1]%s}" % dev_fmt)
+      self.assertEqual(str_specs[2], "{devices=[%d,1,1]%s}" % dev_fmt)
+    else:
+      self.assertEqual(str_specs[1], '{replicated}')
+      self.assertEqual(str_specs[2], '{replicated}')

   def test_multiple_operations(self):
     t1 = torch.randn(2, 2)
     t2 = torch.randn(2, 2)
10 changes: 9 additions & 1 deletion torch_xla/core/xla_model.py
@@ -1012,7 +1012,15 @@ def convert_fn(tensors):
     devices = [str(device)] * len(tensors)
     shardings = None
     if input_sharding:
-      shardings = [input_sharding.xla_spec(t) for t in tensors]
+      if isinstance(input_sharding, list):
+        shardings = [None] * len(tensors)
+        for i, tensor in enumerate(tensors):
+          for sharding in input_sharding:
+            if sharding.can_apply(tensor):
+              shardings[i] = sharding.xla_spec(tensor)
+              break

Contributor
Maybe we could also add a comment noting that the first match is applied?

Collaborator Author
Done, thanks Yeounoh!

Collaborator
Is first-match a common practice? Or would it be better to support either:

  1. one sharding spec applied to everything; or
  2. one sharding spec for each input, applied in order?

That way the behavior is well defined. Otherwise, this "compatibility check" is totally a black box for the user.

Collaborator Author (@jonb377, Nov 14, 2023)
It's definitely not ideal... Probably better would be to have the user provide sharding specs in a structure matching that of the inputs and avoid the black-box compatibility check, e.g.:

# Inputs of the form:
{"a": torch.randn(16), "b": torch.randn(16, 16), "c": torch.randn(16, 16, 16)}

# Then input_sharding would be:
{"a": ShardingSpec(mesh, (0,)), "b": ShardingSpec(mesh, (0, None)), "c": ShardingSpec(mesh, (0, None, None))}

This is a bit of a refactor of the TensorToXlaArena code though, so I took the easy way out. I'll go ahead and take a stab at the cleaner approach now.
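
As a rough illustration of that structure-matching alternative (an editor's sketch, not code from this PR; the helper name and exact signature are assumptions):

# Hypothetical sketch of the structure-matching idea above: inputs and
# input_sharding are parallel dicts, so each tensor takes the spec registered
# under its own key and falls back to replication (None) otherwise.
def pick_shardings_by_key(inputs, input_sharding):
  return {
      key: (input_sharding[key].xla_spec(tensor)
            if key in input_sharding else None)
      for key, tensor in inputs.items()
  }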

Collaborator
No worries. Take your time.

+      else:
+        shardings = [input_sharding.xla_spec(t) for t in tensors]
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                        shardings)
     return xtensors
4 changes: 2 additions & 2 deletions torch_xla/distributed/parallel_loader.py
@@ -74,8 +74,8 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
       Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (Union[ShardingSpec, List[ShardingSpec]], optional): Sharding
+      specs to apply to compatible input tensors when loading.
       Default: None
   """

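For context, a hedged usage sketch of the list form of input_sharding (an editor's illustration, not part of the change; the mesh setup, the module paths such as torch_xla.experimental.xla_sharding, and the toy DataLoader are assumptions that may differ by release):

# Hedged usage sketch (not from this PR): pass a list of ShardingSpecs to
# MpDeviceLoader; the first spec compatible with each input tensor is applied,
# and tensors with no compatible spec stay replicated.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs  # module path circa this PR
import torch_xla.runtime as xr

xr.use_spmd()  # enable SPMD execution mode (API name may vary by release)

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1))

# One spec per input rank produced by the loader below: 4D images, 2D labels.
specs = [
    xs.ShardingSpec(mesh, (0, None, None, None)),  # shard the batch dim of images
    xs.ShardingSpec(mesh, (0, None)),              # shard the batch dim of labels
]

# Toy dataset: batches of (8, 3, 8, 8) images and (8, 10) label vectors.
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 3, 8, 8), torch.randn(64, 10)), batch_size=8)

mp_loader = pl.MpDeviceLoader(
    train_loader, xm.xla_device(), input_sharding=specs)

for images, labels in mp_loader:
  pass  # batches arrive on device with the matching sharding already applied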