From 8ff43a9b89dcae06cd7f1055726276193d278e30 Mon Sep 17 00:00:00 2001
From: Bhavya Bahl
Date: Thu, 3 Oct 2024 17:26:28 -0700
Subject: [PATCH] Allow MpDeviceLoader to shard dictionaries of tensor for 2.5 release (#8212)

---
 docs/spmd_advanced.md                    |  26 +++-
 test/run_tests.sh                        |   1 +
 test/spmd/test_mp_input_sharding.py      | 151 +++++++++++++++++++++++
 test/tpu/run_tests.sh                    |   1 +
 torch_xla/distributed/parallel_loader.py |  77 ++++++++++--
 5 files changed, 239 insertions(+), 17 deletions(-)
 create mode 100644 test/spmd/test_mp_input_sharding.py

diff --git a/docs/spmd_advanced.md b/docs/spmd_advanced.md
index 4cd07a558c9..8ce225c54bd 100644
--- a/docs/spmd_advanced.md
+++ b/docs/spmd_advanced.md
@@ -8,10 +8,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parall
 ```python
 # MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
 train_loader = pl.MpDeviceLoader(
-         train_loader,  # wraps PyTorch DataLoader
-         device,
-         # assume 4d input and we want to shard at the batch dimension.
-         input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
+    train_loader,  # wraps PyTorch DataLoader
+    device,
+    # assume 4d input and we want to shard at the batch dimension.
+    input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
+```
+
+It is also possible to specify a different `input_sharding` for each element of the batch if they have different shapes:
+
+```python
+# if batch = next(train_loader) looks like
+# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}
+
+# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
+train_loader = pl.MpDeviceLoader(
+    train_loader,  # wraps PyTorch DataLoader
+    device,
+    # specify different sharding for each input of the batch.
+    input_sharding={
+        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
+        'y': xs.ShardingSpec(input_mesh, ('data', None))
+    }
+)
 ```
 
 ### Virtual Device Optimization
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 9a8c8fce9d5..8f7ca6b03a3 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -234,6 +234,7 @@ function run_xla_op_tests3 {
   run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
   run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
   run_test "$CDIR/quantized_ops/test_dot_general.py"
+  run_test "$CDIR/spmd/test_mp_input_sharding.py"
   run_test "$CDIR/spmd/test_xla_sharding.py"
   run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
   run_test "$CDIR/spmd/test_xla_virtual_device.py"
diff --git a/test/spmd/test_mp_input_sharding.py b/test/spmd/test_mp_input_sharding.py
new file mode 100644
index 00000000000..e8c59209ec6
--- /dev/null
+++ b/test/spmd/test_mp_input_sharding.py
@@ -0,0 +1,151 @@
+import sys
+import numpy as np
+import unittest
+
+import torch
+import torch_xla
+from torch_xla import runtime as xr
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd import Mesh
+import torch_xla.distributed.spmd as xs
+import torch_xla.distributed.parallel_loader as pl
+
+xr.use_spmd()
+
+
+class MpInputShardingTest(unittest.TestCase):
+
+  class fake_dataloader:
+
+    def __init__(self, batch, size=1):
+      self.batch = batch
+      self.batch_size = size
+      self.counter = 0
+
+    def __iter__(self):
+      return self
+
+    def __next__(self):
+      if self.counter < self.batch_size:
+        self.counter += 1
+        return self.batch
+      raise StopIteration
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_multiple_inputs(self):
+    device = xm.xla_device()
+    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
+    train_loader = self.fake_dataloader(batch)
+    num_devices = xr.global_runtime_device_count()
+    mesh = xs.get_1d_mesh('x')
+
+    train_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        input_sharding={
+            'x': xs.ShardingSpec(mesh, ('x', None)),
+            'y': xs.ShardingSpec(mesh, ('x', None, None))
+        })
+    train_loader = iter(train_loader)
+    data = next(train_loader)
+    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
+        [str(i) for i in range(num_devices)]))
+    annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
+        [str(i) for i in range(num_devices)]))
+    self.assertEqual(annotation_x,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
+    self.assertEqual(annotation_y,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_single_tensor(self):
+    device = xm.xla_device()
+    batch = torch.randn((16, 128))
+    train_loader = self.fake_dataloader(batch)
+    num_devices = xr.global_runtime_device_count()
+    mesh = xs.get_1d_mesh('x')
+
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
+    train_loader = iter(train_loader)
+    data = next(train_loader)
+    annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
+        [str(i) for i in range(num_devices)]))
+    self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_error_single_tensor_with_input_sharding_dict(self):
+    device = xm.xla_device()
+    batch = torch.randn((16, 128))
+    train_loader = self.fake_dataloader(batch)
+    num_devices = xr.global_runtime_device_count()
+    mesh = xs.get_1d_mesh('x')
+
+    train_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
+    train_loader = iter(train_loader)
+    with self.assertRaises(ValueError):
+      data = next(train_loader)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_input_sharding_none(self):
+    device = xm.xla_device()
+    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
+    train_loader = self.fake_dataloader(batch)
+    num_devices = xr.global_runtime_device_count()
+
+    train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
+    train_loader = iter(train_loader)
+    data = next(train_loader)
+    annotation = '{replicated}'
+    self.assertEqual(annotation,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
+    self.assertEqual(annotation,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_error_missing_keys(self):
+    device = xm.xla_device()
+    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
+    train_loader = self.fake_dataloader(batch)
+    mesh = xs.get_1d_mesh('x')
+    train_loader = pl.MpDeviceLoader(
+        train_loader,
+        device,
+        input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
+    train_loader = iter(train_loader)
+    with self.assertRaises(KeyError):
+      data = next(train_loader)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for tupled partition spec")
+  def test_input_sharding_not_dict(self):
+    device = xm.xla_device()
+    num_devices = xr.global_runtime_device_count()
+    batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
+    train_loader = self.fake_dataloader(batch)
+    mesh = xs.get_1d_mesh('x')
+    train_loader = pl.MpDeviceLoader(
+        train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
+    train_loader = iter(train_loader)
+    data = next(train_loader)
+    annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
+        [str(i) for i in range(num_devices)]))
+    annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
+        [str(i) for i in range(num_devices)]))
+    self.assertEqual(annotation_x,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['x']))
+    self.assertEqual(annotation_y,
+                     torch_xla._XLAC._get_xla_sharding_spec(data['y']))
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
\ No newline at end of file
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 52d1de5b150..2d782a52b7e 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -6,6 +6,7 @@ python3 test/test_operations.py -v
 python3 test/pjrt/test_runtime_tpu.py
 python3 test/pjrt/test_collective_ops_tpu.py
 python3 test/spmd/test_xla_sharding.py
+python3 test/spmd/test_mp_input_sharding.py
 python3 test/spmd/test_xla_virtual_device.py
 python3 test/spmd/test_xla_distributed_checkpoint.py
 python3 test/spmd/test_train_spmd_linear_model.py
diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 3d98ff4a225..7053361f795 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -1,4 +1,5 @@
 import itertools
+import queue
 import threading
 import torch
 import torch_xla
@@ -12,7 +13,7 @@ class PerDeviceQueue(object):
 
   def __init__(self, device, loader_prefetch_size, device_prefetch_size):
     self.device = device
-    self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
+    self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
     self.queue = kq.Queue(maxsize=device_prefetch_size)
     self.close_queue_count = itertools.count()
 
@@ -47,6 +48,8 @@ def next(self):
 
     item = self._loader.next_item(self._device)
     if item is None:
+      if not self._loader._exception_queue.empty():
+        raise self._loader._exception_queue.get()
       xm.mark_step()
       raise StopIteration
     return item
@@ -56,7 +59,7 @@ class ParallelLoader(object):
   """Wraps an existing PyTorch DataLoader with background data upload.
 
   Args:
-    loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
+    cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
       wrapped.
     devices (`torch.device`...): The list of devices where the data has to be
       sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -74,13 +77,13 @@ class ParallelLoader(object):
     host_to_device_transfer_threads (int, optional): The number of threads that
       work in parallel to transfer data from loader queue to device queue.
       Default: 1
-    input_sharding (ShardingSpec, optional): Sharding spec to apply to
-      compatible input tensors after loading.
+    input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
+      spec to apply to compatible input tensors after loading.
       Default: None
   """
 
   def __init__(self,
-               loader,
+               cpu_loader,
                devices,
                batchdim=0,
                batches_per_execution=1,
@@ -88,12 +91,13 @@ def __init__(self,
                device_prefetch_size=8,
                host_to_device_transfer_threads=1,
                input_sharding=None):
-    self._loader = loader
+    self._cpu_loader = cpu_loader
     self._devices = [torch.device(x) for x in devices]
     self._batchdim = batchdim
     self._batches_per_execution = batches_per_execution
     self._done = False
     self._queues = dict()
+    self._exception_queue = queue.Queue()
     self._input_sharding = input_sharding
     for device in self._devices:
       self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
@@ -137,7 +141,7 @@ def close(self):
     self._done = True
     for dqueue in self._queues.values():
       dqueue.queue.close()
-      dqueue.loader_queue.close()
+      dqueue.cpu_loader_queue.close()
 
   @property
   def batches_per_execution(self):
@@ -145,7 +149,7 @@ def batches_per_execution(self):
 
   def _loader_worker(self):
     queues = list(self._queues.values())
-    data_iter = enumerate(self._loader)
+    data_iter = enumerate(self._cpu_loader)
     batch = []
     while not self._done:
       try:
@@ -155,27 +159,74 @@ def _loader_worker(self):
       batch.append(data)
       if len(batch) == len(self._devices):
         for queue_no, device_batch in enumerate(batch):
-          queues[queue_no].loader_queue.put(device_batch)
+          queues[queue_no].cpu_loader_queue.put(device_batch)
         batch = []
 
     for dqueue in queues:
-      dqueue.loader_queue.close_write()
+      dqueue.cpu_loader_queue.close_write()
 
   def _get_batch(self, dqueue):
     batch = []
-    while dqueue.queue.max_size() > len(batch):
-      item = dqueue.loader_queue.get()
+    while len(batch) < dqueue.queue.max_size():
+      item = dqueue.cpu_loader_queue.get()
       if item is None:
         break
       batch.append(item)
     return batch
 
+  def send_cpu_data_to_device(self, batches, device):
+    """Move batch to device.
+    Args:
+      batches -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batches
+        present in the cpu memory
+      device: TPU device where the batch should be moved
+
+    Returns:
+      result -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Returns a list
+        of dicts if the input batch is a dict. Otherwise, returns a list of torch.Tensor.
+    """
+    result = None
+    if isinstance(self._input_sharding, dict):
+      if not isinstance(batches[0], dict):
+        raise ValueError(
+            f"input batch should be a dict when input sharding is a dict.")
+      result = []
+      for batch in batches:
+        xla_batch = {}
+        missing_keys = []
+        for key, tensor in batch.items():
+          assert type(tensor) == torch.Tensor
+          sharding_spec = None
+          if self._input_sharding:
+            if key not in self._input_sharding:
+              missing_keys.append(key)
+              continue
+            sharding_spec = self._input_sharding[key]
+
+          # xla_tensor is a list of tensors.
+          xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
+          xla_batch[key] = xla_tensor[0]
+        if len(missing_keys) != 0:
+          # This raise is caught in _worker and forwarded through the exception queue, since an error raised in the data loading thread doesn't surface in the main thread.
+          raise KeyError(
+              f"Keys: {missing_keys} are missing from input_sharding.")
+        result.append(xla_batch)
+    else:
+      result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
+    return result
+
   def _worker(self, dqueue, host_to_device_transfer_threads):
     device = torch.device(dqueue.device)
     while True:
      batch = self._get_batch(dqueue)
       if not batch:
         break
-      batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
+      try:
+        batch = self.send_cpu_data_to_device(batch, device)
+      except Exception as e:
+        # _worker runs in a daemon thread, so raising the error here
+        # will not work. Put the error in an error queue instead.
+        self._exception_queue.put(e)
+        break
       for data in batch:
         dqueue.queue.put(data)
       close_queue_count = next(dqueue.close_queue_count)
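
For readers who want to try the dictionary form of `input_sharding` introduced by this patch, a minimal end-to-end sketch follows. It assumes an SPMD-enabled PJRT environment with more than one device; the `DictDataset` helper is hypothetical and only exists to produce dict-shaped batches like the ones used in the tests above.

```python
# Minimal sketch: per-key input sharding with MpDeviceLoader.
# Assumes a multi-device SPMD/PJRT setup; DictDataset is a made-up toy dataset.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from torch.utils.data import DataLoader, Dataset
from torch_xla import runtime as xr

xr.use_spmd()


class DictDataset(Dataset):
  """Toy dataset whose samples are dicts of differently shaped tensors."""

  def __len__(self):
    return 64

  def __getitem__(self, idx):
    return {'x': torch.randn(128), 'y': torch.randn(128, 128)}


device = xm.xla_device()
mesh = xs.get_1d_mesh('data')  # 1-D mesh over all devices, axis named 'data'

# Default collation turns the dict samples into {'x': [16, 128], 'y': [16, 128, 128]}.
cpu_loader = DataLoader(DictDataset(), batch_size=16)

# Shard each entry of the batch dict along its batch dimension (dim 0).
train_loader = pl.MpDeviceLoader(
    cpu_loader,
    device,
    input_sharding={
        'x': xs.ShardingSpec(mesh, ('data', None)),
        'y': xs.ShardingSpec(mesh, ('data', None, None)),
    })

for batch in train_loader:
  # Both entries arrive on the XLA device, each sharded along dim 0.
  print(torch_xla._XLAC._get_xla_sharding_spec(batch['x']))
  print(torch_xla._XLAC._get_xla_sharding_spec(batch['y']))
  break
```

This mirrors `test_multiple_inputs` above, only with a real `DataLoader` in place of the test's `fake_dataloader`.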