Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to use linear interpolation to fill gaps in input data #230

Merged
merged 9 commits into from
Oct 23, 2023
117 changes: 111 additions & 6 deletions ocf_datapipes/select/select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,47 @@
logger = logging.getLogger(__name__)


def fill_1d_bool_gaps(x, max_gap, fill_ends=False):
    """Fill short runs of False in a 1-dimensional boolean array with True.

    A run of consecutive False elements is filled only when it has a True element on both
    sides and its length is less than or equal to `max_gap`.

    Args:
        x: A 1-dimensional boolean array (anything castable to a boolean numpy array)
        max_gap: integer of the maximum gap size which will be filled with True
        fill_ends: Whether to fill the ends as if there are True values on either side

    Returns:
        A 1-dimensional boolean array

    Examples:
        >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
        >>> fill_1d_bool_gaps(x, max_gap=2, fill_ends=False).astype(int)
        array([0, 1, 1, 1, 1, 1, 1, 0])

        >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
        >>> fill_1d_bool_gaps(x, max_gap=1, fill_ends=True).astype(int)
        array([1, 1, 0, 0, 1, 1, 1, 1])
    """
    # Normalise once so lists / integer arrays behave exactly like boolean arrays below
    x = np.asarray(x, dtype=bool)

    if fill_ends:
        # Pad with True on both sides so leading/trailing gaps become interior gaps,
        # then strip the padding from the result
        x_extended = np.concatenate([[True], x, [True]])
        return fill_1d_bool_gaps(x_extended, max_gap, fill_ends=False)[1:-1]

    should_fill = np.zeros(len(x), dtype=bool)

    # Index of the first False in the current gap; None while not inside a gap that
    # was preceded by a True
    i_start = None

    last_b = False
    for i, b in enumerate(x):
        if last_b and not b:
            # True -> False transition: a gap opens here
            i_start = i
        elif b and not last_b and i_start is not None:
            # False -> True transition: the gap [i_start, i) is closed on both sides
            if i - i_start <= max_gap:
                should_fill[i_start:i] = True
            i_start = None
        last_b = b

    return np.logical_or(should_fill, x)


@functional_datapipe("select_time_slice")
class SelectTimeSliceIterDataPipe(IterDataPipe):
"""Selects time slice"""
Expand All @@ -26,6 +67,7 @@ def __init__(
interval_start: Optional[timedelta] = None,
interval_end: Optional[timedelta] = None,
fill_selection: Optional[bool] = False,
max_steps_gap: Optional[int] = 0,
):
"""
Selects time slice.
Expand All @@ -43,16 +85,23 @@ def __init__(
interval_end (optional): timedelta with respect to t0 where the open interval ends
fill_selection (optional): If True, and if the data yielded from `source_datapipe` does
not extend over the entire requested time period. The missing timestamps are filled
with NaN values in the returned xarray object. Else the default xarray slicing
behaviour is used.
in the returned xarray object to give the expected shape. Else the default xarray
slicing behaviour is used and timestamps may be missing. When filled, the values are
linearly interpolated up to a gap size of `max_steps_gap` steps. If outside this
range, the values are set to NaN.
max_steps_gap (optional): The number of consecutive missing time steps which will be
filled via linear interpolation. If set to zero, no interpolation is used and all
missing timesteps will be NaN.
"""
self.source_datapipe = source_datapipe
self.t0_datapipe = t0_datapipe
self.fill_selection = fill_selection
self.max_steps_gap = max_steps_gap

used_duration = history_duration is not None and forecast_duration is not None
used_intervals = interval_start is not None and interval_end is not None
assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied"
assert max_steps_gap >= 0, "max_steps_gap must be >= 0 "

