Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DataFrame and LazyFrame explode method #1542

Merged
merged 17 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- drop
- drop_nulls
- estimated_size
- explode
- filter
- gather_every
- get_column
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/lazyframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- columns
- drop
- drop_nulls
- explode
- filter
- gather_every
- group_by
Expand Down
52 changes: 52 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,3 +949,55 @@ def unpivot(
value_name=value_name if value_name is not None else "value",
)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode one or more list-typed columns into one row per list element.

    Strategy: when a single column is requested, the native ``explode`` method
    is used directly. When several columns are requested, the first one is
    exploded together with the rest of the frame, every other requested column
    is exploded on its own, and the pieces are concatenated column-wise and
    re-selected in the original column order.

    Arguments:
        columns: Name, or sequence of names, of the columns to explode.
        *more_columns: Additional column names, given positionally.

    Returns:
        A new frame built from the exploded native frame.

    Raises:
        InvalidOperationError: If a requested column is not of `List` dtype.
        ShapeError: If several columns are requested and their list lengths
            (as reported by ``.list.len()``) do not all match the first one.
    """
    from narwhals.exceptions import InvalidOperationError

    dtypes = import_dtypes_module(self._version)

    if isinstance(columns, str):
        to_explode = [columns, *more_columns]
    else:
        to_explode = [*columns, *more_columns]

    # Validate every requested column before touching the native frame.
    schema = self.collect_schema()
    for name in to_explode:
        dtype = schema[name]
        if dtype != dtypes.List:
            msg = (
                f"`explode` operation not supported for dtype `{dtype}`, "
                "expected List type"
            )
            raise InvalidOperationError(msg)

    # Single column: the native method does all the work.
    if len(to_explode) == 1:
        return self._from_native_frame(self._native_frame.explode(to_explode[0]))

    native_frame = self._native_frame
    anchor_lengths = native_frame[to_explode[0]].list.len()

    # Every additional column must have the same list lengths as the first.
    for name in to_explode[1:]:
        if not (native_frame[name].list.len() == anchor_lengths).all():
            from narwhals.exceptions import ShapeError

            msg = "exploded columns must have matching element counts"
            raise ShapeError(msg)

    original_columns = self.columns
    other_columns = [c for c in original_columns if c not in to_explode]

    # First requested column is exploded together with the untouched columns...
    anchor_selection = native_frame[[*other_columns, to_explode[0]]]
    exploded_frame = anchor_selection.explode(to_explode[0])

    # ...while each remaining requested column is exploded individually.
    exploded_series = []
    for name in to_explode[1:]:
        exploded_series.append(native_frame[name].explode().to_frame())

    plx = self.__native_namespace__()
    combined = plx.concat([exploded_frame, *exploded_series], axis=1)

    # Re-select to restore the original column order.
    return self._from_native_frame(combined[original_columns])
125 changes: 123 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ def __eq__(self, other: object) -> NoReturn:
)
raise NotImplementedError(msg)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    # Forward the request unchanged to the compliant frame, then wrap the
    # result back into a frame of the same kind as `self`.
    exploded = self._compliant_frame.explode(columns, *more_columns)
    return self._from_compliant_dataframe(exploded)


class DataFrame(BaseFrame[DataFrameT]):
"""Narwhals DataFrame, backed by a native eager dataframe.
Expand Down Expand Up @@ -592,8 +600,6 @@ def to_pandas(self) -> pd.DataFrame:
0 1 6.0 a
1 2 7.0 b
2 3 8.0 c


