Support list of ShardingSpec in MpDeviceLoader
jonb377 committed Nov 10, 2023
1 parent bd29e79 commit ab617f8
Showing 3 changed files with 28 additions and 3 deletions.
17 changes: 17 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -653,6 +653,23 @@ def test_send_cpu_data_to_device_with_sharding(self):
        torch_xla._XLAC._get_xla_sharding_spec(xt),
        torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))

+  def test_send_cpu_data_to_device_with_multiple_sharding(self):
+    tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
+    mesh = self._get_mesh((self.n_devices, 1))
+    specs = [
+        xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]
+    ]
+    xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)
+    str_specs = [torch_xla._XLAC._get_xla_sharding_spec(t) for t in xtensors]
+    self.assertEqual(str_specs[0], '{replicated}')
+    if self.n_devices > 1:
+      dev_fmt = (self.n_devices, ','.join(map(str, range(self.n_devices))))
+      self.assertEqual(str_specs[1], "{devices=[%d,1]%s}" % dev_fmt)
+      self.assertEqual(str_specs[2], "{devices=[%d,1,1]%s}" % dev_fmt)
+    else:
+      self.assertEqual(str_specs[1], '{replicated}')
+      self.assertEqual(str_specs[2], '{replicated}')
+
  def test_multiple_operations(self):
    t1 = torch.randn(2, 2)
    t2 = torch.randn(2, 2)
10 changes: 9 additions & 1 deletion torch_xla/core/xla_model.py
@@ -1012,7 +1012,15 @@ def convert_fn(tensors):
    devices = [str(device)] * len(tensors)
    shardings = None
    if input_sharding:
-      shardings = [input_sharding.xla_spec(t) for t in tensors]
+      if isinstance(input_sharding, list):
+        shardings = [None] * len(tensors)
+        for i, tensor in enumerate(tensors):
+          for sharding in input_sharding:
+            if sharding.can_apply(tensor):
+              shardings[i] = sharding.xla_spec(tensor)
+              break
+      else:
+        shardings = [input_sharding.xla_spec(t) for t in tensors]
    xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                      shardings)
    return xtensors
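
The new branch in convert_fn applies, to each input tensor, the first ShardingSpec in the list whose can_apply check accepts it; tensors that match no spec keep a None sharding and come out replicated, as the test above asserts. A minimal sketch of calling send_cpu_data_to_device directly with a list of specs (the 4-device mesh, the variable names, and the torch_xla.distributed.spmd import path are assumptions for illustration; older releases expose the same helpers as torch_xla.experimental.xla_sharding):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs  # assumed import path, see note above

num_devices = 4  # assumed device count, for illustration only
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

specs = [
    xs.ShardingSpec(mesh, (0, None)),        # applies to 2-D tensors
    xs.ShardingSpec(mesh, (0, None, None)),  # applies to 3-D tensors
]
tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), input_sharding=specs)
# The 1-D tensor matches neither spec and is left replicated; the 2-D and 3-D
# tensors are annotated to shard their first dimension across the mesh.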
4 changes: 2 additions & 2 deletions torch_xla/distributed/parallel_loader.py
@@ -74,8 +74,8 @@ class ParallelLoader(object):
    host_to_device_transfer_threads (int, optional): The number of threads that
      work in parallel to transfer data from loader queue to device queue.
      Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (Union[ShardingSpec, List[ShardingSpec]], optional): Sharding
+      specs to apply to compatible input tensors when loading.
      Default: None
  """

