diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 23f227ae0..2a1e1b815 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -5,8 +5,8 @@ import numpy as np import xarray as xr -from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.batch import MergeNumpyModalities from ocf_datapipes.config.model import Configuration @@ -92,6 +92,7 @@ def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]): """ return x.where(x.gsp_id != 0, drop=True) + @functional_datapipe("pvnet_select_pv_by_ml_id") class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe): """Select specific set of PV systems by ML ID.""" @@ -111,17 +112,15 @@ def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array): def __iter__(self): for x in self.source_datapipe: - - # Check for missing IDs + # Check for missing IDs ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id) if ml_ids_not_in_data.any(): missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data] logger.warning( f"The following ML IDs were mising in the PV site-level input data: " f"{missing_ml_ids}. The values for these IDs will be set to NaN." - ) - + x_filtered = ( # Many ML-IDs are null, so filter first x.where(~x.ml_id.isnull(), drop=True)