Skip to content

Commit

Permalink
fix: Make compatible with latest cuDF release (#1640)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: raisadz <[email protected]>
  • Loading branch information
MarcoGorelli and raisadz authored Dec 22, 2024
1 parent 599d93a commit e376dfe
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 48 deletions.
48 changes: 29 additions & 19 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals._pandas_like.utils import create_compliant_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import pivot_table
from narwhals._pandas_like.utils import rename
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import validate_dataframe_comparand
Expand Down Expand Up @@ -853,58 +854,67 @@ def pivot(

if isinstance(on, str):
on = [on]

if isinstance(values, str):
values = [values]
if isinstance(index, str):
index = [index]

if index is None:
index = [c for c in self.columns if c not in {*on, *values}] # type: ignore[misc]

if values is None:
values_ = [c for c in self.columns if c not in {*on, *index}] # type: ignore[misc]
elif isinstance(values, str): # pragma: no cover
values_ = [values]
else:
values_ = values
values = [c for c in self.columns if c not in {*on, *index}]

if aggregate_function is None:
result = frame.pivot(columns=on, index=index, values=values_)

result = frame.pivot(columns=on, index=index, values=values)
elif aggregate_function == "len":
result = (
frame.groupby([*on, *index]) # type: ignore[misc]
.agg({v: "size" for v in values_})
frame.groupby([*on, *index])
.agg({v: "size" for v in values})
.reset_index()
.pivot(columns=on, index=index, values=values_)
.pivot(columns=on, index=index, values=values)
)
else:
result = frame.pivot_table(
values=values_,
result = pivot_table(
df=self,
values=values,
index=index,
columns=on,
aggfunc=aggregate_function,
margins=False,
observed=True,
aggregate_function=aggregate_function,
)

# Put columns in the right order
if sort_columns:
if sort_columns and self._implementation is Implementation.CUDF:
uniques = {
col: sorted(self._native_frame[col].unique().to_arrow().to_pylist())
for col in on
}
elif sort_columns:
uniques = {
col: sorted(self._native_frame[col].unique().tolist()) for col in on
}
elif self._implementation is Implementation.CUDF:
uniques = {
col: self._native_frame[col].unique().to_arrow().to_pylist() for col in on
}
else:
uniques = {col: self._native_frame[col].unique().tolist() for col in on}
all_lists = [values_, *list(uniques.values())]
all_lists = [values, *list(uniques.values())]
ordered_cols = list(product(*all_lists))
result = result.loc[:, ordered_cols]
columns = result.columns.tolist()

n_on = len(on)
if n_on == 1:
new_columns = [
separator.join(col).strip() if len(values_) > 1 else col[-1]
separator.join(col).strip() if len(values) > 1 else col[-1]
for col in columns
]
else:
new_columns = [
separator.join([col[0], '{"' + '","'.join(col[-n_on:]) + '"}'])
if len(values_) > 1
if len(values) > 1
else '{"' + '","'.join(col[-n_on:]) + '"}'
for col in columns
]
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
if (
self._df._implementation is Implementation.PANDAS
and self._df._backend_version < (2, 2)
) or (self._df._implementation is Implementation.CUDF): # pragma: no cover
): # pragma: no cover
for key in indices:
yield (key, self._from_native_frame(self._grouped.get_group(key)))
else:
Expand Down
38 changes: 37 additions & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
T = TypeVar("T")

if TYPE_CHECKING:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals.dtypes import DType
Expand Down Expand Up @@ -614,7 +615,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915
)
)
)
else:
else: # pragma: no cover
msg = (
"Converting to List dtype is not supported for implementation "
f"{implementation} and version {version}."
Expand Down Expand Up @@ -770,3 +771,38 @@ def select_columns_by_name(
raise ColumnNotFoundError.from_missing_and_available_column_names(
missing_columns, available_columns
) from e


def pivot_table(
df: PandasLikeDataFrame,
values: list[str],
index: list[str],
columns: list[str],
aggregate_function: str | None,
) -> Any:
dtypes = import_dtypes_module(df._version)
if df._implementation is Implementation.CUDF:
if any(
x == dtypes.Categorical
for x in df.select(*[*values, *index, *columns]).schema.values()
):
msg = "`pivot` with Categoricals is not implemented for cuDF backend"
raise NotImplementedError(msg)
# cuDF doesn't support `observed` argument
result = df._native_frame.pivot_table(
values=values,
index=index,
columns=columns,
aggfunc=aggregate_function,
margins=False,
)
else:
result = df._native_frame.pivot_table(
values=values,
index=index,
columns=columns,
aggfunc=aggregate_function,
margins=False,
observed=True,
)
return result
4 changes: 4 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2924,6 +2924,10 @@ def pivot(
│ 2 ┆ 4 ┆ 1 ┆ 0 ┆ 4 │
└─────┴───────┴───────┴───────┴───────┘
"""
if values is None and index is None:
msg = "At least one of `values` and `index` must be passed"
raise ValueError(msg)

return self._from_compliant_dataframe(
self._compliant_frame.pivot(
on=on,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
]

[project.optional-dependencies]
cudf = ["cudf>=23.08.00"]
cudf = ["cudf>=24.12.0"]
modin = ["modin"]
pandas = ["pandas>=0.25.3"]
polars = ["polars>=0.20.3"]
Expand Down Expand Up @@ -188,7 +188,7 @@ omit = [
exclude_also = [
"if sys.version_info() <",
"if (:?self._)?implementation is Implementation.MODIN",
"if (:?self._)?implementation is Implementation.CUDF",
"if .*implementation is Implementation.CUDF",
'request.applymarker\(pytest.mark.xfail',
'\w+._backend_version < ',
'backend_version <',
Expand Down
2 changes: 1 addition & 1 deletion tests/expr_and_series/over_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) -


def test_over_cumprod(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor):
if any(x in str(constructor) for x in ("pyarrow_table", "dask_lazy_p2", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
Expand Down
45 changes: 38 additions & 7 deletions tests/frame/pivot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {
Expand Down Expand Up @@ -115,14 +116,14 @@
)
@pytest.mark.parametrize(("on", "index"), [("col", "ix"), (["col"], ["ix"])])
def test_pivot(
constructor_eager: Any,
constructor_eager: ConstructorEager,
agg_func: str,
expected: dict[str, list[Any]],
on: str | list[str],
index: str | list[str],
request: pytest.FixtureRequest,
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin", "cudf")):
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
Expand All @@ -149,7 +150,7 @@ def test_pivot(
],
)
def test_pivot_no_agg(
request: Any, constructor_eager: Any, data_: Any, context: Any
request: Any, constructor_eager: ConstructorEager, data_: Any, context: Any
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")):
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -177,9 +178,12 @@ def test_pivot_no_agg(
],
)
def test_pivot_sort_columns(
request: Any, constructor_eager: Any, sort_columns: Any, expected: list[str]
request: Any,
constructor_eager: ConstructorEager,
sort_columns: Any,
expected: list[str],
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin", "cudf")):
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
Expand Down Expand Up @@ -227,9 +231,9 @@ def test_pivot_sort_columns(
],
)
def test_pivot_names_out(
request: Any, constructor_eager: Any, kwargs: Any, expected: list[str]
request: Any, constructor_eager: ConstructorEager, kwargs: Any, expected: list[str]
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin", "cudf")):
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
Expand All @@ -243,3 +247,30 @@ def test_pivot_names_out(
df.pivot(aggregate_function="min", index="ix", **kwargs).collect_schema().names()
)
assert result == expected


def test_pivot_no_index_no_values(constructor_eager: ConstructorEager) -> None:
df = nw.from_native(constructor_eager(data_no_dups), eager_only=True)
with pytest.raises(ValueError, match="At least one of `values` and `index` must"):
df.pivot(on="col")


def test_pivot_no_index(
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
):
# not implemented
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager(data_no_dups), eager_only=True)
result = df.pivot(on="col", values="foo").sort("ix", "bar")
expected = {
"ix": [1, 1, 2, 2],
"bar": ["x", "y", "w", "z"],
"a": [1.0, float("nan"), float("nan"), 3.0],
"b": [float("nan"), 2.0, 4.0, float("nan")],
}
assert_equal_data(result, expected)
5 changes: 0 additions & 5 deletions tests/frame/unpivot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,7 @@ def test_unpivot_var_value_names(
constructor: Constructor,
variable_name: str | None,
value_name: str | None,
request: pytest.FixtureRequest,
) -> None:
if variable_name == "" and "cudf" in str(constructor):
# https://github.com/rapidsai/cudf/issues/16972
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.unpivot(
on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name
Expand Down
24 changes: 12 additions & 12 deletions tests/group_by_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ def test_invalid_group_by() -> None:
)


def test_group_by_iter(constructor_eager: ConstructorEager) -> None:
def test_group_by_iter(
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
) -> None:
if "cudf" in str(constructor_eager):
# https://github.com/rapidsai/cudf/issues/17650
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor_eager(data), eager_only=True)
expected_keys = [(1,), (3,)]
keys = []
Expand Down Expand Up @@ -117,8 +122,6 @@ def test_group_by_depth_1_agg(
expected: dict[str, list[int | float]],
request: pytest.FixtureRequest,
) -> None:
if "cudf" in str(constructor) and attr == "n_unique":
request.applymarker(pytest.mark.xfail)
if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1):
# Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations"
request.applymarker(pytest.mark.xfail)
Expand All @@ -140,13 +143,7 @@ def test_group_by_median(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_group_by_n_unique_w_missing(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "cudf" in str(constructor):
# Issue in cuDF https://github.com/rapidsai/cudf/issues/16861
request.applymarker(pytest.mark.xfail)

def test_group_by_n_unique_w_missing(constructor: Constructor) -> None:
data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
result = (
nw.from_native(constructor(data))
Expand Down Expand Up @@ -294,6 +291,9 @@ def test_key_with_nulls_iter(
if PANDAS_VERSION < (1, 3) and "pandas_constructor" in str(constructor_eager):
# bug in old pandas
request.applymarker(pytest.mark.xfail)
if "cudf" in str(constructor_eager):
# https://github.com/rapidsai/cudf/issues/17650
request.applymarker(pytest.mark.xfail)
data = {"b": ["4", "5", None, "7"], "a": [1, 2, 3, 4], "c": ["4", "3", None, None]}
result = dict(
nw.from_native(constructor_eager(data), eager_only=True)
Expand Down Expand Up @@ -369,10 +369,10 @@ def test_group_by_shift_raises(
def test_double_same_aggregation(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor) or "modin" in str(constructor):
if any(x in str(constructor) for x in ("dask", "modin", "cudf")):
# bugged in dask https://github.com/dask/dask/issues/11612
# and modin lol https://github.com/modin-project/modin/issues/7414
# At least cudf gets it right
# and cudf https://github.com/rapidsai/cudf/issues/17649
request.applymarker(pytest.mark.xfail)
if "pandas" in str(constructor) and PANDAS_VERSION < (1,):
request.applymarker(pytest.mark.xfail)
Expand Down

0 comments on commit e376dfe

Please sign in to comment.