Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DataFrame and LazyFrame explode method #1542

Merged
merged 17 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- drop
- drop_nulls
- estimated_size
- explode
- filter
- gather_every
- get_column
Expand Down
1 change: 1 addition & 0 deletions docs/api-reference/lazyframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- columns
- drop
- drop_nulls
- explode
- filter
- gather_every
- group_by
Expand Down
78 changes: 78 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import scale_bytes
Expand Down Expand Up @@ -752,3 +753,80 @@ def unpivot(
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyarrow has two paths:

  • if nulls or empty lists are not present, then it is enough to:
    1. make sure element counts are same
    2. explode each array individually
  • if nulls or empty lists are present, then these are ignore by pc.list_parent_indices and pc.list_flatten, which is a problem. This implementation falls back to a python list both to flatten the array(s) and to create the corresponding indices .

After flattening, a new table is created by take-ing the indices of the non-flattened arrays and the flattened arrays.

import pyarrow as pa
import pyarrow.compute as pc

from narwhals.exceptions import InvalidOperationError

dtypes = import_dtypes_module(self._version)

to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)

schema = self.collect_schema()
for col_to_explode in to_explode:
dtype = schema[col_to_explode]

if dtype != dtypes.List:
msg = f"`explode` operation not supported for dtype `{dtype}`"
raise InvalidOperationError(msg)

native_frame = self._native_frame
counts = pc.list_value_length(native_frame[to_explode[0]])

if not all(
pc.all(pc.equal(pc.list_value_length(native_frame[col_name]), counts)).as_py()
for col_name in to_explode[1:]
):
from narwhals.exceptions import ShapeError

msg = "exploded columns must have matching element counts"
raise ShapeError(msg)

original_columns = self.columns
other_columns = [c for c in original_columns if c not in to_explode]
fast_path = pc.all(pc.greater_equal(counts, 1)).as_py()

if fast_path:
indices = pc.list_parent_indices(native_frame[to_explode[0]])
flatten_func = pc.list_flatten

else:
indices = pa.array(
[
i
for i, count in enumerate(counts.to_pylist())
for _ in range(max(count or 1, 1))
]
)
parent_indices = pc.list_parent_indices(native_frame[to_explode[0]])
is_valid_index = pc.is_in(indices, value_set=parent_indices)
exploded_size = len(is_valid_index)

def flatten_func(array: pa.ChunkedArray) -> pa.ChunkedArray:
dtype = array.type.value_type

return pc.replace_with_mask(
pa.array([None] * exploded_size, type=dtype),
is_valid_index,
pc.list_flatten(array).combine_chunks(),
)

arrays = [
native_frame[col_name].take(indices)
if col_name in other_columns
else flatten_func(native_frame[col_name])
for col_name in original_columns
]

return self._from_native_frame(
pa.Table.from_arrays(
arrays=arrays,
names=original_columns,
)
)
49 changes: 49 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,3 +949,52 @@ def unpivot(
value_name=value_name if value_name is not None else "value",
)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a single column is to be exploded, then we use the pandas native method. If multiple columns, the strategy is to explode the one column with the rest of the dataframe, and the other series individually and finally concatenating them back, plus sorting by original column names order

from narwhals.exceptions import InvalidOperationError

dtypes = import_dtypes_module(self._version)

to_explode = (
[columns, *more_columns]
if isinstance(columns, str)
else [*columns, *more_columns]
)
schema = self.collect_schema()
for col_to_explode in to_explode:
dtype = schema[col_to_explode]

if dtype != dtypes.List:
msg = f"`explode` operation not supported for dtype `{dtype}`"
raise InvalidOperationError(msg)
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved

if len(to_explode) == 1:
return self._from_native_frame(self._native_frame.explode(to_explode[0]))
else:
native_frame = self._native_frame
anchor_series = native_frame[to_explode[0]].list.len()

if not all(
(native_frame[col_name].list.len() == anchor_series).all()
for col_name in to_explode[1:]
):
from narwhals.exceptions import ShapeError

msg = "exploded columns must have matching element counts"
raise ShapeError(msg)

original_columns = self.columns
other_columns = [c for c in original_columns if c not in to_explode]

exploded_frame = native_frame[[*other_columns, to_explode[0]]].explode(
to_explode[0]
)
exploded_series = [
native_frame[col_name].explode().to_frame() for col_name in to_explode[1:]
]

plx = self.__native_namespace__()

return self._from_native_frame(
plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns]
)
134 changes: 132 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ def __eq__(self, other: object) -> NoReturn:
)
raise NotImplementedError(msg)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
return self._from_compliant_dataframe(
self._compliant_frame.explode(
columns,
*more_columns,
)
)