if used_duration:
self.interval_start = -np.timedelta64(history_duration)
Expand All @@ -63,6 +112,13 @@ def __init__(

self.sample_period_duration = sample_period_duration

if self.fill_selection and max_steps_gap == 0:
self._sel = self._sel_fillnan
elif self.fill_selection and max_steps_gap > 0:
self._sel = self._sel_fillinterp
else:
self._sel = self._sel_default

def _sel_fillnan(self, xr_data, start_dt, end_dt):
requested_times = pd.date_range(
start_dt,
Expand All @@ -72,6 +128,58 @@ def _sel_fillnan(self, xr_data, start_dt, end_dt):
# Missing time indexes are returned with all NaN values
return xr_data.reindex(time_utc=requested_times)

def _sel_fillinterp(self, xr_data, start_dt, end_dt):
    """Select the period [start_dt, end_dt], linearly interpolating short gaps in time.

    Timestamps on the `sample_period_duration` grid which are missing from `xr_data` are
    filled by linear interpolation when they lie in a gap of at most `max_steps_gap`
    consecutive missing steps. Requested timestamps in longer gaps are returned as NaN.

    Args:
        xr_data: xarray object with a `time_utc` coordinate
        start_dt: First timestamp of the requested period
        end_dt: Last timestamp of the requested period

    Returns:
        The selection of `xr_data` between `start_dt` and `end_dt`
    """
    dt_buffer = self.sample_period_duration * self.max_steps_gap

    # Initially select larger period so we can use it to interpolate to requested period
    # This slice also avoids us interpolating the whole dataset to get the requested times
    ds = xr_data.sel(time_utc=slice(start_dt - dt_buffer, end_dt + dt_buffer))

    # These are the times we will ultimately return
    requested_times = pd.date_range(
        start_dt,
        end_dt,
        freq=self.sample_period_duration,
    )

    # If all the requested times are present we avoid running interpolation
    if np.isin(requested_times, ds.time_utc).all():
        return ds.sel(time_utc=slice(start_dt, end_dt))

    # These are the times we use for interpolation to the requested_times
    buffer_requested_times = pd.date_range(
        start_dt - dt_buffer,
        end_dt + dt_buffer,
        freq=self.sample_period_duration,
    )

    # Which of the buffered grid times actually exist in the data
    mask = np.isin(buffer_requested_times, ds.time_utc)

    # Linear interpolation needs at least two known points. With fewer, nothing can be
    # infilled, so fall back to returning the requested times with NaNs for missing steps
    if mask.sum() < 2:
        logger.warning("Cannot run interpolate infilling with less than 2 time steps available")
        return self._sel_fillnan(xr_data, start_dt, end_dt)

    logger.info("Some requested times are missing - running interpolation")

    # Find the timestamps which are within max gap size
    valid_fill_times = fill_1d_bool_gaps(mask, self.max_steps_gap, fill_ends=False)

    # Run the interpolation and filter to requested times
    ds_interp = ds.interp(time_utc=buffer_requested_times, method="linear", assume_sorted=True)

    # Mask the timestamps outside the max gap size
    valid_fill_times_xr = xr.zeros_like(ds_interp.time_utc, dtype=bool)
    valid_fill_times_xr.values[:] = valid_fill_times

    valid_requested_times = valid_fill_times_xr.sel(time_utc=slice(start_dt, end_dt))
    if not valid_requested_times.all():
        not_infilled_times = valid_requested_times.where(~valid_requested_times, drop=True)
        logger.warning(
            "After interpolation the following requested times are still missing:"
            f"{not_infilled_times.time_utc.values}"
        )

    ds_out = ds_interp.where(valid_fill_times_xr)

    # Filter to selected times
    ds_out = ds_out.sel(time_utc=slice(start_dt, end_dt))

    return ds_out

def _sel_default(self, xr_data, start_dt, end_dt):
    """Return the plain xarray slice of `xr_data` along `time_utc` from start_dt to end_dt."""
    time_window = slice(start_dt, end_dt)
    return xr_data.sel(time_utc=time_window)

Expand All @@ -86,7 +194,4 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
start_dt = start_dt.ceil(self.sample_period_duration)
end_dt = end_dt.ceil(self.sample_period_duration)

if self.fill_selection:
yield self._sel_fillnan(xr_data, start_dt, end_dt)
else:
yield self._sel_default(xr_data, start_dt, end_dt)
yield self._sel(xr_data, start_dt, end_dt)
2 changes: 2 additions & 0 deletions ocf_datapipes/training/pvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def slice_datapipes_by_time(
interval_start=minutes(-conf_in.satellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
max_steps_gap=2,
)

# Generate randomly sampled dropout times
Expand Down Expand Up @@ -399,6 +400,7 @@ def slice_datapipes_by_time(
interval_start=minutes(-conf_in.hrvsatellite.history_minutes),
interval_end=sat_delay,
fill_selection=production,
max_steps_gap=2,
)

# Apply the dropout
Expand Down
58 changes: 55 additions & 3 deletions tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,26 @@
import numpy as np
from torchdata.datapipes.iter import IterableWrapper
from ocf_datapipes.select import SelectTimeSlice
from ocf_datapipes.select.select_time_slice import fill_1d_bool_gaps


def test_fill_1d_bool_gaps():
    """Check `fill_1d_bool_gaps` with and without filling the ends of the array."""
    cases = [
        # (keyword arguments, expected result as ints)
        ({"max_gap": 2, "fill_ends": False}, [0, 1, 1, 1, 1, 1, 1, 0]),
        ({"max_gap": 1, "fill_ends": True}, [1, 1, 0, 0, 1, 1, 1, 1]),
    ]

    for kwargs, expected in cases:
        # Fresh input per case so the cases cannot influence each other
        x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
        y = fill_1d_bool_gaps(x, **kwargs).astype(int)
        assert (np.array(expected) == y).all()


def test_select_time_slice_sat(sat_datapipe):
data = next(iter(sat_datapipe))

t0_datapipe = IterableWrapper(pd.to_datetime(data.time_utc.values)[3:6])

# Check with history and forecast durations
# ----------- Check with history and forecast durations -----------

dp = SelectTimeSlice(
sat_datapipe,
t0_datapipe,
Expand All @@ -26,7 +38,8 @@ def test_select_time_slice_sat(sat_datapipe):
assert len(sat_sample.time_utc) == 3
assert sat_sample.time_utc[1] == t0

# Check again with intervals
# ------------------ Check again with intervals -------------------

dp = SelectTimeSlice(
sat_datapipe,
t0_datapipe,
Expand All @@ -41,7 +54,8 @@ def test_select_time_slice_sat(sat_datapipe):
assert len(sat_sample.time_utc) == 3
assert sat_sample.time_utc[1] == t0

# Check with out of bounds selection
# -------------- Check with out of bounds selection ---------------

t_last = pd.to_datetime(data.time_utc.values[-1])
t0_values = [
t_last - timedelta(minutes=5),
Expand All @@ -68,3 +82,41 @@ def test_select_time_slice_sat(sat_datapipe):
# Correct number of time steps are all NaN
sat_sel = sat_sample.isel(x_geostationary=0, y_geostationary=0, channel=0)
assert np.isnan(sat_sel.values).sum() == i

# ------------------- Check with interpolation --------------------

data_times = pd.to_datetime(data.time_utc.values)
t0_datapipe = IterableWrapper(data_times[[0, 1, 4, 5]])

missing_sat_data = data.sel(time_utc=data_times[[0, 2, 3, 6]])
missing_sat_datapipe = IterableWrapper([missing_sat_data]).repeat(len(t0_datapipe))

# For each sample the timestamps should be missing in this order
expected_missing_steps = np.array(
[
[True, False, False],
[False, False, False],
[False, True, True],
[True, True, False],
]
)

dp = SelectTimeSlice(
missing_sat_datapipe,
t0_datapipe,
sample_period_duration=timedelta(minutes=5),
interval_start=timedelta(minutes=-5),
interval_end=timedelta(minutes=5),
fill_selection=True,
max_steps_gap=1,
)

sat_samples = list(dp)
t0_values = list(t0_datapipe)

for i in range(len(sat_samples)):
assert len(sat_samples[i].time_utc) == 3
assert sat_samples[i].time_utc[1] == t0_values[i]
# Correct number of time steps are all NaN
sat_sel = sat_samples[i].isel(x_geostationary=0, y_geostationary=0, channel=0)
assert (np.isnan(sat_sel.values) == expected_missing_steps[i]).all()
Loading