feat: torch multiprocessing support for ShardedXrDataset
sehoffmann committed Mar 20, 2024
1 parent 98b133a commit 172550d
Showing 2 changed files with 250 additions and 11 deletions.
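In short, `ShardedXrDataset.__iter__` now folds the DataLoader worker id into the shard rank (`rank * num_workers + worker_id`), so every worker of every distributed rank reads a disjoint set of chunks. Below is a minimal standalone sketch of that arithmetic; the strided chunk assignment with a dropped remainder is inferred from the test expectations further down, since `shard_indices` itself is not shown in this diff.

```python
# Illustration of the effective-rank arithmetic added in __iter__ below
# (standalone sketch, not part of the commit).
# Mirrors the tests: 100 elements, chunk_size=15 -> 6 chunks.
num_chunks = 100 // 15          # 6
world_size, num_workers = 2, 2  # 2 distributed ranks, 2 DataLoader workers each

eff_world_size = world_size * num_workers                  # 4 independent consumers
usable = (num_chunks // eff_world_size) * eff_world_size   # 4 chunks; the last 2 are dropped

for rank in range(world_size):
    for worker_id in range(num_workers):
        eff_rank = rank * num_workers + worker_id
        # assumed strided (round-robin) assignment over the truncated chunk list
        chunks = list(range(eff_rank, usable, eff_world_size))
        print(f"rank={rank} worker={worker_id}: chunks {chunks}")

# rank=0 worker=0: chunks [0]
# rank=0 worker=1: chunks [1]
# rank=1 worker=0: chunks [2]
# rank=1 worker=1: chunks [3]
```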
27 changes: 17 additions & 10 deletions dmlcloud/util/data.py
@@ -3,7 +3,7 @@
import numpy as np
import torch.distributed as dist
import xarray as xr
-from torch.utils.data import IterableDataset
+from torch.utils.data import get_worker_info, IterableDataset


def shard_indices(
@@ -25,7 +25,6 @@ def sharded_xr_dataset(
    chunk_size: int,
    dim: str,
    shuffle: bool = False,
-    drop_remainder: bool = True,
    seed: int = 0,
    rank: int | None = None,
    world_size: int | None = None,
@@ -40,9 +39,7 @@
    if world_size is None:
        world_size = dist.get_world_size(process_group)

-    chunk_indices = shard_indices(
-        num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=drop_remainder, seed=seed
-    )
+    chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=True, seed=seed)

    for chunk_idx in chunk_indices:
        start = chunk_idx * chunk_size
@@ -60,7 +57,6 @@ def __init__(
        chunk_size: int,
        dim: str,
        shuffle: bool = False,
-        drop_remainder: bool = True,
        seed: int = 0,
        rank: int | None = None,
        world_size: int | None = None,
@@ -71,7 +67,6 @@ def __init__(
        self.chunk_size = chunk_size
        self.dim = dim
        self.shuffle = shuffle
-        self.drop_remainder = drop_remainder
        self.seed = seed
        self.load = load

@@ -85,15 +80,27 @@ def __init__(
        else:
            self.world_size = world_size

+    def __len__(self):
+        num_total_elements = len(self.ds[self.dim])
+        num_chunks = num_total_elements // self.chunk_size
+        return num_chunks // self.world_size
+
    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:
+            rank = self.rank
+            world_size = self.world_size
+        else:
+            rank = self.rank * worker_info.num_workers + worker_info.id
+            world_size = self.world_size * worker_info.num_workers
+
        return sharded_xr_dataset(
            self.ds,
            self.chunk_size,
            self.dim,
            self.shuffle,
-            self.drop_remainder,
            self.seed,
-            self.rank,
-            self.world_size,
+            rank,
+            world_size,
            self.load,
        )
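For reference, the new `__len__` reports chunks per distributed rank (DataLoader workers are not folded into this count). A small worked example using the same 100-element, chunk-size-15 dataset as the tests:

```python
# len(ShardedXrDataset(...)) == (num_elements // chunk_size) // world_size
num_elements, chunk_size = 100, 15
num_chunks = num_elements // chunk_size           # 6
for world_size in (1, 2, 3, 4):
    print(world_size, num_chunks // world_size)   # 6, 3, 2, 1 -- matches test_XrShardedDataset_length
```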
234 changes: 233 additions & 1 deletion test/test_data.py
@@ -3,8 +3,9 @@
import numpy as np
import pytest
import xarray as xr
-from dmlcloud.util.data import shard_indices, sharded_xr_dataset
+from dmlcloud.util.data import shard_indices, sharded_xr_dataset, ShardedXrDataset
from numpy.testing import assert_array_equal
+from torch.utils.data import DataLoader, IterableDataset


class TestSharding:
@@ -91,6 +92,237 @@ def test_shuffled(self):
chunk = chunks_1[0]['var'].values
assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1))

def test_XrShardedDataset_multiprocessing(self):
class _Unzip(IterableDataset):
def __init__(self, ds):
self.ds = ds

def __iter__(self):
for chunk in self.ds:
arr = chunk.to_array().values[0]
yield from arr

xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()

# Simple case: 2 workers, world_size=1
# Workers act as additional processes and we expect interleaved chunks
torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=1, rank=0, shuffle=False)
torch_ds = _Unzip(torch_ds)
dataloader = DataLoader(
torch_ds,
num_workers=2,
batch_size=1,
prefetch_factor=1,
)
results = list(batch.item() for batch in dataloader)
assert results == [
0,
15,
1,
16,
2,
17,
3,
18,
4,
19,
5,
20,
6,
21,
7,
22,
8,
23,
9,
24,
10,
25,
11,
26,
12,
27,
13,
28,
14,
29,
30,
45,
31,
46,
32,
47,
33,
48,
34,
49,
35,
50,
36,
51,
37,
52,
38,
53,
39,
54,
40,
55,
41,
56,
42,
57,
43,
58,
44,
59,
60,
75,
61,
76,
62,
77,
63,
78,
64,
79,
65,
80,
66,
81,
67,
82,
68,
83,
69,
84,
70,
85,
71,
86,
72,
87,
73,
88,
74,
89,
]

# Advanced case: 2 workers, world_size=2
# Each rank gets consecutive chunks and splits them between workers (which interleave again)
# Since the effective world size is now 4, and the dataset has 6 chunks in total, we will only get 4 chunks (up to 60)
torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=0, shuffle=False)
torch_ds = _Unzip(torch_ds)
dataloader = DataLoader(
torch_ds,
num_workers=2,
batch_size=1,
prefetch_factor=1,
)
results = list(batch.item() for batch in dataloader)
assert results == [
0,
15,
1,
16,
2,
17,
3,
18,
4,
19,
5,
20,
6,
21,
7,
22,
8,
23,
9,
24,
10,
25,
11,
26,
12,
27,
13,
28,
14,
29,
]

torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=1, shuffle=False)
torch_ds = _Unzip(torch_ds)
dataloader = DataLoader(
torch_ds,
num_workers=2,
batch_size=1,
prefetch_factor=1,
)
results = list(batch.item() for batch in dataloader)
assert results == [
30,
45,
31,
46,
32,
47,
33,
48,
34,
49,
35,
50,
36,
51,
37,
52,
38,
53,
39,
54,
40,
55,
41,
56,
42,
57,
43,
58,
44,
59,
]

def test_XrShardedDataset_length(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
chunk_size = 15

torch_ds = ShardedXrDataset(ds, chunk_size, 'x', world_size=1, rank=0, shuffle=False)
assert len(torch_ds) == 6

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=1, shuffle=False)
assert len(torch_ds1) == 3
assert len(torch_ds2) == 3

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=1, shuffle=False)
torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=2, shuffle=False)
assert len(torch_ds1) == 2
assert len(torch_ds2) == 2
assert len(torch_ds3) == 2

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=1, shuffle=False)
torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=2, shuffle=False)
torch_ds4 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=3, shuffle=False)
assert len(torch_ds1) == 1
assert len(torch_ds2) == 1
assert len(torch_ds3) == 1
assert len(torch_ds4) == 1


if __name__ == '__main__':
sys.exit(pytest.main([__file__]))

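As a usage sketch, mirroring the test's `_Unzip` wrapper (dataset and class names here are illustrative, not part of the commit): the dataset can be iterated directly, or wrapped so that a `DataLoader` with `num_workers > 0` fans the chunks out over worker processes.

```python
import numpy as np
import xarray as xr
from torch.utils.data import DataLoader, IterableDataset

from dmlcloud.util.data import ShardedXrDataset


class UnzipChunks(IterableDataset):
    """Yield individual samples from each xarray chunk (as in the test's _Unzip)."""

    def __init__(self, ds):
        self.ds = ds

    def __iter__(self):
        for chunk in self.ds:
            yield from chunk['var'].values


# Illustrative in-memory dataset: 100 points along 'x', as in the tests.
xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()

# rank/world_size are given explicitly here; omit them to fall back to
# torch.distributed, as the constructor above shows.
sharded = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', rank=0, world_size=2, shuffle=False)

# The two workers receive disjoint chunks via get_worker_info() in __iter__.
loader = DataLoader(UnzipChunks(sharded), num_workers=2, batch_size=5)
for batch in loader:
    pass  # each batch holds 5 samples taken from one worker's current chunk
```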