-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: DataFrame and LazyFrame explode * arrow refactor * raise for invalid type and docstrings * Update narwhals/dataframe.py * old versions * almost all native * doctest * better error message, fail for arrow with nulls * doctest-modules * completely remove pyarrow implementation
- Loading branch information
Showing
6 changed files
with
327 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
- drop | ||
- drop_nulls | ||
- estimated_size | ||
- explode | ||
- filter | ||
- gather_every | ||
- get_column | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
- columns | ||
- drop | ||
- drop_nulls | ||
- explode | ||
- filter | ||
- gather_every | ||
- group_by | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Sequence | ||
|
||
import pytest | ||
from polars.exceptions import InvalidOperationError as PlInvalidOperationError | ||
from polars.exceptions import ShapeError as PlShapeError | ||
|
||
import narwhals.stable.v1 as nw | ||
from narwhals.exceptions import InvalidOperationError | ||
from narwhals.exceptions import ShapeError | ||
from tests.utils import PANDAS_VERSION | ||
from tests.utils import POLARS_VERSION | ||
from tests.utils import Constructor | ||
from tests.utils import assert_equal_data | ||
|
||
# Shared input for every test in this module.
#
# Polars only allows exploding several columns at once when, row by row, the
# columns hold the same number of elements. Hence "l1" can be exploded together
# with "l2", and "l3" together with "l4", but "l1" cannot be exploded together
# with "l3".
data = {
    "a": ["x", "y", "z", "w"],
    # "l1" / "l2": per-row element counts agree (2, null, 1, 0).
    "l1": [[1, 2], None, [None], []],
    "l2": [[3, None], None, [42], []],
    # "l3" / "l4": per-row element counts agree (2, 1, 1, 1); no null or
    # empty lists at the row level.
    "l3": [[1, 2], [3], [None], [1]],
    "l4": [[1, 2], [3], [123], [456]],
}
|
||
|
||
@pytest.mark.parametrize(
    ("column", "expected_values"),
    [
        ("l2", [3, None, None, 42, None]),
        ("l3", [1, 2, 3, None, 1]),  # fast path for arrow
    ],
)
def test_explode_single_col(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    column: str,
    expected_values: list[int | None],
) -> None:
    """Explode a single list column and compare it against the expected rows.

    The column is first cast to ``List(Int32)``; after exploding, only ``"a"``
    and the exploded column are selected and checked.

    The test is marked as an expected failure for the dask, modin, cudf and
    pyarrow-table constructors, and for pandas-like constructors when the
    installed pandas is older than 2.2.
    """
    constructor_name = str(constructor)
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")

    if any(name in constructor_name for name in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))
    typed = frame.with_columns(nw.col(column).cast(nw.List(nw.Int32())))
    result = typed.explode(column).select("a", column)

    # Row "x" holds two elements, so its key is repeated once in the output.
    expected = {"a": ["x", "x", "y", "z", "w"], column: expected_values}
    assert_equal_data(result, expected)
|
||
|
||
@pytest.mark.parametrize(
    ("columns", "more_columns", "expected"),
    [
        (
            "l1",
            ["l2"],
            {
                "a": ["x", "x", "y", "z", "w"],
                "l1": [1, 2, None, None, None],
                "l2": [3, None, None, 42, None],
            },
        ),
        (
            "l3",
            ["l4"],
            {
                "a": ["x", "x", "y", "z", "w"],
                "l3": [1, 2, 3, None, 1],
                "l4": [1, 2, 3, 123, 456],
            },
        ),
    ],
)
def test_explode_multiple_cols(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    columns: str | Sequence[str],
    more_columns: Sequence[str],
    expected: dict[str, list[str | int | None]],
) -> None:
    """Explode several list columns at once and compare against ``expected``.

    All target columns are cast to ``List(Int32)``, exploded together, and the
    result is narrowed to ``"a"`` plus the exploded columns before comparison.

    The test is marked as an expected failure for the dask, modin, cudf and
    pyarrow-table constructors, and for pandas-like constructors when the
    installed pandas is older than 2.2.
    """
    constructor_name = str(constructor)
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")

    if any(name in constructor_name for name in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    # Leading argument followed by the extras, in the order they are passed on.
    targets = (columns, *more_columns)

    frame = nw.from_native(constructor(data))
    typed = frame.with_columns(nw.col(*targets).cast(nw.List(nw.Int32())))
    result = typed.explode(*targets).select("a", *targets)

    assert_equal_data(result, expected)
|
||
|
||
def test_explode_shape_error(
    request: pytest.FixtureRequest, constructor: Constructor
) -> None:
    """Exploding ``"l1"`` and ``"l3"`` together must raise a shape error.

    The pipeline is run lazily and collected; either the narwhals or the
    polars ``ShapeError`` is accepted, as long as its message matches
    "exploded columns must have matching element counts".

    The test is marked as an expected failure for the dask, modin, cudf and
    pyarrow-table constructors, and for pandas-like constructors when the
    installed pandas is older than 2.2.
    """
    constructor_name = str(constructor)
    xfail_backends = ("dask", "modin", "cudf", "pyarrow_table")

    if any(name in constructor_name for name in xfail_backends):
        request.applymarker(pytest.mark.xfail)

    if "pandas" in constructor_name and PANDAS_VERSION < (2, 2):
        request.applymarker(pytest.mark.xfail)

    accepted_errors = (ShapeError, PlShapeError)
    message = "exploded columns must have matching element counts"

    # Every step stays inside the context manager so that the point at which
    # the error may surface (build, explode or collect) is left unconstrained.
    with pytest.raises(accepted_errors, match=message):
        lazy_frame = nw.from_native(constructor(data)).lazy()
        typed = lazy_frame.with_columns(
            nw.col("l1", "l2", "l3").cast(nw.List(nw.Int32()))
        )
        _ = typed.explode("l1", "l3").collect()
|
||
|
||
def test_explode_invalid_operation_error(
    request: pytest.FixtureRequest, constructor: Constructor
) -> None:
    """Exploding the ``"a"`` column must raise an invalid-operation error.

    The pipeline is run lazily and collected; either the narwhals or the
    polars ``InvalidOperationError`` is accepted, as long as its message
    matches "`explode` operation not supported for dtype".

    The test is marked as an expected failure for the dask and pyarrow-table
    constructors, and for polars constructors when the installed polars is
    older than 0.20.6.
    """
    constructor_name = str(constructor)

    if any(name in constructor_name for name in ("dask", "pyarrow_table")):
        request.applymarker(pytest.mark.xfail)

    if "polars" in constructor_name and POLARS_VERSION < (0, 20, 6):
        request.applymarker(pytest.mark.xfail)

    accepted_errors = (InvalidOperationError, PlInvalidOperationError)
    message = "`explode` operation not supported for dtype"

    with pytest.raises(accepted_errors, match=message):
        lazy_frame = nw.from_native(constructor(data)).lazy()
        _ = lazy_frame.explode("a").collect()