diff --git a/pvnet/data/utils.py b/pvnet/data/utils.py index 00cdeb78..f3a0d64c 100644 --- a/pvnet/data/utils.py +++ b/pvnet/data/utils.py @@ -23,9 +23,10 @@ def batch_to_tensor(batch): return batch -def split_batches(batch): +def split_batches(batch, splitting_key=BatchKey.gsp): """Splits a single batch of data.""" - n_samples = batch[BatchKey.gsp].shape[0] + + n_samples = batch[splitting_key].shape[0] keys = list(batch.keys()) examples = [{} for _ in range(n_samples)] for i in range(n_samples): @@ -42,12 +43,13 @@ def split_batches(batch): class BatchSplitter(IterDataPipe): """Pipeline step to split batches of data and yield single examples""" - def __init__(self, source_datapipe: IterDataPipe): + def __init__(self, source_datapipe: IterDataPipe, splitting_key: BatchKey = BatchKey.gsp): """Pipeline step to split batches of data and yield single examples""" self.source_datapipe = source_datapipe + self.splitting_key = splitting_key def __iter__(self): """Opens the NWP data""" for batch in self.source_datapipe: - for example in split_batches(batch): + for example in split_batches(batch, splitting_key=self.splitting_key): yield example