Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 23, 2023
1 parent be796d2 commit f864553
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 48 deletions.
59 changes: 29 additions & 30 deletions ocf_datapipes/select/select_time_slice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Selects time slice"""
import logging
from datetime import timedelta
from typing import Optional, Union, Literal
from typing import Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -12,45 +12,44 @@
logger = logging.getLogger(__name__)



def fill_1d_bool_gaps(x, max_gap, fill_ends=False):
    """Fill short runs of False in a 1-dimensional boolean array with True.

    A run of consecutive False elements is filled only if it is bounded by True
    on both sides and its length is less than or equal to `max_gap`.

    Args:
        x: A 1-dimensional boolean array (any array-like whose elements can be
            interpreted as booleans, e.g. an array of 0s and 1s)
        max_gap: integer of the maximum gap size which will be filled with True
        fill_ends: Whether to fill the ends as if there are True values on either side

    Returns:
        A 1-dimensional boolean array. The input is not modified.

    Raises:
        ValueError: If `x` is not 1-dimensional.

    Examples:
        >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
        >>> fill_1d_bool_gaps(x, max_gap=2, fill_ends=False).astype(int)
        array([0, 1, 1, 1, 1, 1, 1, 0])

        >>> x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
        >>> fill_1d_bool_gaps(x, max_gap=1, fill_ends=True).astype(int)
        array([1, 1, 0, 0, 1, 1, 1, 1])
    """
    mask = np.asarray(x, dtype=bool)
    if mask.ndim != 1:
        raise ValueError(f"x must be 1-dimensional, got {mask.ndim} dimensions")

    if fill_ends:
        # Pad with True on both sides so leading/trailing gaps become interior
        # gaps, fill as normal, then strip the padding off again.
        padded = np.concatenate([[True], mask, [True]])
        return fill_1d_bool_gaps(padded, max_gap, fill_ends=False)[1:-1]

    # Work on a copy so the caller's array is never mutated
    filled = mask.copy()

    # Index of the first False in the gap currently being scanned. It is only
    # set once a True has been seen, so a leading run of False is never filled.
    gap_start = None
    previous = False
    for i, current in enumerate(mask):
        if previous and not current:
            # Falling edge: a new gap opens here
            gap_start = i
        elif current and not previous and gap_start is not None:
            # Rising edge: the gap [gap_start, i) has closed
            if i - gap_start <= max_gap:
                filled[gap_start:i] = True
            gap_start = None
        previous = current

    return filled


Expand Down Expand Up @@ -86,12 +85,12 @@ def __init__(
interval_end (optional): timedelta with respect to t0 where the open interval ends
fill_selection (optional): If True, and if the data yielded from `source_datapipe` does
not extend over the entire requested time period. The missing timestamps are filled
in the returned xarray object to give the expected shape. Else the default xarray
in the returned xarray object to give the expected shape. Else the default xarray
slicing behaviour is used and timestamps may be missing. When filled, the values are
linearly interpolated up to a gap size of `max_steps_gap` steps. If outside this
linearly interpolated up to a gap size of `max_steps_gap` steps. If outside this
range, the values are set to NaN.
max_steps_gap (optional): The number of consecutive missing time steps which will be
filled via linear interpolation. If set to zero, no interpolation is used and all
max_steps_gap (optional): The number of consecutive missing time steps which will be
filled via linear interpolation. If set to zero, no interpolation is used and all
missing timesteps will be NaN.
"""
self.source_datapipe = source_datapipe
Expand All @@ -112,10 +111,10 @@ def __init__(
self.interval_end = np.timedelta64(interval_end)

self.sample_period_duration = sample_period_duration
if self.fill_selection and max_steps_gap==0:

if self.fill_selection and max_steps_gap == 0:
self._sel = self._sel_fillnan
elif self.fill_selection and max_steps_gap>0:
elif self.fill_selection and max_steps_gap > 0:
self._sel = self._sel_fillinterp
else:
self._sel = self._sel_default
Expand All @@ -128,30 +127,30 @@ def _sel_fillnan(self, xr_data, start_dt, end_dt):
)
# Missing time indexes are returned with all NaN values
return xr_data.reindex(time_utc=requested_times)

def _sel_fillinterp(self, xr_data, start_dt, end_dt):
dt_buffer = self.sample_period_duration*self.max_steps_gap
dt_buffer = self.sample_period_duration * self.max_steps_gap

