diff --git a/ocf_datapipes/config/model.py b/ocf_datapipes/config/model.py
index e446860c2..bb152fb21 100644
--- a/ocf_datapipes/config/model.py
+++ b/ocf_datapipes/config/model.py
@@ -272,6 +272,11 @@ class PV(DataSourceMixin, StartEndDatetimeMixin, TimeResolutionMixin, XYDimensio
         description="The CSV files describing each PV system.",
     )
 
+    pv_ml_ids: List[int] = Field(
+        None,
+        description="List of the ML IDs of the PV systems you'd like to filter to.",
+    )
+
     is_live: bool = Field(
         False, description="Option to use live data from the nowcasting pv database"
     )
diff --git a/ocf_datapipes/convert/numpy/batch/pv.py b/ocf_datapipes/convert/numpy/batch/pv.py
index 1043af18d..8a39d9b17 100644
--- a/ocf_datapipes/convert/numpy/batch/pv.py
+++ b/ocf_datapipes/convert/numpy/batch/pv.py
@@ -35,7 +35,8 @@ def __iter__(self) -> NumpyBatch:
             BatchKey.pv_t0_idx: xr_data.attrs["t0_idx"],
             BatchKey.pv_ml_id: xr_data["ml_id"].values,
             BatchKey.pv_id: xr_data["pv_system_id"].values.astype(np.float32),
-            BatchKey.pv_capacity_watt_power: xr_data["capacity_watt_power"].values,
+            BatchKey.pv_observed_capacity_wp: (xr_data["observed_capacity_wp"].values),
+            BatchKey.pv_nominal_capacity_wp: (xr_data["nominal_capacity_wp"].values),
             BatchKey.pv_time_utc: datetime64_to_float(xr_data["time_utc"].values),
             BatchKey.pv_latitude: xr_data["latitude"].values,
             BatchKey.pv_longitude: xr_data["longitude"].values,
diff --git a/ocf_datapipes/load/__init__.py b/ocf_datapipes/load/__init__.py
index 50f55ef15..23f021222 100644
--- a/ocf_datapipes/load/__init__.py
+++ b/ocf_datapipes/load/__init__.py
@@ -4,6 +4,7 @@
 from ocf_datapipes.load.gsp.gsp_national import OpenGSPNationalIterDataPipe as OpenGSPNational
 from ocf_datapipes.load.nwp.providers.gfs import OpenGFSForecastIterDataPipe as OpenGFSForecast
 from ocf_datapipes.load.pv.database import OpenPVFromDBIterDataPipe as OpenPVFromDB
+from ocf_datapipes.load.pv.database import OpenPVFromPVSitesDBIterDataPipe as OpenPVFromPVSitesDB
 from ocf_datapipes.load.pv.pv import OpenPVFromNetCDFIterDataPipe as OpenPVFromNetCDF
 
 from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration
diff --git a/ocf_datapipes/load/pv/database.py b/ocf_datapipes/load/pv/database.py
index 9b23c0162..307b17fed 100644
--- a/ocf_datapipes/load/pv/database.py
+++ b/ocf_datapipes/load/pv/database.py
@@ -7,7 +7,7 @@
 import numpy as np
 import pandas as pd
 from nowcasting_datamodel.connection import DatabaseConnection
-from nowcasting_datamodel.models.base import Base_PV
+from nowcasting_datamodel.models.base import Base_Forecast, Base_PV
 from nowcasting_datamodel.models.pv import (
     PVSystem,
     PVSystemSQL,
@@ -17,6 +17,7 @@
     solar_sheffield_passiv,
 )
 from nowcasting_datamodel.read.read_pv import get_pv_systems, get_pv_yield
+from sqlalchemy import text
 from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
 
@@ -75,6 +76,9 @@ def __iter__(self):
             load_extra_minutes=self.load_extra_minutes,
         )
 
+        # Database record is very short. Set observed max to NaN
+        pv_metadata["observed_capacity_watt_power"] = np.nan
+
         # select metadata that is in pv_power
         logger.debug(
             f"There are currently {len(pv_metadata.index)} pv systems in the metadata, "
@@ -91,7 +95,8 @@ def __iter__(self):
         # Compile data into an xarray DataArray
         data_xr = put_pv_data_into_an_xr_dataarray(
             df_gen=pv_power,
-            system_capacities=pv_metadata.capacity_watt_power,
+            observed_system_capacities=pv_metadata.observed_capacity_watt_power,
+            nominal_system_capacities=pv_metadata.capacity_watt_power,
             ml_id=pv_metadata.ml_id,
             latitude=pv_metadata.latitude,
             longitude=pv_metadata.longitude,
@@ -339,3 +344,99 @@ def create_empty_pv_data(
             data.iloc[mask, i] = 0.0
     logger.debug(f"Finished adding zeros to pv data for elevation below {sun_elevation_limit}")
     return data
+
+
+@functional_datapipe("open_pv_from_pvsites_db")
+class OpenPVFromPVSitesDBIterDataPipe(IterDataPipe):
+    """Datapipe for getting PV data from the pvsites database"""
+
+    def __init__(
+        self,
+        history_minutes: int = 30,
+    ):
+        """
+        Datapipe to get PV from pvsites database
+
+        Args:
+            history_minutes: How many minutes of history to use
+        """
+
+        super().__init__()
+
+        self.history_minutes = history_minutes
+        self.history_duration = pd.Timedelta(self.history_minutes, unit="minutes")
+
+    def __iter__(self):
+        df_metadata = get_metadata_from_pvsites_database()
+        df_gen = get_pv_power_from_pvsites_database(history_duration=self.history_duration)
+
+        # Database record is very short. Set observed max to NaN
+        df_metadata["observed_capacity_wp"] = np.nan
+
+        # Ensure systems are consistent between the generation data and metadata
+        common_systems = list(np.intersect1d(df_metadata.index, df_gen.columns))
+        df_gen = df_gen[common_systems]
+        df_metadata = df_metadata.loc[common_systems]
+
+        # Compile data into an xarray DataArray
+        xr_array = put_pv_data_into_an_xr_dataarray(
+            df_gen=df_gen,
+            observed_system_capacities=df_metadata.observed_capacity_wp,
+            nominal_system_capacities=df_metadata.nominal_capacity_wp,
+            ml_id=df_metadata.ml_id,
+            latitude=df_metadata.latitude,
+            longitude=df_metadata.longitude,
+            tilt=df_metadata.get("tilt"),
+            orientation=df_metadata.get("orientation"),
+        )
+
+        logger.info(f"Found {len(xr_array.ml_id)} PV systems")
+
+        while True:
+            yield xr_array
+
+
+def get_metadata_from_pvsites_database() -> pd.DataFrame:
+    """Load metadata from the pvsites database"""
+    # make database connection
+    url = os.getenv("DB_URL_PV")
+    db_connection = DatabaseConnection(url=url, base=Base_Forecast)
+
+    with db_connection.engine.connect() as conn:
+        df_sites_metadata = pd.DataFrame(conn.execute(text("SELECT * FROM sites")).fetchall())
+
+    df_sites_metadata["nominal_capacity_wp"] = df_sites_metadata["capacity_kw"] * 1000
+
+    df_sites_metadata = df_sites_metadata.set_index("site_uuid")
+
+    return df_sites_metadata
+
+
+def get_pv_power_from_pvsites_database(history_duration: timedelta):
+    """Load recent generation data from the pvsites database"""
+
+    # make database connection
+    url = os.getenv("DB_URL_PV")
+    db_connection = DatabaseConnection(url=url, base=Base_Forecast)
+
+    columns = "site_uuid, generation_power_kw, start_utc, end_utc"
+
+    start_time = f"{datetime.now() - history_duration}"
+
+    with db_connection.engine.connect() as conn:
+        df_db_raw = pd.DataFrame(
+            conn.execute(
+                text(f"SELECT {columns} FROM generation where end_utc >= '{start_time}'")
+            ).fetchall()
+        )
+
+    # Reshape
+    df_gen = df_db_raw.pivot(index="end_utc", columns="site_uuid", values="generation_power_kw")
+
+    # Rescale from kW to W
+    df_gen = df_gen * 1000
+
+    # Fix data types
+    df_gen = df_gen.astype(np.float32)
+
+    return df_gen
diff --git a/ocf_datapipes/load/pv/pv.py b/ocf_datapipes/load/pv/pv.py
index 592586e24..bee2538f9 100644
--- a/ocf_datapipes/load/pv/pv.py
+++ b/ocf_datapipes/load/pv/pv.py
@@ -3,7 +3,7 @@
 import logging
 from datetime import datetime
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import fsspec
 import numpy as np
@@ -48,39 +48,21 @@ def __init__(
         self.end_datetime = pv.end_datetime
 
     def __iter__(self):
-        pv_datas_xr = []
+        pv_array_list = []
         for i in range(len(self.pv_power_filenames)):
-            one_data: xr.DataArray = load_everything_into_ram(
+            pv_array: xr.DataArray = load_everything_into_ram(
                 self.pv_power_filenames[i],
                 self.pv_metadata_filenames[i],
                 start_datetime=self.start_datetime,
                 end_datetime=self.end_datetime,
                 inferred_metadata_filename=self.inferred_metadata_filenames[i],
             )
-            pv_datas_xr.append(one_data)
+            pv_array_list.append(pv_array)
 
-        data = join_pv(pv_datas_xr)
+        pv_array = xr.concat(pv_array_list, dim="pv_system_id")
 
         while True:
-            yield data
-
-
-def join_pv(data_arrays: List[xr.DataArray]) -> xr.DataArray:
-    """Join PV data arrays together.
-
-    Args:
-        data_arrays: List of PV data arrays
-
-    Returns: one data array containing all pv systems
-    """
-
-    if len(data_arrays) == 1:
-        return data_arrays[0]
-
-    # expand each dataset to full time_utc
-    joined_data_array = xr.concat(data_arrays, dim="pv_system_id")
-
-    return joined_data_array
+            yield pv_array
 
 
 def load_everything_into_ram(
@@ -126,9 +108,10 @@ def load_everything_into_ram(
     estimated_capacities = estimated_capacities.loc[common_systems]
 
     # Compile data into an xarray DataArray
-    data_in_ram = put_pv_data_into_an_xr_dataarray(
+    xr_array = put_pv_data_into_an_xr_dataarray(
         df_gen=df_gen,
-        system_capacities=estimated_capacities,
+        observed_system_capacities=estimated_capacities,
+        nominal_system_capacities=df_metadata.capacity_watts,
         ml_id=df_metadata.ml_id,
         latitude=df_metadata.latitude,
         longitude=df_metadata.longitude,
@@ -137,11 +120,11 @@ def load_everything_into_ram(
     )
 
     # Sanity checks
-    time_utc = pd.DatetimeIndex(data_in_ram.time_utc)
+    time_utc = pd.DatetimeIndex(xr_array.time_utc)
     assert time_utc.is_monotonic_increasing
     assert time_utc.is_unique
 
-    return data_in_ram
+    return xr_array
 
 
 def _load_pv_generation_and_capacity(
@@ -241,48 +224,51 @@ def _load_pv_metadata(filename: str, inferred_filename: Optional[str] = None) ->
     Shape of the returned pd.DataFrame for Passiv PV data:
         Index: ss_id (Sheffield Solar ID)
         Columns: llsoacd, orientation, tilt, kwp, operational_at, latitude, longitude, system_id,
-            ml_id
+            ml_id, capacity_watts
     """
     _log.info(f"Loading PV metadata from {filename}")
 
-    if "passiv" in str(filename):
-        index_col = "ss_id"
-    else:
-        index_col = "system_id"
-
+    index_col = "ss_id" if "passiv" in str(filename) else "system_id"
     df_metadata = pd.read_csv(filename, index_col=index_col)
 
-    # Maybe load inferred metadata if passiv
-    if inferred_filename is not None:
-        df_metadata = _load_inferred_metadata(filename, df_metadata)
-
-    if "Unnamed: 0" in df_metadata.columns:
-        df_metadata.drop(columns="Unnamed: 0", inplace=True)
+    # Drop the unnamed index column if it exists
+    df_metadata.drop(columns="Unnamed: 0", inplace=True, errors="ignore")
 
-    # Add ml_id column if not in metadata
+    # Add ml_id column if not in metadata already
    if "ml_id" not in df_metadata.columns:
         df_metadata["ml_id"] = np.nan
 
-    _log.info(f"Found {len(df_metadata)} PV systems in {filename}")
+    if "passiv" in str(filename):
+        # Add capacity in watts
+        df_metadata["capacity_watts"] = df_metadata.kwp * 1000
+        # Maybe load inferred metadata if passiv
+        if inferred_filename is not None:
+            df_metadata = _load_inferred_metadata(filename, df_metadata)
+    else:
+        # For PVOutput.org data
+        df_metadata["capacity_watts"] = df_metadata.system_size_watts
+        # Rename PVOutput.org tilt name to be simpler
+        # There is a second degree tilt, but this should be fine for now
+        if "array_tilt_degrees" in df_metadata.columns:
+            df_metadata["tilt"] = df_metadata["array_tilt_degrees"]
+
+        # Need to change orientation to a number if a string (e.g. SE) that PVOutput.org uses by
+        # default
+        mapping = {
+            "N": 0.0,
+            "NE": 45.0,
+            "E": 90.0,
+            "SE": 135.0,
+            "S": 180.0,
+            "SW": 225.0,
+            "W": 270.0,
+            "NW": 315.0,
+        }
+
+        # Any keys other than those in the dict above are mapped to NaN
+        df_metadata["orientation"] = df_metadata.orientation.map(mapping)
 
-    # Rename PVOutput.org tilt name to be simpler
-    # There is a second degree tilt, but this should be fine for now
-    if "array_tilt_degrees" in df_metadata.columns:
-        df_metadata["tilt"] = df_metadata["array_tilt_degrees"]
-
-    # Need to change orientation to a number if a string (i.e. SE) that PVOutput.org uses by default
-    mapping = {
-        "S": 180.0,
-        "SE": 135.0,
-        "SW": 225.0,
-        "E": 90.0,
-        "W": 270.0,
-        "N": 0.0,
-        "NE": 45.0,
-        "NW": 315.0,
-        "EW": np.nan,
-    }
-    df_metadata = df_metadata.replace({"orientation": mapping})
+    _log.info(f"Found {len(df_metadata)} PV systems in {filename}")
 
     return df_metadata
diff --git a/ocf_datapipes/load/pv/utils.py b/ocf_datapipes/load/pv/utils.py
index 4e601a4cb..78369c252 100644
--- a/ocf_datapipes/load/pv/utils.py
+++ b/ocf_datapipes/load/pv/utils.py
@@ -12,7 +12,8 @@
 def put_pv_data_into_an_xr_dataarray(
     df_gen: pd.DataFrame,
-    system_capacities: pd.Series,
+    observed_system_capacities: pd.Series,
+    nominal_system_capacities: pd.Series,
     ml_id: pd.Series,
     longitude: pd.Series,
     latitude: pd.Series,
@@ -24,7 +25,9 @@
     Args:
         df_gen: pd.DataFrame where the columns are PV systems (and the column names are ints),
             and the index is UTC datetime
-        system_capacities: The max power output of each PV system in Watts. Index is PV system IDs.
+        observed_system_capacities: The max power output observed in the time series for each PV
+            system, in watts. Index is PV system IDs
+        nominal_system_capacities: The nominal capacity of each PV system in watts, from metadata
         ml_id: The `ml_id` used to identify each PV system
         longitude: longitude of the locations
         latitude: latitude of the locations
@@ -34,14 +37,18 @@
     # Sanity check!
     system_ids = df_gen.columns
     for name, series in (
+        ("observed_system_capacities", observed_system_capacities),
+        ("nominal_system_capacities", nominal_system_capacities),
+        ("ml_id", ml_id),
         ("longitude", longitude),
         ("latitude", latitude),
-        ("system_capacities", system_capacities),
+        ("tilt", tilt),
+        ("orientation", orientation),
     ):
-        logger.debug(f"Checking {name}")
-        if not np.array_equal(series.index, system_ids, equal_nan=True):
-            logger.debug(f"Index of {name} does not equal {system_ids}. Index is {series.index}")
-            assert np.array_equal(series.index, system_ids, equal_nan=True)
+        if (series is not None) and (not np.array_equal(series.index, system_ids)):
+            raise ValueError(
+                f"Index of {name} does not equal {system_ids}. Index is {series.index}"
+            )
 
     data_array = xr.DataArray(
         data=df_gen.values,
@@ -53,10 +60,11 @@
     ).astype(np.float32)
 
     data_array = data_array.assign_coords(
+        observed_capacity_wp=("pv_system_id", observed_system_capacities),
+        nominal_capacity_wp=("pv_system_id", nominal_system_capacities),
+        ml_id=("pv_system_id", ml_id),
         longitude=("pv_system_id", longitude),
         latitude=("pv_system_id", latitude),
-        capacity_watt_power=("pv_system_id", system_capacities),
-        ml_id=("pv_system_id", ml_id),
     )
 
     if tilt is not None:
diff --git a/ocf_datapipes/select/drop_pv_sys_generating_overnight.py b/ocf_datapipes/select/drop_pv_sys_generating_overnight.py
index 10022e7bf..cfab6aadb 100644
--- a/ocf_datapipes/select/drop_pv_sys_generating_overnight.py
+++ b/ocf_datapipes/select/drop_pv_sys_generating_overnight.py
@@ -44,7 +44,7 @@ def __iter__(self) -> xr.DataArray():
         ds_night = ds.where(ds.status_daynight == "night", drop=True)
 
         # Find relative maximum night-time generation for each system
-        night_time_max_gen = (ds_night / ds_night.capacity_watt_power).max(dim="time_utc")
+        night_time_max_gen = (ds_night / ds_night.observed_capacity_wp).max(dim="time_utc")
 
         # Find systems above threshold
         mask = night_time_max_gen > self.threshold
diff --git a/ocf_datapipes/select/select_pv_systems_on_capacity.py b/ocf_datapipes/select/select_pv_systems_on_capacity.py
index cb9d76875..ebf71fb45 100644
--- a/ocf_datapipes/select/select_pv_systems_on_capacity.py
+++ b/ocf_datapipes/select/select_pv_systems_on_capacity.py
@@ -33,7 +33,7 @@ def __init__(
 
     def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
         for ds in self.source_datapipe:
-            too_low = ds.capacity_watt_power < self.min_capacity_watts
-            too_high = ds.capacity_watt_power > self.max_capacity_watts
+            too_low = ds.observed_capacity_wp < self.min_capacity_watts
+            too_high = ds.observed_capacity_wp > self.max_capacity_watts
             mask = np.logical_or(too_low, too_high)
             yield ds.where(~mask, drop=True)
diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py
index 606a71aff..2b0e16262 100644
--- a/ocf_datapipes/training/common.py
+++ b/ocf_datapipes/training/common.py
@@ -228,7 +228,7 @@ def get_and_return_overlapping_time_periods_and_t0(used_datapipes: dict, key_for
     for i, key in enumerate(list(datapipes_to_return.keys())):
         datapipes_to_return[key + "_t0"] = t0_datapipes[i]
 
-    # Readd config for later
+    # Re-add config for later
     datapipes_to_return["config"] = configuration
     if "topo" in used_datapipes.keys():
         datapipes_to_return["topo"] = used_datapipes["topo"]
diff --git a/ocf_datapipes/training/example/nwp_pv.py b/ocf_datapipes/training/example/nwp_pv.py
index 905c0d5e1..63db7630a 100644
--- a/ocf_datapipes/training/example/nwp_pv.py
+++ b/ocf_datapipes/training/example/nwp_pv.py
@@ -65,7 +65,7 @@ def nwp_pv_datapipe(
             minutes=configuration.input_data.pv.time_resolution_minutes
         ),
         history_duration=timedelta(minutes=configuration.input_data.pv.history_minutes),
-    ).normalize(normalize_fn=lambda x: x / x.capacity_watt_power)
+    ).normalize(normalize_fn=lambda x: x / x.observed_capacity_wp)
     nwp_datapipe = nwp_datapipe.add_t0_idx_and_sample_period_duration(
         sample_period_duration=timedelta(
             minutes=configuration.input_data.nwp.time_resolution_minutes
diff --git a/ocf_datapipes/training/example/simple_pv.py b/ocf_datapipes/training/example/simple_pv.py
index 652d07b95..f5b80b106 100644
--- a/ocf_datapipes/training/example/simple_pv.py
+++ b/ocf_datapipes/training/example/simple_pv.py
@@ -63,7 +63,7 @@ def simple_pv_datapipe(
     logger.debug("Making PV space slice")
     pv_datapipe, pv_t0_datapipe, pv_time_periods_datapipe = (
-        pv_datapipe.normalize(normalize_fn=lambda x: x / x.capacity_watt_power)
+        pv_datapipe.normalize(normalize_fn=lambda x: x / x.observed_capacity_wp)
         .add_t0_idx_and_sample_period_duration(
             sample_period_duration=timedelta(
                 minutes=configuration.input_data.pv.time_resolution_minutes
diff --git a/ocf_datapipes/training/metnet_gsp_national.py b/ocf_datapipes/training/metnet_gsp_national.py
index e40a9df14..6d96ee8d8 100644
--- a/ocf_datapipes/training/metnet_gsp_national.py
+++ b/ocf_datapipes/training/metnet_gsp_national.py
@@ -45,7 +45,7 @@ def normalize_pv(x):  # So it can be pickled
     Returns:
         Normalized DataArray
     """
-    return x / x.capacity_watt_power
+    return x / x.observed_capacity_wp
 
 
 def _remove_nans(x):
diff --git a/ocf_datapipes/training/metnet_pv_national.py b/ocf_datapipes/training/metnet_pv_national.py
index 54f11f8b6..47607c8d1 100644
--- a/ocf_datapipes/training/metnet_pv_national.py
+++ b/ocf_datapipes/training/metnet_pv_national.py
@@ -50,7 +50,7 @@ def normalize_pv(x):  # So it can be pickled
     Returns:
         Normalized DataArray
     """
-    return x / x.capacity_watt_power
+    return x / x.observed_capacity_wp
 
 
 def _remove_nans(x):
diff --git a/ocf_datapipes/training/metnet_pv_site.py b/ocf_datapipes/training/metnet_pv_site.py
index 99c464012..05b10af18 100644
--- a/ocf_datapipes/training/metnet_pv_site.py
+++ b/ocf_datapipes/training/metnet_pv_site.py
@@ -33,7 +33,7 @@ def normalize_pv(x):  # So it can be pickled
     Returns:
         Normalized DataArray
     """
-    return x / x.capacity_watt_power
+    return x / x.observed_capacity_wp
 
 
 def _remove_nans(x):
diff --git a/ocf_datapipes/training/pseudo_irradience.py b/ocf_datapipes/training/pseudo_irradience.py
index 6c435dfaa..00f7c0d1b 100644
--- a/ocf_datapipes/training/pseudo_irradience.py
+++ b/ocf_datapipes/training/pseudo_irradience.py
@@ -36,7 +36,7 @@ def normalize_pv(x):  # So it can be pickled
     Returns:
         Normalized DataArray
     """
-    return x / x.capacity_watt_power
+    return x / x.observed_capacity_wp
 
 
 def _remove_nans(x):
@@ -111,7 +111,7 @@ def _normalize_by_pvlib(pv_system):
         clear_sky["dni"] + clear_sky["dhi"] + clear_sky["ghi"]
     )
     print(fraction_clear_sky)
-    pv_system /= pv_system.capacity_watt_power
+    pv_system /= pv_system.observed_capacity_wp
     print(pv_system)
     pv_system *= fraction_clear_sky
     print(pv_system)
@@ -468,8 +468,6 @@ def pseudo_irradiance_datapipe(
     pv_loc_datapipe, pv_sav_loc = LocationPicker(
         pv_loc_datapipe,
         return_all_locations=True if is_test else False,
-        x_dim_name="latitude",
-        y_dim_name="longitude",
     ).fork(2, buffer_size=-1)
     pv_sav_loc = pv_sav_loc.map(_get_id_from_location)
     pv_meta_save = pv_meta_save.map(_extract_test_info)
diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py
index 9ffed2627..27015e74e 100644
--- a/ocf_datapipes/training/pvnet.py
+++ b/ocf_datapipes/training/pvnet.py
@@ -5,11 +5,12 @@
 
 import numpy as np
 import xarray as xr
+from torchdata.datapipes import functional_datapipe
 from torchdata.datapipes.iter import IterDataPipe
 
 from ocf_datapipes.batch import MergeNumpyModalities
 from ocf_datapipes.config.model import Configuration
-from ocf_datapipes.load import OpenGSPFromDatabase
+from ocf_datapipes.load import OpenGSPFromDatabase, OpenPVFromPVSitesDB
 from ocf_datapipes.training.common import (
     create_t0_and_loc_datapipes,
     open_and_return_datapipes,
@@ -39,6 +40,18 @@ def normalize_gsp(x):
     return x / x.effective_capacity_mwp
 
 
+def normalize_pv(x):
+    """Normalize the PV data
+
+    Args:
+        x: Input DataArray
+
+    Returns:
+        Normalized DataArray
+    """
+    return (x / x.nominal_capacity_wp).clip(None, 5)
+
+
 def production_sat_scale(x):
     """Scale the production satellite data
 
@@ -51,18 +64,18 @@
     return x / 1024
 
 
-def pvnet_concat_gsp(gsp_dataarrays: List[xr.DataArray]):
-    """This function is used to combine the split history and future gsp dataarrays.
+def concat_xr_time_utc(gsp_dataarrays: List[xr.DataArray]):
+    """This function is used to combine the split history and future gsp/pv dataarrays.
 
     These are split inside the `slice_datapipes_by_time()` function below.
 
     Splitting them inside that function allows us to apply dropout to the
-    history GSP whilst leaving the future GSP without NaNs.
+    history GSP/PV whilst leaving the future GSP/PV without NaNs.
 
     We recombine the history and future with this function to allow us to use the
     `MergeNumpyModalities()` datapipe without redefining the BatchKeys.
 
-    The `pvnet` model was also written to use a GSP array which has historical and future
+    The `pvnet` model was also written to use a GSP/PV array which has historical and future
     and to split it out. This maintains that assumption.
     """
     return xr.concat(gsp_dataarrays, dim="time_utc")
@@ -80,6 +93,59 @@ def gsp_drop_national(x: Union[xr.DataArray, xr.Dataset]):
     return x.where(x.gsp_id != 0, drop=True)
 
 
+@functional_datapipe("pvnet_select_pv_by_ml_id")
+class PVNetSelectPVbyMLIDIterDataPipe(IterDataPipe):
+    """Select specific set of PV systems by ML ID."""
+
+    def __init__(self, source_datapipe: IterDataPipe, ml_ids: np.array):
+        """Select specific set of PV systems by ML ID.
+
+        Args:
+            source_datapipe: Datapipe emitting PV xarray data
+            ml_ids: List-like of ML IDs to select
+
+        Returns:
+            Filtered data source
+        """
+        self.source_datapipe = source_datapipe
+        self.ml_ids = ml_ids
+
+    def __iter__(self):
+        for x in self.source_datapipe:
+            # Check for missing IDs
+            ml_ids_not_in_data = ~np.isin(self.ml_ids, x.ml_id)
+            if ml_ids_not_in_data.any():
+                missing_ml_ids = np.array(self.ml_ids)[ml_ids_not_in_data]
+                logger.warning(
+                    f"The following ML IDs were missing in the PV site-level input data: "
+                    f"{missing_ml_ids}. The values for these IDs will be set to NaN."
+                )
+
+            x_filtered = (
+                # Many ML-IDs are null, so filter first
+                x.where(~x.ml_id.isnull(), drop=True)
+                # Swap dimensions so we can select by ml_id coordinate
+                .swap_dims({"pv_system_id": "ml_id"})
+                # Select IDs - missing IDs are given NaN values
+                .reindex(ml_id=self.ml_ids)
+                # Swap back dimensions
+                .swap_dims({"ml_id": "pv_system_id"})
+            )
+            yield x_filtered
+
+
+def fill_nans_in_pv(x: Union[xr.DataArray, xr.Dataset]):
+    """Fill NaNs in PV data with the value -1
+
+    Args:
+        x: Input DataArray
+
+    Returns:
+        DataArray with NaNs replaced by -1
+    """
+    return x.fillna(-1)
+
+
 def fill_nans_in_arrays(batch: NumpyBatch) -> NumpyBatch:
     """Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
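[Reviewer note - not part of the diff] The pieces added to pvnet.py above compose into a small PV branch of the pipeline. Below is a minimal sketch of the intended wiring, using only names introduced in this diff; the 30-minute history, the example ML IDs, and the ordering relative to the history/future split are assumptions for illustration:

    import numpy as np

    from ocf_datapipes.load import OpenPVFromPVSitesDB
    from ocf_datapipes.training.pvnet import fill_nans_in_pv, normalize_pv

    # Production path: stream recent PV readings from the pvsites database
    # (requires the DB_URL_PV environment variable at iteration time)
    pv_datapipe = OpenPVFromPVSitesDB(history_minutes=30)

    # Filter to a fixed set of systems; IDs absent from the data become all-NaN rows
    pv_datapipe = pv_datapipe.pvnet_select_pv_by_ml_id(np.array([100, 101, 102]))

    # Normalise by nominal capacity (clipped at 5 to bound bad-capacity outliers),
    # then replace remaining NaNs with the -1 marker expected downstream
    pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv).map(fill_nans_in_pv)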
@@ -223,21 +289,30 @@ def _get_datapipes_dict(
     datapipes_dict = open_and_return_datapipes(
         configuration_filename=config_filename,
         use_gsp=(not production),
-        use_pv=False,
+        use_pv=(not production),
         use_sat=not block_sat,  # Only loaded if we aren't replacing them with zeros
         use_hrv=False,
         use_nwp=not block_nwp,  # Only loaded if we aren't replacing them with zeros
         use_topo=False,
         production=production,
     )
-    if production:
-        configuration: Configuration = datapipes_dict["config"]
+    config: Configuration = datapipes_dict["config"]
+
+    if production:
         datapipes_dict["gsp"] = OpenGSPFromDatabase().add_t0_idx_and_sample_period_duration(
             sample_period_duration=timedelta(minutes=30),
-            history_duration=timedelta(minutes=configuration.input_data.gsp.history_minutes),
+            history_duration=timedelta(minutes=config.input_data.gsp.history_minutes),
+        )
+        if "sat" in datapipes_dict:
+            datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale)
+        if "pv" in datapipes_dict:
+            datapipes_dict["pv"] = OpenPVFromPVSitesDB(config.input_data.pv.history_minutes)
+
+    if "pv" in datapipes_dict and config.input_data.pv.pv_ml_ids != []:
+        datapipes_dict["pv"] = datapipes_dict["pv"].pvnet_select_pv_by_ml_id(
+            config.input_data.pv.pv_ml_ids
         )
-        datapipes_dict["sat"] = datapipes_dict["sat"].map(production_sat_scale)
 
     return datapipes_dict
@@ -420,13 +495,33 @@ def slice_datapipes_by_time(
     )
 
     datapipes_dict["pv"] = datapipes_dict["pv"].select_time_slice(
-        t0_datapipe=get_t0_datapipe("pv"),
+        t0_datapipe=get_t0_datapipe(None),
         sample_period_duration=minutes(5),
         interval_start=minutes(-conf_in.pv.history_minutes),
         interval_end=minutes(0),
         fill_selection=production,
     )
 
+    # Dropout on the PV, but not the future PV
+    pv_dropout_time_datapipe = get_t0_datapipe("pv").select_dropout_time(
+        # All PV data could be delayed by up to 30 minutes
+        # (these delays are an assumption, not derived from production data)
+        dropout_timedeltas=[minutes(m) for m in range(-30, 0, 5)],
+        dropout_frac=0.1 if production else 1,
+    )
+
+    datapipes_dict["pv"] = datapipes_dict["pv"].apply_dropout_time(
+        dropout_time_datapipe=pv_dropout_time_datapipe,
+    )
+
+    # Apply extra PV dropout using different delays per system and dropping out entire PV
+    # systems independently
+    if not production:
+        datapipes_dict["pv"] = datapipes_dict["pv"].apply_pv_dropout(
+            system_dropout_fractions=np.linspace(0, 0.2, 100),
+            system_dropout_timedeltas=[minutes(m) for m in [-15, -10, -5, 0]],
+        )
+
     if "gsp" in datapipes_dict:
         datapipes_dict["gsp"], dp = datapipes_dict["gsp"].fork(2, buffer_size=5)
@@ -530,6 +625,16 @@
         sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)
         numpy_modalities.append(sat_datapipe.convert_satellite_to_numpy_batch())
 
+    if "pv" in datapipes_dict:
+        # Recombine PV arrays - see function doc for further explanation
+        pv_datapipe = (
+            datapipes_dict["pv"].zip_ocf(datapipes_dict["pv_future"]).map(concat_xr_time_utc)
+        )
+        pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv)
+        pv_datapipe = pv_datapipe.map(fill_nans_in_pv)
+
+        numpy_modalities.append(pv_datapipe.convert_pv_to_numpy_batch())
+
     # GSP always assumed to be in data
     location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5)
     gsp_future_datapipe = datapipes_dict["gsp_future"]
@@ -549,7 +654,7 @@
     )
 
     # Recombine GSP arrays - see function doc for further explanation
-    gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(pvnet_concat_gsp)
+    gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(concat_xr_time_utc)
     gsp_datapipe = gsp_datapipe.normalize(normalize_fn=normalize_gsp)
 
     numpy_modalities.append(gsp_datapipe.convert_gsp_to_numpy_batch())
diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py
index 785563e8a..1b3da6222 100644
--- a/ocf_datapipes/transform/xarray/__init__.py
+++ b/ocf_datapipes/transform/xarray/__init__.py
@@ -54,6 +54,7 @@
 from .pv.create_pv_meta_image import (
     CreatePVMetadataImageIterDataPipe as CreatePVMetadataImage,
 )
+from .pv_dropout import ApplyPVDropoutIterDataPipe as ApplyPVDropout
 from .remove_nans import RemoveNansIterDataPipe as RemoveNans
 from .reproject_topographic_data import (
     ReprojectTopographyIterDataPipe as ReprojectTopography,
diff --git a/ocf_datapipes/transform/xarray/pv/create_pv_image.py b/ocf_datapipes/transform/xarray/pv/create_pv_image.py
index 9ccbbcdf5..9d976c433 100644
--- a/ocf_datapipes/transform/xarray/pv/create_pv_image.py
+++ b/ocf_datapipes/transform/xarray/pv/create_pv_image.py
@@ -209,6 +209,6 @@ def _normalize_by_pvlib(pv_system):
     fraction_clear_sky = total_irradiance["poa_global"] / (
         clear_sky["dni"] + clear_sky["dhi"] + clear_sky["ghi"]
     )
-    pv_system /= pv_system.capacity_watt_power
+    pv_system /= pv_system.observed_capacity_wp
     pv_system *= fraction_clear_sky
     return pv_system
diff --git a/ocf_datapipes/transform/xarray/pv_dropout.py b/ocf_datapipes/transform/xarray/pv_dropout.py
new file mode 100644
index 000000000..cb8a9aa52
--- /dev/null
+++ b/ocf_datapipes/transform/xarray/pv_dropout.py
@@ -0,0 +1,89 @@
+"""Apply dropout to PV data to mimic delayed and dropped-out systems in production"""
+import logging
+from datetime import timedelta
+from typing import List, Union
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import IterDataPipe
+
+logger = logging.getLogger(__name__)
+
+
+@functional_datapipe("apply_pv_dropout")
+class ApplyPVDropoutIterDataPipe(IterDataPipe):
+    """Apply PV system dropout to mimic production
+
+    Systems have independent delay times. Systems may also completely drop out.
+    """
+
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe,
+        system_dropout_fractions: List[float],
+        system_dropout_timedeltas: List[timedelta],
+    ):
+        """Apply PV system dropout to mimic production
+
+        Systems have independent delay times. Systems may also completely drop out.
+
+        Args:
+            source_datapipe: Datapipe emitting an Xarray Dataset with time_utc indexer.
+            system_dropout_fractions: List of possible system dropout fractions to apply to each
+                sample. For each yielded sample, one of these fractions will be chosen and used to
+                dropout each PV system. Using a list instead of a single value allows us to avoid
+                overfitting to the fraction of dropped out systems.
+            system_dropout_timedeltas: List of timedeltas. We randomly select the delay for each PV
+                system from this list. These should be negative timedeltas w.r.t. the last time_utc
+                value of the xarray data.
+        """
+        self.source_datapipe = source_datapipe
+        self.system_dropout_fractions = system_dropout_fractions
+        self.system_dropout_timedeltas = system_dropout_timedeltas
+
+        assert (
+            len(system_dropout_timedeltas) >= 1
+        ), "Must include list of relative dropout timedeltas"
+
+        assert all(
+            [t <= timedelta(minutes=0) for t in system_dropout_timedeltas]
+        ), f"dropout timedeltas must be negative: {system_dropout_timedeltas}"
+
+        assert all(
+            [0 <= f <= 1 for f in system_dropout_fractions]
+        ), "dropout fractions must be in the closed range [0, 1]"
+
+    def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
+        """Iterate through Xarray dataset using dropout"""
+
+        for xr_data in self.source_datapipe:
+            # Assign these values for convenience
+            t0 = pd.Timestamp(xr_data.time_utc.values[-1])
+            n_systems = len(xr_data.pv_system_id)
+
+            # Apply PV system dropout - individual systems are dropped out
+
+            # Don't want the fraction of dropped-out systems to be the same in each sample,
+            # as this might lead to overfitting. Instead, vary it
+            dropout_p = np.random.choice(self.system_dropout_fractions)
+
+            system_mask = xr.zeros_like(xr_data.pv_system_id, dtype=bool)
+            system_mask.values[:] = np.random.uniform(size=n_systems) >= dropout_p
+
+            # Apply independent delay to each PV system
+            last_available_times = xr.zeros_like(xr_data.pv_system_id, dtype=xr_data.time_utc.dtype)
+            last_available_times.values[:] = t0 + np.random.choice(
+                self.system_dropout_timedeltas, size=n_systems
+            )
+
+            delay_mask = xr_data.time_utc <= last_available_times
+
+            # Apply masking
+            xr_data = xr_data.where(system_mask).where(delay_mask)
+
+            yield xr_data
diff --git a/ocf_datapipes/utils/consts.py b/ocf_datapipes/utils/consts.py
index 2430547f3..c3f743425 100644
--- a/ocf_datapipes/utils/consts.py
+++ b/ocf_datapipes/utils/consts.py
@@ -7,56 +7,14 @@
 import xarray as xr
 from pydantic import BaseModel, validator
 
-PV_TIME_AXIS = 1
-PV_SYSTEM_AXIS = 2
-
 Y_OSGB_MEAN = 357021.38
 Y_OSGB_STD = 612920.2
 X_OSGB_MEAN = 187459.94
 X_OSGB_STD = 622805.44
 
-SATELLITE_SPACER_LEN = 17  # Patch of 4x4 + 1 for surface height.
-PV_SPACER_LEN = 18  # 16 for embedding dim + 1 for marker + 1 for history
-
-PV_SYSTEM_ID: str = "pv_system_id"
-PV_ML_ID = "pv_ml_id"
-PV_SYSTEM_X_COORDS = "pv_system_x_coords"
-PV_SYSTEM_Y_COORDS = "pv_system_y_coords"
-
-SUN_AZIMUTH_ANGLE = "sun_azimuth_angle"
-SUN_ELEVATION_ANGLE = "sun_elevation_angle"
-PV_YIELD = "pv_yield"
-PV_DATETIME_INDEX = "pv_datetime_index"
 DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE = 2048
-GSP_ID: str = "gsp_id"
-GSP_YIELD = "gsp_yield"
-GSP_X_COORDS = "gsp_x_coords"
-GSP_Y_COORDS = "gsp_y_coords"
-GSP_DATETIME_INDEX = "gsp_datetime_index"
-N_GSPS = 317
-
 DEFAULT_N_GSP_PER_EXAMPLE = 32
-OBJECT_AT_CENTER = "object_at_center"
-DATETIME_FEATURE_NAMES = (
-    "hour_of_day_sin",
-    "hour_of_day_cos",
-    "day_of_year_sin",
-    "day_of_year_cos",
-)
-SATELLITE_DATA = "sat_data"
-SATELLITE_Y_COORDS = "sat_y_coords"
-SATELLITE_X_COORDS = "sat_x_coords"
-SATELLITE_DATETIME_INDEX = "sat_datetime_index"
-NWP_TARGET_TIME = "nwp_target_time"
-NWP_DATA = "nwp"
-NWP_X_COORDS = "nwp_x_coords"
-NWP_Y_COORDS = "nwp_y_coords"
-X_CENTERS_OSGB = "x_centers_osgb"
-Y_CENTERS_OSGB = "y_centers_osgb"
-TOPOGRAPHIC_DATA = "topo_data"
-TOPOGRAPHIC_X_COORDS = "topo_x_coords"
-TOPOGRAPHIC_Y_COORDS = "topo_y_coords"
 
 # "Safe" default NWP variable names:
 NWP_VARIABLE_NAMES = (
@@ -72,27 +30,6 @@
     "hcc",
 )
 
-# A complete set of NWP variable names. Not all are currently used.
-FULL_NWP_VARIABLE_NAMES = (
-    "cdcb",
-    "lcc",
-    "mcc",
-    "hcc",
-    "sde",
-    "hcct",
-    "dswrf",
-    "dlwrf",
-    "h",
-    "t",
-    "r",
-    "dpt",
-    "vis",
-    "si10",
-    "wdir10",
-    "prmsl",
-    "prate",
-)
-
 SAT_VARIABLE_NAMES = (
     "HRV",
     "IR_016",
@@ -108,38 +45,6 @@
     "WV_073",
 )
 
-DEFAULT_REQUIRED_KEYS = [
-    NWP_DATA,
-    NWP_X_COORDS,
-    NWP_Y_COORDS,
-    SATELLITE_DATA,
-    SATELLITE_X_COORDS,
-    SATELLITE_Y_COORDS,
-    PV_YIELD,
-    PV_SYSTEM_ID,
-    PV_ML_ID,
-    PV_SYSTEM_X_COORDS,
-    PV_SYSTEM_Y_COORDS,
-    X_CENTERS_OSGB,
-    Y_CENTERS_OSGB,
-    GSP_ID,
-    GSP_YIELD,
-    GSP_X_COORDS,
-    GSP_Y_COORDS,
-    GSP_DATETIME_INDEX,
-    TOPOGRAPHIC_DATA,
-    TOPOGRAPHIC_Y_COORDS,
-    TOPOGRAPHIC_X_COORDS,
-] + list(DATETIME_FEATURE_NAMES)
-T0_DT = "t0_dt"
-
-
-SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME = (
-    "spatial_and_temporal_locations_of_each_example.csv"
-)
-
-LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR")
-
 
 class Location(BaseModel):
     """Represent a spatial location."""
@@ -257,10 +162,8 @@ class BatchKey(Enum):
     pv_t0_idx = auto()  # shape: scalar
     pv_ml_id = auto()  # shape: (batch_size, n_pv_systems)
     pv_id = auto()  # shape: (batch_size, n_pv_systems)
-    # PV AC system capacity in watts peak.
-    # Warning: In v15, pv_capacity_watt_power is sometimes 0. This will be fixed in
-    # https://github.com/openclimatefix/nowcasting_dataset/issues/622
-    pv_capacity_watt_power = auto()  # shape: (batch_size, n_pv_systems)
+    pv_observed_capacity_wp = auto()  # shape: (batch_size, n_pv_systems)
+    pv_nominal_capacity_wp = auto()  # shape: (batch_size, n_pv_systems)
     #: pv_mask is True for good PV systems in each example.
     # The RawPVDataSource doesn't use pv_mask. Instead it sets missing PV systems to NaN
     # across all PV batch keys.
@@ -281,8 +184,7 @@ class BatchKey(Enum):
     pv_time_utc_fourier_t0 = auto()  # Added by SaveT0Time. Shape: (batch_size, n_fourier_features)
 
     # -------------- GSP --------------------------------------------
-    gsp = auto()  # shape: (batch_size, time, 1)  (the RawGSPDataSource include a '1',
-    # not sure if the prepared dataset does!)
+    gsp = auto()  # shape: (batch_size, time, 1)
     gsp_t0_idx = auto()  # shape: scalar
     gsp_id = auto()  # shape: (batch_size)
diff --git a/tests/config/test.yaml b/tests/config/test.yaml
index 5589ff402..b0ad59069 100644
--- a/tests/config/test.yaml
+++ b/tests/config/test.yaml
@@ -30,6 +30,7 @@ input_data:
     n_pv_systems_per_example: 32
     start_datetime: "2010-01-01 00:00:00"
     end_datetime: "2030-01-01 00:00:00"
+    pv_ml_ids: []
   satellite:
     satellite_channels:
       - IR_016
diff --git a/tests/conftest.py b/tests/conftest.py
index 8e2006992..d093a90a7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,8 @@
 import tempfile
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
+import uuid
+
 
 import numpy as np
 import pandas as pd
@@ -148,7 +150,57 @@ def gsp_datapipe():
 
 
 @pytest.fixture
-def db_connection():
+def pv_system_db_data():
+    # Create generation data
+    n_systems = 10
+
+    t0 = pd.Timestamp.now().floor("5T")
+    datetimes = pd.date_range(t0 - timedelta(minutes=120), t0, freq="5T")
+    site_uuids = [str(uuid.uuid4()) for _ in range(n_systems)]
+
+    data = np.zeros((len(datetimes), n_systems))
+
+    # Make data a nice sinusoidal curve
+    data[:] = (
+        0.5
+        * (1 - np.cos((datetimes.hour + datetimes.minute / 60) / 24 * 2 * np.pi).values)[:, None]
+    )
+
+    # Chuck in some nan values
+    data[:, 1] = np.nan
+    data[-5:, 2] = np.nan
+    data[::3, 3] = np.nan
+
+    da = xr.DataArray(
+        data,
+        coords=(
+            ("end_utc", datetimes),
+            ("site_uuid", site_uuids),
+        ),
+    )
+
+    # Reshape for tabular database
+    df_gen = da.to_dataframe("generation_power_kw").reset_index()
+    df_gen["start_utc"] = df_gen["end_utc"] - timedelta(minutes=5)
+
+    # Create metadata
+    df_meta = pd.DataFrame(
+        dict(
+            site_uuid=site_uuids,
+            orientation=np.random.uniform(0, 360, n_systems),
+            tilt=np.random.uniform(0, 90, n_systems),
+            longitude=np.random.uniform(-3.07, 0.59, n_systems),
+            latitude=np.random.uniform(51.56, 52.89, n_systems),
+            capacity_kw=np.random.uniform(1, 5, n_systems),
+            ml_id=np.arange(n_systems),
+        )
+    )
+
+    return df_gen, df_meta
+
+
+@pytest.fixture
+def db_connection(pv_system_db_data):
     """Create data connection"""
 
     with tempfile.NamedTemporaryFile(suffix=".db") as temp:
@@ -167,10 +219,14 @@ def db_connection():
         for table in [PVYieldSQL, PVSystemSQL, GSPYieldSQL, LocationSQL]:
             table.__table__.create(connection.engine)
 
-        yield connection
+        # Create and populate pvsites tables
+        df_gen, df_meta = pv_system_db_data
+        with connection.engine.connect() as conn:
+            df_gen.to_sql(name="generation", con=conn, index=False)
+            df_meta.to_sql(name="sites", con=conn, index=False)
+            conn.commit()
 
-        for table in [PVYieldSQL, PVSystemSQL, GSPYieldSQL, LocationSQL]:
-            table.__table__.drop(connection.engine)
+        yield connection
diff --git a/tests/load/pv/test_pv_database.py b/tests/load/pv/test_pv_database.py
index 4fe33bd8d..b2d204a0e 100644
--- a/tests/load/pv/test_pv_database.py
+++ b/tests/load/pv/test_pv_database.py
@@ -10,13 +10,15 @@
     OpenPVFromDBIterDataPipe,
     get_metadata_from_database,
     get_pv_power_from_database,
+    OpenPVFromPVSitesDBIterDataPipe,
+    get_metadata_from_pvsites_database,
+    get_pv_power_from_pvsites_database,
 )
 
 
 def test_get_metadata_from_database(pv_yields_and_systems):
     """Test get meteadata from database"""
     meteadata = get_metadata_from_database()
-
     assert len(meteadata) == 4
@@ -120,3 +122,42 @@
     pv_datapipe = OpenPVFromDBIterDataPipe(pv_config=pv_config)
     data = next(iter(pv_datapipe))
     assert data is not None
+
+
+def test_get_pv_power_from_pvsites_database():
+    df_gen = get_pv_power_from_pvsites_database(timedelta(minutes=30))
+    # 30 minutes of history gives 6 five-minutely timestamps, for 10 PV systems
+    assert df_gen.shape == (6, 10)
+
+
+def test_get_metadata_from_pvsites_database():
+    df_meta = get_metadata_from_pvsites_database()
+    assert len(df_meta) == 10
+    for column in [
+        "orientation",
+        "tilt",
+        "longitude",
+        "latitude",
+        "capacity_kw",
+        "ml_id",
+    ]:
+        assert column in df_meta.columns
+
+
+def test_open_pv_from_pvsites_db():
+    dp = OpenPVFromPVSitesDBIterDataPipe(history_minutes=30)
+    da = next(iter(dp))
+    # 30 minutes of history gives 6 five-minutely timestamps, for 10 PV systems
+    assert da.shape == (6, 10)
+    for variable in [
+        "time_utc",
+        "pv_system_id",
+        "observed_capacity_wp",
+        "nominal_capacity_wp",
+        "orientation",
+        "tilt",
+        "longitude",
+        "latitude",
+        "ml_id",
+    ]:
+        assert variable in da.coords
diff --git a/tests/select/test_drop_pv_generating_overnight.py b/tests/select/test_drop_pv_generating_overnight.py
index d69ea265c..3bba1489b 100644
--- a/tests/select/test_drop_pv_generating_overnight.py
+++ b/tests/select/test_drop_pv_generating_overnight.py
@@ -70,7 +70,7 @@ def test_drop_with_constructed_dataarray():
         coords=ALL_COORDS,
     )
     data_array = data_array.assign_coords(
-        capacity_watt_power=("pv_system_id", np.ones(len(pv_system_id))),
+        observed_capacity_wp=("pv_system_id", np.ones(len(pv_system_id))),
     )
 
     # run the function
diff --git a/tests/transform/xarray/test_normalize.py b/tests/transform/xarray/test_normalize.py
index 89e3d1459..66c188088 100644
--- a/tests/transform/xarray/test_normalize.py
+++ b/tests/transform/xarray/test_normalize.py
@@ -37,7 +37,7 @@ def test_normalize_gsp(gsp_datapipe):
 
 
 def test_normalize_passiv(passiv_datapipe):
-    passiv_datapipe = passiv_datapipe.normalize(normalize_fn=lambda x: x / x.capacity_watt_power)
+    passiv_datapipe = passiv_datapipe.normalize(normalize_fn=lambda x: x / x.observed_capacity_wp)
     data = next(iter(passiv_datapipe))
     assert np.min(data) >= 0.0
     assert np.max(data) <= 1.0
@@ -45,7 +45,7 @@ def test_normalize_pvoutput(pvoutput_datapipe):
     pvoutput_datapipe = pvoutput_datapipe.normalize(
-        normalize_fn=lambda x: x / x.capacity_watt_power
+        normalize_fn=lambda x: x / x.observed_capacity_wp
     )
     data = next(iter(pvoutput_datapipe))
     assert np.min(data) >= 0.0
diff --git a/tests/transform/xarray/test_pv_dropout.py b/tests/transform/xarray/test_pv_dropout.py
new file mode 100644
index 000000000..280f27b73
--- /dev/null
+++ b/tests/transform/xarray/test_pv_dropout.py
@@ -0,0 +1,70 @@
+from datetime import timedelta
+
+from torchdata.datapipes.iter import IterableWrapper
+import numpy as np
+
+from ocf_datapipes.transform.xarray import ApplyPVDropout
+
+
+def test_apply_pv_dropout(passiv_datapipe):
+    data = (
+        next(iter(passiv_datapipe))
+        .isel(pv_system_id=slice(0, 50))
+        .isel(time_utc=slice(-10, None))
+        .compute()
+    )
+
+    data = data.fillna(0)
+
+    pv_datapipe = IterableWrapper([data for _ in range(3)])
+
+    # ----------------
+    # Apply no dropout
+    pv_dropout_datapipe = ApplyPVDropout(
+        pv_datapipe,
+        system_dropout_fractions=[0],
+        system_dropout_timedeltas=[timedelta(minutes=0)],
+    )
+
+    # No dropout should have been applied
+    for pv_data in pv_dropout_datapipe:
+        assert not np.isnan(pv_data.values).any()
+
+    # --------------------------
+    # Apply only system dropout
+    pv_dropout_datapipe = ApplyPVDropout(
+        pv_datapipe,
+        system_dropout_fractions=[0.5],
+        system_dropout_timedeltas=[timedelta(minutes=0)],
+    )
+
+    # Each system should have either all NaNs or no NaNs
+    for pv_data in pv_dropout_datapipe:
+        all_system_nan = pv_data.isnull().all(dim="time_utc")
+        any_system_nan = pv_data.isnull().any(dim="time_utc")
+        assert np.logical_or(all_system_nan.values, ~any_system_nan.values).all()
+
+    # --------------------------
+    # Apply only delay dropout
+    pv_dropout_datapipe = ApplyPVDropout(
+        pv_datapipe,
+        system_dropout_fractions=[0.0],
+        system_dropout_timedeltas=[timedelta(minutes=-5)],
+    )
+
+    # Each system should have exactly 1 NaN
+    for pv_data in pv_dropout_datapipe:
+        assert (pv_data.isnull().sum(dim="time_utc") == 1).all()
+
+    # --------------------------
+    # Apply combo dropout
+    pv_dropout_datapipe = ApplyPVDropout(
+        pv_datapipe,
+        system_dropout_fractions=[0.5],
+        system_dropout_timedeltas=[timedelta(minutes=-5)],
+    )
+
+    # Each system should have either all NaNs or exactly one NaN
+    for pv_data in pv_dropout_datapipe:
+        all_system_nan = pv_data.isnull().all(dim="time_utc")
+        one_system_nan = pv_data.isnull().sum(dim="time_utc") == 1
+        assert np.logical_or(all_system_nan.values, one_system_nan.values).all()
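[Reviewer note - not part of the diff] For a concrete feel of what ApplyPVDropout yields, the following standalone sketch uses synthetic data (all values here are assumptions, not taken from the tests). With a dropout fraction of 0.5 and delays of -5 or 0 minutes, roughly half the systems come out fully NaN, and each surviving system randomly loses either its final five-minute reading or nothing:

    from datetime import timedelta

    import numpy as np
    import pandas as pd
    import xarray as xr
    from torchdata.datapipes.iter import IterableWrapper

    from ocf_datapipes.transform.xarray import ApplyPVDropout

    # Synthetic 12-step, 4-system PV array at 5-minute resolution
    times = pd.date_range("2023-01-01 12:00", periods=12, freq="5T")
    da = xr.DataArray(
        np.ones((12, 4), dtype=np.float32),
        coords=(("time_utc", times), ("pv_system_id", np.arange(4))),
    )

    dp = ApplyPVDropout(
        IterableWrapper([da]),
        system_dropout_fractions=[0.5],
        system_dropout_timedeltas=[timedelta(minutes=-5), timedelta(minutes=0)],
    )
    out = next(iter(dp))  # NaNs mark dropped-out systems and delayed readings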