Skip to content

Commit

Permalink
fix: .list namespace should preserve pandas index (#1538)
Browse files Browse the repository at this point in the history
* fix: .list preserve index

* use set_axis helper function and check pandas versions

* no cover?
  • Loading branch information
FBruzzesi authored Dec 8, 2024
1 parent fe15dd7 commit adadecc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
13 changes: 12 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,18 @@ def len(self: Self) -> PandasLikeSeries:
from narwhals.utils import import_dtypes_module

native_series = self._compliant_series._native_series
native_result = native_series.list.len().rename(native_series.name, copy=False)
native_result = native_series.list.len()

if (
self._compliant_series._implementation is Implementation.PANDAS
and self._compliant_series._backend_version < (3, 0)
): # pragma: no cover
native_result = set_axis(
native_result.rename(native_series.name, copy=False),
index=native_series.index,
implementation=self._compliant_series._implementation,
backend_version=self._compliant_series._backend_version,
)
dtype = narwhals_to_native_dtype(
dtype=import_dtypes_module(self._compliant_series._version).UInt32(),
starting_dtype=native_result.dtype,
Expand Down
13 changes: 13 additions & 0 deletions tests/expr_and_series/list/len_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pandas as pd
import pytest

import narwhals.stable.v1 as nw
Expand Down Expand Up @@ -43,3 +44,15 @@ def test_len_series(

result = df["a"].cast(nw.List(nw.Int32())).list.len()
assert_equal_data({"a": result}, expected)


def test_pandas_preserve_index(request: pytest.FixtureRequest) -> None:
if PANDAS_VERSION < (2, 2):
request.applymarker(pytest.mark.xfail)

index = pd.Index(["a", "b", "c", "d", "e"])
df = nw.from_native(pd.DataFrame(data, index=index), eager_only=True)

result = df["a"].cast(nw.List(nw.Int32())).list.len()
assert_equal_data({"a": result}, expected)
assert (result.to_native().index == index).all()

0 comments on commit adadecc

Please sign in to comment.