Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/slice intersect multi series #2592

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
64 changes: 49 additions & 15 deletions darts/tests/test_timeseries.py
ymatzkevich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.stats import kurtosis, skew

from darts import TimeSeries, concatenate
from darts.timeseries import slice_intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index

Expand Down Expand Up @@ -603,6 +604,11 @@ def check_intersect(other, start_, end_, freq_):
s_int_idx = series.slice_intersect_times(other, copy=False)
assert s_int.time_index.equals(s_int_idx)

assert slice_intersect([series, other]) == [
series.slice_intersect(other),
other.slice_intersect(series),
]

# slice with exact range
startA = start
endA = end
Expand All @@ -611,11 +617,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesA, startA, endA, freq_expected)

# entire slice within the range
startA = start + freq
endA = startA + 6 * freq_other
idxA = generate_index(startA, endA, freq=freq_other)
seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA))
check_intersect(seriesA, startA, endA, freq_expected)
startB = start + freq
endB = startB + 6 * freq_other
idxB = generate_index(startB, endB, freq=freq_other)
seriesB = TimeSeries.from_series(pd.Series(range(len(idxB)), index=idxB))
check_intersect(seriesB, startB, endB, freq_expected)

# start outside of range
startC = start - 4 * freq
Expand All @@ -625,11 +631,11 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesC, start, endC, freq_expected)

# end outside of range
startC = start + 4 * freq
endC = end + 4 * freq_other
idxC = generate_index(startC, endC, freq=freq_other)
seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC))
check_intersect(seriesC, startC, end, freq_expected)
startD = start + 4 * freq
endD = end + 4 * freq_other
idxD = generate_index(startD, endD, freq=freq_other)
seriesD = TimeSeries.from_series(pd.Series(range(len(idxD)), index=idxD))
check_intersect(seriesD, startD, end, freq_expected)

# small intersect
startE = start + (n_steps - 1) * freq
Expand All @@ -639,12 +645,40 @@ def check_intersect(other, start_, end_, freq_):
check_intersect(seriesE, startE, end, freq_expected)

# No intersect
startG = end + 3 * freq
endG = startG + 6 * freq_other
idxG = generate_index(startG, endG, freq=freq_other)
seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG))
startF = end + 3 * freq
endF = startF + 6 * freq_other
idxF = generate_index(startF, endF, freq=freq_other)
seriesF = TimeSeries.from_series(pd.Series(range(len(idxF)), index=idxF))
# for empty slices, we expect the original freq
check_intersect(seriesG, None, None, freq)
check_intersect(seriesF, None, None, freq)

# sequence with zero or one element
assert slice_intersect([]) == []
assert slice_intersect([series]) == [series]

# sequence with more than 2 elements
intersected_series = slice_intersect([series, seriesA, seriesE])
s1_int = intersected_series[0]
s2_int = intersected_series[1]
s3_int = intersected_series[2]

assert (
s1_int.start_time() == s2_int.start_time() == s3_int.start_time() == startE
)
assert s1_int.end_time() == s2_int.end_time() == s3_int.end_time() == endA

# check treatment different time index types
if series.has_datetime_index:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_numeric(idxF))
)
else:
seriesF = TimeSeries.from_series(
pd.Series(range(len(idxF)), index=pd.to_datetime(idxF))
)

with pytest.raises(IndexError):
slice_intersect([series, seriesF])

@staticmethod
def helper_test_shift(test_case, test_series: TimeSeries):
Expand Down
39 changes: 31 additions & 8 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2494,14 +2494,7 @@ def slice_intersect(self, other: Self) -> Self:
TimeSeries
a new series, containing the values of this series, over the time-span common to both time series.
"""
if other.has_same_time_as(self):
return self.__class__(self._xa)
if other.freq == self.freq:
start, end = self._slice_intersect_bounds(other)
return self[start:end]
else:
time_index = self.time_index.intersection(other.time_index)
return self[time_index]
return slice_intersect([self, other])[0]

def slice_intersect_values(self, other: Self, copy: bool = False) -> np.ndarray:
"""
Expand Down Expand Up @@ -5659,6 +5652,36 @@ def concatenate(
return TimeSeries.from_xarray(da_concat, fill_missing_dates=False)


def slice_intersect(series: Sequence[TimeSeries]) -> Sequence[TimeSeries]:
"""Returns a list of ``TimeSeries``, where all `series` have been intersected along the time index.

Parameters
----------
series : Sequence[TimeSeries]
sequence of ``TimeSeries`` to intersect

Returns
-------
Sequence[TimeSeries]
Intersected series.
"""
if not series:
return []

int_time_index = series[0].time_index
for ts in series[1:]:
int_time_index = int_time_index.intersection(ts.time_index)

ts_other = series[0]
for ts in series[1:]:
int_time_index = int_time_index.intersection(
ts.time_index.intersection(ts_other.time_index)
)
ts_other = ts

return [ts[int_time_index] for ts in series]


def _finite_rows_boundaries(
values: np.ndarray, how: str = "all"
) -> tuple[Optional[int], Optional[int]]:
Expand Down
Loading