From 0ae172936a47a7994428bb61d6800d4914dba32e Mon Sep 17 00:00:00 2001 From: ymatzkevich Date: Tue, 12 Nov 2024 14:12:18 +0100 Subject: [PATCH] Proper treatment of intersection of empty sequence. --- darts/tests/test_timeseries.py | 132 ++++++--------------------------- darts/timeseries.py | 2 + 2 files changed, 26 insertions(+), 108 deletions(-) diff --git a/darts/tests/test_timeseries.py b/darts/tests/test_timeseries.py index daefbfd317..7a5258a685 100644 --- a/darts/tests/test_timeseries.py +++ b/darts/tests/test_timeseries.py @@ -12,7 +12,7 @@ from darts import TimeSeries, concatenate from darts.timeseries import intersect from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries -from darts.utils.utils import expand_arr, freqs, generate_index +from darts.utils.utils import freqs, generate_index class TestTimeSeries: @@ -791,9 +791,6 @@ def helper_test_prepend_values(test_case, test_series: TimeSeries): assert test_series.time_index.equals(prepended_sq.time_index) assert prepended_sq.components.equals(test_series.components) - # component and sample dimension should match - assert prepended._xa.shape[1:] == test_series._xa.shape[1:] - def test_slice(self): TestTimeSeries.helper_test_slice(self, self.series1) @@ -829,112 +826,18 @@ def test_append(self): assert appended.time_index.equals(expected_idx) assert appended.components.equals(series_1.components) - @pytest.mark.parametrize( - "config", - itertools.product( - [ - ( # univariate array - np.array([0, 1, 2]).reshape((3, 1, 1)), - np.array([0, 1]).reshape((2, 1, 1)), - ), - ( # multivariate array - np.array([0, 1, 2, 3, 4, 5]).reshape((3, 2, 1)), - np.array([0, 1, 2, 3]).reshape((2, 2, 1)), - ), - ( # empty array - np.array([0, 1, 2]).reshape((3, 1, 1)), - np.array([]).reshape((0, 1, 1)), - ), - ( - # wrong number of components - np.array([0, 1, 2]).reshape((3, 1, 1)), - np.array([0, 1, 2, 3]).reshape((2, 2, 1)), - ), - ( - # wrong number of samples - np.array([0, 1, 2]).reshape((3, 1, 1)), - np.array([0, 1, 2, 3]).reshape((2, 1, 2)), - ), - ( # univariate list with times - np.array([0, 1, 2]).reshape((3, 1, 1)), - [0, 1], - ), - ( # univariate list with times and components - np.array([0, 1, 2]).reshape((3, 1, 1)), - [[0], [1]], - ), - ( # univariate list with times, components and samples - np.array([0, 1, 2]).reshape((3, 1, 1)), - [[[0]], [[1]]], - ), - ( # multivar with list has wrong shape - np.array([0, 1, 2, 3]).reshape((2, 2, 1)), - [[1, 2], [3, 4]], - ), - ( # list with wrong number of components - np.array([0, 1, 2]).reshape((3, 1, 1)), - [[1, 2], [3, 4]], - ), - ( # list with wrong number of samples - np.array([0, 1, 2]).reshape((3, 1, 1)), - [[[0, 1]], [[1, 2]]], - ), - ( # multivar input but list has wrong shape - np.array([0, 1, 2, 3]).reshape((2, 2, 1)), - [1, 2], - ), - ], - [True, False], - ["append_values", "prepend_values"], - ), - ) - def test_append_and_prepend_values(self, config): - (series_vals, vals), is_datetime, method = config - start = "20240101" if is_datetime else 1 - series_idx = generate_index( - start=start, length=len(series_vals), name="some_name" - ) - series = TimeSeries.from_times_and_values( - times=series_idx, - values=series_vals, + def test_append_values(self): + TestTimeSeries.helper_test_append_values(self, self.series1) + # Check `append_values` deals with `RangeIndex` series correctly: + series = linear_timeseries(start=1, length=5, freq=2) + appended = series.append_values(np.ones((2, 1, 1))) + expected_vals = np.concatenate( + [series.all_values(), np.ones((2, 1, 1))], axis=0 ) - - # expand if it's a list - vals_arr = np.array(vals) if isinstance(vals, list) else vals - vals_arr = expand_arr(vals_arr, ndim=3) - - ts_method = getattr(TimeSeries, method) - - if vals_arr.shape[1:] != series_vals.shape[1:]: - with pytest.raises(ValueError) as exc: - _ = ts_method(series, vals) - assert str(exc.value).startswith( - "The (expanded) values must have the same number of components and samples" - ) - return - - appended = ts_method(series, vals) - - if method == "append_values": - expected_vals = np.concatenate([series_vals, vals_arr], axis=0) - expected_idx = generate_index( - start=series.start_time(), - length=len(series_vals) + len(vals), - freq=series.freq, - ) - else: - expected_vals = np.concatenate([vals_arr, series_vals], axis=0) - expected_idx = generate_index( - end=series.end_time(), - length=len(series_vals) + len(vals), - freq=series.freq, - ) - + expected_idx = pd.RangeIndex(start=1, stop=15, step=2) assert np.allclose(appended.all_values(), expected_vals) assert appended.time_index.equals(expected_idx) assert appended.components.equals(series.components) - assert appended._xa.shape[1:] == series._xa.shape[1:] - assert appended.time_index.name == series.time_index.name def test_prepend(self): TestTimeSeries.helper_test_prepend(self, self.series1) @@ -950,6 +853,19 @@ def test_prepend(self): assert prepended.time_index.equals(expected_idx) assert prepended.components.equals(series_1.components) + def test_prepend_values(self): + TestTimeSeries.helper_test_prepend_values(self, self.series1) + # Check `prepend_values` deals with `RangeIndex` series correctly: + series = linear_timeseries(start=1, length=5, freq=2) + prepended = series.prepend_values(np.ones((2, 1, 1))) + expected_vals = np.concatenate( + [np.ones((2, 1, 1)), series.all_values()], axis=0 + ) + expected_idx = pd.RangeIndex(start=-3, stop=11, step=2) + assert np.allclose(prepended.all_values(), expected_vals) + assert prepended.time_index.equals(expected_idx) + assert prepended.components.equals(series.components) + @pytest.mark.parametrize( "config", [ @@ -2461,8 +2377,8 @@ def test_time_col_with_tz(self): assert list(ts.time_index.tz_localize("CET")) == list(time_range_H) assert ts.time_index.tz is None - series = pd.Series(data=values, index=time_range_H) - ts = TimeSeries.from_series(pd_series=series) + serie = pd.Series(data=values, index=time_range_H) + ts = TimeSeries.from_series(pd_series=serie) assert list(ts.time_index) == list(time_range_H.tz_localize(None)) assert list(ts.time_index.tz_localize("CET")) == list(time_range_H) assert ts.time_index.tz is None diff --git a/darts/timeseries.py b/darts/timeseries.py index d9b670e011..6d84433401 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -5672,6 +5672,8 @@ def intersect(series: Sequence[TimeSeries]): Sequence[TimeSeries] Intersected series """ + if not series: + return [] data_arrays = [] has_datetime_index = series[0].has_datetime_index