Skip to content

Commit

Permalink
.nest.get_flat_index to work with multiindex
Browse files Browse the repository at this point in the history
  • Loading branch information
hombit committed May 10, 2024
1 parent 725da0c commit 7b87cbd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/nested_pandas/series/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ def query_flat(self, query: str) -> pd.Series:

def get_flat_index(self) -> pd.Index:
    """Index of the flat arrays

    Every label of the series index is repeated once per element of the
    corresponding nested row. The per-row element counts are the
    differences between consecutive ``list_offsets`` of the underlying
    array, so a row with no elements contributes no labels.

    Returns
    -------
    pd.Index
        The repeated index. It is of the same kind as the series index
        (a MultiIndex stays a MultiIndex) and keeps its name(s).
    """
    row_lengths = np.diff(self._series.array.list_offsets)
    # Repeat the Index object itself instead of ``index.values``: going
    # through ``.values`` turns a MultiIndex into an object array of tuples,
    # so the result would no longer be a MultiIndex.
    return self._series.index.repeat(row_lengths)

def get_flat_series(self, field: str) -> pd.Series:
"""Get the flat-array field as a Series
Expand Down
22 changes: 21 additions & 1 deletion tests/nested_pandas/series/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nested_pandas.series.ext_array import NestedExtensionArray
from nested_pandas.series.packer import pack_flat
from numpy.testing import assert_array_equal
from pandas.testing import assert_frame_equal, assert_series_equal
from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal


def test_registered():
Expand Down Expand Up @@ -345,6 +345,26 @@ def test_query_flat_empty_rows():
assert_series_equal(filtered, desired)


@pytest.mark.parametrize(
    "df",
    [
        pd.DataFrame({"a": [1] * len(flat_index)}, index=flat_index)
        for flat_index in (
            # Plain index: duplicated labels, already in sorted order
            [1, 2, 2, 3, 3, 3, 4, 4, 4, 4],
            # Two-level index: first level sorted, second level not
            pd.MultiIndex.from_arrays(([1, 1, 1, 1, 1, 1, 2, 2, 2, 2], [0, 1, 1, 2, 2, 2, 1, 1, 0, 0])),
            # Plain index: duplicated labels in unsorted order
            [1, 0, 0, 3, 3, 3, 0, 0, 0, 0],
            # Two-level index: both levels interleaved
            pd.MultiIndex.from_arrays(([0, 1, 0, 1, 0, 1], [1, 0, 0, 1, 0, 2])),
        )
    ],
)
def test_get_flat_index(df):
    """Check .nest.get_flat_index() against the sorted index of the flat frame.

    The frame is packed with ``pack_flat`` and the flat index of the
    resulting nested series is compared with ``df.index.sort_values()``.
    """
    packed = pack_flat(df)
    actual_index = packed.nest.get_flat_index()
    expected_index = df.index.sort_values()
    assert_index_equal(actual_index, expected_index)


def test_get_list_series():
"""Test that the .nest.get_list_series() method works."""
struct_array = pa.StructArray.from_arrays(
Expand Down

0 comments on commit 7b87cbd

Please sign in to comment.