Skip to content

Commit

Permalink
feat: add DataFrame and LazyFrame explode method (#1542)
Browse files Browse the repository at this point in the history
* feat: DataFrame and LazyFrame explode

* arrow refactor

* raise for invalid type and docstrings

* Update narwhals/dataframe.py

* old versions

* almost all native

* doctest

* better error message, fail for arrow with nulls

* doctest-modules

* completely remove pyarrow implementation
  • Loading branch information
FBruzzesi authored Dec 22, 2024
1 parent e376dfe commit e112a99
Show file tree
Hide file tree
Showing 6 changed files with 327 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- drop
- drop_nulls
- estimated_size
- explode
- filter
- gather_every
- get_column
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/lazyframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- columns
- drop
- drop_nulls
- explode
- filter
- gather_every
- group_by
Expand Down
52 changes: 52 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,55 @@ def unpivot(
value_name=value_name if value_name is not None else "value",
)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
from narwhals.exceptions import InvalidOperationError

dtypes = import_dtypes_module(self._version)

to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)
schema = self.collect_schema()
for col_to_explode in to_explode:
dtype = schema[col_to_explode]

if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)

if len(to_explode) == 1:
return self._from_native_frame(self._native_frame.explode(to_explode[0]))
else:
native_frame = self._native_frame
anchor_series = native_frame[to_explode[0]].list.len()

if not all(
(native_frame[col_name].list.len() == anchor_series).all()
for col_name in to_explode[1:]
):
from narwhals.exceptions import ShapeError

msg = "exploded columns must have matching element counts"
raise ShapeError(msg)

original_columns = self.columns
other_columns = [c for c in original_columns if c not in to_explode]

exploded_frame = native_frame[[*other_columns, to_explode[0]]].explode(
to_explode[0]
)
exploded_series = [
native_frame[col_name].explode().to_frame() for col_name in to_explode[1:]
]

plx = self.__native_namespace__()

return self._from_native_frame(
plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns]
)
125 changes: 123 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ def __eq__(self, other: object) -> NoReturn:
)
raise NotImplementedError(msg)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
return self._from_compliant_dataframe(
self._compliant_frame.explode(
columns,
*more_columns,
)
)


class DataFrame(BaseFrame[DataFrameT]):
"""Narwhals DataFrame, backed by a native eager dataframe.
Expand Down Expand Up @@ -592,8 +600,6 @@ def to_pandas(self) -> pd.DataFrame:
0 1 6.0 a
1 2 7.0 b
2 3 8.0 c
"""
return self._compliant_frame.to_pandas()

Expand Down Expand Up @@ -3129,6 +3135,68 @@ def unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
"""Explode the dataframe to long format by exploding the given columns.
Notes:
It is possible to explode multiple columns only if these columns must have
matching element counts.
Arguments:
columns: Column names. The underlying columns being exploded must be of the `List` data type.
*more_columns: Additional names of columns to explode, specified as positional arguments.
Returns:
New DataFrame
Examples:
>>> import narwhals as nw
>>> from narwhals.typing import IntoDataFrameT
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {
... "a": ["x", "y", "z", "w"],
... "lst1": [[1, 2], None, [None], []],
... "lst2": [[3, None], None, [42], []],
... }
We define a library agnostic function:
>>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT:
... return (
... nw.from_native(df_native)
... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
... .explode("lst1", "lst2")
... .to_native()
... )
We can then pass any supported library such as pandas, Polars (eager),
or PyArrow to `agnostic_explode`:
>>> agnostic_explode(pd.DataFrame(data))
a lst1 lst2
0 x 1 3
0 x 2 <NA>
1 y <NA> <NA>
2 z <NA> 42
3 w <NA> <NA>
>>> agnostic_explode(pl.DataFrame(data))
shape: (5, 3)
┌─────┬──────┬──────┐
│ a ┆ lst1 ┆ lst2 │
│ --- ┆ --- ┆ --- │
│ str ┆ i32 ┆ i32 │
╞═════╪══════╪══════╡
│ x ┆ 1 ┆ 3 │
│ x ┆ 2 ┆ null │
│ y ┆ null ┆ null │
│ z ┆ null ┆ 42 │
│ w ┆ null ┆ null │
└─────┴──────┴──────┘
"""
return super().explode(columns, *more_columns)


