diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index 8cf07be..06f2bd3 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -3,6 +3,7 @@ import numpy as np import torch.distributed as dist import xarray as xr +from torch.utils.data import IterableDataset def shard_indices( @@ -19,7 +20,7 @@ def shard_indices( return indices[rank::world_size].tolist() # this also converts np.int64 to python's int -def chunked_xr_dataset( +def sharded_xr_dataset( ds: xr.Dataset | xr.DataArray, chunk_size: int, dim: str, @@ -50,3 +51,49 @@ def chunked_xr_dataset( if load: chunk.load() yield chunk + + +class ShardedXrDataset(IterableDataset): + def __init__( + self, + ds: xr.Dataset | xr.DataArray, + chunk_size: int, + dim: str, + shuffle: bool = False, + drop_remainder: bool = True, + seed: int = 0, + rank: int | None = None, + world_size: int | None = None, + process_group: dist.ProcessGroup | None = None, + load: bool = True, + ): + self.ds = ds + self.chunk_size = chunk_size + self.dim = dim + self.shuffle = shuffle + self.drop_remainder = drop_remainder + self.seed = seed + self.load = load + + if rank is None: + self.rank = dist.get_rank(process_group) + else: + self.rank = rank + + if world_size is None: + self.world_size = dist.get_world_size(process_group) + else: + self.world_size = world_size + + def __iter__(self): + return sharded_xr_dataset( + self.ds, + self.chunk_size, + self.dim, + self.shuffle, + self.drop_remainder, + self.seed, + self.rank, + self.world_size, + self.load, + ) diff --git a/test/test_data.py b/test/test_data.py index 92aa1a7..a804a96 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -3,7 +3,7 @@ import numpy as np import pytest import xarray as xr -from dmlcloud.util.data import chunked_xr_dataset, shard_indices +from dmlcloud.util.data import shard_indices, sharded_xr_dataset from numpy.testing import assert_array_equal @@ -41,15 +41,15 @@ def test_shuffling(self): assert (np.array(indices) >= 0).all() and (np.array(indices) <= 9).all() -class TestChunking: +class TestShardedXr: def test_basic(self): ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() world_size = 3 chunk_size = 15 - chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False)) - chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False)) - chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False)) + chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=False)) + chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=False)) + chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=False)) assert len(chunks_1) == 2 assert len(chunks_2) == 2 @@ -76,9 +76,9 @@ def test_shuffled(self): world_size = 3 chunk_size = 15 - chunks_1 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0)) - chunks_2 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0)) - chunks_3 = list(chunked_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0)) + chunks_1 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=0, shuffle=True, seed=0)) + chunks_2 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=1, shuffle=True, seed=0)) + chunks_3 = list(sharded_xr_dataset(ds, chunk_size, 'x', world_size=world_size, rank=2, shuffle=True, seed=0)) assert len(chunks_1) == 2 assert len(chunks_2) == 2