Commit 77ecaa8: Merge branch 'main' into pv_inputs_from_database

dfulu committed Oct 24, 2023
2 parents d539a77 + e383474
Showing 39 changed files with 605 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,7 +1,7 @@
[bumpversion]
commit = True
tag = True
current_version = 2.0.1
current_version = 2.0.3
message = Bump version: {current_version} → {new_version} [skip ci]

[bumpversion:file:setup.py]
3 changes: 3 additions & 0 deletions ocf_datapipes/load/nwp/nwp.py
@@ -8,6 +8,7 @@
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

from ocf_datapipes.load.nwp.providers.ecmwf import open_ifs
from ocf_datapipes.load.nwp.providers.icon import open_icon_eu, open_icon_global
from ocf_datapipes.load.nwp.providers.ukv import open_ukv

@@ -39,6 +40,8 @@ def __init__(
self.open_nwp = open_icon_eu
elif provider == "icon-global":
self.open_nwp = open_icon_global
elif provider == "ecmwf":
self.open_nwp = open_ifs
else:
raise ValueError(f"Unknown provider: {provider}")

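For orientation, here is a hedged sketch of the provider dispatch this hunk extends, rewritten as a plain mapping. The `ukv` entry is assumed from the import above (its branch is outside the visible hunk); the others mirror the elif chain.

from ocf_datapipes.load.nwp.providers.ecmwf import open_ifs
from ocf_datapipes.load.nwp.providers.icon import open_icon_eu, open_icon_global
from ocf_datapipes.load.nwp.providers.ukv import open_ukv

# Provider string -> loader function, mirroring the branches in __init__
NWP_LOADERS = {
    "ukv": open_ukv,  # assumed from the import; branch not shown in this hunk
    "icon-eu": open_icon_eu,
    "icon-global": open_icon_global,
    "ecmwf": open_ifs,  # new in this commit
}

def get_nwp_loader(provider: str):
    """Return the loader for `provider`, raising like the datapipe does."""
    if provider not in NWP_LOADERS:
        raise ValueError(f"Unknown provider: {provider}")
    return NWP_LOADERS[provider]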
40 changes: 40 additions & 0 deletions ocf_datapipes/load/nwp/providers/ecmwf.py
@@ -0,0 +1,40 @@
"""ECMWF provider loaders"""
import pandas as pd
import xarray as xr

from ocf_datapipes.load.nwp.providers.utils import open_zarr_paths


def open_ifs(zarr_path) -> xr.DataArray:
"""
Opens the ECMWF IFS NWP data
Args:
zarr_path: Path to the zarr to open
Returns:
Xarray DataArray of the NWP data
"""
# Open the data
nwp = open_zarr_paths(zarr_path)
    data_vars = list(nwp.data_vars.keys())
    if len(data_vars) > 1:
        raise Exception(f"Expected a single data variable, found {len(data_vars)}: {data_vars}")
    ifs: xr.DataArray = nwp[data_vars[0]]
    del nwp
    ifs = ifs.transpose("init_time", "step", "variable", "latitude", "longitude")
    # Only "init_time" and "variable" change name; latitude/longitude already match
    ifs = ifs.rename(
        {
            "init_time": "init_time_utc",
            "variable": "channel",
        }
    )
# Sanity checks.
time = pd.DatetimeIndex(ifs.init_time_utc)
assert time.is_unique
assert time.is_monotonic_increasing
return ifs
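A minimal usage sketch of the new loader, assuming the `tests/data/ifs.zarr` fixture added further down in this commit contains the expected arrays:

from ocf_datapipes.load.nwp.providers.ecmwf import open_ifs

# Open the IFS test fixture and inspect the dimension order set by open_ifs
ifs = open_ifs("tests/data/ifs.zarr")
print(ifs.dims)  # expected: ('init_time_utc', 'step', 'channel', 'latitude', 'longitude')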
121 changes: 115 additions & 6 deletions ocf_datapipes/select/select_time_slice.py
@@ -12,6 +12,47 @@
logger = logging.getLogger(__name__)


def fill_1d_bool_gaps(x, max_gap, fill_ends=False):
"""In a boolean array, fill consecutive False elements if their number is less than the gap_size
Args:
x: A 1-dimensional boolean array
max_gap: integer of the maximum gap size which will be filled with True
fill_ends: Whether to fill the ends as if there are True values on either side
Returns:
A 1-dimensional boolean array
Examples:
>>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
>>> fill_1d_bool_gaps(x, max_gap=2, fill_ends=False).astype(int)
array([0, 1, 1, 1, 1, 1, 1, 0])
>>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
>>> fill_1d_bool_gaps(x, max_gap=1, fill_ends=True).astype(int)
array([1, 1, 0, 0, 1, 1, 1, 1])
"""
if fill_ends:
x_extended = np.concatenate([[True], x, [True]])
return fill_1d_bool_gaps(x_extended, max_gap, fill_ends=False)[1:-1]

should_fill = np.zeros(len(x), dtype=bool)

i_start = None

last_b = False
for i, b in enumerate(x):
if last_b and not b:
i_start = i
elif b and not last_b and i_start is not None:
if i - i_start <= max_gap:
should_fill[i_start:i] = True
i_start = None
last_b = b

return np.logical_or(should_fill, x)
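
A short illustrative sketch of how this helper is combined with `np.isin` in `_sel_fillinterp` below; the timestamps are invented for the example:

import numpy as np
import pandas as pd

# Five-minutely timestamps we want vs. those actually present in the data
requested = pd.date_range("2023-01-01 00:00", "2023-01-01 00:30", freq="5min")
available = requested[[0, 1, 4, 6]]  # missing steps 2-3 (gap of 2) and 5 (gap of 1)

mask = np.isin(requested, available)           # [1 1 0 0 1 0 1]
fillable = fill_1d_bool_gaps(mask, max_gap=2)  # both gaps are <= 2 steps long
print(fillable.astype(int))                    # -> [1 1 1 1 1 1 1]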


@functional_datapipe("select_time_slice")
class SelectTimeSliceIterDataPipe(IterDataPipe):
"""Selects time slice"""
@@ -26,6 +67,7 @@ def __init__(
interval_start: Optional[timedelta] = None,
interval_end: Optional[timedelta] = None,
fill_selection: Optional[bool] = False,
max_steps_gap: Optional[int] = 0,
):
"""
Selects time slice.
@@ -43,16 +85,23 @@
interval_end (optional): timedelta with respect to t0 where the open interval ends
            fill_selection (optional): If True and the data yielded from `source_datapipe` does
                not extend over the entire requested time period, the missing timestamps are
                filled in the returned xarray object so it has the expected shape. Otherwise
                the default xarray slicing behaviour is used and timestamps may be missing.
                When filling, values are linearly interpolated across gaps of up to
                `max_steps_gap` steps; values in larger gaps are set to NaN.
            max_steps_gap (optional): The maximum number of consecutive missing time steps
                which will be filled via linear interpolation. If set to zero, no
                interpolation is used and all missing timesteps will be NaN.
"""
self.source_datapipe = source_datapipe
self.t0_datapipe = t0_datapipe
self.fill_selection = fill_selection
self.max_steps_gap = max_steps_gap

used_duration = history_duration is not None and forecast_duration is not None
used_intervals = interval_start is not None and interval_end is not None
assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
        assert max_steps_gap >= 0, "max_steps_gap must be >= 0"

if used_duration:
self.interval_start = -np.timedelta64(history_duration)
@@ -63,6 +112,13 @@

self.sample_period_duration = sample_period_duration

if self.fill_selection and max_steps_gap == 0:
self._sel = self._sel_fillnan
elif self.fill_selection and max_steps_gap > 0:
self._sel = self._sel_fillinterp
else:
self._sel = self._sel_default

def _sel_fillnan(self, xr_data, start_dt, end_dt):
requested_times = pd.date_range(
start_dt,
@@ -72,6 +128,62 @@ def _sel_fillnan(self, xr_data, start_dt, end_dt):
# Missing time indexes are returned with all NaN values
return xr_data.reindex(time_utc=requested_times)

def _sel_fillinterp(self, xr_data, start_dt, end_dt):
dt_buffer = self.sample_period_duration * self.max_steps_gap

# Initially select larger period so we can use it to interpolate to requested period
# This slice also avoids us interpolating the whole dataset to get the requested times
ds = xr_data.sel(time_utc=slice(start_dt - dt_buffer, end_dt + dt_buffer))

# These are the times we will ultimately return
requested_times = pd.date_range(
start_dt,
end_dt,
freq=self.sample_period_duration,
)

# These are the times we use for interpolation to the requested_times
buffer_requested_times = pd.date_range(
start_dt - dt_buffer,
end_dt + dt_buffer,
freq=self.sample_period_duration,
)

# If all the requested times are present we avoid running interpolation
if np.isin(requested_times, ds.time_utc).all():
return ds.sel(time_utc=slice(start_dt, end_dt))

        # If fewer than 2 of the buffered requested times are present we cannot interpolate
        elif np.isin(buffer_requested_times, ds.time_utc).sum() < 2:
            logger.warning("Cannot run interpolation infilling with fewer than 2 time steps available")
return self._sel_fillnan(xr_data, start_dt, end_dt)

logger.info("Some requested times are missing - running interpolation")
# Find the timestamps which are within max gap size
mask = np.isin(buffer_requested_times, ds.time_utc)
valid_fill_times = fill_1d_bool_gaps(mask, self.max_steps_gap, fill_ends=False)

# Run the interpolation and filter to requested times
ds_interp = ds.interp(time_utc=buffer_requested_times, method="linear", assume_sorted=True)

# Mask the timestamps outside the max gap size
valid_fill_times_xr = xr.zeros_like(ds_interp.time_utc, dtype=bool)
valid_fill_times_xr.values[:] = valid_fill_times

valid_requested_times = valid_fill_times_xr.sel(time_utc=slice(start_dt, end_dt))
if not valid_requested_times.all():
not_infilled_times = valid_requested_times.where(~valid_requested_times, drop=True)
logger.warning(
"After interpolation the following requested times are still missing:"
f"{not_infilled_times.time_utc.values}"
)

ds_out = ds_interp.where(valid_fill_times_xr)

# Filter to selected times
ds_out = ds_out.sel(time_utc=slice(start_dt, end_dt))
return ds_out

def _sel_default(self, xr_data, start_dt, end_dt):
return xr_data.sel(time_utc=slice(start_dt, end_dt))

@@ -86,7 +198,4 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
start_dt = start_dt.ceil(self.sample_period_duration)
end_dt = end_dt.ceil(self.sample_period_duration)

if self.fill_selection:
yield self._sel_fillnan(xr_data, start_dt, end_dt)
else:
yield self._sel_default(xr_data, start_dt, end_dt)
yield self._sel(xr_data, start_dt, end_dt)
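A hedged usage sketch of the new interpolation mode; `sat_datapipe` and `t0_datapipe` stand in for real upstream datapipes, and the durations are illustrative:

from datetime import timedelta

# Gaps of up to 2 missing 5-minute steps are linearly interpolated;
# longer gaps are left as NaN (fill_selection behaviour)
sliced = sat_datapipe.select_time_slice(
    t0_datapipe=t0_datapipe,
    sample_period_duration=timedelta(minutes=5),
    history_duration=timedelta(minutes=60),
    forecast_duration=timedelta(minutes=0),
    fill_selection=True,
    max_steps_gap=2,
)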
13 changes: 4 additions & 9 deletions ocf_datapipes/training/pseudo_irradience.py
Expand Up @@ -418,6 +418,10 @@ def pseudo_irradiance_datapipe(
batch_size: Batch size for the datapipe
one_d: Whether to return a 1D array or not, i.e. a single PV site in the center as
opposed to a 2D array of PV sites
size_meters: Size, in meters, of the output image
use_meters: Whether to use meters or pixels
normalize_by_pvlib: Whether to normalize the PV generation by the PVLib generation
is_test: Whether to return the test set or not
Returns: datapipe
"""
@@ -434,20 +438,16 @@
use_gsp=False,
use_pv=use_pv,
)
# print(used_datapipes.keys())
# Load GSP national data
used_datapipes["pv"] = used_datapipes["pv"].select_train_test_time(start_time, end_time)

# Now get overlapping time periods
used_datapipes = get_and_return_overlapping_time_periods_and_t0(
used_datapipes, key_for_t0="pv", return_all_times=True if is_test else False
)
# print(used_datapipes.keys())
# return used_datapipes["pv"].zip_ocf(used_datapipes["nwp"],used_datapipes["pv_t0"],used_datapipes["nwp_t0"])
# And now get time slices
used_datapipes = add_selected_time_slices_from_datapipes(used_datapipes)
# print(used_datapipes.keys())
# return used_datapipes["pv"].zip_ocf(used_datapipes["sat"],used_datapipes["pv_future"])

# Now do the extra processing
pv_history = used_datapipes["pv"].map(
@@ -490,7 +490,6 @@
)
pv_datapipe = pv_datapipe.select_id(pv_one_d_datapipe, data_source_name="pv")
pv_history = pv_history.select_id(pv_one_d_datapipe2, data_source_name="pv")
# return pv_datapipe.zip_ocf(pv_history, pv_loc_datapipe, pv_meta_save, pv_sav_loc, used_datapipes["sat"])

if "nwp" in used_datapipes.keys():
# take nwp time slices
@@ -594,7 +593,6 @@
pv_datapipe, pv_meta = pv_datapipe.fork(2)
pv_meta = pv_meta.map(_get_meta)
pv_datapipe = pv_datapipe.map(_get_values)
# return pv_datapipe.zip_ocf(pv_history, pv_loc_datapipe, pv_meta_save, pv_sav_loc, pv_meta, sat_datapipe)
else:
if "hrv" in used_datapipes.keys():
sat_hrv_datapipe, sat_gsp_datapipe = sat_hrv_datapipe.fork(2)
@@ -664,7 +662,6 @@
elif "sat" in used_datapipes.keys():
sat_datapipe, sun_image_datapipe = sat_datapipe.fork(2)
sun_image_datapipe = sun_image_datapipe.create_sun_image(normalize=True)
# return pv_datapipe.zip_ocf(pv_history, pv_loc_datapipe, pv_meta_save, pv_sav_loc, pv_meta, sat_datapipe)
if "nwp" in used_datapipes.keys():
nwp_datapipe, time_image_datapipe = nwp_datapipe.fork(2, buffer_size=100)
time_image_datapipe = time_image_datapipe.create_time_image(
Expand All @@ -679,7 +676,6 @@ def pseudo_irradiance_datapipe(
time_image_datapipe = time_image_datapipe.create_time_image()
else:
time_image_datapipe = None
# return pv_datapipe.zip_ocf(pv_history, pv_loc_datapipe, pv_meta_save, pv_sav_loc, pv_meta, sat_datapipe, time_image_datapipe)

modalities = []
if not one_d:
@@ -698,7 +694,6 @@
modalities.append(time_image_datapipe)

stacked_xarray_inputs = StackXarray(modalities)
# return pv_datapipe.zip_ocf(pv_history, pv_loc_datapipe, pv_meta_save, pv_sav_loc, pv_meta, stacked_xarray_inputs)
return stacked_xarray_inputs.batch(batch_size).zip_ocf(
pv_meta.batch(batch_size),
pv_datapipe.batch(batch_size),
2 changes: 2 additions & 0 deletions ocf_datapipes/training/pvnet.py
@@ -442,6 +442,7 @@ def slice_datapipes_by_time(
interval_start=minutes(-conf_in.satellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
max_steps_gap=2,
)

# Generate randomly sampled dropout times
@@ -475,6 +476,7 @@ def slice_datapipes_by_time(
interval_start=minutes(-conf_in.hrvsatellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
max_steps_gap=2,
)

# Apply the dropout
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setup(
name="ocf_datapipes",
version="2.0.1",
version="2.0.3",
license="MIT",
description="Pytorch Datapipes built for use in Open Climate Fix's forecasting work",
author="Jacob Bieker, Jack Kelly, Peter Dudfield, James Fulton",
1 change: 1 addition & 0 deletions tests/data/ifs.zarr/.zattrs
@@ -0,0 +1 @@
{}
3 changes: 3 additions & 0 deletions tests/data/ifs.zarr/.zgroup
@@ -0,0 +1,3 @@
{
"zarr_format": 2
}