diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 1473fd5f995..29749a19596 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -17,6 +17,7 @@
 import torch_xla.debug.metrics as met
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
+import torch_xla.distributed.parallel_loader as pl
 import test_xla_sharding_base
 
 import torch_xla.core.xla_env_vars as xenv
@@ -1310,6 +1311,72 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))
 
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
+  def test_data_loader_with_sharding(self):
+    device = torch_xla.device()
+    mesh = xs.get_1d_mesh("data")
+    batch_size = 8
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(batch_size, 3, 64,
+                          64), torch.zeros(batch_size, dtype=torch.int64)),
+        sample_count=100)
+    train_device_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        # Shard the input's batch dimension along the `data` axis; no sharding along the other dimensions.
+        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
+    data, _ = iter(train_device_loader).__next__()
+    self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(data),
+        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    )
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
+  def test_data_loader_with_non_batch_size(self):
+    device = torch_xla.device()
+    mesh = xs.get_1d_mesh("data")
+    batch_size = mesh.size() - 1
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(batch_size, 3, 64,
+                          64), torch.zeros(batch_size, dtype=torch.int64)),
+        sample_count=100)
+    train_device_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        # Shard the input's batch dimension along the `data` axis; no sharding along the other dimensions.
+        input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
+    data, _ = iter(train_device_loader).__next__()
+    self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
+    self.assertEqual(
+        torch_xla._XLAC._get_xla_sharding_spec(data),
+        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    )
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
+  def test_data_loader_with_non_batch_size_and_mini_batch(self):
+    device = torch_xla.device()
+    mesh = xs.get_1d_mesh("data")
+    batch_size = mesh.size() - 1
+    train_loader = xu.SampleGenerator(
+        data=(torch.zeros(batch_size, 3, 64,
+                          64), torch.zeros(batch_size, dtype=torch.int64)),
+        sample_count=100)
+    train_device_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        # Shard the input's batch dimension along the `data` axis; no sharding along the other dimensions.
+        input_sharding=xs.ShardingSpec(
+            mesh, ('data', None, None, None), minibatch=True))
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "When minibatch is configured, batch dimension of the tensor must be divisible by local runtime device count*"
+    ):
+      data, _ = iter(train_device_loader).__next__()
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 39f6083e828..0e403fb1982 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1299,6 +1299,22 @@ def convert_fn(tensors):
     shardings = None
     if input_sharding:
       shardings = [input_sharding.xla_spec(t) for t in tensors]
+    if input_sharding and input_sharding.minibatch:
+      # When minibatch is configured, we must make sure the batch dimension of
+      # the tensor is divisible by the local runtime device count.
+      for tensor, sharding in zip(tensors, shardings):
+        # Assume the batch dimension is 0.
+        local_runtime_device_count = torch_xla.runtime.addressable_runtime_device_count(
+        )
+        if sharding and tensor.dim() > 0 and (tensor.size()[0] %
+                                              local_runtime_device_count) != 0:
+          raise RuntimeError(
+              "When minibatch is configured, batch dimension of the tensor " +
+              "must be divisible by local runtime device count. Input data shape "
+              +
+              f"={tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
+          )
+
     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
                                                       shardings)
     return xtensors
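
Usage note (illustrative, not part of the patch): the sketch below shows how the minibatch divisibility check added in xla_model.py surfaces to a training loop. Names and APIs follow the tests above; it assumes a single-host run with more than one addressable device, SPMD mode enabled, and an arbitrary per-device batch of 4. A host-local batch that is a multiple of torch_xla.runtime.addressable_runtime_device_count() loads normally, while any other batch size would now raise the RuntimeError introduced by this patch.

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()  # enable SPMD execution mode (assumed prerequisite for sharded data loading)

device = torch_xla.device()
mesh = xs.get_1d_mesh("data")
local_devices = xr.addressable_runtime_device_count()

# Pick a host-local batch size that divides evenly across the host's devices,
# which is what the new check in convert_fn enforces when minibatch=True.
batch_size = 4 * local_devices
train_loader = xu.SampleGenerator(
    data=(torch.zeros(batch_size, 3, 64, 64),
          torch.zeros(batch_size, dtype=torch.int64)),
    sample_count=10)
train_device_loader = pl.MpDeviceLoader(
    train_loader,
    device,
    # Shard the batch dimension along `data`; minibatch=True means each host
    # only feeds its own slice of the global batch.
    input_sharding=xs.ShardingSpec(
        mesh, ('data', None, None, None), minibatch=True))

data, _ = next(iter(train_device_loader))
print(data.size())  # host-local batch, e.g. torch.Size([4 * local_devices, 3, 64, 64])
print(torch_xla._XLAC._get_xla_sharding_spec(data))  # batch dim split across devices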