Skip to content

Commit

Permalink
add NWP contiguous function
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Nov 15, 2023
1 parent dd64df6 commit 9c484cf
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 50 deletions.
86 changes: 49 additions & 37 deletions ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def create_t0_and_loc_datapipes(
key_for_t0: str = "gsp",
shuffle: bool = True,
nwp_max_dropout_minutes: int = 0,
max_staleness_minutes: int = 180,
):
"""
Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times.
Expand All @@ -349,6 +350,7 @@ def create_t0_and_loc_datapipes(
"""
assert key_for_t0 in datapipes_dict
assert key_for_t0 in ["gsp", "pv"]
assert max_staleness_minutes >= nwp_max_dropout_minutes

contiguous_time_datapipes = [] # Used to store contiguous time periods from each data source

Expand All @@ -359,53 +361,63 @@ def create_t0_and_loc_datapipes(
continue

elif key == "nwp":
sample_frequency = 180 # Init times are 3 hours apart

# If using NWP dropout we need to make sure the previous forecast is available
# Setting the history to larger here will do the required filtering
history_duration = max(
configuration.input_data.nwp.history_minutes, nwp_max_dropout_minutes
)
forecast_duration = configuration.input_data.nwp.forecast_minutes
time_dim = "init_time_utc"

elif key == "sat":
sample_frequency = 5
history_duration = configuration.input_data.satellite.history_minutes
forecast_duration = 0
time_dim = "time_utc"

elif key == "hrv":
sample_frequency = 5
history_duration = configuration.input_data.hrvsatellite.history_minutes
forecast_duration = 0
time_dim = "time_utc"

elif key == "pv":
sample_frequency = 5
history_duration = configuration.input_data.pv.history_minutes
forecast_duration = configuration.input_data.pv.forecast_minutes
time_dim = "time_utc"

elif key == "gsp":
sample_frequency = 30
history_duration = configuration.input_data.gsp.history_minutes
forecast_duration = configuration.input_data.gsp.forecast_minutes
time_dim = "time_utc"

datapipes_dict["nwp"], datapipe_copy = datapipes_dict["nwp"].fork(2, buffer_size=5)

# NWP is a forecast product so gets its own contiguous function
time_periods = datapipe_copy.get_contiguous_time_periods_nwp(
history_duration=timedelta(minutes=history_duration),
max_staleness=timedelta(minutes=max_staleness_minutes),
time_dim="init_time_utc",
)

contiguous_time_datapipes.append(time_periods)

else:
raise ValueError(f"Unexpected key: {key}")

datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5)

time_periods = datapipe_copy.get_contiguous_time_periods(
sample_period_duration=timedelta(minutes=sample_frequency),
history_duration=timedelta(minutes=history_duration),
forecast_duration=timedelta(minutes=forecast_duration),
time_dim=time_dim,
)
if key == "sat":
sample_frequency = 5
history_duration = configuration.input_data.satellite.history_minutes
forecast_duration = 0
time_dim = "time_utc"

elif key == "hrv":
sample_frequency = 5
history_duration = configuration.input_data.hrvsatellite.history_minutes
forecast_duration = 0
time_dim = "time_utc"

elif key == "pv":
sample_frequency = 5
history_duration = configuration.input_data.pv.history_minutes
forecast_duration = configuration.input_data.pv.forecast_minutes
time_dim = "time_utc"

elif key == "gsp":
sample_frequency = 30
history_duration = configuration.input_data.gsp.history_minutes
forecast_duration = configuration.input_data.gsp.forecast_minutes
time_dim = "time_utc"

else:
raise ValueError(f"Unexpected key: {key}")

datapipes_dict[key], datapipe_copy = datapipes_dict[key].fork(2, buffer_size=5)

time_periods = datapipe_copy.get_contiguous_time_periods(
sample_period_duration=timedelta(minutes=sample_frequency),
history_duration=timedelta(minutes=history_duration),
forecast_duration=timedelta(minutes=forecast_duration),
time_dim=time_dim,
)

contiguous_time_datapipes.append(time_periods)
contiguous_time_datapipes.append(time_periods)

# Find joint overlapping contiguous time periods
if len(contiguous_time_datapipes) > 1:
Expand Down
2 changes: 2 additions & 0 deletions ocf_datapipes/training/pvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ def construct_loctime_pipelines(
key_for_t0="gsp",
shuffle=True,
nwp_max_dropout_minutes=180,
# The forecast is sometimes issued only 4 times per day (6-hour intervals);
# adding the 3-hour dropout allowance on top gives a 9-hour maximum staleness.
max_staleness_minutes=60*9,
)

