Skip to content

Commit

Permalink
feat: overlapping chunks, drop_remainder -> even_shards, introduced e…
Browse files Browse the repository at this point in the history
…qual_chunks
  • Loading branch information
sehoffmann committed Mar 29, 2024
1 parent a7ee2af commit a25dfb4
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 66 deletions.
62 changes: 37 additions & 25 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,34 @@


def shard_indices(
n: int, rank: int, world_size: int, shuffle: bool = False, drop_remainder: bool = True, seed: int = 0
num_elements: int,
rank: int,
world_size: int,
shuffle: bool = False,
even_shards: bool = True,
seed: int = 0,
) -> list[int]:
indices = np.arange(n)
"""
even_shards: If True, every worker receives the same number of shards, and the last shards are dropped.
"""
indices = np.arange(num_elements)

if shuffle:
np.random.Generator(np.random.MT19937(seed)).shuffle(indices)

if drop_remainder:
indices = indices[: n - n % world_size]
if even_shards:
indices = indices[: num_elements - num_elements % world_size]

return indices[rank::world_size].tolist() # this also converts np.int64 to python's int


def sharded_xr_dataset(
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
chunk_size: int,
chunk_overlap: int = 0,
even_shards: bool = True,
equal_chunks: bool = True,
shuffle: bool = False,
seed: int = 0,
rank: int | None = None,
Expand All @@ -34,18 +45,22 @@ def sharded_xr_dataset(
load_kwargs: dict | None = None,
) -> Iterable[xr.Dataset | xr.DataArray]:
num_total_elements = len(ds[dim])
num_chunks = num_total_elements // chunk_size

if equal_chunks:
num_chunks = num_total_elements // chunk_size
else:
num_chunks = (num_total_elements + chunk_size - 1) // chunk_size

if rank is None:
rank = dist.get_rank(process_group)
if world_size is None:
world_size = dist.get_world_size(process_group)

chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, drop_remainder=True, seed=seed)
chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed)

for chunk_idx in chunk_indices:
start = chunk_idx * chunk_size
end = start + chunk_size
end = start + chunk_size + chunk_overlap
chunk = ds.isel({dim: slice(start, end)})

if load:
Expand All @@ -59,8 +74,11 @@ class ShardedXrDataset(IterableDataset):
def __init__(
self,
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
chunk_size: int,
chunk_overlap: int = 0,
even_shards: bool = True,
equal_chunks: bool = True,
shuffle: bool = False,
seed: int = 0,
rank: int | None = None,
Expand All @@ -70,27 +88,18 @@ def __init__(
load_kwargs: dict | None = None,
):
self.ds = ds
self.chunk_size = chunk_size
self.dim = dim
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.even_shards = even_shards
self.equal_chunks = equal_chunks
self.shuffle = shuffle
self.seed = seed
self.load = load
self.load_kwargs = load_kwargs

if rank is None:
self.rank = dist.get_rank(process_group)
else:
self.rank = rank

if world_size is None:
self.world_size = dist.get_world_size(process_group)
else:
self.world_size = world_size

def __len__(self):
num_total_elements = len(self.ds[self.dim])
num_chunks = num_total_elements // self.chunk_size
return num_chunks // self.world_size
self.rank = rank if rank is not None else dist.get_rank(process_group)
self.world_size = world_size if world_size is not None else dist.get_world_size(process_group)

def __iter__(self):
worker_info = get_worker_info()
Expand All @@ -103,8 +112,11 @@ def __iter__(self):

return sharded_xr_dataset(
self.ds,
self.chunk_size,
self.dim,
self.chunk_size,
self.chunk_overlap,
self.even_shards,
self.equal_chunks,
self.shuffle,
self.seed,
rank,
Expand Down
196 changes: 155 additions & 41 deletions test/test_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from functools import partial

import numpy as np
import pytest
Expand All @@ -21,32 +22,32 @@ def __iter__(self):

class TestSharding:
def test_types(self):
indices = shard_indices(10, 0, 2, shuffle=False, drop_remainder=False)
indices = shard_indices(10, 0, 2, shuffle=False, even_shards=False)
assert isinstance(indices, list)
assert all(isinstance(i, int) for i in indices)

def test_even(self):
assert shard_indices(10, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8]
assert shard_indices(10, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9]
assert shard_indices(10, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8]
assert shard_indices(10, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9]

def test_uneven(self):
assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=False) == [0, 3, 6, 9]
assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=False) == [1, 4, 7]
assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=False) == [2, 5, 8]
assert shard_indices(10, 0, 3, shuffle=False, even_shards=False) == [0, 3, 6, 9]
assert shard_indices(10, 1, 3, shuffle=False, even_shards=False) == [1, 4, 7]
assert shard_indices(10, 2, 3, shuffle=False, even_shards=False) == [2, 5, 8]

assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=False) == [0, 2, 4, 6, 8, 10]
assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=False) == [1, 3, 5, 7, 9]
assert shard_indices(11, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8, 10]
assert shard_indices(11, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9]

def test_dropping(self):
assert shard_indices(10, 0, 3, shuffle=False, drop_remainder=True) == [0, 3, 6]
assert shard_indices(10, 1, 3, shuffle=False, drop_remainder=True) == [1, 4, 7]
assert shard_indices(10, 2, 3, shuffle=False, drop_remainder=True) == [2, 5, 8]
assert shard_indices(10, 0, 3, shuffle=False, even_shards=True) == [0, 3, 6]
assert shard_indices(10, 1, 3, shuffle=False, even_shards=True) == [1, 4, 7]
assert shard_indices(10, 2, 3, shuffle=False, even_shards=True) == [2, 5, 8]

assert shard_indices(11, 0, 2, shuffle=False, drop_remainder=True) == [0, 2, 4, 6, 8]
assert shard_indices(11, 1, 2, shuffle=False, drop_remainder=True) == [1, 3, 5, 7, 9]
assert shard_indices(11, 0, 2, shuffle=False, even_shards=True) == [0, 2, 4, 6, 8]
assert shard_indices(11, 1, 2, shuffle=False, even_shards=True) == [1, 3, 5, 7, 9]

def test_shuffling(self):
indices = shard_indices(10, 0, 2, shuffle=True, drop_remainder=False, seed=0)
indices = shard_indices(10, 0, 2, shuffle=True, even_shards=False, seed=0)
assert len(indices) == 5
assert len(np.unique(indices)) == 5
assert indices != list(sorted(indices))
Expand All @@ -59,9 +60,10 @@ def test_basic(self):
world_size = 3
chunk_size = 15

chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False))
chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False))
chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False))
shard = partial(sharded_xr_dataset, ds, 'x', chunk_size, world_size=world_size, shuffle=False)
chunks_1 = list(shard(rank=0))
chunks_2 = list(shard(rank=1))
chunks_3 = list(shard(rank=2))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
Expand All @@ -83,6 +85,68 @@ def test_basic(self):
assert_array_equal(chunks_2[1]['var'], np.arange(60, 75))
assert_array_equal(chunks_3[1]['var'], np.arange(75, 90))

