Skip to content

Commit

Permalink
feat: ShardedXrDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 18, 2024
1 parent 3c8cf34 commit 98b133a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
49 changes: 48 additions & 1 deletion dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch.distributed as dist
import xarray as xr
from torch.utils.data import IterableDataset


def shard_indices(
Expand All @@ -19,7 +20,7 @@ def shard_indices(
return indices[rank::world_size].tolist() # this also converts np.int64 to python's int


def chunked_xr_dataset(
def sharded_xr_dataset(
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
Expand Down Expand Up @@ -50,3 +51,49 @@ def chunked_xr_dataset(
if load:
chunk.load()
yield chunk


class ShardedXrDataset(IterableDataset):
def __init__(
self,
ds: xr.Dataset | xr.DataArray,
chunk_size: int,
dim: str,
shuffle: bool = False,
drop_remainder: bool = True,
seed: int = 0,
rank: int | None = None,
world_size: int | None = None,
process_group: dist.ProcessGroup | None = None,
load: bool = True,
):
self.ds = ds
self.chunk_size = chunk_size
self.dim = dim
self.shuffle = shuffle
self.drop_remainder = drop_remainder
self.seed = seed
self.load = load

if rank is None:
self.rank = dist.get_rank(process_group)
else:
self.rank = rank

if world_size is None:
self.world_size = dist.get_world_size(process_group)
else:
self.world_size = world_size

def __iter__(self):
return sharded_xr_dataset(
self.ds,
self.chunk_size,
self.dim,
self.shuffle,
self.drop_remainder,
self.seed,
self.rank,
self.world_size,
self.load,
)
16 changes: 8 additions & 8 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest
import xarray as xr
from dmlcloud.util.data import chunked_xr_dataset, shard_indices
from dmlcloud.util.data import shard_indices, sharded_xr_dataset
from numpy.testing import assert_array_equal


Expand Down Expand Up @@ -41,15 +41,15 @@ def test_shuffling(self):
assert (np.array(indices) >= 0).all() and (np.array(indices) <= 9).all()


class TestChunking:
class TestShardedXr:
def test_basic(self):
ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset()
world_size = 3
chunk_size = 15

chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False))
chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False))
chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False))
chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False))
chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False))
chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
Expand All @@ -76,9 +76,9 @@ def test_shuffled(self):
world_size = 3
chunk_size = 15

chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0))
chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0))
chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0))
chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0))
chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0))
chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0))

assert len(chunks_1) == 2
assert len(chunks_2) == 2
Expand Down

0 comments on commit 98b133a

Please sign in to comment.