Skip to content

Commit

Permalink
set_columns
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Dec 22, 2024
1 parent dfd940c commit 5cc5f46
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
19 changes: 13 additions & 6 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import set_columns
from narwhals.utils import Implementation
from narwhals.utils import find_stacklevel
from narwhals.utils import remove_prefix
Expand Down Expand Up @@ -271,18 +272,24 @@ def agg_pandas( # noqa: PLR0915
if std_aggs:
result_aggs.extend(
[
grouped[std_root_names]
.std(ddof=ddof)
.set_axis(std_output_names, axis="columns", copy=False)
set_columns(
grouped[std_root_names].std(ddof=ddof),
columns=std_output_names,
implementation=implementation,
backend_version=backend_version,
)
for ddof, (std_root_names, std_output_names) in std_aggs.items()
]
)
if var_aggs:
result_aggs.extend(
[
grouped[var_root_names]
.var(ddof=ddof)
.set_axis(var_output_names, axis="columns", copy=False)
set_columns(
grouped[var_root_names].var(ddof=ddof),
columns=var_output_names,
implementation=implementation,
backend_version=backend_version,
)
for ddof, (var_root_names, var_output_names) in var_aggs.items()
]
)
Expand Down
6 changes: 3 additions & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import rename
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import set_axis
from narwhals._pandas_like.utils import set_index
from narwhals._pandas_like.utils import to_datetime
from narwhals.dependencies import is_numpy_scalar
from narwhals.typing import CompliantSeries
Expand Down Expand Up @@ -211,7 +211,7 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
# .copy() is necessary in some pre-2.2 versions of pandas to avoid
# `values` also getting modified (!)
_, values = broadcast_align_and_extract_native(self, values)
values = set_axis(
values = set_index(
values.copy(),
self._native_series.index[indices],
implementation=self._implementation,
Expand Down Expand Up @@ -1423,7 +1423,7 @@ def len(self: Self) -> PandasLikeSeries:
self._compliant_series._implementation is Implementation.PANDAS
and self._compliant_series._backend_version < (3, 0)
): # pragma: no cover
native_result = set_axis(
native_result = set_index(
rename(
native_result,
native_series.name,
Expand Down
43 changes: 38 additions & 5 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def broadcast_align_and_extract_native(
if rhs._native_series.index is not lhs_index:
return (
lhs._native_series,
set_axis(
set_index(
rhs._native_series,
lhs_index,
implementation=rhs._implementation,
Expand Down Expand Up @@ -168,7 +168,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
s = other._native_series
return s.__class__(s.iloc[0], index=index, dtype=s.dtype, name=s.name)
if other._native_series.index is not index:
return set_axis(
return set_index(
other._native_series,
index,
implementation=other._implementation,
Expand Down Expand Up @@ -302,14 +302,17 @@ def native_series_from_iterable(
raise TypeError(msg)


def set_axis(
def set_index(
obj: T,
index: Any,
*,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> T:
"""Wrapper around pandas' set_axis so that we can set `copy` / `inplace` based on implementation/version."""
"""Wrapper around pandas' set_axis to set object index.
We can set `copy` / `inplace` based on implementation/version.
"""
if implementation is Implementation.CUDF: # pragma: no cover
obj = obj.copy(deep=False) # type: ignore[attr-defined]
obj.index = index # type: ignore[attr-defined]
Expand All @@ -329,6 +332,36 @@ def set_axis(
return obj.set_axis(index, axis=0, **kwargs) # type: ignore[attr-defined, no-any-return]


def set_columns(
obj: T,
columns: list[str],
*,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> T:
"""Wrapper around pandas' set_axis to set object columns.
We can set `copy` / `inplace` based on implementation/version.
"""
if implementation is Implementation.CUDF: # pragma: no cover
obj = obj.copy(deep=False) # type: ignore[attr-defined]
obj.columns = columns # type: ignore[attr-defined]
return obj
if implementation is Implementation.PANDAS and (
backend_version < (1,)
): # pragma: no cover
kwargs = {"inplace": False}
else:
kwargs = {}
if implementation is Implementation.PANDAS and (
(1, 5) <= backend_version < (3,)
): # pragma: no cover
kwargs["copy"] = False
else: # pragma: no cover
pass
return obj.set_axis(columns, axis=1, **kwargs) # type: ignore[attr-defined, no-any-return]


def rename(
obj: T,
*args: Any,
Expand Down Expand Up @@ -654,7 +687,7 @@ def broadcast_series(series: Sequence[PandasLikeSeries]) -> list[Any]:

elif s_native.index is not idx:
reindexed.append(
set_axis(
set_index(
s_native,
idx,
implementation=s._implementation,
Expand Down
4 changes: 2 additions & 2 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,13 @@ def maybe_set_index(
df_any._compliant_frame._from_native_frame(native_obj.set_index(keys))
)
elif is_pandas_like_series(native_obj):
from narwhals._pandas_like.utils import set_axis
from narwhals._pandas_like.utils import set_index

if column_names:
msg = "Cannot set index using column names on a Series"
raise ValueError(msg)

native_obj = set_axis(
native_obj = set_index(
native_obj,
keys,
implementation=obj._compliant_series._implementation, # type: ignore[union-attr]
Expand Down

0 comments on commit 5cc5f46

Please sign in to comment.