From ab617f89fbe90fa36618cfe6c5d9775d280b2daf Mon Sep 17 00:00:00 2001
From: Jon Bolin
Date: Fri, 10 Nov 2023 00:57:01 +0000
Subject: [PATCH] Support list of ShardingSpec in MpDeviceLoader

---
 test/spmd/test_xla_sharding.py           | 17 +++++++++++++++++
 torch_xla/core/xla_model.py              | 10 +++++++++-
 torch_xla/distributed/parallel_loader.py |  4 ++--
 3 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1b128164a22..5b74f05c35f 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -653,6 +653,23 @@ def test_send_cpu_data_to_device_with_sharding(self):
         torch_xla._XLAC._get_xla_sharding_spec(xt),
         torch_xla._XLAC._get_xla_sharding_spec(explicit_xt))
 
+  def test_send_cpu_data_to_device_with_multiple_sharding(self):
+    tensors = [torch.randn(16), torch.randn(16, 16), torch.randn(16, 16, 16)]
+    mesh = self._get_mesh((self.n_devices, 1))
+    specs = [
+        xs.ShardingSpec(mesh, spec) for spec in [(0, None), (0, None, None)]
+    ]
+    xtensors = xm.send_cpu_data_to_device(tensors, xm.xla_device(), specs)
+    str_specs = [torch_xla._XLAC._get_xla_sharding_spec(t) for t in xtensors]
+    self.assertEqual(str_specs[0], '{replicated}')
+    if self.n_devices > 1:
+      dev_fmt = (self.n_devices, ','.join(map(str, range(self.n_devices))))
+      self.assertEqual(str_specs[1], "{devices=[%d,1]%s}" % dev_fmt)
+      self.assertEqual(str_specs[2], "{devices=[%d,1,1]%s}" % dev_fmt)
+    else:
+      self.assertEqual(str_specs[1], '{replicated}')
+      self.assertEqual(str_specs[2], '{replicated}')
+
   def test_multiple_operations(self):
     t1 = torch.randn(2, 2)
     t2 = torch.randn(2, 2)
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index e85db1d20a6..8b3edc7fac4 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1012,7 +1012,15 @@ def convert_fn(tensors):
     devices = [str(device)] * len(tensors)
     shardings = None
     if input_sharding:
-      shardings = [input_sharding.xla_spec(t) for t in tensors]
+      if isinstance(input_sharding, list):
+        shardings = [None] * len(tensors)
+        for i, tensor in enumerate(tensors):
+          for sharding in input_sharding:
+            if sharding.can_apply(tensor):
+              shardings[i] = sharding.xla_spec(tensor)
+              break
+      else:
+        shardings = [input_sharding.xla_spec(t) for t in tensors]
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                       shardings)
     return xtensors
diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 8af7196e95c..305933a6b11 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -74,8 +74,8 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
       Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (Union[ShardingSpec, List[ShardingSpec]], optional): Sharding
+      specs to apply to compatible input tensors when loading.
       Default: None
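
Usage sketch (not part of the patch): with this change, the input_sharding
argument accepted by MpDeviceLoader/ParallelLoader may be a list of
ShardingSpec. Each input tensor is matched against the list in order and takes
the first spec whose can_apply() returns True; tensors with no matching spec
get no explicit sharding (the test above shows they come out replicated). The
sharding import path, the (num_devices, 1) mesh shape, and names like
train_loader below are illustrative assumptions, not part of this patch.

  import numpy as np
  import torch_xla.core.xla_model as xm
  import torch_xla.runtime as xr
  import torch_xla.distributed.parallel_loader as pl
  import torch_xla.experimental.xla_sharding as xs  # import path may differ by release

  xr.use_spmd()  # enable SPMD mode; availability depends on the release
  device = xm.xla_device()
  num_devices = xr.global_runtime_device_count()

  # Hypothetical 2D mesh over all devices; only the first axis is used to shard.
  mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

  # One spec per expected input rank; each batch tensor uses the first spec
  # that can apply to it, sharding its batch dimension across the mesh.
  input_sharding = [
      xs.ShardingSpec(mesh, (0, None)),        # e.g. (batch, features) inputs
      xs.ShardingSpec(mesh, (0, None, None)),  # e.g. (batch, seq, features) inputs
  ]

  # train_loader is a placeholder torch.utils.data.DataLoader.
  loader = pl.MpDeviceLoader(train_loader, device, input_sharding=input_sharding)
  for batch in loader:
    ...  # batches arrive on device with the matching sharding applied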