Skip to content

Commit

Permalink
feat: PrefetchDataset, BatchDataset, DownstreamDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed May 10, 2024
1 parent 78755d7 commit d7c11f1
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions dmlcloud/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.distributed as dist
import xarray as xr
from torch.utils.data import get_worker_info, IterableDataset
from concurrent.futures import ThreadPoolExecutor


def shard_indices(
Expand Down Expand Up @@ -199,6 +200,63 @@ def __iter__(self):
load_kwargs=self.load_kwargs,
)

class DownstreamDataset(IterableDataset):

def __init__(self, source_ds: Iterable[xr.Dataset]):
self.source_ds = source_ds

def set_epoch(self, epoch: int):
if hasattr(self.source_ds, 'set_epoch'):
self.source_ds.set_epoch(epoch)

def __len__(self):
return len(self.source_ds)


class PrefetchDataset(DownstreamDataset):
def __init__(self, source_ds: Iterable, num_elements: int):
super().__init__(source_ds)
self.num_elements = num_elements

def __iter__(self):
pool = ThreadPoolExecutor(max_workers=1)
iter_ = iter(self.source_ds)

with pool:
futures = [pool.submit(next, iter_) for _ in range(self.num_elements)]
while True:
future = futures.pop(0)
try:
element = future.result()
except StopIteration:
return
futures += [pool.submit(next, iter_)]
yield element


class BatchDataset(DownstreamDataset):

def __init__(self, source_ds: Iterable, batch_size: int, drop_remainder: bool = False):
super().__init__(source_ds)
self.batch_size = batch_size
self.drop_remainder = drop_remainder

def __len__(self):
if self.drop_remainder:
return len(self.source_ds) // self.batch_size
else:
return (len(self.source_ds) + self.batch_size - 1) // self.batch_size

def __iter__(self):
batch = []
for element in self.source_ds:
batch.append(element)
if len(batch) == self.batch_size:
yield batch
batch = []
if batch and not self.drop_remainder:
yield batch


def interleave_batches(
iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False
Expand Down

0 comments on commit d7c11f1

Please sign in to comment.