diff --git a/pvnet/data/utils.py b/pvnet/data/utils.py index e216a22d..0005a8ba 100644 --- a/pvnet/data/utils.py +++ b/pvnet/data/utils.py @@ -1,27 +1,33 @@ """Utils common between Wind and PV datamodules""" import numpy as np import torch -from ocf_datapipes.batch import unstack_np_batch_into_examples -from ocf_datapipes.utils.consts import BatchKey +from ocf_datapipes.batch import BatchKey, unstack_np_batch_into_examples from torch.utils.data import IterDataPipe, functional_datapipe def copy_batch_to_device(batch, device): """Moves a dict-batch of tensors to new device.""" batch_copy = {} - for k in list(batch.keys()): - if isinstance(batch[k], torch.Tensor): - batch_copy[k] = batch[k].to(device) + + for k, v in batch.items(): + if isinstance(v, dict): + # Recursion to reach the nested NWP + batch_copy[k] = copy_batch_to_device(v) + elif isinstance(v, torch.Tensor): + batch_copy[k] = v.to(device) else: - batch_copy[k] = batch[k] + batch_copy[k] = v return batch_copy def batch_to_tensor(batch): """Moves numpy batch to a tensor""" - for k in list(batch.keys()): - if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number): - batch[k] = torch.as_tensor(batch[k]) + for k, v in batch.items(): + if isinstance(v, dict): + # Recursion to reach the nested NWP + batch[k] = batch_to_tensor(v) + elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number): + batch[k] = torch.as_tensor(v) return batch