From 08640f22d0ae176459c242fff82ab6604c2cf901 Mon Sep 17 00:00:00 2001 From: AlessiopSymplectic <160468679+AlessiopSymplectic@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:23:27 +0200 Subject: [PATCH] Fix/bug 2491 revisited (#2520) * probably found where the bug #2491 stems from * the DatetimeIndex casting before the index sort was unnecessary and removing it solved the issue * addes a unit test to assert the expected behaviour of the time_index type * Update darts/tests/test_timeseries.py typo in function name Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com> * * Used set_index for better readability * Included a warning about monotonically increasing index also in the case where time_col is set (analogous to the warning in the case 'time_cole is None and df.index.is_monotonic_increasing) * * changed the testing time index to: i) an unsorted integer list and ii) an unsorted datetimeindex * * only sort if unsorted * potential speed improvement by case distinction within the time column (need to check this again) * * code in comment removed: there is no performance difference between set_index and casting DatetimeIndex when the time_col is of type datetime * * moved the test "test_from_group_dataframe" to where the other test concerning "from_group_dataframe" are allocated * * included a values check * Update darts/timeseries.py Co-authored-by: Dennis Bader * * included the change in the changelog file * * Changed from_group_datagrame so that it handles different dtypes of time_col values properly (in the prior fix, string dates of the form "2024-01-01" were not coverted to datetimes) * Included a unit test to cover this case * * renamed test to be more descriptive * Update darts/timeseries.py Use else statement for better readability Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com> * Update darts/timeseries.py cleaner case distinction for the time_index type recasting Co-authored-by: Dennis Bader * * parametrized the tests with pytest as suggested by Dennis (thanks!) --------- Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com> Co-authored-by: Dennis Bader --- .../test_timeseries_static_covariates.py | 47 ++++++------------- darts/timeseries.py | 10 +++- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/darts/tests/test_timeseries_static_covariates.py b/darts/tests/test_timeseries_static_covariates.py index 895ef5b686..07590eb82f 100644 --- a/darts/tests/test_timeseries_static_covariates.py +++ b/darts/tests/test_timeseries_static_covariates.py @@ -104,33 +104,22 @@ def test_ts_from_x(self, tmpdir_module): ts, TimeSeries.from_json(ts_json, static_covariates=ts.static_covariates) ) - def test_from_group_dataframe(self): - # checks that the time_index is of RangeIndex type when the time_col is a(n) (unsorted) list and/or a Rangeindex + @pytest.mark.parametrize("index_type", ["int", "dt", "str"]) + def test_from_group_dataframe(self, index_type): + """Tests correct extract of TimeSeries groups from a long DataFrame with unsorted (time/integer) index""" group = ["a", "a", "a", "b", "b", "b"] values = np.arange(len(group)) - # for time as a unsorted list - time = [2, 1, 0, 0, 1, 2] - df = pd.DataFrame({ - "group": group, - "time": time, - "x": values, - }) - ts = TimeSeries.from_group_dataframe(df, group_cols="group", time_col="time") - - # check the time index - assert ts[0].time_index.equals(pd.RangeIndex(3)) - assert ts[1].time_index.equals(pd.RangeIndex(3)) + if index_type == "int": + index_expected = pd.RangeIndex(3) + time = [2, 1, 0, 0, 1, 2] + else: + index_expected = pd.date_range("2024-01-01", periods=3) + time = index_expected[::-1].append(index_expected) + if index_type == "str": + time = time.astype(str) - # check the values - assert (ts[0].values().flatten() == [values[2], values[1], values[0]]).all() - assert (ts[1].values().flatten() == [values[3], values[4], values[5]]).all() - - # for time as Timestamps - time = pd.DatetimeIndex( - [pd.Timestamp("20240103") - pd.Timedelta(i, "d") for i in range(3)] - + [pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3)] - ) + # create a df with unsorted time df = pd.DataFrame({ "group": group, "time": time, @@ -139,16 +128,8 @@ def test_from_group_dataframe(self): ts = TimeSeries.from_group_dataframe(df, group_cols="group", time_col="time") # check the time index - assert ts[0].time_index.equals( - pd.DatetimeIndex([ - pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3) - ]) - ) - assert ts[1].time_index.equals( - pd.DatetimeIndex([ - pd.Timestamp("20240101") + pd.Timedelta(i, "d") for i in range(3) - ]) - ) + assert ts[0].time_index.equals(index_expected) + assert ts[1].time_index.equals(index_expected) # check the values assert (ts[0].values().flatten() == [values[2], values[1], values[0]]).all() diff --git a/darts/timeseries.py b/darts/timeseries.py index 5f7878eb56..b14e37a4b9 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -866,7 +866,13 @@ def from_group_dataframe( df = df[static_cov_cols + extract_value_cols + extract_time_col] if time_col: - df = df.set_index(df[time_col]) + if np.issubdtype(df[time_col].dtype, object) or np.issubdtype( + df[time_col].dtype, np.datetime64 + ): + df.index = pd.DatetimeIndex(df[time_col]) + df = df.drop(columns=time_col) + else: + df = df.set_index(time_col) if df.index.is_monotonic_increasing: logger.warning( @@ -876,7 +882,7 @@ def from_group_dataframe( ) # sort on entire `df` to avoid having to sort individually later on - if not df.index.is_monotonic_increasing: + else: df = df.sort_index() groups = df.groupby(group_cols[0] if len(group_cols) == 1 else group_cols)