Skip to content

Commit

Permalink
feat: support passing index object directly into maybe_set_index (#1319)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: FBruzzesi <[email protected]>
  • Loading branch information
Riik and FBruzzesi authored Nov 12, 2024
1 parent c694148 commit 0929d52
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 19 deletions.
78 changes: 67 additions & 11 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,37 @@ def maybe_get_index(obj: T) -> Any | None:
return None


def maybe_set_index(
    obj: T,
    column_names: str | list[str] | None = None,
    *,
    index: Series | list[Series] | None = None,
) -> T:
    """
    Set the index of a DataFrame or a Series, if it's pandas-like.

    Arguments:
        obj: object for which to maybe set the index (can be either a Narwhals
            `DataFrame` or `Series`).
        column_names: name or list of names of the columns to set as index.
            For dataframes, only one of `column_names` and `index` can be
            specified, not both. If `column_names` is passed and `obj` is a
            Series, then a `ValueError` is raised.
        index: series or list of series to set as index.

    Returns:
        An object of the same kind as `obj` with the new index applied when
        `obj` is pandas-like; otherwise `obj` itself, unchanged.

    Raises:
        ValueError: If one of the following conditions happens:

            - none of `column_names` and `index` are provided
            - both `column_names` and `index` are provided
            - `column_names` is provided and `obj` is a Series

    Notes:
        This is only really intended for backwards-compatibility purposes, for
        example if your library already aligns indices for users.
        If you're designing a new library, we highly encourage you to not
        rely on the Index.

        For non-pandas-like inputs, this is a no-op (the argument validation
        described under "Raises" still runs first).

    Examples:
        >>> import pandas as pd
        >>> import narwhals as nw
        >>> df_pd = pd.DataFrame({"a": [1, 2], "b": [4, 5]})
        >>> df = nw.from_native(df_pd)
        >>> nw.to_native(nw.maybe_set_index(df, "b"))  # doctest: +NORMALIZE_WHITESPACE
           a
        b
        4  1
        5  2
    """
    df_any = cast(Any, obj)
    native_obj = to_native(df_any)

    if column_names is not None and index is not None:
        msg = "Only one of `column_names` or `index` should be provided"
        raise ValueError(msg)

    # `index` is compared against None rather than tested for truthiness:
    # whether a Series-like object is "truthy" says nothing about whether the
    # caller supplied it, and an explicitly passed empty index must not be
    # reported as missing.
    if not column_names and index is None:
        msg = "Either `column_names` or `index` should be provided"
        raise ValueError(msg)

    if index is not None:
        # Unwrap Narwhals series into their native counterparts; anything that
        # is not a Narwhals object is passed through untouched.
        keys = (
            [to_native(idx, pass_through=True) for idx in index]
            if _is_iterable(index)
            else to_native(index, pass_through=True)
        )
    else:
        keys = column_names

    if is_pandas_like_dataframe(native_obj):
        return df_any._from_compliant_dataframe(  # type: ignore[no-any-return]
            df_any._compliant_frame._from_native_frame(native_obj.set_index(keys))
        )
    elif is_pandas_like_series(native_obj):
        if column_names:
            msg = "Cannot set index using column names on a Series"
            raise ValueError(msg)

        if (
            df_any._compliant_series._implementation is Implementation.PANDAS
            and df_any._compliant_series._backend_version < (1,)
        ):  # pragma: no cover
            # pandas older than 1.0 gets an explicit `inplace=False`
            # (presumably `set_axis` did not return a new object by default
            # there -- confirm against the pandas changelog).
            native_obj = native_obj.set_axis(keys, inplace=False)
        else:
            native_obj = native_obj.set_axis(keys)

        return df_any._from_compliant_series(  # type: ignore[no-any-return]
            df_any._compliant_series._from_native_series(native_obj)
        )
    else:
        return df_any  # type: ignore[no-any-return]


def maybe_reset_index(obj: T) -> T:
Expand Down
108 changes: 100 additions & 8 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import string
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import pandas as pd
Expand All @@ -15,6 +16,9 @@
from tests.utils import PANDAS_VERSION
from tests.utils import get_module_version_as_tuple

if TYPE_CHECKING:
from narwhals.series import Series


def test_maybe_align_index_pandas() -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0]))
Expand Down Expand Up @@ -58,21 +62,109 @@ def test_maybe_align_index_polars() -> None:
nw.maybe_align_index(df, s[1:])


