From d7c11f19468a435d6212fb815d6d826113736c44 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Fri, 10 May 2024 20:11:57 +0200 Subject: [PATCH] feat: PrefetchDataset, BatchDataset, DownstreamDataset --- dmlcloud/util/data.py | 58 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/dmlcloud/util/data.py b/dmlcloud/util/data.py index 01188da..804e5bb 100644 --- a/dmlcloud/util/data.py +++ b/dmlcloud/util/data.py @@ -5,6 +5,7 @@ import torch.distributed as dist import xarray as xr from torch.utils.data import get_worker_info, IterableDataset +from concurrent.futures import ThreadPoolExecutor def shard_indices( @@ -199,6 +200,63 @@ def __iter__(self): load_kwargs=self.load_kwargs, ) +class DownstreamDataset(IterableDataset): + + def __init__(self, source_ds: Iterable[xr.Dataset]): + self.source_ds = source_ds + + def set_epoch(self, epoch: int): + if hasattr(self.source_ds, 'set_epoch'): + self.source_ds.set_epoch(epoch) + + def __len__(self): + return len(self.source_ds) + + +class PrefetchDataset(DownstreamDataset): + def __init__(self, source_ds: Iterable, num_elements: int): + super().__init__(source_ds) + self.num_elements = num_elements + + def __iter__(self): + pool = ThreadPoolExecutor(max_workers=1) + iter_ = iter(self.source_ds) + + with pool: + futures = [pool.submit(next, iter_) for _ in range(self.num_elements)] + while True: + future = futures.pop(0) + try: + element = future.result() + except StopIteration: + return + futures += [pool.submit(next, iter_)] + yield element + + +class BatchDataset(DownstreamDataset): + + def __init__(self, source_ds: Iterable, batch_size: int, drop_remainder: bool = False): + super().__init__(source_ds) + self.batch_size = batch_size + self.drop_remainder = drop_remainder + + def __len__(self): + if self.drop_remainder: + return len(self.source_ds) // self.batch_size + else: + return (len(self.source_ds) + self.batch_size - 1) // self.batch_size + + def __iter__(self): + batch = [] + for element in self.source_ds: + batch.append(element) + if len(batch) == self.batch_size: + yield batch + batch = [] + if batch and not self.drop_remainder: + yield batch + def interleave_batches( iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False