def test_uneven(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
chunk_size = 20

shard = partial(
sharded_xr_dataset, ds, 'x', chunk_size, even_shards=False, world_size=world_size, shuffle=False
)
chunks_1 = list(shard(rank=0))
chunks_2 = list(shard(rank=1))
chunks_3 = list(shard(rank=2))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
assert len(chunks_3) == 1

assert isinstance(chunks_1[0], xr.Dataset)

assert chunks_1[0].x.size == 20
assert chunks_1[1].x.size == 20
assert chunks_2[0].x.size == 20
assert chunks_2[1].x.size == 20
assert chunks_3[0].x.size == 20

assert_array_equal(chunks_1[0]['var'], np.arange(0, 20))
assert_array_equal(chunks_2[0]['var'], np.arange(20, 40))
assert_array_equal(chunks_3[0]['var'], np.arange(40, 60))
assert_array_equal(chunks_1[1]['var'], np.arange(60, 80))
assert_array_equal(chunks_2[1]['var'], np.arange(80, 100))

def test_unequal(self):
ds = xr.DataArray(np.arange(110), dims=['x'], name='var').to_dataset()
world_size = 3
chunk_size = 20

shard = partial(
sharded_xr_dataset, ds, 'x', chunk_size, equal_chunks=False, world_size=world_size, shuffle=False
)
chunks_1 = list(shard(rank=0))
chunks_2 = list(shard(rank=1))
chunks_3 = list(shard(rank=2))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
assert len(chunks_3) == 2

assert isinstance(chunks_1[0], xr.Dataset)

assert chunks_1[0].x.size == 20
assert chunks_1[1].x.size == 20
assert chunks_2[0].x.size == 20
assert chunks_2[1].x.size == 20
assert chunks_3[0].x.size == 20
assert chunks_3[1].x.size == 10

assert_array_equal(chunks_1[0]['var'], np.arange(0, 20))
assert_array_equal(chunks_2[0]['var'], np.arange(20, 40))
assert_array_equal(chunks_3[0]['var'], np.arange(40, 60))
assert_array_equal(chunks_1[1]['var'], np.arange(60, 80))
assert_array_equal(chunks_2[1]['var'], np.arange(80, 100))
assert_array_equal(chunks_3[1]['var'], np.arange(100, 110))

def test_shuffled(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
Expand Down Expand Up @@ -297,33 +361,83 @@ def test_XrShardedDataset_multiprocessing(self):
59,
]

def test_XrShardedDataset_length(self):
def test_overlap(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
chunk_size = 15
overlap = 5

shard = partial(
sharded_xr_dataset, ds, 'x', chunk_size, chunk_overlap=overlap, world_size=world_size, shuffle=False
)

chunks_1 = list(shard(rank=0))
chunks_2 = list(shard(rank=1))
chunks_3 = list(shard(rank=2))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
assert len(chunks_3) == 2

assert isinstance(chunks_1[0], xr.Dataset)

assert chunks_1[0].x.size == 20
assert chunks_1[1].x.size == 20
assert chunks_2[0].x.size == 20
assert chunks_2[1].x.size == 20
assert chunks_3[0].x.size == 20
assert chunks_3[1].x.size == 20

assert_array_equal(chunks_1[0]['var'], np.arange(0, 20))
assert_array_equal(chunks_2[0]['var'], np.arange(15, 35))
assert_array_equal(chunks_3[0]['var'], np.arange(30, 50))
assert_array_equal(chunks_1[1]['var'], np.arange(45, 65))
assert_array_equal(chunks_2[1]['var'], np.arange(60, 80))
assert_array_equal(chunks_3[1]['var'], np.arange(75, 95))

def test_overlap_unequal_uneven(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
chunk_size = 15
overlap = 5

shard = partial(
sharded_xr_dataset,
ds,
'x',
chunk_size,
chunk_overlap=overlap,
even_shards=False,
equal_chunks=False,
world_size=world_size,
shuffle=False,
)

chunks_1 = list(shard(rank=0))
chunks_2 = list(shard(rank=1))
chunks_3 = list(shard(rank=2))

assert len(chunks_1) == 3
assert len(chunks_2) == 2
assert len(chunks_3) == 2

assert isinstance(chunks_1[0], xr.Dataset)

torch_ds = ShardedXrDataset(ds, chunk_size, 'x', world_size=1, rank=0, shuffle=False)
assert len(torch_ds) == 6

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=2, rank=1, shuffle=False)
assert len(torch_ds1) == 3
assert len(torch_ds2) == 3

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=1, shuffle=False)
torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=3, rank=2, shuffle=False)
assert len(torch_ds1) == 2
assert len(torch_ds2) == 2
assert len(torch_ds3) == 2

torch_ds1 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=0, shuffle=False)
torch_ds2 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=1, shuffle=False)
torch_ds3 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=2, shuffle=False)
torch_ds4 = ShardedXrDataset(ds, chunk_size, 'x', world_size=4, rank=3, shuffle=False)
assert len(torch_ds1) == 1
assert len(torch_ds2) == 1
assert len(torch_ds3) == 1
assert len(torch_ds4) == 1
assert chunks_1[0].x.size == 20
assert chunks_1[1].x.size == 20
assert chunks_2[0].x.size == 20
assert chunks_2[1].x.size == 20
assert chunks_3[0].x.size == 20
assert chunks_3[1].x.size == 20
assert chunks_1[2].x.size == 10

assert_array_equal(chunks_1[0]['var'], np.arange(0, 20))
assert_array_equal(chunks_2[0]['var'], np.arange(15, 35))
assert_array_equal(chunks_3[0]['var'], np.arange(30, 50))
assert_array_equal(chunks_1[1]['var'], np.arange(45, 65))
assert_array_equal(chunks_2[1]['var'], np.arange(60, 80))
assert_array_equal(chunks_3[1]['var'], np.arange(75, 95))
assert_array_equal(chunks_1[2]['var'], np.arange(90, 100))


class TestInterleaveBatches:
Expand Down

0 comments on commit a25dfb4

Please sign in to comment.