Skip to content

Commit

Permalink
Proper treatment of intersection of empty sequence.
Browse files Browse the repository at this point in the history
  • Loading branch information
ymatzkevich committed Nov 12, 2024
1 parent b6f6812 commit 0ae1729
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 108 deletions.
132 changes: 24 additions & 108 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from darts import TimeSeries, concatenate
from darts.timeseries import intersect
from darts.utils.timeseries_generation import constant_timeseries, linear_timeseries
from darts.utils.utils import expand_arr, freqs, generate_index
from darts.utils.utils import freqs, generate_index


class TestTimeSeries:
Expand Down Expand Up @@ -791,9 +791,6 @@ def helper_test_prepend_values(test_case, test_series: TimeSeries):
assert test_series.time_index.equals(prepended_sq.time_index)
assert prepended_sq.components.equals(test_series.components)

# component and sample dimension should match
assert prepended._xa.shape[1:] == test_series._xa.shape[1:]

def test_slice(self):
TestTimeSeries.helper_test_slice(self, self.series1)

Expand Down Expand Up @@ -829,112 +826,18 @@ def test_append(self):
assert appended.time_index.equals(expected_idx)
assert appended.components.equals(series_1.components)

@pytest.mark.parametrize(
"config",
itertools.product(
[
( # univariate array
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1]).reshape((2, 1, 1)),
),
( # multivariate array
np.array([0, 1, 2, 3, 4, 5]).reshape((3, 2, 1)),
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
),
( # empty array
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([]).reshape((0, 1, 1)),
),
(
# wrong number of components
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
),
(
# wrong number of samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
np.array([0, 1, 2, 3]).reshape((2, 1, 2)),
),
( # univariate list with times
np.array([0, 1, 2]).reshape((3, 1, 1)),
[0, 1],
),
( # univariate list with times and components
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[0], [1]],
),
( # univariate list with times, components and samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[[0]], [[1]]],
),
( # multivar with list has wrong shape
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
[[1, 2], [3, 4]],
),
( # list with wrong number of components
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[1, 2], [3, 4]],
),
( # list with wrong number of samples
np.array([0, 1, 2]).reshape((3, 1, 1)),
[[[0, 1]], [[1, 2]]],
),
( # multivar input but list has wrong shape
np.array([0, 1, 2, 3]).reshape((2, 2, 1)),
[1, 2],
),
],
[True, False],
["append_values", "prepend_values"],
),
)
def test_append_and_prepend_values(self, config):
(series_vals, vals), is_datetime, method = config
start = "20240101" if is_datetime else 1
series_idx = generate_index(
start=start, length=len(series_vals), name="some_name"
)
series = TimeSeries.from_times_and_values(
times=series_idx,
values=series_vals,
def test_append_values(self):
TestTimeSeries.helper_test_append_values(self, self.series1)
# Check `append_values` deals with `RangeIndex` series correctly:
series = linear_timeseries(start=1, length=5, freq=2)
appended = series.append_values(np.ones((2, 1, 1)))
expected_vals = np.concatenate(
[series.all_values(), np.ones((2, 1, 1))], axis=0
)

# expand if it's a list
vals_arr = np.array(vals) if isinstance(vals, list) else vals
vals_arr = expand_arr(vals_arr, ndim=3)

ts_method = getattr(TimeSeries, method)

if vals_arr.shape[1:] != series_vals.shape[1:]:
with pytest.raises(ValueError) as exc:
_ = ts_method(series, vals)
assert str(exc.value).startswith(
"The (expanded) values must have the same number of components and samples"
)
return

appended = ts_method(series, vals)

if method == "append_values":
expected_vals = np.concatenate([series_vals, vals_arr], axis=0)
expected_idx = generate_index(
start=series.start_time(),
length=len(series_vals) + len(vals),
freq=series.freq,
)
else:
expected_vals = np.concatenate([vals_arr, series_vals], axis=0)
expected_idx = generate_index(
end=series.end_time(),
length=len(series_vals) + len(vals),
freq=series.freq,
)

expected_idx = pd.RangeIndex(start=1, stop=15, step=2)
assert np.allclose(appended.all_values(), expected_vals)
assert appended.time_index.equals(expected_idx)
assert appended.components.equals(series.components)
assert appended._xa.shape[1:] == series._xa.shape[1:]
assert appended.time_index.name == series.time_index.name

def test_prepend(self):
TestTimeSeries.helper_test_prepend(self, self.series1)
Expand All @@ -950,6 +853,19 @@ def test_prepend(self):
assert prepended.time_index.equals(expected_idx)
assert prepended.components.equals(series_1.components)

def test_prepend_values(self):
TestTimeSeries.helper_test_prepend_values(self, self.series1)
# Check `prepend_values` deals with `RangeIndex` series correctly:
series = linear_timeseries(start=1, length=5, freq=2)
prepended = series.prepend_values(np.ones((2, 1, 1)))
expected_vals = np.concatenate(
[np.ones((2, 1, 1)), series.all_values()], axis=0
)
expected_idx = pd.RangeIndex(start=-3, stop=11, step=2)
assert np.allclose(prepended.all_values(), expected_vals)
assert prepended.time_index.equals(expected_idx)
assert prepended.components.equals(series.components)

@pytest.mark.parametrize(
"config",
[
Expand Down Expand Up @@ -2461,8 +2377,8 @@ def test_time_col_with_tz(self):
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
assert ts.time_index.tz is None

series = pd.Series(data=values, index=time_range_H)
ts = TimeSeries.from_series(pd_series=series)
serie = pd.Series(data=values, index=time_range_H)
ts = TimeSeries.from_series(pd_series=serie)
assert list(ts.time_index) == list(time_range_H.tz_localize(None))
assert list(ts.time_index.tz_localize("CET")) == list(time_range_H)
assert ts.time_index.tz is None
Expand Down
2 changes: 2 additions & 0 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5672,6 +5672,8 @@ def intersect(series: Sequence[TimeSeries]):
Sequence[TimeSeries]
Intersected series
"""
if not series:
return []

data_arrays = []
has_datetime_index = series[0].has_datetime_index
Expand Down

0 comments on commit 0ae1729

Please sign in to comment.