# Initially select larger period so we can use it to interpolate to requested period
# This slice also avoids us interpolating the whole dataset to get the requested times
ds = xr_data.sel(time_utc=slice(start_dt-dt_buffer, end_dt+dt_buffer))
ds = xr_data.sel(time_utc=slice(start_dt - dt_buffer, end_dt + dt_buffer))

# These are the times we will ultimately return
requested_times = pd.date_range(
start_dt,
end_dt,
freq=self.sample_period_duration,
)

# If all the requested times are present we avoid running interpolation
if np.isin(requested_times, ds.time_utc).all():
return ds.sel(time_utc=slice(start_dt, end_dt))

logger.info("Some requested times are missing - running interpolation")
# These are the times we use for interpolation to the requested_times
buffer_requested_times = pd.date_range(
start_dt-dt_buffer,
end_dt+dt_buffer,
start_dt - dt_buffer,
end_dt + dt_buffer,
freq=self.sample_period_duration,
)

Expand All @@ -165,20 +164,20 @@ def _sel_fillinterp(self, xr_data, start_dt, end_dt):
# Mask the timestamps outside the max gap size
valid_fill_times_xr = xr.zeros_like(ds_interp.time_utc, dtype=bool)
valid_fill_times_xr.values[:] = valid_fill_times

valid_requested_times = valid_fill_times_xr.sel(time_utc=slice(start_dt, end_dt))
if not valid_requested_times.all():
not_infilled_times = valid_requested_times.where(~valid_requested_times, drop=True)
logger.warning(
"After interpolation the following requested times are still missing:"
f"{not_infilled_times.time_utc.values}"
)

ds_out = ds_interp.where(valid_fill_times_xr)

# Filter to selected times
ds_out = ds_out.sel(time_utc=slice(start_dt, end_dt))

return ds_out

def _sel_default(self, xr_data, start_dt, end_dt):
Expand Down
37 changes: 19 additions & 18 deletions tests/select/test_select_time_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def test_fill_1d_bool_gaps():
x = np.array([0, 1, 0, 0, 1, 0, 1, 0])
y = fill_1d_bool_gaps(x, max_gap=1, fill_ends=True).astype(int)
assert (np.array([1, 1, 0, 0, 1, 1, 1, 1]) == y).all()


def test_select_time_slice_sat(sat_datapipe):
data = next(iter(sat_datapipe))

t0_datapipe = IterableWrapper(pd.to_datetime(data.time_utc.values)[3:6])

# ----------- Check with history and forecast durations -----------

dp = SelectTimeSlice(
sat_datapipe,
t0_datapipe,
Expand All @@ -37,9 +37,9 @@ def test_select_time_slice_sat(sat_datapipe):
for sat_sample, t0 in zip(sat_samples, t0_values):
assert len(sat_sample.time_utc) == 3
assert sat_sample.time_utc[1] == t0

# ------------------ Check again with intervals -------------------

dp = SelectTimeSlice(
sat_datapipe,
t0_datapipe,
Expand All @@ -55,7 +55,7 @@ def test_select_time_slice_sat(sat_datapipe):
assert sat_sample.time_utc[1] == t0

# -------------- Check with out of bounds selection ---------------

t_last = pd.to_datetime(data.time_utc.values[-1])
t0_values = [
t_last - timedelta(minutes=5),
Expand All @@ -82,23 +82,24 @@ def test_select_time_slice_sat(sat_datapipe):
# Correct number of time steps are all NaN
sat_sel = sat_sample.isel(x_geostationary=0, y_geostationary=0, channel=0)
assert np.isnan(sat_sel.values).sum() == i

# ------------------- Check with interpolation --------------------

data_times = pd.to_datetime(data.time_utc.values)
t0_datapipe = IterableWrapper(data_times[[0,1,4,5]])
missing_sat_data = data.sel(time_utc=data_times[[0,2,3,6]])
t0_datapipe = IterableWrapper(data_times[[0, 1, 4, 5]])

missing_sat_data = data.sel(time_utc=data_times[[0, 2, 3, 6]])
missing_sat_datapipe = IterableWrapper([missing_sat_data]).repeat(len(t0_datapipe))

# For each sample the timestamps should be missing in this order
expected_missing_steps = np.array([
[True, False, False],
[False, False, False],
[False, True, True],
[True, True, False],
])

# For each sample the timestamps should be missing in this order
expected_missing_steps = np.array(
[
[True, False, False],
[False, False, False],
[False, True, True],
[True, True, False],
]
)

dp = SelectTimeSlice(
missing_sat_datapipe,
Expand Down

0 comments on commit f864553

Please sign in to comment.