Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ArrowDataFrame.explode #1644

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import scale_bytes
Expand Down Expand Up @@ -752,3 +753,85 @@ def unpivot(
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode one or more list columns into one row per list element.

    Columns that are not exploded have their values repeated so that they
    stay aligned with the exploded values. The column order of the frame
    is preserved.

    Arguments:
        columns: Name, or sequence of names, of the list column(s) to explode.
        *more_columns: Additional column names, given positionally.

    Returns:
        A new frame in which every exploded column holds the list elements.
        A row whose list is null or empty produces a single null row.

    Raises:
        InvalidOperationError: If any requested column is not of List dtype.
        ShapeError: If the exploded columns do not have matching element
            counts on every row.
    """
    import pyarrow as pa
    import pyarrow.compute as pc

    from narwhals.exceptions import InvalidOperationError

    dtypes = import_dtypes_module(self._version)

    to_explode = (
        [columns, *more_columns]
        if isinstance(columns, str)
        else [*columns, *more_columns]
    )

    # Validate dtypes up front so that no work is done on invalid input.
    schema = self.collect_schema()
    for col_to_explode in to_explode:
        dtype = schema[col_to_explode]

        if dtype != dtypes.List:
            msg = (
                f"`explode` operation not supported for dtype `{dtype}`, "
                "expected List type"
            )

            raise InvalidOperationError(msg)

    native_frame = self._native_frame
    # Per-row list length of the first column; a null list has a null length.
    counts = pc.list_value_length(native_frame[to_explode[0]])

    # NOTE(review): a null list yields a null count, so `pc.equal` yields
    # null for that row and `pc.all` skips nulls by default. Rows with a
    # null list therefore do not take part in this check - confirm that
    # this is the intended behaviour.
    if not all(
        pc.all(pc.equal(pc.list_value_length(native_frame[col_name]), counts)).as_py()
        for col_name in to_explode[1:]
    ):
        from narwhals.exceptions import ShapeError

        msg = "exploded columns must have matching element counts"
        raise ShapeError(msg)

    original_columns = self.columns
    exploded_names = set(to_explode)

    # The fast path relies on `list_flatten` / `list_parent_indices`, which
    # emit nothing for null or empty lists. It is therefore only valid when
    # every row holds a non-null list with at least one element.
    # `skip_nulls=False` makes a null count yield null (-> falsy) instead of
    # being ignored, so that null rows are not silently dropped.
    fast_path = bool(
        pc.all(pc.greater_equal(counts, 1), skip_nulls=False).as_py()
    )

    if fast_path:
        indices = pc.list_parent_indices(native_frame[to_explode[0]])
        flatten_func = pc.list_flatten

    else:
        # Null and empty lists must still produce exactly one output row.
        filled_counts = pc.max_element_wise(counts, 1, skip_nulls=True)
        indices = pa.array(
            [
                i
                for i, count in enumerate(filled_counts.to_pylist())
                for _ in range(count)
            ]
        )

        # Output rows whose parent index appears among the flattened values
        # receive a real element; all other rows stay null.
        parent_indices = pc.list_parent_indices(native_frame[to_explode[0]])
        is_valid_index = pc.is_in(indices, value_set=parent_indices)
        exploded_size = len(is_valid_index)

        def flatten_func(array: pa.ChunkedArray) -> pa.ChunkedArray:
            dtype = array.type.value_type
            # Start from an all-null array and fill in the flattened values
            # at the valid positions.
            return pc.replace_with_mask(
                pa.repeat(pa.scalar(None, type=dtype), exploded_size),
                is_valid_index,
                pc.list_flatten(array).combine_chunks(),
            )

    arrays = [
        flatten_func(native_frame[col_name])
        if col_name in exploded_names
        else native_frame[col_name].take(indices)
        for col_name in original_columns
    ]

    return self._from_native_frame(
        pa.Table.from_arrays(
            arrays=arrays,
            names=original_columns,
        )
    )
17 changes: 4 additions & 13 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ def test_explode_single_col(
column: str,
expected_values: list[int | None],
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand Down Expand Up @@ -87,10 +84,7 @@ def test_explode_multiple_cols(
more_columns: Sequence[str],
expected: dict[str, list[str | int | None]],
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand All @@ -108,10 +102,7 @@ def test_explode_multiple_cols(
def test_explode_shape_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand All @@ -133,7 +124,7 @@ def test_explode_shape_error(
def test_explode_invalid_operation_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if "dask" in str(constructor) or "pyarrow_table" in str(constructor):
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
Expand Down
Loading