diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index ac8b2c2..0dee40a 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -30,7 +30,8 @@ def sharded_xr_dataset( rank: int | None = None, world_size: int | None = None, process_group: dist.ProcessGroup | None = None, - load: bool = True, + load: bool = False, + load_kwargs: dict | None = None, ) -> Iterable[xr.Dataset | xr.DataArray]: num_total_elements = len(ds[dim]) num_chunks = num_total_elements // chunk_size @@ -46,8 +47,11 @@ def sharded_xr_dataset( start = chunk_idx * chunk_size end = start + chunk_size chunk = ds.isel({dim: slice(start, end)}) + if load: - chunk.load() + kwargs = load_kwargs or {} + chunk.load(**kwargs) + yield chunk @@ -62,7 +66,8 @@ def __init__( rank: int | None = None, world_size: int | None = None, process_group: dist.ProcessGroup | None = None, - load: bool = True, + load: bool = False, + load_kwargs: dict | None = None, ): self.ds = ds self.chunk_size = chunk_size @@ -70,6 +75,7 @@ def __init__( self.shuffle = shuffle self.seed = seed self.load = load + self.load_kwargs = load_kwargs if rank is None: self.rank = dist.get_rank(process_group) @@ -104,6 +110,7 @@ def __iter__(self): rank, world_size, self.load, + self.load_kwargs, ) @@ -114,6 +121,11 @@ def interleave_batches( Interleaves batches from an iterable of batches. Important: Returned batches must be used immediately or copied to avoid overwriting. """ + if num_batches < 1: + raise ValueError('num_batches must be greater than 0') + + if num_batches == 1: + yield from iterable batches = [] memory = None