diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index 772250a..e868dc9 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -8,23 +8,34 @@ def shard_indices( - n: int, rank: int, world_size: int, shuffle: bool = False, drop_remainder: bool = True, seed: int = 0 + num_elements: int, + rank: int, + world_size: int, + shuffle: bool = False, + even_shards: bool = True, + seed: int = 0, ) -> list[int]: - indices = np.arange(n) + """ + even_shards: If True, every worker receives the same number of shards, and the last shards are dropped. + """ + indices = np.arange(num_elements) if shuffle: np.random.Generator(np.random.MT19937(seed)).shuffle(indices) - if drop_remainder: - indices = indices[: n - n % world_size] + if even_shards: + indices = indices[: num_elements - num_elements % world_size] return indices[rank::world_size].tolist() # this also converts np.int64 to python's int def sharded_xr_dataset( ds: xr.Dataset | xr.DataArray, - chunk_size: int, dim: str, + chunk_size: int, + chunk_overlap: int = 0, + even_shards: bool = True, + equal_chunks: bool = True, shuffle: bool = False, seed: int = 0, rank: int | None = None, @@ -34,18 +45,22 @@ def sharded_xr_dataset( load_kwargs: dict | None = None, ) -> Iterable[xr.Dataset | xr.DataArray]: num_total_elements = len(ds[dim]) - num_chunks = num_total_elements // chunk_size + + if equal_chunks: + num_chunks = num_total_elements // chunk_size + else: + num_chunks = (num_total_elements + chunk_size - 1) // chunk_size if rank is None: rank = dist.get_rank(process_group) if world_size is None: world_size = dist.get_world_size(process_group) - chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=True, seed=seed) + chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed) for chunk_idx in chunk_indices: start = chunk_idx * chunk_size - end = start + chunk_size + end = start + chunk_size + chunk_overlap chunk = ds.isel({dim: slice(start, end)}) if load: @@ -59,8 +74,11 @@ class ShardedXrDataset(IterableDataset): def __init__( self, ds: xr.Dataset | xr.DataArray, - chunk_size: int, dim: str, + chunk_size: int, + chunk_overlap: int = 0, + even_shards: bool = True, + equal_chunks: bool = True, shuffle: bool = False, seed: int = 0, rank: int | None = None, @@ -70,27 +88,18 @@ def __init__( load_kwargs: dict | None = None, ): self.ds = ds - self.chunk_size = chunk_size self.dim = dim + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.even_shards = even_shards + self.equal_chunks = equal_chunks self.shuffle = shuffle self.seed = seed self.load = load self.load_kwargs = load_kwargs - if rank is None: - self.rank = dist.get_rank(process_group) - else: - self.rank = rank - - if world_size is None: - self.world_size = dist.get_world_size(process_group) - else: - self.world_size = world_size - - def __len__(self): - num_total_elements = len(self.ds[self.dim]) - num_chunks = num_total_elements // self.chunk_size - return num_chunks // self.world_size + self.rank = rank if rank is not None else dist.get_rank(process_group) + self.world_size = world_size if world_size is not None else dist.get_world_size(process_group) def __iter__(self): worker_info = get_worker_info() @@ -103,8 +112,11 @@ def __iter__(self): return sharded_xr_dataset( self.ds, - self.chunk_size, self.dim, + self.chunk_size, + self.chunk_overlap, + self.even_shards, + self.equal_chunks, self.shuffle, self.seed, rank, diff --git a/test/test_data.py b/test/test_data.py index a5c7853..8a5d778 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,4 +1,5 @@ import sys +from functools import partial import numpy as np import pytest @@ -21,32 +22,32 @@ def __iter__(self): class TestSharding: def test_types(self): - indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) + indices = shard_indices(10, 0, 2, shuffle=False, even_shards=False) assert isinstance(indices, list) assert all(isinstance(i, int) for i in indices) def test_even(self): - assert shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8] - assert shard_indices(10, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9] + assert shard_indices(10, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8] + assert shard_indices(10, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9] def test_uneven(self): - assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=False) == [0, 3, 6, 9] - assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=False) == [1, 4, 7] - assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=False) == [2, 5, 8] + assert shard_indices(10, 0, 3, shuffle=False, even_shards=False) == [0, 3, 6, 9] + assert shard_indices(10, 1, 3, shuffle=False, even_shards=False) == [1, 4, 7] + assert shard_indices(10, 2, 3, shuffle=False, even_shards=False) == [2, 5, 8] - assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8, 10] - assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9] + assert shard_indices(11, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8, 10] + assert shard_indices(11, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9] def test_dropping(self): - assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=True) == [0, 3, 6] - assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=True) == [1, 4, 7] - assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=True) == [2, 5, 8] + assert shard_indices(10, 0, 3, shuffle=False, even_shards=True) == [0, 3, 6] + assert shard_indices(10, 1, 3, shuffle=False, even_shards=True) == [1, 4, 7] + assert shard_indices(10, 2, 3, shuffle=False, even_shards=True) == [2, 5, 8] - assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=True) == [0, 2, 4, 6, 8] - assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=True) == [1, 3, 5, 7, 9] + assert shard_indices(11, 0, 2, shuffle=False, even_shards=True) == [0, 2, 4, 6, 8] + assert shard_indices(11, 1, 2, shuffle=False, even_shards=True) == [1, 3, 5, 7, 9] def test_shuffling(self): - indices = shard_indices(10, 0, 2, shuffle=True, drop_remainder=False, seed=0) + indices = shard_indices(10, 0, 2, shuffle=True, even_shards=False, seed=0) assert len(indices) == 5 assert len(np.unique(indices)) == 5 assert indices != list(sorted(indices)) @@ -59,9 +60,10 @@ def test_basic(self): world_size = 3 chunk_size = 15 - chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False)) - chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False)) - chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False)) + shard = partial(sharded_xr_dataset, ds, 'x', chunk_size, world_size=world_size, shuffle=False) + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) assert len(chunks_1) == 2 assert len(chunks_2) == 2 @@ -83,6 +85,68 @@ def test_basic(self): assert_array_equal(chunks_2[1]['var'], np.arange(60, 75)) assert_array_equal(chunks_3[1]['var'], np.arange(75, 90)) + def test_uneven(self): + ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() + world_size = 3 + chunk_size = 20 + + shard = partial( + sharded_xr_dataset, ds, 'x', chunk_size, even_shards=False, world_size=world_size, shuffle=False + ) + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) + + assert len(chunks_1) == 2 + assert len(chunks_2) == 2 + assert len(chunks_3) == 1 + + assert isinstance(chunks_1[0], xr.Dataset) + + assert chunks_1[0].x.size == 20 + assert chunks_1[1].x.size == 20 + assert chunks_2[0].x.size == 20 + assert chunks_2[1].x.size == 20 + assert chunks_3[0].x.size == 20 + + assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) + assert_array_equal(chunks_2[0]['var'], np.arange(20, 40)) + assert_array_equal(chunks_3[0]['var'], np.arange(40, 60)) + assert_array_equal(chunks_1[1]['var'], np.arange(60, 80)) + assert_array_equal(chunks_2[1]['var'], np.arange(80, 100)) + + def test_unequal(self): + ds = xr.DataArray(np.arange(110), dims=['x'], name='var').to_dataset() + world_size = 3 + chunk_size = 20 + + shard = partial( + sharded_xr_dataset, ds, 'x', chunk_size, equal_chunks=False, world_size=world_size, shuffle=False + ) + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) + + assert len(chunks_1) == 2 + assert len(chunks_2) == 2 + assert len(chunks_3) == 2 + + assert isinstance(chunks_1[0], xr.Dataset) + + assert chunks_1[0].x.size == 20 + assert chunks_1[1].x.size == 20 + assert chunks_2[0].x.size == 20 + assert chunks_2[1].x.size == 20 + assert chunks_3[0].x.size == 20 + assert chunks_3[1].x.size == 10 + + assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) + assert_array_equal(chunks_2[0]['var'], np.arange(20, 40)) + assert_array_equal(chunks_3[0]['var'], np.arange(40, 60)) + assert_array_equal(chunks_1[1]['var'], np.arange(60, 80)) + assert_array_equal(chunks_2[1]['var'], np.arange(80, 100)) + assert_array_equal(chunks_3[1]['var'], np.arange(100, 110)) + def test_shuffled(self): ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() world_size = 3 @@ -297,33 +361,83 @@ def test_XrShardedDataset_multiprocessing(self): 59, ] - def test_XrShardedDataset_length(self): + def test_overlap(self): ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() + world_size = 3 chunk_size = 15 + overlap = 5 + + shard = partial( + sharded_xr_dataset, ds, 'x', chunk_size, chunk_overlap=overlap, world_size=world_size, shuffle=False + ) + + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) + + assert len(chunks_1) == 2 + assert len(chunks_2) == 2 + assert len(chunks_3) == 2 + + assert isinstance(chunks_1[0], xr.Dataset) + + assert chunks_1[0].x.size == 20 + assert chunks_1[1].x.size == 20 + assert chunks_2[0].x.size == 20 + assert chunks_2[1].x.size == 20 + assert chunks_3[0].x.size == 20 + assert chunks_3[1].x.size == 20 + + assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) + assert_array_equal(chunks_2[0]['var'], np.arange(15, 35)) + assert_array_equal(chunks_3[0]['var'], np.arange(30, 50)) + assert_array_equal(chunks_1[1]['var'], np.arange(45, 65)) + assert_array_equal(chunks_2[1]['var'], np.arange(60, 80)) + assert_array_equal(chunks_3[1]['var'], np.arange(75, 95)) + + def test_overlap_unequal_uneven(self): + ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() + world_size = 3 + chunk_size = 15 + overlap = 5 + + shard = partial( + sharded_xr_dataset, + ds, + 'x', + chunk_size, + chunk_overlap=overlap, + even_shards=False, + equal_chunks=False, + world_size=world_size, + shuffle=False, + ) + + chunks_1 = list(shard(rank=0)) + chunks_2 = list(shard(rank=1)) + chunks_3 = list(shard(rank=2)) + + assert len(chunks_1) == 3 + assert len(chunks_2) == 2 + assert len(chunks_3) == 2 + + assert isinstance(chunks_1[0], xr.Dataset) - torch_ds = ShardedXrDataset(ds, chunk_size, 'x', world_size=1, rank=0, shuffle=False) - assert len(torch_ds) == 6 - - torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=0, shuffle=False) - torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=1, shuffle=False) - assert len(torch_ds1) == 3 - assert len(torch_ds2) == 3 - - torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=0, shuffle=False) - torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=1, shuffle=False) - torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=2, shuffle=False) - assert len(torch_ds1) == 2 - assert len(torch_ds2) == 2 - assert len(torch_ds3) == 2 - - torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=0, shuffle=False) - torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=1, shuffle=False) - torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=2, shuffle=False) - torch_ds4 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=3, shuffle=False) - assert len(torch_ds1) == 1 - assert len(torch_ds2) == 1 - assert len(torch_ds3) == 1 - assert len(torch_ds4) == 1 + assert chunks_1[0].x.size == 20 + assert chunks_1[1].x.size == 20 + assert chunks_2[0].x.size == 20 + assert chunks_2[1].x.size == 20 + assert chunks_3[0].x.size == 20 + assert chunks_3[1].x.size == 20 + assert chunks_1[2].x.size == 10 + + assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) + assert_array_equal(chunks_2[0]['var'], np.arange(15, 35)) + assert_array_equal(chunks_3[0]['var'], np.arange(30, 50)) + assert_array_equal(chunks_1[1]['var'], np.arange(45, 65)) + assert_array_equal(chunks_2[1]['var'], np.arange(60, 80)) + assert_array_equal(chunks_3[1]['var'], np.arange(75, 95)) + assert_array_equal(chunks_1[2]['var'], np.arange(90, 100)) class TestInterleaveBatches: