Skip to content

Commit

Permalink
Add basic tests of dataframe scan (#16003)
Browse files Browse the repository at this point in the history
Also assert that unsupported file scan operations raise.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - https://github.com/brandon-b-miller

URL: #16003
  • Loading branch information
wence- authored Jun 19, 2024
1 parent c83e5b3 commit f536e30
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 3 deletions.
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def __post_init__(self) -> None:
if self.file_options.n_rows is not None:
raise NotImplementedError("row limit in scan")
if self.typ not in ("csv", "parquet"):
raise NotImplementedError(f"Unhandled scan type: {self.typ}")
raise NotImplementedError(
f"Unhandled scan type: {self.typ}"
) # pragma: no cover; polars raises on the rust side for now

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
Expand Down
34 changes: 33 additions & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from polars.testing.asserts import assert_frame_equal

from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

if TYPE_CHECKING:
from collections.abc import Mapping
Expand All @@ -19,7 +20,7 @@

from cudf_polars.typing import OptimizationArgs

__all__: list[str] = ["assert_gpu_result_equal"]
__all__: list[str] = ["assert_gpu_result_equal", "assert_ir_translation_raises"]


def assert_gpu_result_equal(
Expand Down Expand Up @@ -84,3 +85,34 @@ def assert_gpu_result_equal(
atol=atol,
categorical_as_str=categorical_as_str,
)


def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) -> None:
"""
Assert that translation of a query raises an exception.
Parameters
----------
q
Query to translate.
exceptions
Exceptions that one expects might be raised.
Returns
-------
None
If translation successfully raised the specified exceptions.
Raises
------
AssertionError
If the specified exceptions were not raised.
"""
try:
_ = translate_ir(q._ldf.visit())
except exceptions:
return
except Exception as e:
raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e
else:
raise AssertionError(f"Translation DID NOT RAISE {exceptions}")
18 changes: 18 additions & 0 deletions python/cudf_polars/docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,24 @@ def test_whatever():
assert_gpu_result_equal(query)
```

## Test coverage and asserting failure modes

Where translation of a query should fail due to the feature being
unsupported we should test this. To assert that _translation_ raises
an exception (usually `NotImplementedError`), use the utility function
`assert_ir_translation_raises`:

```python
from cudf_polars.testing.asserts import assert_ir_translation_raises


def test_whatever():
unsupported_query = ...
assert_ir_translation_raises(unsupported_query, NotImplementedError)
```

This test will fail if translation does not raise.

# Debugging

If the callback execution fails during the polars `collect` call, we
Expand Down
43 changes: 43 additions & 0 deletions python/cudf_polars/tests/test_dataframescan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
"subset",
[
None,
["a", "c"],
["b", "c", "d"],
["b", "d"],
["b", "c"],
["c", "e"],
["d", "e"],
pl.selectors.string(),
pl.selectors.integer(),
],
)
@pytest.mark.parametrize("predicate_pushdown", [False, True])
def test_scan_drop_nulls(subset, predicate_pushdown):
df = pl.LazyFrame(
{
"a": [1, 2, 3, 4],
"b": [None, 4, 5, None],
"c": [6, 7, None, None],
"d": [8, None, 9, 10],
"e": [None, None, "A", None],
}
)
# Drop nulls are pushed into filters
q = df.drop_nulls(subset)

assert_gpu_result_equal(
q, collect_kwargs={"predicate_pushdown": predicate_pushdown}
)
13 changes: 12 additions & 1 deletion python/cudf_polars/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal
from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
)


@pytest.fixture(
Expand Down Expand Up @@ -86,3 +89,11 @@ def test_scan(df, columns, mask):
if columns is not None:
q = df.select(*columns)
assert_gpu_result_equal(q)


def test_scan_unsupported_raises(tmp_path):
df = pl.DataFrame({"a": [1, 2, 3]})

df.write_ndjson(tmp_path / "df.json")
q = pl.scan_ndjson(tmp_path / "df.json")
assert_ir_translation_raises(q, NotImplementedError)
6 changes: 6 additions & 0 deletions python/cudf_polars/tests/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

__all__: list[str] = []
35 changes: 35 additions & 0 deletions python/cudf_polars/tests/testing/test_asserts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
)


def test_translation_assert_raises():
df = pl.LazyFrame({"a": [1, 2, 3]})

# This should succeed
assert_gpu_result_equal(df)

with pytest.raises(AssertionError):
# This should fail, because we can translate this query.
assert_ir_translation_raises(df, NotImplementedError)

class E(Exception):
pass

unsupported = df.group_by("a").agg(pl.col("a").cum_max().alias("b"))
# Unsupported query should raise NotImplementedError
assert_ir_translation_raises(unsupported, NotImplementedError)

with pytest.raises(AssertionError):
# This should fail, because we can't translate this query, but it doesn't raise E.
assert_ir_translation_raises(unsupported, E)

0 comments on commit f536e30

Please sign in to comment.