class LazyFrame(BaseFrame[FrameT]):
"""Narwhals LazyFrame, backed by a native lazyframe.
Expand Down Expand Up @@ -4914,3 +4982,56 @@ def unpivot(
return super().unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
"""Explode the dataframe to long format by exploding the given columns.
Notes:
It is possible to explode multiple columns only if these columns must have
matching element counts.
Arguments:
columns: Column names. The underlying columns being exploded must be of the `List` data type.
*more_columns: Additional names of columns to explode, specified as positional arguments.
Returns:
New LazyFrame
Examples:
>>> import narwhals as nw
>>> from narwhals.typing import IntoFrameT
>>> import polars as pl
>>> data = {
... "a": ["x", "y", "z", "w"],
... "lst1": [[1, 2], None, [None], []],
... "lst2": [[3, None], None, [42], []],
... }
We define a library agnostic function:
>>> def agnostic_explode(df_native: IntoFrameT) -> IntoFrameT:
... return (
... nw.from_native(df_native)
... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
... .explode("lst1", "lst2")
... .to_native()
... )
We can then pass any supported library such as pandas, Polars (eager),
or PyArrow to `agnostic_explode`:
>>> agnostic_explode(pl.LazyFrame(data)).collect()
shape: (5, 3)
┌─────┬──────┬──────┐
│ a ┆ lst1 ┆ lst2 │
│ --- ┆ --- ┆ --- │
│ str ┆ i32 ┆ i32 │
╞═════╪══════╪══════╡
│ x ┆ 1 ┆ 3 │
│ x ┆ 2 ┆ null │
│ y ┆ null ┆ null │
│ z ┆ null ┆ 42 │
│ w ┆ null ┆ null │
└─────┴──────┴──────┘
"""
return super().explode(columns, *more_columns)
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def from_missing_and_available_column_names(
return ColumnNotFoundError(message)


class ShapeError(Exception):
"""Exception raised when trying to perform operations on data structures with incompatible shapes."""


class InvalidOperationError(Exception):
"""Exception raised during invalid operations."""

Expand Down
146 changes: 146 additions & 0 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from __future__ import annotations

from typing import Sequence

import pytest
from polars.exceptions import InvalidOperationError as PlInvalidOperationError
from polars.exceptions import ShapeError as PlShapeError

import narwhals.stable.v1 as nw
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import ShapeError
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

# For context, polars allows to explode multiple columns only if the columns
# have matching element counts, therefore, l1 and l2 but not l1 and l3 together.
data = {
"a": ["x", "y", "z", "w"],
"l1": [[1, 2], None, [None], []],
"l2": [[3, None], None, [42], []],
"l3": [[1, 2], [3], [None], [1]],
"l4": [[1, 2], [3], [123], [456]],
}


@pytest.mark.parametrize(
("column", "expected_values"),
[
("l2", [3, None, None, 42, None]),
("l3", [1, 2, 3, None, 1]), # fast path for arrow
],
)
def test_explode_single_col(
request: pytest.FixtureRequest,
constructor: Constructor,
column: str,
expected_values: list[int | None],
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
request.applymarker(pytest.mark.xfail)

result = (
nw.from_native(constructor(data))
.with_columns(nw.col(column).cast(nw.List(nw.Int32())))
.explode(column)
.select("a", column)
)
expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values}
assert_equal_data(result, expected)


@pytest.mark.parametrize(
("columns", "more_columns", "expected"),
[
(
"l1",
["l2"],
{
"a": ["x", "x", "y", "z", "w"],
"l1": [1, 2, None, None, None],
"l2": [3, None, None, 42, None],
},
),
(
"l3",
["l4"],
{
"a": ["x", "x", "y", "z", "w"],
"l3": [1, 2, 3, None, 1],
"l4": [1, 2, 3, 123, 456],
},
),
],
)
def test_explode_multiple_cols(
request: pytest.FixtureRequest,
constructor: Constructor,
columns: str | Sequence[str],
more_columns: Sequence[str],
expected: dict[str, list[str | int | None]],
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
request.applymarker(pytest.mark.xfail)

result = (
nw.from_native(constructor(data))
.with_columns(nw.col(columns, *more_columns).cast(nw.List(nw.Int32())))
.explode(columns, *more_columns)
.select("a", columns, *more_columns)
)
assert_equal_data(result, expected)


def test_explode_shape_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
request.applymarker(pytest.mark.xfail)

with pytest.raises(
(ShapeError, PlShapeError),
match="exploded columns must have matching element counts",
):
_ = (
nw.from_native(constructor(data))
.lazy()
.with_columns(nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32())))
.explode("l1", "l3")
.collect()
)


def test_explode_invalid_operation_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
request.applymarker(pytest.mark.xfail)

with pytest.raises(
(InvalidOperationError, PlInvalidOperationError),
match="`explode` operation not supported for dtype",
):
_ = nw.from_native(constructor(data)).lazy().explode("a").collect()

0 comments on commit e112a99

Please sign in to comment.