class DataFrame(BaseFrame[DataFrameT]):
"""Narwhals DataFrame, backed by a native dataframe.
Expand Down Expand Up @@ -576,8 +584,6 @@ def to_pandas(self) -> pd.DataFrame:
0 1 6.0 a
1 2 7.0 b
2 3 8.0 c


"""
return self._compliant_frame.to_pandas()

Expand Down Expand Up @@ -3002,6 +3008,77 @@ def unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
"""Explode the dataframe to long format by exploding the given columns.

Notes:
It is possible to explode multiple columns only if these columns must have
matching element counts.

Arguments:
columns: Column names. The underlying columns being exploded must be of the `List` data type.
*more_columns: Additional names of columns to explode, specified as positional arguments.

Returns:
New DataFrame

Examples:
>>> import narwhals as nw
>>> from narwhals.typing import IntoDataFrameT
>>> import pandas as pd
>>> import polars as pl
>>> import pyarrow as pa
>>> data = {
... "a": ["x", "y", "z", "w"],
... "lst1": [[1, 2], None, [None], []],
... "lst2": [[3, None], None, [42], []],
... }

We define a library agnostic function:

>>> def agnostic_explode(df_native: IntoDataFrameT) -> IntoDataFrameT:
... return (
... nw.from_native(df_native)
... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
... .explode("lst1", "lst2")
... .to_native()
... )

We can then pass any supported library such as pandas, Polars (eager),
or PyArrow to `agnostic_explode`:

>>> agnostic_explode(pd.DataFrame(data))
a lst1 lst2
0 x 1 3
0 x 2 <NA>
1 y <NA> <NA>
2 z <NA> 42
3 w <NA> <NA>
>>> agnostic_explode(pl.DataFrame(data))
shape: (5, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”
β”‚ a ┆ lst1 ┆ lst2 β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ str ┆ i32 ┆ i32 β”‚
β•žβ•β•β•β•β•β•ͺ══════β•ͺ══════║
β”‚ x ┆ 1 ┆ 3 β”‚
β”‚ x ┆ 2 ┆ null β”‚
β”‚ y ┆ null ┆ null β”‚
β”‚ z ┆ null ┆ 42 β”‚
β”‚ w ┆ null ┆ null β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”˜
>>> agnostic_explode(pa.table(data))
pyarrow.Table
a: string
lst1: int32
lst2: int32
----
a: [["x","x","y","z","w"]]
lst1: [[1,2,null,null,null]]
lst2: [[3,null,null,42,null]]
"""
return super().explode(columns, *more_columns)


class LazyFrame(BaseFrame[FrameT]):
"""Narwhals DataFrame, backed by a native dataframe.
Expand Down Expand Up @@ -4720,3 +4797,56 @@ def unpivot(
return super().unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
"""Explode the dataframe to long format by exploding the given columns.

Notes:
It is possible to explode multiple columns only if these columns must have
matching element counts.

Arguments:
columns: Column names. The underlying columns being exploded must be of the `List` data type.
*more_columns: Additional names of columns to explode, specified as positional arguments.

Returns:
New LazyFrame

Examples:
>>> import narwhals as nw
>>> from narwhals.typing import IntoFrameT
>>> import polars as pl
>>> data = {
... "a": ["x", "y", "z", "w"],
... "lst1": [[1, 2], None, [None], []],
... "lst2": [[3, None], None, [42], []],
... }

We define a library agnostic function:

>>> def agnostic_explode(df_native: IntoFrameT) -> IntoFrameT:
... return (
... nw.from_native(df_native)
... .with_columns(nw.col("lst1", "lst2").cast(nw.List(nw.Int32())))
... .explode("lst1", "lst2")
... .to_native()
... )

We can then pass any supported library such as pandas, Polars (eager),
or PyArrow to `agnostic_explode`:

>>> agnostic_explode(pl.LazyFrame(data)).collect()
shape: (5, 3)
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”
β”‚ a ┆ lst1 ┆ lst2 β”‚
β”‚ --- ┆ --- ┆ --- β”‚
β”‚ str ┆ i32 ┆ i32 β”‚
β•žβ•β•β•β•β•β•ͺ══════β•ͺ══════║
β”‚ x ┆ 1 ┆ 3 β”‚
β”‚ x ┆ 2 ┆ null β”‚
β”‚ y ┆ null ┆ null β”‚
β”‚ z ┆ null ┆ 42 β”‚
β”‚ w ┆ null ┆ null β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”˜
"""
return super().explode(columns, *more_columns)
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def from_missing_and_available_column_names(
return ColumnNotFoundError(message)


class ShapeError(Exception):
"""Exception raised when trying to perform operations on data structures with incompatible shapes."""


class InvalidOperationError(Exception):
"""Exception raised during invalid operations."""

Expand Down
Loading
Loading