diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index d5dfc9e1ff5..25f11863b0d 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -27,6 +27,7 @@ jobs: - wheel-tests-cudf - wheel-build-cudf-polars - wheel-tests-cudf-polars + - cudf-polars-polars-tests - wheel-build-dask-cudf - wheel-tests-dask-cudf - devcontainer @@ -154,6 +155,17 @@ jobs: # This always runs, but only fails if this PR touches code in # pylibcudf or cudf_polars script: "ci/test_wheel_cudf_polars.sh" + cudf-polars-polars-tests: + needs: wheel-build-cudf-polars + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.08 + with: + # This selects "ARCH=amd64 + the latest supported Python + CUDA". + matrix_filter: map(select(.ARCH == "amd64")) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))])) + build_type: pull-request + # This always runs, but only fails if this PR touches code in + # pylibcudf or cudf_polars + script: "ci/test_cudf_polars_polars_tests.sh" wheel-build-dask-cudf: needs: wheel-build-cudf secrets: inherit diff --git a/ci/run_cudf_polars_polars_tests.sh b/ci/run_cudf_polars_polars_tests.sh new file mode 100755 index 00000000000..52a827af94c --- /dev/null +++ b/ci/run_cudf_polars_polars_tests.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +set -euo pipefail + +# Support invoking run_cudf_polars_pytests.sh outside the script directory +# Assumption, polars has been cloned in the root of the repo. +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../polars/ + +DESELECTED_TESTS=( + "tests/unit/test_polars_import.py::test_polars_import" # relies on a polars built in place + "tests/unit/streaming/test_streaming_sort.py::test_streaming_sort[True]" # relies on polars built in debug mode + "tests/unit/test_cpu_check.py::test_check_cpu_flags_skipped_no_flags" # Mock library error + "tests/docs/test_user_guide.py" # No dot binary in CI image +) + +DESELECTED_TESTS=$(printf -- " --deselect %s" "${DESELECTED_TESTS[@]}") +python -m pytest \ + --import-mode=importlib \ + --cache-clear \ + -m "" \ + -p cudf_polars.testing.plugin \ + -v \ + --tb=short \ + ${DESELECTED_TESTS} \ + "$@" \ + py-polars/tests diff --git a/ci/test_cudf_polars_polars_tests.sh b/ci/test_cudf_polars_polars_tests.sh new file mode 100755 index 00000000000..ed8943ecb57 --- /dev/null +++ b/ci/test_cudf_polars_polars_tests.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +set -eou pipefail + +# We will only fail these tests if the PR touches code in pylibcudf +# or cudf_polars itself. +# Note, the three dots mean we are doing diff between the merge-base +# of upstream and HEAD. So this is asking, "does _this branch_ touch +# files in cudf_polars/pylibcudf", rather than "are there changes +# between upstream and this branch which touch cudf_polars/pylibcudf" +# TODO: is the target branch exposed anywhere in an environment variable? +if [ -n "$(git diff --name-only origin/branch-24.08...HEAD -- python/cudf_polars/ python/cudf/cudf/_lib/pylibcudf/)" ]; +then + HAS_CHANGES=1 + rapids-logger "PR has changes in cudf-polars/pylibcudf, test fails treated as failure" +else + HAS_CHANGES=0 + rapids-logger "PR does not have changes in cudf-polars/pylibcudf, test fails NOT treated as failure" +fi + +rapids-logger "Download wheels" + +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" +RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 ./dist + +# Download the cudf built in the previous step +RAPIDS_PY_WHEEL_NAME="cudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-cudf-dep + +rapids-logger "Install cudf" +python -m pip install ./local-cudf-dep/cudf*.whl + +rapids-logger "Install cudf_polars" +python -m pip install $(echo ./dist/cudf_polars*.whl) + +TAG=$(python -c 'import polars; print(f"py-{polars.__version__}")') +rapids-logger "Clone polars to ${TAG}" +git clone https://github.com/pola-rs/polars.git --branch ${TAG} --depth 1 + +# Install requirements for running polars tests +rapids-logger "Install polars test requirements" +python -m pip install -r polars/py-polars/requirements-dev.txt -r polars/py-polars/requirements-ci.txt + +function set_exitcode() +{ + EXITCODE=$? +} +EXITCODE=0 +trap set_exitcode ERR +set +e + +rapids-logger "Run polars tests" +./ci/run_cudf_polars_polars_tests.sh + +trap ERR +set -e + +if [ ${EXITCODE} != 0 ]; then + rapids-logger "Running polars test suite FAILED: exitcode ${EXITCODE}" +else + rapids-logger "Running polars test suite PASSED" +fi + +if [ ${HAS_CHANGES} == 1 ]; then + exit ${EXITCODE} +else + exit 0 +fi diff --git a/ci/test_wheel_cudf_polars.sh b/ci/test_wheel_cudf_polars.sh index 900acd5d473..2f0dda4c96e 100755 --- a/ci/test_wheel_cudf_polars.sh +++ b/ci/test_wheel_cudf_polars.sh @@ -13,15 +13,21 @@ set -eou pipefail if [ -n "$(git diff --name-only origin/branch-24.08...HEAD -- python/cudf_polars/ python/cudf/cudf/_lib/pylibcudf/)" ]; then HAS_CHANGES=1 + rapids-logger "PR has changes in cudf-polars/pylibcudf, test fails treated as failure" else HAS_CHANGES=0 + rapids-logger "PR does not have changes in cudf-polars/pylibcudf, test fails NOT treated as failure" fi +rapids-logger "Download wheels" + RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 ./dist # Download the cudf built in the previous step RAPIDS_PY_WHEEL_NAME="cudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-cudf-dep + +rapids-logger "Install cudf" python -m pip install ./local-cudf-dep/cudf*.whl rapids-logger "Install cudf_polars" diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index b30ef875a7b..9a66d2527db 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -16,6 +16,8 @@ from cudf._lib.pylibcudf.libcudf.scalar.scalar cimport scalar from cudf._lib.pylibcudf.libcudf.types cimport size_type from cudf._lib.scalar cimport DeviceScalar +import cudf._lib.pylibcudf as plc + @acquire_spill_lock() def add_months(Column col, Column months): @@ -37,43 +39,9 @@ def add_months(Column col, Column months): @acquire_spill_lock() def extract_datetime_component(Column col, object field): - - cdef unique_ptr[column] c_result - cdef column_view col_view = col.view() - - with nogil: - if field == "year": - c_result = move(libcudf_datetime.extract_year(col_view)) - elif field == "month": - c_result = move(libcudf_datetime.extract_month(col_view)) - elif field == "day": - c_result = move(libcudf_datetime.extract_day(col_view)) - elif field == "weekday": - c_result = move(libcudf_datetime.extract_weekday(col_view)) - elif field == "hour": - c_result = move(libcudf_datetime.extract_hour(col_view)) - elif field == "minute": - c_result = move(libcudf_datetime.extract_minute(col_view)) - elif field == "second": - c_result = move(libcudf_datetime.extract_second(col_view)) - elif field == "millisecond": - c_result = move( - libcudf_datetime.extract_millisecond_fraction(col_view) - ) - elif field == "microsecond": - c_result = move( - libcudf_datetime.extract_microsecond_fraction(col_view) - ) - elif field == "nanosecond": - c_result = move( - libcudf_datetime.extract_nanosecond_fraction(col_view) - ) - elif field == "day_of_year": - c_result = move(libcudf_datetime.day_of_year(col_view)) - else: - raise ValueError(f"Invalid datetime field: '{field}'") - - result = Column.from_unique_ptr(move(c_result)) + result = Column.from_pylibcudf( + plc.datetime.extract_datetime_component(col.to_pylibcudf(mode="read"), field) + ) if field == "weekday": # Pandas counts Monday-Sunday as 0-6 diff --git a/python/cudf/cudf/_lib/pylibcudf/datetime.pyx b/python/cudf/cudf/_lib/pylibcudf/datetime.pyx index 82351327de6..87efcd495b9 100644 --- a/python/cudf/cudf/_lib/pylibcudf/datetime.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/datetime.pyx @@ -4,6 +4,16 @@ from libcpp.utility cimport move from cudf._lib.pylibcudf.libcudf.column.column cimport column from cudf._lib.pylibcudf.libcudf.datetime cimport ( + day_of_year as cpp_day_of_year, + extract_day as cpp_extract_day, + extract_hour as cpp_extract_hour, + extract_microsecond_fraction as cpp_extract_microsecond_fraction, + extract_millisecond_fraction as cpp_extract_millisecond_fraction, + extract_minute as cpp_extract_minute, + extract_month as cpp_extract_month, + extract_nanosecond_fraction as cpp_extract_nanosecond_fraction, + extract_second as cpp_extract_second, + extract_weekday as cpp_extract_weekday, extract_year as cpp_extract_year, ) @@ -31,3 +41,42 @@ cpdef Column extract_year( with nogil: result = move(cpp_extract_year(values.view())) return Column.from_libcudf(move(result)) + + +def extract_datetime_component(Column col, str field): + + cdef unique_ptr[column] c_result + + with nogil: + if field == "year": + c_result = move(cpp_extract_year(col.view())) + elif field == "month": + c_result = move(cpp_extract_month(col.view())) + elif field == "day": + c_result = move(cpp_extract_day(col.view())) + elif field == "weekday": + c_result = move(cpp_extract_weekday(col.view())) + elif field == "hour": + c_result = move(cpp_extract_hour(col.view())) + elif field == "minute": + c_result = move(cpp_extract_minute(col.view())) + elif field == "second": + c_result = move(cpp_extract_second(col.view())) + elif field == "millisecond": + c_result = move( + cpp_extract_millisecond_fraction(col.view()) + ) + elif field == "microsecond": + c_result = move( + cpp_extract_microsecond_fraction(col.view()) + ) + elif field == "nanosecond": + c_result = move( + cpp_extract_nanosecond_fraction(col.view()) + ) + elif field == "day_of_year": + c_result = move(cpp_day_of_year(col.view())) + else: + raise ValueError(f"Invalid datetime field: '{field}'") + + return Column.from_libcudf(move(c_result)) diff --git a/python/cudf/cudf/pylibcudf_tests/test_datetime.py b/python/cudf/cudf/pylibcudf_tests/test_datetime.py index 75af0fa6ca1..777c234c192 100644 --- a/python/cudf/cudf/pylibcudf_tests/test_datetime.py +++ b/python/cudf/cudf/pylibcudf_tests/test_datetime.py @@ -1,8 +1,10 @@ # Copyright (c) 2024, NVIDIA CORPORATION. import datetime +import functools import pyarrow as pa +import pyarrow.compute as pc import pytest from utils import assert_column_eq @@ -10,7 +12,7 @@ @pytest.fixture -def column(has_nulls): +def date_column(has_nulls): values = [ datetime.date(1999, 1, 1), datetime.date(2024, 10, 12), @@ -22,9 +24,41 @@ def column(has_nulls): return plc.interop.from_arrow(pa.array(values, type=pa.date32())) -def test_extract_year(column): - got = plc.datetime.extract_year(column) +@pytest.fixture(scope="module", params=["s", "ms", "us", "ns"]) +def datetime_column(has_nulls, request): + values = [ + datetime.datetime(1999, 1, 1), + datetime.datetime(2024, 10, 12), + datetime.datetime(1970, 1, 1), + datetime.datetime(2260, 1, 1), + datetime.datetime(2024, 2, 29, 3, 14, 15), + datetime.datetime(2024, 2, 29, 3, 14, 15, 999), + ] + if has_nulls: + values[2] = None + return plc.interop.from_arrow( + pa.array(values, type=pa.timestamp(request.param)) + ) + + +@pytest.mark.parametrize( + "component, pc_fun", + [ + ("year", pc.year), + ("month", pc.month), + ("day", pc.day), + ("weekday", functools.partial(pc.day_of_week, count_from_zero=False)), + ("hour", pc.hour), + ("minute", pc.minute), + ("second", pc.second), + ("millisecond", pc.millisecond), + ("microsecond", pc.microsecond), + ("nanosecond", pc.nanosecond), + ], +) +def test_extraction(datetime_column, component, pc_fun): + got = plc.datetime.extract_datetime_component(datetime_column, component) # libcudf produces an int16, arrow produces an int64 - expect = pa.compute.year(plc.interop.to_arrow(column)).cast(pa.int16()) + expect = pc_fun(plc.interop.to_arrow(datetime_column)).cast(pa.int16()) assert_column_eq(expect, got) diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index effcff20d3d..d6f44621406 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -961,6 +961,18 @@ def do_evaluate( class TemporalFunction(Expr): __slots__ = ("name", "options", "children") + _COMPONENT_MAP: ClassVar[dict[pl_expr.TemporalFunction, str]] = { + pl_expr.TemporalFunction.Year: "year", + pl_expr.TemporalFunction.Month: "month", + pl_expr.TemporalFunction.Day: "day", + pl_expr.TemporalFunction.WeekDay: "weekday", + pl_expr.TemporalFunction.Hour: "hour", + pl_expr.TemporalFunction.Minute: "minute", + pl_expr.TemporalFunction.Second: "second", + pl_expr.TemporalFunction.Millisecond: "millisecond", + pl_expr.TemporalFunction.Microsecond: "microsecond", + pl_expr.TemporalFunction.Nanosecond: "nanosecond", + } _non_child = ("dtype", "name", "options") children: tuple[Expr, ...] @@ -975,7 +987,7 @@ def __init__( self.options = options self.name = name self.children = children - if self.name != pl_expr.TemporalFunction.Year: + if self.name not in self._COMPONENT_MAP: raise NotImplementedError(f"Temporal function {self.name}") def do_evaluate( @@ -990,12 +1002,59 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == pl_expr.TemporalFunction.Year: - (column,) = columns - return Column(plc.datetime.extract_year(column.obj)) - raise NotImplementedError( - f"TemporalFunction {self.name}" - ) # pragma: no cover; init trips first + (column,) = columns + if self.name == pl_expr.TemporalFunction.Microsecond: + millis = plc.datetime.extract_datetime_component(column.obj, "millisecond") + micros = plc.datetime.extract_datetime_component(column.obj, "microsecond") + millis_as_micros = plc.binaryop.binary_operation( + millis, + plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.DataType(plc.TypeId.INT32), + ) + total_micros = plc.binaryop.binary_operation( + micros, + millis_as_micros, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + return Column(total_micros) + elif self.name == pl_expr.TemporalFunction.Nanosecond: + millis = plc.datetime.extract_datetime_component(column.obj, "millisecond") + micros = plc.datetime.extract_datetime_component(column.obj, "microsecond") + nanos = plc.datetime.extract_datetime_component(column.obj, "nanosecond") + millis_as_nanos = plc.binaryop.binary_operation( + millis, + plc.interop.from_arrow(pa.scalar(1_000_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.types.DataType(plc.types.TypeId.INT32), + ) + micros_as_nanos = plc.binaryop.binary_operation( + micros, + plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.types.DataType(plc.types.TypeId.INT32), + ) + total_nanos = plc.binaryop.binary_operation( + nanos, + millis_as_nanos, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + total_nanos = plc.binaryop.binary_operation( + total_nanos, + micros_as_nanos, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + return Column(total_nanos) + + return Column( + plc.datetime.extract_datetime_component( + column.obj, + self._COMPONENT_MAP[self.name], + ) + ) class UnaryFunction(Expr): diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index a5046107e2a..e27c7827e9a 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -273,7 +273,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: eol = chr(parse_options["eol_char"]) if self.reader_options["schema"] is not None: # Reader schema provides names - column_names = list(self.reader_options["schema"]["inner"].keys()) + column_names = list(self.reader_options["schema"]["fields"].keys()) else: # file provides column names column_names = None diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index df5c0294141..2886f1c684f 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -94,6 +94,13 @@ def _( cloud_options = None else: reader_options, cloud_options = map(json.loads, options) + if ( + typ == "csv" + and visitor.version()[0] == 1 + and reader_options["schema"] is not None + ): + # Polars 1.7 renames the inner slot from "inner" to "fields". + reader_options["schema"] = {"fields": reader_options["schema"]["inner"]} file_options = node.file_options with_columns = file_options.with_columns n_rows = file_options.n_rows @@ -310,8 +317,8 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: # IR is versioned with major.minor, minor is bumped for backwards # compatible changes (e.g. adding new nodes), major is bumped for # incompatible changes (e.g. renaming nodes). - # Polars 1.4 changes definition of PythonScan. - if (version := visitor.version()) >= (2, 0): + # Polars 1.7 changes definition of the CSV reader options schema name. + if (version := visitor.version()) >= (3, 0): raise NotImplementedError( f"No support for polars IR {version=}" ) # pragma: no cover; no such version for now. @@ -419,12 +426,29 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex *(translate_expr(visitor, n=n) for n in node.input), ) elif isinstance(name, pl_expr.TemporalFunction): - return expr.TemporalFunction( + # functions for which evaluation of the expression may not return + # the same dtype as polars, either due to libcudf returning a different + # dtype, or due to our internal processing affecting what libcudf returns + needs_cast = { + pl_expr.TemporalFunction.Year, + pl_expr.TemporalFunction.Month, + pl_expr.TemporalFunction.Day, + pl_expr.TemporalFunction.WeekDay, + pl_expr.TemporalFunction.Hour, + pl_expr.TemporalFunction.Minute, + pl_expr.TemporalFunction.Second, + pl_expr.TemporalFunction.Millisecond, + } + result_expr = expr.TemporalFunction( dtype, name, options, *(translate_expr(visitor, n=n) for n in node.input), ) + if name in needs_cast: + return expr.Cast(dtype, result_expr) + return result_expr + elif isinstance(name, str): children = (translate_expr(visitor, n=n) for n in node.input) if name == "log": diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py new file mode 100644 index 00000000000..7be40f6f762 --- /dev/null +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Plugin for running polars test suite setting GPU engine as default.""" + +from __future__ import annotations + +from functools import partialmethod +from typing import TYPE_CHECKING + +import pytest + +import polars + +if TYPE_CHECKING: + from collections.abc import Mapping + + +def pytest_addoption(parser: pytest.Parser): + """Add plugin-specific options.""" + group = parser.getgroup( + "cudf-polars", "Plugin to set GPU as default engine for polars tests" + ) + group.addoption( + "--cudf-polars-no-fallback", + action="store_true", + help="Turn off fallback to CPU when running tests (default use fallback)", + ) + + +def pytest_configure(config: pytest.Config): + """Enable use of this module as a pytest plugin to enable GPU collection.""" + no_fallback = config.getoption("--cudf-polars-no-fallback") + collect = polars.LazyFrame.collect + engine = polars.GPUEngine(raise_on_fail=no_fallback) + polars.LazyFrame.collect = partialmethod(collect, engine=engine) + config.addinivalue_line( + "filterwarnings", + "ignore:.*GPU engine does not support streaming or background collection", + ) + config.addinivalue_line( + "filterwarnings", + "ignore:.*Query execution with GPU not supported", + ) + + +EXPECTED_FAILURES: Mapping[str, str] = { + "tests/unit/io/test_csv.py::test_compressed_csv": "Need to determine if file is compressed", + "tests/unit/io/test_csv.py::test_read_csv_only_loads_selected_columns": "Memory usage won't be correct due to GPU", + "tests/unit/io/test_lazy_count_star.py::test_count_compressed_csv_18057": "Need to determine if file is compressed", + "tests/unit/io/test_lazy_csv.py::test_scan_csv_slice_offset_zero": "Integer overflow in sliced read", + "tests/unit/io/test_lazy_parquet.py::test_parquet_is_in_statistics": "Debug output on stderr doesn't match", + "tests/unit/io/test_lazy_parquet.py::test_parquet_statistics": "Debug output on stderr doesn't match", + "tests/unit/io/test_lazy_parquet.py::test_parquet_different_schema[False]": "Needs cudf#16394", + "tests/unit/io/test_lazy_parquet.py::test_parquet_schema_mismatch_panic_17067[False]": "Needs cudf#16394", + "tests/unit/io/test_lazy_parquet.py::test_parquet_slice_pushdown_non_zero_offset[True]": "Unknown error: invalid parquet?", + "tests/unit/io/test_lazy_parquet.py::test_parquet_slice_pushdown_non_zero_offset[False]": "Unknown error: invalid parquet?", + "tests/unit/io/test_parquet.py::test_read_parquet_only_loads_selected_columns_15098": "Memory usage won't be correct due to GPU", + "tests/unit/io/test_scan.py::test_scan[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_projected_out[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_filter_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_projected_out[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_filter_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_parquet-write_parquet]": "Need to add include_file_path to IR", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_csv-write_csv]": "Need to add include_file_path to IR", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_ndjson-write_ndjson]": "Need to add include_file_path to IR", + "tests/unit/lazyframe/test_engine_selection.py::test_engine_import_error_raises[gpu]": "Expect this to pass because cudf-polars is installed", + "tests/unit/lazyframe/test_engine_selection.py::test_engine_import_error_raises[engine1]": "Expect this to pass because cudf-polars is installed", + "tests/unit/lazyframe/test_lazyframe.py::test_round[dtype1-123.55-1-123.6]": "Rounding midpoints is handled incorrectly", + "tests/unit/lazyframe/test_lazyframe.py::test_cast_frame": "Casting that raises not supported on GPU", + "tests/unit/lazyframe/test_lazyframe.py::test_lazy_cache_hit": "Debug output on stderr doesn't match", + "tests/unit/operations/aggregation/test_aggregations.py::test_duration_function_literal": "Broadcasting inside groupby-agg not supported", + "tests/unit/operations/aggregation/test_aggregations.py::test_sum_empty_and_null_set": "libcudf sums column of all nulls to null, not zero", + "tests/unit/operations/aggregation/test_aggregations.py::test_binary_op_agg_context_no_simplify_expr_12423": "groupby-agg of just literals should not produce collect_list", + "tests/unit/operations/aggregation/test_aggregations.py::test_nan_inf_aggregation": "treatment of nans and nulls together is different in libcudf and polars in groupby-agg context", + "tests/unit/operations/test_abs.py::test_abs_duration": "Need to raise for unsupported uops on timelike values", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input7-expected7-Float32-Float32]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input10-expected10-Date-output_dtype10]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input11-expected11-input_dtype11-output_dtype11]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input12-expected12-input_dtype12-output_dtype12]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input13-expected13-input_dtype13-output_dtype13]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input7-expected7-Float32-Float32]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input10-expected10-Date-output_dtype10]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input11-expected11-input_dtype11-output_dtype11]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input12-expected12-input_dtype12-output_dtype12]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input13-expected13-input_dtype13-output_dtype13]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input14-expected14-input_dtype14-output_dtype14]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input15-expected15-input_dtype15-output_dtype15]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input16-expected16-input_dtype16-output_dtype16]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_binary_agg_with_literal": "Incorrect broadcasting of literals in groupby-agg", + "tests/unit/operations/test_group_by.py::test_group_by_apply_first_input_is_literal": "Polars advertises incorrect schema names polars#18524", + "tests/unit/operations/test_group_by.py::test_aggregated_scalar_elementwise_15602": "Unsupported boolean function/dtype combination in groupby-agg", + "tests/unit/operations/test_group_by.py::test_schemas[data1-expr1-expected_select1-expected_gb1]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_by_monday_and_offset_5444": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[left-expected0]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[right-expected1]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[datapoint-expected2]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_rolling_dynamic_sortedness_check": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_validation": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_15225": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_join.py::test_cross_join_slice_pushdown": "Need to implement slice pushdown for cross joins", + "tests/unit/sql/test_cast.py::test_cast_errors[values0-values::uint8-conversion from `f64` to `u64` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_cast.py::test_cast_errors[values1-values::uint4-conversion from `i64` to `u32` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_cast.py::test_cast_errors[values2-values::int1-conversion from `i64` to `i8` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_miscellaneous.py::test_read_csv": "Incorrect handling of missing_is_null in read_csv", + "tests/unit/sql/test_wildcard_opts.py::test_select_wildcard_errors": "Raises correctly but with different exception", + "tests/unit/streaming/test_streaming_io.py::test_parquet_eq_statistics": "Debug output on stderr doesn't match", + "tests/unit/test_cse.py::test_cse_predicate_self_join": "Debug output on stderr doesn't match", + "tests/unit/test_empty.py::test_empty_9137": "Mismatching dtypes, needs cudf#15852", + # Maybe flaky, order-dependent? + "tests/unit/test_projections.py::test_schema_full_outer_join_projection_pd_13287": "Order-specific result check, query is correct but in different order", + "tests/unit/test_queries.py::test_group_by_agg_equals_zero_3535": "libcudf sums all nulls to null, not zero", +} + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: list[pytest.Item] +): + """Mark known failing tests.""" + if config.getoption("--cudf-polars-no-fallback"): + # Don't xfail tests if running without fallback + return + for item in items: + if item.nodeid in EXPECTED_FAILURES: + item.add_marker(pytest.mark.xfail(reason=EXPECTED_FAILURES[item.nodeid])) diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index d4a3bb50dae..06c0e217403 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -58,6 +58,9 @@ exclude_also = [ "class .*\\bProtocol\\):", "assert_never\\(" ] +# The cudf_polars test suite doesn't exercise the plugin, so we omit +# it from coverage checks. +omit = ["cudf_polars/testing/plugin.py"] [tool.ruff] line-length = 88 diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py index 218101bf87c..c6ea29ddd38 100644 --- a/python/cudf_polars/tests/expressions/test_datetime_basic.py +++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py @@ -9,7 +9,11 @@ import polars as pl -from cudf_polars.testing.asserts import assert_gpu_result_equal +from cudf_polars.dsl.expr import TemporalFunction +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) @pytest.mark.parametrize( @@ -37,26 +41,97 @@ def test_datetime_dataframe_scan(dtype): assert_gpu_result_equal(query) +datetime_extract_fields = [ + "year", + "month", + "day", + "weekday", + "hour", + "minute", + "second", + "millisecond", + "microsecond", + "nanosecond", +] + + +@pytest.fixture( + ids=datetime_extract_fields, + params=[methodcaller(f) for f in datetime_extract_fields], +) +def field(request): + return request.param + + +def test_datetime_extract(field): + ldf = pl.LazyFrame( + { + "datetimes": pl.datetime_range( + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 12, 30), + "3mo14h15s11ms33us999ns", + eager=True, + ) + } + ) + + q = ldf.select(field(pl.col("datetimes").dt)) + + assert_gpu_result_equal(q) + + +def test_datetime_extra_unsupported(monkeypatch): + ldf = pl.LazyFrame( + { + "datetimes": pl.datetime_range( + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 12, 30), + "3mo14h15s11ms33us999ns", + eager=True, + ) + } + ) + + def unsupported_name_setter(self, value): + pass + + def unsupported_name_getter(self): + return "unsupported" + + monkeypatch.setattr( + TemporalFunction, + "name", + property(unsupported_name_getter, unsupported_name_setter), + ) + + q = ldf.select(pl.col("datetimes").dt.nanosecond()) + + assert_ir_translation_raises(q, NotImplementedError) + + @pytest.mark.parametrize( "field", [ methodcaller("year"), - pytest.param( - methodcaller("day"), - marks=pytest.mark.xfail(reason="day extraction not implemented"), - ), + methodcaller("month"), + methodcaller("day"), + methodcaller("weekday"), ], ) -def test_datetime_extract(field): +def test_date_extract(field): + ldf = pl.LazyFrame( + { + "dates": [ + datetime.date(2024, 1, 1), + datetime.date(2024, 10, 11), + ] + } + ) + ldf = pl.LazyFrame( {"dates": [datetime.date(2024, 1, 1), datetime.date(2024, 10, 11)]} ) - q = ldf.select(field(pl.col("dates").dt)) - with pytest.raises(AssertionError): - # polars produces int32, libcudf produces int16 for the year extraction - # libcudf can lose data here. - # https://github.com/rapidsai/cudf/issues/16196 - assert_gpu_result_equal(q) + q = ldf.select(field(pl.col("dates").dt)) - assert_gpu_result_equal(q, check_dtypes=False) + assert_gpu_result_equal(q) diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index 6f996e0e0ec..68ee118c701 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -5,6 +5,7 @@ import itertools import pytest +from packaging import version import polars as pl @@ -168,7 +169,13 @@ def test_groupby_nan_minmax_raises(op): "expr", [ pl.lit(1).alias("value"), - pl.lit([[4, 5, 6]]).alias("value"), + pytest.param( + pl.lit([[4, 5, 6]]).alias("value"), + marks=pytest.mark.xfail( + condition=version.parse(pl.__version__) >= version.parse("1.7.1"), + reason="Broken in polars 1.7.1", + ), + ), pl.col("float") * (1 - pl.col("int")), [pl.lit(2).alias("value"), pl.col("float") * 2], ],