From 7b87cbdc5253590f03706ea9a65cefc26f4225aa Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Fri, 10 May 2024 14:21:18 -0400 Subject: [PATCH] .nest.get_flat_index to work with multiindex --- src/nested_pandas/series/accessor.py | 8 ++++---- tests/nested_pandas/series/test_accessor.py | 22 ++++++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/nested_pandas/series/accessor.py b/src/nested_pandas/series/accessor.py index 8a075a7..3bd967b 100644 --- a/src/nested_pandas/series/accessor.py +++ b/src/nested_pandas/series/accessor.py @@ -182,10 +182,10 @@ def query_flat(self, query: str) -> pd.Series: def get_flat_index(self) -> pd.Index: """Index of the flat arrays""" - return pd.Index( - np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)), - name=self._series.index.name, - ) + flat_index = np.repeat(self._series.index, np.diff(self._series.array.list_offsets)) + # pd.Index supports np.repeat, so flat_index is the same type as self._series.index + flat_index = cast(pd.Index, flat_index) + return flat_index def get_flat_series(self, field: str) -> pd.Series: """Get the flat-array field as a Series diff --git a/tests/nested_pandas/series/test_accessor.py b/tests/nested_pandas/series/test_accessor.py index e461ac6..5ee617d 100644 --- a/tests/nested_pandas/series/test_accessor.py +++ b/tests/nested_pandas/series/test_accessor.py @@ -6,7 +6,7 @@ from nested_pandas.series.ext_array import NestedExtensionArray from nested_pandas.series.packer import pack_flat from numpy.testing import assert_array_equal -from pandas.testing import assert_frame_equal, assert_series_equal +from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal def test_registered(): @@ -345,6 +345,26 @@ def test_query_flat_empty_rows(): assert_series_equal(filtered, desired) +@pytest.mark.parametrize( + "df", + [ + pd.DataFrame({"a": [1] * 10}, index=[1, 2, 2, 3, 3, 3, 4, 4, 4, 4]), + pd.DataFrame( + {"a": [1] * 10}, + index=pd.MultiIndex.from_arrays(([1, 1, 1, 1, 1, 1, 2, 2, 2, 2], [0, 1, 1, 2, 2, 2, 1, 1, 0, 0])), + ), + pd.DataFrame({"a": [1] * 10}, index=[1, 0, 0, 3, 3, 3, 0, 0, 0, 0]), + pd.DataFrame( + {"a": [1] * 6}, index=pd.MultiIndex.from_arrays(([0, 1, 0, 1, 0, 1], [1, 0, 0, 1, 0, 2])) + ), + ], +) +def test_get_flat_index(df): + """Test .nest.get_flat_index() returns the index of the original flat df""" + series = pack_flat(df) + assert_index_equal(series.nest.get_flat_index(), df.index.sort_values()) + + def test_get_list_series(): """Test that the .nest.get_list_series() method works.""" struct_array = pa.StructArray.from_arrays(