diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py
index 06f2bd3..af33587 100644
--- a/dmlcloud/util/data.py
+++ b/dmlcloud/util/data.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch.distributed as dist
 import xarray as xr
-from torch.utils.data import IterableDataset
+from torch.utils.data import get_worker_info, IterableDataset
 
 
 def shard_indices(
@@ -25,7 +25,6 @@ def sharded_xr_dataset(
     chunk_size: int,
     dim: str,
     shuffle: bool = False,
-    drop_remainder: bool = True,
     seed: int = 0,
     rank: int | None = None,
     world_size: int | None = None,
@@ -40,9 +39,7 @@
     if world_size is None:
         world_size = dist.get_world_size(process_group)
 
-    chunk_indices = shard_indices(
-        num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=drop_remainder, seed=seed
-    )
+    chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=True, seed=seed)
 
     for chunk_idx in chunk_indices:
         start = chunk_idx * chunk_size
@@ -60,7 +57,6 @@ def __init__(
         chunk_size: int,
         dim: str,
         shuffle: bool = False,
-        drop_remainder: bool = True,
         seed: int = 0,
         rank: int | None = None,
         world_size: int | None = None,
@@ -71,7 +67,6 @@ def __init__(
         self.chunk_size = chunk_size
         self.dim = dim
         self.shuffle = shuffle
-        self.drop_remainder = drop_remainder
         self.seed = seed
         self.load = load
 
@@ -85,15 +80,27 @@ def __init__(
         else:
             self.world_size = world_size
 
+    def __len__(self):
+        num_total_elements = len(self.ds[self.dim])
+        num_chunks = num_total_elements // self.chunk_size
+        return num_chunks // self.world_size
+
     def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None:
+            rank = self.rank
+            world_size = self.world_size
+        else:
+            rank = self.rank * worker_info.num_workers + worker_info.id
+            world_size = self.world_size * worker_info.num_workers
+
         return sharded_xr_dataset(
             self.ds,
             self.chunk_size,
             self.dim,
             self.shuffle,
-            self.drop_remainder,
             self.seed,
-            self.rank,
-            self.world_size,
+            rank,
+            world_size,
             self.load,
         )
diff --git a/test/test_data.py b/test/test_data.py
index a804a96..d94fedd 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -3,8 +3,9 @@
 import numpy as np
 import pytest
 import xarray as xr
-from dmlcloud.util.data import shard_indices, sharded_xr_dataset
+from dmlcloud.util.data import shard_indices, sharded_xr_dataset, ShardedXrDataset
 from numpy.testing import assert_array_equal
+from torch.utils.data import DataLoader, IterableDataset
 
 
 class TestSharding:
@@ -91,6 +92,237 @@ def test_shuffled(self):
         chunk = chunks_1[0]['var'].values
         assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1))
 
+    def test_XrShardedDataset_multiprocessing(self):
+        class _Unzip(IterableDataset):
+            def __init__(self, ds):
+                self.ds = ds
+
+            def __iter__(self):
+                for chunk in self.ds:
+                    arr = chunk.to_array().values[0]
+                    yield from arr
+
+        xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
+
+        # Simple case: 2 workers, world_size=1
+        # Workers act as additional processes and we expect interleaved chunks
+        torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=1, rank=0, shuffle=False)
+        torch_ds = _Unzip(torch_ds)
+        dataloader = DataLoader(
+            torch_ds,
+            num_workers=2,
+            batch_size=1,
+            prefetch_factor=1,
+        )
+        results = list(batch.item() for batch in dataloader)
+        assert results == [
+            0,
+            15,
+            1,
+            16,
+            2,
+            17,
+            3,
+            18,
+            4,
+            19,
+            5,
+            20,
+            6,
+            21,
+            7,
+            22,
+            8,
+            23,
+            9,
+            24,
+            10,
+            25,
+            11,
+            26,
+            12,
+            27,
+            13,
+            28,
+            14,
+            29,
+            30,
+            45,
+            31,
+            46,
+            32,
+            47,
+            33,
+            48,
+            34,
+            49,
+            35,
+            50,
+            36,
+            51,
+            37,
+            52,
+            38,
+            53,
+            39,
+            54,
+            40,
+            55,
+            41,
+            56,
+            42,
+            57,
+            43,
+            58,
+            44,
+            59,
+            60,
+            75,
+            61,
+            76,
+            62,
+            77,
+            63,
+            78,
+            64,
+            79,
+            65,
+            80,
+            66,
+            81,
+            67,
+            82,
+            68,
+            83,
+            69,
+            84,
+            70,
+            85,
+            71,
+            86,
+            72,
+            87,
+            73,
+            88,
+            74,
+            89,
+        ]
+
+        # Advanced case: 2 workers, world_size=2
+        # Each rank gets consecutive chunks and splits them between workers (which interleave again)
+        # Since the effective world size is now 4, and the dataset has 6 chunks in total, we will only get 4 chunks (up to 60)
+        torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=0, shuffle=False)
+        torch_ds = _Unzip(torch_ds)
+        dataloader = DataLoader(
+            torch_ds,
+            num_workers=2,
+            batch_size=1,
+            prefetch_factor=1,
+        )
+        results = list(batch.item() for batch in dataloader)
+        assert results == [
+            0,
+            15,
+            1,
+            16,
+            2,
+            17,
+            3,
+            18,
+            4,
+            19,
+            5,
+            20,
+            6,
+            21,
+            7,
+            22,
+            8,
+            23,
+            9,
+            24,
+            10,
+            25,
+            11,
+            26,
+            12,
+            27,
+            13,
+            28,
+            14,
+            29,
+        ]
+
+        torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=1, shuffle=False)
+        torch_ds = _Unzip(torch_ds)
+        dataloader = DataLoader(
+            torch_ds,
+            num_workers=2,
+            batch_size=1,
+            prefetch_factor=1,
+        )
+        results = list(batch.item() for batch in dataloader)
+        assert results == [
+            30,
+            45,
+            31,
+            46,
+            32,
+            47,
+            33,
+            48,
+            34,
+            49,
+            35,
+            50,
+            36,
+            51,
+            37,
+            52,
+            38,
+            53,
+            39,
+            54,
+            40,
+            55,
+            41,
+            56,
+            42,
+            57,
+            43,
+            58,
+            44,
+            59,
+        ]
+
+    def test_XrShardedDataset_length(self):
+        ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
+        chunk_size = 15
+
+        torch_ds = ShardedXrDataset(ds, chunk_size, 'x', world_size=1, rank=0, shuffle=False)
+        assert len(torch_ds) == 6
+
+        torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=0, shuffle=False)
+        torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=1, shuffle=False)
+        assert len(torch_ds1) == 3
+        assert len(torch_ds2) == 3
+
+        torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=0, shuffle=False)
+        torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=1, shuffle=False)
+        torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=2, shuffle=False)
+        assert len(torch_ds1) == 2
+        assert len(torch_ds2) == 2
+        assert len(torch_ds3) == 2
+
+        torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=0, shuffle=False)
+        torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=1, shuffle=False)
+        torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=2, shuffle=False)
+        torch_ds4 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=3, shuffle=False)
+        assert len(torch_ds1) == 1
+        assert len(torch_ds2) == 1
+        assert len(torch_ds3) == 1
+        assert len(torch_ds4) == 1
+
 
 if __name__ == '__main__':
     sys.exit(pytest.main([__file__]))
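
Reviewer note (illustration only, not part of the patch): the new __iter__ folds DataLoader workers into the distributed shard grid by treating each (rank, worker) pair as one consumer, with effective rank rank * num_workers + worker_id over an effective world size of world_size * num_workers. Below is a minimal Python sketch of that composition. It assumes round-robin chunk assignment with the remainder dropped, which is consistent with the orderings asserted in the tests above but not visible in this diff (shard_indices itself is not included); effective_shard is a hypothetical helper name.

    # Sketch: mirrors the rank/world_size composition in ShardedXrDataset.__iter__
    def effective_shard(rank, world_size, worker_id, num_workers):
        # every DataLoader worker acts as one extra process in the shard grid
        return rank * num_workers + worker_id, world_size * num_workers

    num_chunks = 100 // 15  # the test dataset yields 6 full chunks
    for rank in range(2):
        for worker_id in range(2):
            er, ews = effective_shard(rank, world_size=2, worker_id=worker_id, num_workers=2)
            # assumed round-robin assignment with the remainder dropped (drop_remainder=True)
            chunks = list(range(num_chunks))[er::ews][: num_chunks // ews]
            print(f'rank={rank} worker={worker_id} -> chunks {chunks}')

Running the sketch prints chunks [0] and [1] for rank 0 and chunks [2] and [3] for rank 1, matching the element ranges 0-29 and 30-59 asserted in test_XrShardedDataset_multiprocessing, with chunks 4 and 5 dropped. Note that __len__ computes num_chunks // world_size without folding in num_workers, so with multiple workers the number of chunks a rank actually iterates can be smaller than its reported length.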