Skip to content

Commit

Permalink
Fix/bug 2491 revisited (#2520)
Browse files Browse the repository at this point in the history
* probably found where the bug #2491 stems from

* the DatetimeIndex casting before the index sort was unnecessary and removing it solved the issue

* addes a unit test to assert the expected behaviour of the time_index type

* Update darts/tests/test_timeseries.py

typo in function name

Co-authored-by: madtoinou <[email protected]>

* * Used set_index for better readability

* Included a warning about monotonically increasing index also in the case where time_col is set (analogous to the warning in the case 'time_cole is None and df.index.is_monotonic_increasing)

* * changed the testing time index to: i) an unsorted integer list and ii) an unsorted datetimeindex

* * only sort if unsorted
* potential speed improvement by case distinction within the time column (need to check this again)

* * code in comment removed: there is no performance difference between set_index and casting DatetimeIndex when the time_col is of type datetime

* * moved the test "test_from_group_dataframe" to where the other test concerning "from_group_dataframe" are allocated

* * included a values check

* Update darts/timeseries.py

Co-authored-by: Dennis Bader <[email protected]>

* * included the change in the changelog file

* * Changed from_group_datagrame so that it handles different dtypes of time_col values properly (in the prior fix, string dates of the form "2024-01-01" were not coverted to datetimes)

* Included a unit test to cover this case

* * renamed test to be more descriptive

* Update darts/timeseries.py

Use else statement for better readability

Co-authored-by: madtoinou <[email protected]>

* Update darts/timeseries.py

cleaner case distinction for the time_index type recasting

Co-authored-by: Dennis Bader <[email protected]>

* * parametrized the tests with pytest as suggested by Dennis (thanks!)

---------

Co-authored-by: madtoinou <[email protected]>
Co-authored-by: Dennis Bader <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2024
1 parent 38c066b commit 08640f2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 35 deletions.
47 changes: 14 additions & 33 deletions darts/tests/test_timeseries_static_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,33 +104,22 @@ def test_ts_from_x(self, tmpdir_module):
ts, TimeSeries.from_json(ts_json, static_covariates=ts.static_covariates)
)

def test_from_group_dataframe(self):
# checks that the time_index is of RangeIndex type when the time_col is a(n) (unsorted) list and/or a Rangeindex
@pytest.mark.parametrize("index_type", ["int", "dt", "str"])
def test_from_group_dataframe(self, index_type):
"""Tests correct extract of TimeSeries groups from a long DataFrame with unsorted (time/integer) index"""
group = ["a", "a", "a", "b", "b", "b"]
values = np.arange(len(group))

# for time as a unsorted list
time = [2, 1, 0, 0, 1, 2]
df = pd.DataFrame({
"group": group,
"time": time,
"x": values,
})
ts = TimeSeries.from_group_dataframe(df, group_cols="group", time_col="time")

# check the time index
assert ts[0].time_index.equals(pd.RangeIndex(3))
assert ts[1].time_index.equals(pd.RangeIndex(3))
if index_type == "int":
index_expected = pd.RangeIndex(3)
time = [2, 1, 0, 0, 1, 2]
else:
index_expected = pd.date_range("2024-01-01", periods=3)
time = index_expected[::-1].append(index_expected)
if index_type == "str":
time = time.astype(str)

# check the values
assert (ts[0].values().flatten() == [values[2], values[1], values[0]]).all()
assert (ts[1].values().flatten() == [values[3], values[4], values[5]]).all()

# for time as Timestamps
time = pd.DatetimeIndex(
[pd.Timestamp("20240103") - pd.Timedelta(i, "d") for i in range(3)]
+ [pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3)]
)
# create a df with unsorted time
df = pd.DataFrame({
"group": group,
"time": time,
Expand All @@ -139,16 +128,8 @@ def test_from_group_dataframe(self):
ts = TimeSeries.from_group_dataframe(df, group_cols="group", time_col="time")

# check the time index
assert ts[0].time_index.equals(
pd.DatetimeIndex([
pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3)
])
)
assert ts[1].time_index.equals(
pd.DatetimeIndex([
pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3)
])
)
assert ts[0].time_index.equals(index_expected)
assert ts[1].time_index.equals(index_expected)

# check the values
assert (ts[0].values().flatten() == [values[2], values[1], values[0]]).all()
Expand Down
10 changes: 8 additions & 2 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,13 @@ def from_group_dataframe(
df = df[static_cov_cols + extract_value_cols + extract_time_col]

if time_col:
df = df.set_index(df[time_col])
if np.issubdtype(df[time_col].dtype, object) or np.issubdtype(
df[time_col].dtype, np.datetime64
):
df.index = pd.DatetimeIndex(df[time_col])
df = df.drop(columns=time_col)
else:
df = df.set_index(time_col)

if df.index.is_monotonic_increasing:
logger.warning(
Expand All @@ -876,7 +882,7 @@ def from_group_dataframe(
)

# sort on entire `df` to avoid having to sort individually later on
if not df.index.is_monotonic_increasing:
else:
df = df.sort_index()

groups = df.groupby(group_cols[0] if len(group_cols) == 1 else group_cols)
Expand Down

0 comments on commit 08640f2

Please sign in to comment.