diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 2cbbc83a1..31c677575 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -326,6 +326,7 @@ def create_t0_and_loc_datapipes( key_for_t0: str = "gsp", shuffle: bool = True, nwp_max_dropout_minutes: int = 0, + max_staleness_minutes: int = 180, ): """ Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. @@ -349,6 +350,7 @@ def create_t0_and_loc_datapipes( """ assert key_for_t0 in datapipes_dict assert key_for_t0 in ["gsp", "pv"] + assert max_staleness_minutes >= nwp_max_dropout_minutes contiguous_time_datapipes = [] # Used to store contiguous time periods from each data source @@ -359,53 +361,63 @@ def create_t0_and_loc_datapipes( continue elif key == "nwp": - sample_frequency = 180 # Init times are 3 hours apart # If using NWP dropout we need to make sure the previous forecast is available # Setting the history to larger here will do the required filtering history_duration = max( configuration.input_data.nwp.history_minutes, nwp_max_dropout_minutes ) - forecast_duration = configuration.input_data.nwp.forecast_minutes - time_dim = "init_time_utc" - - elif key == "sat": - sample_frequency = 5 - history_duration = configuration.input_data.satellite.history_minutes - forecast_duration = 0 - time_dim = "time_utc" - - elif key == "hrv": - sample_frequency = 5 - history_duration = configuration.input_data.hrvsatellite.history_minutes - forecast_duration = 0 - time_dim = "time_utc" - - elif key == "pv": - sample_frequency = 5 - history_duration = configuration.input_data.pv.history_minutes - forecast_duration = configuration.input_data.pv.forecast_minutes - time_dim = "time_utc" - - elif key == "gsp": - sample_frequency = 30 - history_duration = configuration.input_data.gsp.history_minutes - forecast_duration = configuration.input_data.gsp.forecast_minutes - time_dim = "time_utc" + + datapipes_dict["nwp"], 
datapipe_copy = datapipes_dict["nwp"].fork(2, buffer_size=5) + + # NWP is a forecast product so gets its own contiguous function + time_periods = datapipe_copy.get_contiguous_time_periods_nwp( + history_duration=timedelta(minutes=history_duration), + max_staleness=timedelta(minutes=max_staleness_minutes), + time_dim="init_time_utc", + ) + contiguous_time_datapipes.append(time_periods) + else: - raise ValueError(f"Unexpected key: {key}") - - datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5) - time_periods = datapipe_copy.get_contiguous_time_periods( - sample_period_duration=timedelta(minutes=sample_frequency), - history_duration=timedelta(minutes=history_duration), - forecast_duration=timedelta(minutes=forecast_duration), - time_dim=time_dim, - ) + if key == "sat": + sample_frequency = 5 + history_duration = configuration.input_data.satellite.history_minutes + forecast_duration = 0 + time_dim = "time_utc" + + elif key == "hrv": + sample_frequency = 5 + history_duration = configuration.input_data.hrvsatellite.history_minutes + forecast_duration = 0 + time_dim = "time_utc" + + elif key == "pv": + sample_frequency = 5 + history_duration = configuration.input_data.pv.history_minutes + forecast_duration = configuration.input_data.pv.forecast_minutes + time_dim = "time_utc" + + elif key == "gsp": + sample_frequency = 30 + history_duration = configuration.input_data.gsp.history_minutes + forecast_duration = configuration.input_data.gsp.forecast_minutes + time_dim = "time_utc" + + else: + raise ValueError(f"Unexpected key: {key}") + + datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5) + + time_periods = datapipe_copy.get_contiguous_time_periods( + sample_period_duration=timedelta(minutes=sample_frequency), + history_duration=timedelta(minutes=history_duration), + forecast_duration=timedelta(minutes=forecast_duration), + time_dim=time_dim, + ) - contiguous_time_datapipes.append(time_periods) + 
contiguous_time_datapipes.append(time_periods) # Find joint overlapping contiguous time periods if len(contiguous_time_datapipes) > 1: diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 27015e74e..ab757749f 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -356,6 +356,8 @@ def construct_loctime_pipelines( key_for_t0="gsp", shuffle=True, nwp_max_dropout_minutes=180, + # Sometimes the forecast is only 4/day so 6 hour intervals - then we add 3-hour dropout + max_staleness_minutes=60*9, ) return location_pipe, t0_datapipe diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py index 1b3da6222..aca22c861 100644 --- a/ocf_datapipes/transform/xarray/__init__.py +++ b/ocf_datapipes/transform/xarray/__init__.py @@ -33,6 +33,7 @@ from .downsample import DownsampleIterDataPipe as Downsample from .get_contiguous_time_periods import ( GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods, + GetContiguousT0TimePeriodsNWPIterDataPipe as GetContiguousT0TimePeriodsNWP, ) from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage from .gsp.ensure_n_gsp_per_example import ( diff --git a/ocf_datapipes/transform/xarray/get_contiguous_time_periods.py b/ocf_datapipes/transform/xarray/get_contiguous_time_periods.py index 3b0c87d20..4ab1c7e5a 100644 --- a/ocf_datapipes/transform/xarray/get_contiguous_time_periods.py +++ b/ocf_datapipes/transform/xarray/get_contiguous_time_periods.py @@ -60,22 +60,43 @@ def __iter__(self) -> pd.DataFrame: logger.debug("Get contiguous time periods:done") yield contiguous_time_periods + +@functional_datapipe("get_contiguous_time_periods_nwp") +class GetContiguousT0TimePeriodsNWPIterDataPipe(IterDataPipe): + """Get contiguous NWP time periods for training""" -def get_contiguous_t0_time_periods( - contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta -) -> pd.DataFrame: - 
"""Get all time periods which contain valid t0 datetimes. + def __init__( + self, + source_datapipe: IterDataPipe, + history_duration: timedelta, + max_staleness: timedelta = timedelta(minutes=0), + time_dim: str = "init_time_utc", + ): + """ + Get contiguous time periods for use in determing t0 times for training - `t0` is the datetime of the most recent observation. + Args: + source_datapipe: Datapipe emitting a Xarray dataset + history_duration: Length of the historical slice used for a sample + max_staleness: Up to how long after an NWP forecast init_time are we willing to use the + forecast. + time_dim: time dimensions for which to find the contiguous time periods + """ + self.source_datapipe = source_datapipe + self.history_duration = history_duration + self.max_staleness = max_staleness + self.time_dim = time_dim - Returns: - pd.DataFrame where each row represents a single time period. The pd.DataFrame - has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). 
- """ - contiguous_time_periods["start_dt"] += history_duration - contiguous_time_periods["end_dt"] -= forecast_duration - assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all() - return contiguous_time_periods + def __iter__(self) -> pd.DataFrame: + """Calculate contiguous time periods and return a dataframe containing them""" + for xr_data in self.source_datapipe: + logger.debug("Getting contiguous NWP t0 time periods") + contiguous_time_periods = get_contiguous_t0_periods_nwp( + datetimes=pd.DatetimeIndex(xr_data[self.time_dim]), + history_duration=self.history_duration, + max_staleness=self.max_staleness, + ) + yield contiguous_time_periods def get_contiguous_time_periods( @@ -132,3 +153,73 @@ def get_contiguous_time_periods( ) return pd.DataFrame(periods) + + +def get_contiguous_t0_time_periods( + contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta +) -> pd.DataFrame: + """Get all time periods which contain valid t0 datetimes. + + `t0` is the datetime of the most recent observation. + + Returns: + pd.DataFrame where each row represents a single time period. The pd.DataFrame + has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). + """ + contiguous_time_periods["start_dt"] += history_duration + contiguous_time_periods["end_dt"] -= forecast_duration + assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all() + return contiguous_time_periods + + +def get_contiguous_t0_periods_nwp( + datetimes: pd.DatetimeIndex, + history_duration: timedelta, + max_staleness: timedelta, +) -> pd.DataFrame: + """Get all time periods from the NWP init times which are valid as t0 datetimes. + + Args: + datetimes: Sorted pd.DatetimeIndex + history_duration: Length of the historical slice used for a sample + max_staleness: Up to how long after an NWP forecast init_time are we willing to use the + forecast. This must be >= forecast_duration. 
+
+    Returns:
+        pd.DataFrame where each row represents a single time period. The pd.DataFrame
+        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
+    """
+    # Sanity checks.
+    assert len(datetimes) > 0
+    assert datetimes.is_monotonic_increasing
+    assert datetimes.is_unique
+    assert history_duration >= timedelta(0)
+    assert max_staleness >= timedelta(0)
+
+    # Each forecast init time covers up to this time before we consider it too stale
+    stale_datetimes = datetimes + max_staleness
+
+    # Store contiguous periods
+    contiguous_periods = []
+
+    # dt_stale_prev: the timestamp after which the previous init time becomes "stale"
+    dt_stale_prev = stale_datetimes[0]
+
+    # Start first period allowing for history slice
+    start_this_period = datetimes[0] + history_duration
+
+    for dt_init, dt_stale in zip(datetimes[1:], stale_datetimes[1:]):
+        # If the previous init time becomes stale before the next init time
+        if dt_stale_prev < dt_init:
+            # Store a contiguous t0 period - allowing for forecast slice
+            if start_this_period <= dt_stale_prev:
+                contiguous_periods += [[start_this_period, dt_stale_prev]]
+
+            # And start a new period
+            start_this_period = dt_init + history_duration
+        dt_stale_prev = dt_stale
+
+    if start_this_period <= dt_stale_prev:
+        contiguous_periods += [[start_this_period, dt_stale_prev]]
+
+    return pd.DataFrame(contiguous_periods, columns=["start_dt", "end_dt"])
\ No newline at end of file