"""
return self._compliant_frame.to_pandas()

Expand Down Expand Up @@ -3125,6 +3131,68 @@ def unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode the dataframe to long format by exploding the given columns.

    Notes:
        It is possible to explode multiple columns only if these columns have
        matching element counts.

    Arguments:
        columns: Column names. The underlying columns being exploded must be of the `List` data type.
        *more_columns: Additional names of columns to explode, specified as positional arguments.

    Returns:
        New DataFrame

    Examples:
        >>> import narwhals as nw
        >>> from narwhals.typing import IntoDataFrameT
        >>> import pandas as pd
        >>> import polars as pl
        >>> data = {
        ...     "a": ["x", "y", "z", "w"],
        ...     "lst1": [[1, 2], None, [None], []],
        ...     "lst2": [[3, None], None, [42], []],
        ... }

        We define a library agnostic function:

        >>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT:
        ...     return (
        ...         nw.from_native(df_native)
        ...         .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
        ...         .explode("lst1", "lst2")
        ...         .to_native()
        ...     )

        We can then pass a supported library such as pandas or Polars (eager)
        to `agnostic_explode`:

        >>> agnostic_explode(pd.DataFrame(data))
           a  lst1  lst2
        0  x     1     3
        0  x     2  <NA>
        1  y  <NA>  <NA>
        2  z  <NA>    42
        3  w  <NA>  <NA>
        >>> agnostic_explode(pl.DataFrame(data))
        shape: (5, 3)
        ┌─────┬──────┬──────┐
        │ a   ┆ lst1 ┆ lst2 │
        │ --- ┆ ---  ┆ ---  │
        │ str ┆ i32  ┆ i32  │
        ╞═════╪══════╪══════╡
        │ x   ┆ 1    ┆ 3    │
        │ x   ┆ 2    ┆ null │
        │ y   ┆ null ┆ null │
        │ z   ┆ null ┆ 42   │
        │ w   ┆ null ┆ null │
        └─────┴──────┴──────┘
    """
    return super().explode(columns, *more_columns)


class LazyFrame(BaseFrame[FrameT]):
"""Narwhals LazyFrame, backed by a native lazyframe.
Expand Down Expand Up @@ -4910,3 +4978,56 @@ def unpivot(
return super().unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode the LazyFrame to long format by exploding the given columns.

    Notes:
        It is possible to explode multiple columns only if these columns have
        matching element counts.

    Arguments:
        columns: Column names. The underlying columns being exploded must be of the `List` data type.
        *more_columns: Additional names of columns to explode, specified as positional arguments.

    Returns:
        New LazyFrame

    Examples:
        >>> import narwhals as nw
        >>> from narwhals.typing import IntoFrameT
        >>> import polars as pl
        >>> data = {
        ...     "a": ["x", "y", "z", "w"],
        ...     "lst1": [[1, 2], None, [None], []],
        ...     "lst2": [[3, None], None, [42], []],
        ... }

        We define a library agnostic function:

        >>> def agnostic_explode(df_native: IntoFrameT) -> IntoFrameT:
        ...     return (
        ...         nw.from_native(df_native)
        ...         .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
        ...         .explode("lst1", "lst2")
        ...         .to_native()
        ...     )

        We can then pass a supported lazy frame, such as a Polars LazyFrame,
        to `agnostic_explode`:

        >>> agnostic_explode(pl.LazyFrame(data)).collect()
        shape: (5, 3)
        ┌─────┬──────┬──────┐
        │ a   ┆ lst1 ┆ lst2 │
        │ --- ┆ ---  ┆ ---  │
        │ str ┆ i32  ┆ i32  │
        ╞═════╪══════╪══════╡
        │ x   ┆ 1    ┆ 3    │
        │ x   ┆ 2    ┆ null │
        │ y   ┆ null ┆ null │
        │ z   ┆ null ┆ 42   │
        │ w   ┆ null ┆ null │
        └─────┴──────┴──────┘
    """
    return super().explode(columns, *more_columns)
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def from_missing_and_available_column_names(
return ColumnNotFoundError(message)


class ShapeError(Exception):
    """Raised when an operation is attempted on data structures whose shapes are incompatible."""


class InvalidOperationError(Exception):
    """Raised when an invalid operation is attempted."""

Expand Down
146 changes: 146 additions & 0 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from typing import Sequence

import pytest
from polars.exceptions import InvalidOperationError as PlInvalidOperationError
from polars.exceptions import ShapeError as PlShapeError

import narwhals.stable.v1 as nw
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import ShapeError
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

# For context: Polars only allows exploding several columns together when the
# columns have matching element counts. Hence "l1" and "l2" can be exploded
# together, but "l1" and "l3" cannot.
data = {
    "a": ["x", "y", "z", "w"],
    "l1": [[1, 2], None, [None], []],
    "l2": [[3, None], None, [42], []],
    "l3": [[1, 2], [3], [None], [1]],
    "l4": [[1, 2], [3], [123], [456]],
}


@pytest.mark.parametrize(
    ("column", "expected_values"),
    [
        ("l2", [3, None, None, 42, None]),
        ("l3", [1, 2, 3, None, 1]),  # fast path for arrow
    ],
)
def test_explode_single_col(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    column: str,
    expected_values: list[int | None],
) -> None:
    """Exploding a single List column yields one row per element."""
    constructor_name = str(constructor)

    # Backends marked as expected failures for `explode` in this test.
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")
    if any(backend in constructor_name for backend in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))
    frame = frame.with_columns(nw.col(column).cast(nw.List(nw.Int32())))
    result = frame.explode(column).select("a", column)

    expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values}
    assert_equal_data(result, expected)


