
Commit

Allow MpDeviceLoader to shard dictionaries of tensors for 2.5 release (#…
bhavya01 authored Oct 4, 2024
1 parent c074e31 commit 8ff43a9
Showing 5 changed files with 239 additions and 17 deletions.
26 changes: 22 additions & 4 deletions docs/spmd_advanced.md
@@ -8,10 +8,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parallel
```python
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # assume 4d input and we want to shard at the batch dimension.
    input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

It is also possible to specify a different `input_sharding` for each element of a dict batch, which is useful when the elements have different shapes; pass a dictionary mapping each batch key to its `ShardingSpec`:

```python
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
train_loader, # wraps PyTorch DataLoader
device,
# specify different sharding for each input of the batch.
input_sharding={
'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
'y': xs.ShardingSpec(input_mesh, ('data', None))
}
)
```
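A minimal sketch (not part of this commit's docs), assuming the `device`, `input_mesh`, and dict-yielding `train_loader` from the example above, of the two failure modes the new tests exercise: a batch key missing from the `input_sharding` dict raises a `KeyError`, and a dict `input_sharding` paired with a plain-tensor batch raises a `ValueError`; both surface from `next()` on the loader.

```python
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

# 'y' is deliberately omitted from input_sharding, so the missing key is
# reported when the batch is loaded.
loader = pl.MpDeviceLoader(
    train_loader,
    device,
    input_sharding={
        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
    })
try:
    next(iter(loader))
except KeyError as e:
    print(e)  # lists the batch keys missing from input_sharding

# Conversely, a dict input_sharding combined with a loader that yields plain
# tensors instead of dicts fails with a ValueError, again raised from next().
```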

### Virtual Device Optimization
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -234,6 +234,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
run_test "$CDIR/quantized_ops/test_dot_general.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"
151 changes: 151 additions & 0 deletions test/spmd/test_mp_input_sharding.py
@@ -0,0 +1,151 @@
import sys
import numpy as np
import unittest

import torch
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl

xr.use_spmd()


class MpInputShardingTest(unittest.TestCase):

class fake_dataloader:

def __init__(self, batch, size=1):
self.batch = batch
self.batch_size = size
self.counter = 0

def __iter__(self):
return self

def __next__(self):
if self.counter < self.batch_size:
self.counter += 1
return self.batch
raise StopIteration

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_multiple_inputs(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={
'x': xs.ShardingSpec(mesh, ('x', None)),
'y': xs.ShardingSpec(mesh, ('x', None, None))
})
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_single_tensor(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(data))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_single_tensor_with_input_sharding_dict(self):
device = xm.xla_device()
batch = torch.randn((16, 128))
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()
mesh = xs.get_1d_mesh('x')

train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(ValueError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_none(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
num_devices = xr.global_runtime_device_count()

train_loader = pl.MpDeviceLoader(train_loader, device, input_sharding=None)
train_loader = iter(train_loader)
data = next(train_loader)
annotation = '{replicated}'
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_error_missing_keys(self):
device = xm.xla_device()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader,
device,
input_sharding={'x': xs.ShardingSpec(mesh, ('x', None))})
train_loader = iter(train_loader)
with self.assertRaises(KeyError):
data = next(train_loader)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required for tupled partition spec")
def test_input_sharding_not_dict(self):
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
batch = {'x': torch.randn((16, 128)), 'y': torch.randn((16, 128))}
train_loader = self.fake_dataloader(batch)
mesh = xs.get_1d_mesh('x')
train_loader = pl.MpDeviceLoader(
train_loader, device, input_sharding=xs.ShardingSpec(mesh, ('x', None)))
train_loader = iter(train_loader)
data = next(train_loader)
annotation_x = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
annotation_y = '{devices=[%d,1]%s}' % (num_devices, ','.join(
[str(i) for i in range(num_devices)]))
self.assertEqual(annotation_x,
torch_xla._XLAC._get_xla_sharding_spec(data['x']))
self.assertEqual(annotation_y,
torch_xla._XLAC._get_xla_sharding_spec(data['y']))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
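Note that every test in this new file is gated on `xr.global_runtime_device_count() > 1`, so the suite is effectively a no-op on single-device setups; it is wired into both `test/run_tests.sh` above and `test/tpu/run_tests.sh` below.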
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -6,6 +6,7 @@ python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_mp_input_sharding.py
python3 test/spmd/test_xla_virtual_device.py
python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
77 changes: 64 additions & 13 deletions torch_xla/distributed/parallel_loader.py
@@ -1,4 +1,5 @@
import itertools
import queue
import threading
import torch
import torch_xla
@@ -12,7 +13,7 @@ class PerDeviceQueue(object):

def __init__(self, device, loader_prefetch_size, device_prefetch_size):
self.device = device
self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.cpu_loader_queue = kq.Queue(maxsize=loader_prefetch_size)
self.queue = kq.Queue(maxsize=device_prefetch_size)
self.close_queue_count = itertools.count()

@@ -47,6 +48,8 @@ def next(self):

item = self._loader.next_item(self._device)
if item is None:
if not self._loader._exception_queue.empty():
raise self._loader._exception_queue.get()
xm.mark_step()
raise StopIteration
return item
@@ -56,7 +59,7 @@ class ParallelLoader(object):
"""Wraps an existing PyTorch DataLoader with background data upload.
Args:
loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
cpu_loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
wrapped.
devices (`torch.device`...): The list of devices where the data has to be
sent. The i-th sample returned by the `loader` will be sent to `devices[i
@@ -74,26 +77,27 @@ class ParallelLoader(object):
host_to_device_transfer_threads (int, optional): The number of threads that
work in parallel to transfer data from loader queue to device queue.
Default: 1
input_sharding (ShardingSpec, optional): Sharding spec to apply to
compatible input tensors after loading.
input_sharding (ShardingSpec, Dict(str, ShardingSpec), optional): Sharding
spec to apply to compatible input tensors after loading.
Default: None
"""

def __init__(self,
loader,
cpu_loader,
devices,
batchdim=0,
batches_per_execution=1,
loader_prefetch_size=16,
device_prefetch_size=8,
host_to_device_transfer_threads=1,
input_sharding=None):
self._loader = loader
self._cpu_loader = cpu_loader
self._devices = [torch.device(x) for x in devices]
self._batchdim = batchdim
self._batches_per_execution = batches_per_execution
self._done = False
self._queues = dict()
self._exception_queue = queue.Queue()
self._input_sharding = input_sharding
for device in self._devices:
self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
@@ -137,15 +141,15 @@ def close(self):
self._done = True
for dqueue in self._queues.values():
dqueue.queue.close()
dqueue.loader_queue.close()
dqueue.cpu_loader_queue.close()

@property
def batches_per_execution(self):
return self._batches_per_execution

def _loader_worker(self):
queues = list(self._queues.values())
data_iter = enumerate(self._loader)
data_iter = enumerate(self._cpu_loader)
batch = []
while not self._done:
try:
@@ -155,27 +159,74 @@ def _loader_worker(self):
batch.append(data)
if len(batch) == len(self._devices):
for queue_no, device_batch in enumerate(batch):
queues[queue_no].loader_queue.put(device_batch)
queues[queue_no].cpu_loader_queue.put(device_batch)
batch = []
for dqueue in queues:
dqueue.loader_queue.close_write()
dqueue.cpu_loader_queue.close_write()

def _get_batch(self, dqueue):
batch = []
while dqueue.queue.max_size() > len(batch):
item = dqueue.loader_queue.get()
while len(batch) < dqueue.queue.max_size():
item = dqueue.cpu_loader_queue.get()
if item is None:
break
batch.append(item)
return batch

def send_cpu_data_to_device(self, batches, device):
"""Move batch to device.
Args:
batch -> List(torch.Tensor), List(Dict(str: torch.Tensor)): Input batch
present in the cpu memory
device: TPU device where the batch should be moved
Returns:
result -> List(torch.Tensor), Dict(str: torch.Tensor): Returns a dict if the
input batch is a dict. Otherwise, returns a list of torch.Tensor.
"""
result = None
if isinstance(self._input_sharding, dict):
if not isinstance(batches[0], dict):
raise ValueError(
f"input batch should be a dict when input sharding is a dict.")
result = []
for batch in batches:
xla_batch = {}
missing_keys = []
for key, tensor in batch.items():
assert type(tensor) == torch.Tensor
sharding_spec = None
if self._input_sharding:
if key not in self._input_sharding:
missing_keys.append(key)
continue
sharding_spec = self._input_sharding[key]

# xla_tensor is a list of tensors.
xla_tensor = xm.send_cpu_data_to_device(tensor, device, sharding_spec)
xla_batch[key] = xla_tensor[0]
if len(missing_keys) != 0:
# This KeyError is caught in _worker and forwarded through _exception_queue,
# since an exception raised directly in the data-loading thread would not
# surface in the main thread.
raise KeyError(
f"Keys: {missing_keys} are missing from input_sharding.")
result.append(xla_batch)
else:
result = xm.send_cpu_data_to_device(batches, device, self._input_sharding)
return result

def _worker(self, dqueue, host_to_device_transfer_threads):
device = torch.device(dqueue.device)
while True:
batch = self._get_batch(dqueue)
if not batch:
break
batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
try:
batch = self.send_cpu_data_to_device(batch, device)
except Exception as e:
# _worker runs in a daemon thread, so raising the error here would not
# surface it in the main thread. Put it in the exception queue instead.
self._exception_queue.put(e)
break
for data in batch:
dqueue.queue.put(data)
close_queue_count = next(dqueue.close_queue_count)
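The error-propagation change above relies on a small pattern: the transfer worker runs in a daemon thread, where an uncaught exception never reaches the training loop, so exceptions are parked in a `queue.Queue` and re-raised from the iterator in the main thread. Below is a stripped-down, PyTorch/XLA-free sketch of that pattern; all names (`BackgroundLoader`, `_process`, ...) are illustrative, not part of the library.

```python
import queue
import threading


class BackgroundLoader:
  """Minimal sketch of surfacing worker-thread errors in the consumer thread."""

  def __init__(self, items):
    self._items = items
    self._out = queue.Queue(maxsize=8)
    self._errors = queue.Queue()
    threading.Thread(target=self._worker, daemon=True).start()

  def _process(self, item):
    # Stand-in for the device transfer: reject batches of the wrong type.
    if not isinstance(item, dict):
      raise ValueError("expected a dict batch")
    return item

  def _worker(self):
    try:
      for item in self._items:
        self._out.put(self._process(item))
    except Exception as e:
      # Raising here would die silently inside the daemon thread, so the
      # exception is parked for the consumer to re-raise.
      self._errors.put(e)
    finally:
      self._out.put(None)  # sentinel: no more items

  def __iter__(self):
    while True:
      item = self._out.get()
      if item is None:
        if not self._errors.empty():
          raise self._errors.get()  # surfaces the worker's failure here
        return
      yield item


# The ValueError raised in the worker shows up in the main thread's iteration.
try:
  for batch in BackgroundLoader([{'x': 1}, 'not-a-dict']):
    print(batch)
except ValueError as e:
  print('surfaced in main thread:', e)
```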