return location_pipe, t0_datapipe
Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/transform/xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .downsample import DownsampleIterDataPipe as Downsample
from .get_contiguous_time_periods import (
GetContiguousT0TimePeriodsIterDataPipe as GetContiguousT0TimePeriods,
GetContiguousT0TimePeriodsNWPIterDataPipe as GetContiguousT0TimePeriodsNWP,
)
from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage
from .gsp.ensure_n_gsp_per_example import (
Expand Down
117 changes: 104 additions & 13 deletions ocf_datapipes/transform/xarray/get_contiguous_time_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,43 @@ def __iter__(self) -> pd.DataFrame:
logger.debug("Get contiguous time periods:done")
yield contiguous_time_periods


@functional_datapipe("get_contiguous_time_periods_nwp")
class GetContiguousT0TimePeriodsNWPIterDataPipe(IterDataPipe):
    """Get contiguous NWP time periods for training.

    NWP is a forecast product: each init time is usable from
    ``init_time + history_duration`` until ``init_time + max_staleness``, so
    valid-t0 periods are derived from the init times rather than from a
    regular sample frequency.
    """

    def __init__(
        self,
        source_datapipe: IterDataPipe,
        history_duration: timedelta,
        max_staleness: timedelta = timedelta(minutes=0),
        time_dim: str = "init_time_utc",
    ):
        """Get contiguous time periods for use in determining t0 times for training.

        Args:
            source_datapipe: Datapipe emitting an xarray dataset.
            history_duration: Length of the historical slice used for a sample.
            max_staleness: Up to how long after an NWP forecast init_time we are
                willing to use the forecast.
            time_dim: Name of the init-time dimension along which to find the
                contiguous time periods.
        """
        self.source_datapipe = source_datapipe
        self.history_duration = history_duration
        self.max_staleness = max_staleness
        self.time_dim = time_dim

    def __iter__(self) -> pd.DataFrame:
        """Calculate contiguous time periods and yield a dataframe containing them."""
        for xr_data in self.source_datapipe:
            logger.debug("Getting contiguous NWP t0 time periods")
            contiguous_time_periods = get_contiguous_t0_periods_nwp(
                datetimes=pd.DatetimeIndex(xr_data[self.time_dim]),
                history_duration=self.history_duration,
                max_staleness=self.max_staleness,
            )
            yield contiguous_time_periods


def get_contiguous_time_periods(
Expand Down Expand Up @@ -132,3 +153,73 @@ def get_contiguous_time_periods(
)

return pd.DataFrame(periods)


def get_contiguous_t0_time_periods(
    contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
) -> pd.DataFrame:
    """Get all time periods which contain valid t0 datetimes.

    `t0` is the datetime of the most recent observation. Each period is
    trimmed so a full sample fits inside it: the start is pushed forward by
    `history_duration` and the end is pulled back by `forecast_duration`.

    Args:
        contiguous_time_periods: pd.DataFrame with `start_dt` and `end_dt`
            columns. The input is not modified.
        history_duration: Length of the historical slice used for each sample.
        forecast_duration: Length of the forecast slice used for each sample.

    Returns:
        pd.DataFrame where each row represents a single time period. The pd.DataFrame
        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').

    Raises:
        ValueError: If trimming leaves any period with `start_dt` >= `end_dt`.
    """
    # Work on a copy so the caller's DataFrame is not mutated in place.
    trimmed = contiguous_time_periods.copy()
    trimmed["start_dt"] += history_duration
    trimmed["end_dt"] -= forecast_duration
    # Raise rather than assert: asserts are stripped under `python -O`.
    if not (trimmed["start_dt"] < trimmed["end_dt"]).all():
        raise ValueError(
            "A contiguous time period is too short to fit both the history and "
            "forecast durations"
        )
    return trimmed


def get_contiguous_t0_periods_nwp(
    datetimes: pd.DatetimeIndex,
    history_duration: timedelta,
    max_staleness: timedelta,
) -> pd.DataFrame:
    """Get all time periods from the NWP init times which are valid as t0 datetimes.

    A t0 is valid if some init time lies at least `history_duration` before it
    and no more than `max_staleness` before it. Consecutive init times whose
    staleness windows touch or overlap merge into one contiguous period.

    Args:
        datetimes: Sorted pd.DatetimeIndex of NWP init times.
        history_duration: Length of the historical slice used for a sample.
        max_staleness: Up to how long after an NWP forecast init_time we are
            willing to use the forecast.

    Returns:
        pd.DataFrame where each row represents a single time period. The pd.DataFrame
        has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
    """
    # Sanity checks.
    assert len(datetimes) > 0
    assert datetimes.is_monotonic_increasing
    assert datetimes.is_unique
    assert history_duration >= timedelta(0)
    assert max_staleness >= timedelta(0)

    # Timestamp after which each init time is considered too stale to use.
    expiry_times = datetimes + max_staleness

    periods = []

    # Open the first candidate period, leaving room for the history slice.
    period_start = datetimes[0] + history_duration
    # Expiry of the most recent init time seen so far.
    prev_expiry = expiry_times[0]

    for init_time, expiry in zip(datetimes[1:], expiry_times[1:]):
        if init_time > prev_expiry:
            # Gap: the previous forecast expires before this init time arrives.
            # Close out the current period if it is non-degenerate.
            if prev_expiry >= period_start:
                periods.append([period_start, prev_expiry])
            # Open a new period starting at this init time.
            period_start = init_time + history_duration
        prev_expiry = expiry

    # Close out the final open period.
    if prev_expiry >= period_start:
        periods.append([period_start, prev_expiry])

    return pd.DataFrame(periods, columns=["start_dt", "end_dt"])

0 comments on commit 9c484cf

Please sign in to comment.