@pytest.mark.parametrize(
    ("columns", "more_columns", "expected"),
    [
        (
            "l1",
            ["l2"],
            {
                "a": ["x", "x", "y", "z", "w"],
                "l1": [1, 2, None, None, None],
                "l2": [3, None, None, 42, None],
            },
        ),
        (
            "l3",
            ["l4"],
            {
                "a": ["x", "x", "y", "z", "w"],
                "l3": [1, 2, 3, None, 1],
                "l4": [1, 2, 3, 123, 456],
            },
        ),
    ],
)
def test_explode_multiple_cols(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    columns: str | Sequence[str],
    more_columns: Sequence[str],
    expected: dict[str, list[str | int | None]],
) -> None:
    """Exploding several List columns with matching element counts works."""
    constructor_name = str(constructor)

    # Backends marked as expected failures for `explode` in this test.
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")
    if any(backend in constructor_name for backend in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))
    frame = frame.with_columns(
        nw.col(columns, *more_columns).cast(nw.List(nw.Int32()))
    )
    result = frame.explode(columns, *more_columns).select(
        "a", columns, *more_columns
    )
    assert_equal_data(result, expected)


def test_explode_shape_error(
    request: pytest.FixtureRequest, constructor: Constructor
) -> None:
    """Exploding columns with mismatched element counts raises a shape error."""
    constructor_name = str(constructor)

    # Backends marked as expected failures for `explode` in this test.
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")
    if any(backend in constructor_name for backend in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    # Either the narwhals error or the native Polars one is accepted.
    accepted_errors = (ShapeError, PlShapeError)
    with pytest.raises(
        accepted_errors,
        match="exploded columns must have matching element counts",
    ):
        frame = nw.from_native(constructor(data)).lazy()
        frame = frame.with_columns(
            nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))
        )
        # "l1" and "l3" have different element counts per row.
        frame.explode("l1", "l3").collect()


def test_explode_invalid_operation_error(
    request: pytest.FixtureRequest, constructor: Constructor
) -> None:
    """Exploding a non-List column raises an invalid-operation error."""
    constructor_name = str(constructor)

    if "dask" in constructor_name or "pyarrow_table" in constructor_name:
        request.applymarker(pytest.mark.xfail)

    if "polars" in constructor_name and POLARS_VERSION < (0, 20, 6):
        request.applymarker(pytest.mark.xfail)

    # Either the narwhals error or the native Polars one is accepted.
    accepted_errors = (InvalidOperationError, PlInvalidOperationError)
    with pytest.raises(
        accepted_errors,
        match="`explode` operation not supported for dtype",
    ):
        # Column "a" holds strings, not lists.
        frame = nw.from_native(constructor(data)).lazy()
        frame.explode("a").collect()
Loading