@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_pandas_column_names(
    column_names: str | list[str] | None,
) -> None:
    """Setting the index by column name(s) on a pandas frame matches `set_index`."""
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    narwhals_frame = nw.from_native(pd.DataFrame(data))
    # Built from a separate frame so the expectation is independent of the input.
    expected = pd.DataFrame(data).set_index(column_names)
    result = nw.maybe_set_index(narwhals_frame, column_names)
    assert_frame_equal(nw.to_native(result), expected)


@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_polars_column_names(
    column_names: str | list[str] | None,
) -> None:
    """For a Polars frame, setting the index by column name(s) is a no-op."""
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    # The very same object must come back, not merely an equal one.
    assert nw.maybe_set_index(frame, column_names) is frame


@pytest.mark.parametrize(
    "native_df_or_series",
    [pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), pd.Series([0, 1, 2])],
)
@pytest.mark.parametrize(
    ("narwhals_index", "pandas_index"),
    [
        (nw.from_native(pd.Series([1, 2, 0]), series_only=True), pd.Series([1, 2, 0])),
        (
            [
                nw.from_native(pd.Series([0, 1, 2]), series_only=True),
                nw.from_native(pd.Series([1, 2, 0]), series_only=True),
            ],
            [
                pd.Series([0, 1, 2]),
                pd.Series([1, 2, 0]),
            ],
        ),
    ],
)
def test_maybe_set_index_pandas_direct_index(
    narwhals_index: Series | list[Series] | None,
    pandas_index: pd.Series | list[pd.Series] | None,
    native_df_or_series: pd.DataFrame | pd.Series,
) -> None:
    """Passing `index=` sets the index on both pandas frames and pandas series.

    The expected value is produced with plain pandas from the same input and
    the matching native index (`pandas_index`).
    """
    df = nw.from_native(native_df_or_series, allow_series=True)
    result = nw.maybe_set_index(df, index=narwhals_index)
    if isinstance(native_df_or_series, pd.Series):
        # Work on a copy: the parametrized object is created once in the
        # decorator and reused by every generated case, so assigning `.index`
        # on it directly would leak state from one case into the next.
        expected_series = native_df_or_series.copy()
        expected_series.index = pandas_index
        assert_series_equal(nw.to_native(result), expected_series)
    else:
        expected = native_df_or_series.set_index(pandas_index)
        assert_frame_equal(nw.to_native(result), expected)


@pytest.mark.parametrize(
    "index",
    [
        nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        [
            nw.from_native(pd.Series([0, 1, 2]), series_only=True),
            nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        ],
    ],
)
def test_maybe_set_index_polars_direct_index(
    index: Series | list[Series] | None,
) -> None:
    """For a Polars frame, passing `index=` is a no-op."""
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    # The very same object must come back, not merely an equal one.
    assert nw.maybe_set_index(frame, index=index) is frame


def test_maybe_set_index_pandas_series_column_names() -> None:
    """Using `column_names` on a pandas-backed Series raises `ValueError`."""
    series = nw.from_native(pd.Series([0, 1, 2]), allow_series=True)
    expected_message = "Cannot set index using column names on a Series"
    with pytest.raises(ValueError, match=expected_message):
        nw.maybe_set_index(series, column_names=["a"])


def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
    """Exactly one of `column_names` / `index` is required; both or neither raises."""
    frame = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    names = ["a", "b"]
    index_series = nw.from_native(pd.Series([0, 1, 2]), series_only=True)

    # Both arguments supplied.
    both_message = "Only one of `column_names` or `index` should be provided"
    with pytest.raises(ValueError, match=both_message):
        nw.maybe_set_index(frame, column_names=names, index=index_series)

    # Neither argument supplied.
    neither_message = "Either `column_names` or `index` should be provided"
    with pytest.raises(ValueError, match=neither_message):
        nw.maybe_set_index(frame)


def test_maybe_get_index_pandas() -> None:
pandas_df = pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0])
result = nw.maybe_get_index(nw.from_native(pandas_df))
Expand Down

0 comments on commit 0929d52

Please sign in to comment.