diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index 10cd710ba9..5626c2d159 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -10,6 +10,7 @@ from scipy.stats import kurtosis, skew from darts import TimeSeries, concatenate +from darts.timeseries import slice_intersect from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries from darts.utils.utils import expand_arr, freqs, generate_index @@ -603,6 +604,11 @@ def check_intersect(other, start_, end_, freq_): s_int_idx = series.slice_intersect_times(other, copy=False) assert s_int.time_index.equals(s_int_idx) + assert slice_intersect([series, other]) == [ + series.slice_intersect(other), + other.slice_intersect(series), + ] + # slice with exact range startA = start endA = end @@ -611,11 +617,11 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesA, startA, endA, freq_expected) # entire slice within the range - startA = start + freq - endA = startA + 6 * freq_other - idxA = generate_index(startA, endA, freq=freq_other) - seriesA = TimeSeries.from_series(pd.Series(range(len(idxA)), index=idxA)) - check_intersect(seriesA, startA, endA, freq_expected) + startB = start + freq + endB = startB + 6 * freq_other + idxB = generate_index(startB, endB, freq=freq_other) + seriesB = TimeSeries.from_series(pd.Series(range(len(idxB)), index=idxB)) + check_intersect(seriesB, startB, endB, freq_expected) # start outside of range startC = start - 4 * freq @@ -625,11 +631,11 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesC, start, endC, freq_expected) # end outside of range - startC = start + 4 * freq - endC = end + 4 * freq_other - idxC = generate_index(startC, endC, freq=freq_other) - seriesC = TimeSeries.from_series(pd.Series(range(len(idxC)), index=idxC)) - check_intersect(seriesC, startC, end, freq_expected) + startD = start + 4 * freq + endD = end + 4 * freq_other + idxD = generate_index(startD, 
endD, freq=freq_other) + seriesD = TimeSeries.from_series(pd.Series(range(len(idxD)), index=idxD)) + check_intersect(seriesD, startD, end, freq_expected) # small intersect startE = start + (n_steps - 1) * freq @@ -639,12 +645,40 @@ def check_intersect(other, start_, end_, freq_): check_intersect(seriesE, startE, end, freq_expected) # No intersect - startG = end + 3 * freq - endG = startG + 6 * freq_other - idxG = generate_index(startG, endG, freq=freq_other) - seriesG = TimeSeries.from_series(pd.Series(range(len(idxG)), index=idxG)) + startF = end + 3 * freq + endF = startF + 6 * freq_other + idxF = generate_index(startF, endF, freq=freq_other) + seriesF = TimeSeries.from_series(pd.Series(range(len(idxF)), index=idxF)) # for empty slices, we expect the original freq - check_intersect(seriesG, None, None, freq) + check_intersect(seriesF, None, None, freq) + + # sequence with zero or one element + assert slice_intersect([]) == [] + assert slice_intersect([series]) == [series] + + # sequence with more than 2 elements + intersected_series = slice_intersect([series, seriesA, seriesE]) + s1_int = intersected_series[0] + s2_int = intersected_series[1] + s3_int = intersected_series[2] + + assert ( + s1_int.start_time() == s2_int.start_time() == s3_int.start_time() == startE + ) + assert s1_int.end_time() == s2_int.end_time() == s3_int.end_time() == endA + + # check treatment different time index types + if series.has_datetime_index: + seriesF = TimeSeries.from_series( + pd.Series(range(len(idxF)), index=pd.to_numeric(idxF)) + ) + else: + seriesF = TimeSeries.from_series( + pd.Series(range(len(idxF)), index=pd.to_datetime(idxF)) + ) + + with pytest.raises(IndexError): + slice_intersect([series, seriesF]) @staticmethod def helper_test_shift(test_case, test_series: TimeSeries): diff --git a/darts/timeseries.py b/darts/timeseries.py index 6f2597dd15..8b45c77e9f 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -2494,14 +2494,7 @@ def slice_intersect(self, other: 
def slice_intersect(series: Sequence[TimeSeries]) -> Sequence[TimeSeries]:
    """Returns a list of ``TimeSeries``, where all `series` have been intersected along the time index.

    The common time index is computed by successively intersecting the time index of every series.
    Each input series is then indexed with that common time index, so the output has the same
    length and order as the input.

    Parameters
    ----------
    series : Sequence[TimeSeries]
        sequence of ``TimeSeries`` to intersect

    Returns
    -------
    Sequence[TimeSeries]
        Intersected series, in the same order as `series`. An empty input gives an empty list.
    """
    if not series:
        return []

    # Fold all time indexes into the single index shared by every series.
    # One pass is enough: the running intersection is already a subset of every
    # individual time index, so intersecting it again with pairwise intersections
    # of neighbouring series cannot remove any further timestamps.
    int_time_index = series[0].time_index
    for ts in series[1:]:
        int_time_index = int_time_index.intersection(ts.time_index)

    # Index each series with the common time index.
    return [ts[int_time_index] for ts in series]