Skip to content

Commit

Permalink
feat: load_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 22, 2024
1 parent 8784045 commit 02e96b9
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def sharded_xr_dataset(
rank: int | None = None,
world_size: int | None = None,
process_group: dist.ProcessGroup | None = None,
load: bool = True,
load: bool = False,
load_kwargs: dict | None = None,
) -> Iterable[xr.Dataset | xr.DataArray]:
num_total_elements = len(ds[dim])
num_chunks = num_total_elements // chunk_size
Expand All @@ -46,8 +47,11 @@ def sharded_xr_dataset(
start = chunk_idx * chunk_size
end = start + chunk_size
chunk = ds.isel({dim: slice(start, end)})

if load:
chunk.load()
kwargs = load_kwargs or {}
chunk.load(**kwargs)

yield chunk


Expand All @@ -62,14 +66,16 @@ def __init__(
rank: int | None = None,
world_size: int | None = None,
process_group: dist.ProcessGroup | None = None,
load: bool = True,
load: bool = False,
load_kwargs: dict | None = None,
):
self.ds = ds
self.chunk_size = chunk_size
self.dim = dim
self.shuffle = shuffle
self.seed = seed
self.load = load
self.load_kwargs = load_kwargs

if rank is None:
self.rank = dist.get_rank(process_group)
Expand Down Expand Up @@ -104,6 +110,7 @@ def __iter__(self):
rank,
world_size,
self.load,
self.load_kwargs,
)


Expand All @@ -114,6 +121,11 @@ def interleave_batches(
Interleaves batches from an iterable of batches.
Important: Returned batches must be used immediately or copied to avoid overwriting.
"""
if num_batches < 1:
raise ValueError('num_batches must be greater than 0')

if num_batches == 1:
yield from iterable

batches = []
memory = None
Expand Down

0 comments on commit 02e96b9

Please sign in to comment.