Enable dask_cudf json and s3 tests with query-planning on (#15408)
Addresses parts of #15027 (json and s3 testing).

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #15408
rjzamora authored Apr 1, 2024
1 parent e5f9e2d commit 09f8c8a
Showing 5 changed files with 38 additions and 13 deletions.
15 changes: 14 additions & 1 deletion python/dask_cudf/dask_cudf/backends.py
@@ -2,6 +2,7 @@

 import warnings
 from collections.abc import Iterator
+from functools import partial

 import cupy as cp
 import numpy as np
@@ -484,7 +485,6 @@ def sizeof_cudf_series_index(obj):
 def _simple_cudf_encode(_):
     # Basic pickle-based encoding for a partd k-v store
     import pickle
-    from functools import partial

     import partd

@@ -686,6 +686,19 @@ def from_dict(
             constructor=constructor,
         )

+    @staticmethod
+    def read_json(*args, engine="auto", **kwargs):
+        return _default_backend(
+            dd.read_json,
+            *args,
+            engine=(
+                partial(cudf.read_json, engine=engine)
+                if isinstance(engine, str)
+                else engine
+            ),
+            **kwargs,
+        )
+

 # Import/register cudf-specific classes for dask-expr
 try:
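For context, a hedged sketch of how the dispatch added above might be exercised from user code. The glob path is hypothetical, and this assumes dd.read_json routes creation calls through the configured "dataframe.backend" entrypoint, which is what the test_read_json_backend_dispatch test below checks.

# Hedged sketch, not part of the diff: exercise the cudf backend's
# read_json dispatch.  The glob path is hypothetical.
import dask
import dask.dataframe as dd

with dask.config.set({"dataframe.backend": "cudf"}):
    # dd.read_json is routed to the read_json staticmethod added above,
    # which forwards per-partition reads to cudf.read_json.
    ddf = dd.read_json("data/records-*.json", lines=True)
    print(ddf.head())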
4 changes: 2 additions & 2 deletions python/dask_cudf/dask_cudf/io/tests/test_json.py
@@ -12,8 +12,8 @@
 import dask_cudf
 from dask_cudf.tests.utils import skip_dask_expr

-# No dask-expr support
-pytestmark = skip_dask_expr()
+# No dask-expr support for dask_expr<=1.0.5
+pytestmark = skip_dask_expr(lt_version="1.0.5+a")


 def test_read_json_backend_dispatch(tmp_path):
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/io/tests/test_parquet.py
@@ -535,7 +535,7 @@ def test_check_file_size(tmpdir):
         dask_cudf.io.read_parquet(fn, check_file_size=1).compute()


-@xfail_dask_expr("HivePartitioning cannot be hashed")
+@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="1.0")
 def test_null_partition(tmpdir):
     import pyarrow as pa
     from pyarrow.dataset import HivePartitioning
6 changes: 1 addition & 5 deletions python/dask_cudf/dask_cudf/io/tests/test_s3.py
@@ -10,10 +10,6 @@
 import pytest

 import dask_cudf
-from dask_cudf.tests.utils import skip_dask_expr
-
-# No dask-expr support
-pytestmark = skip_dask_expr()

 moto = pytest.importorskip("moto", minversion="3.1.6")
 boto3 = pytest.importorskip("boto3")
@@ -111,7 +107,7 @@ def test_read_csv(s3_base, s3so):
         s3_base=s3_base, bucket="daskcsv", files={"a.csv": b"a,b\n1,2\n3,4\n"}
     ):
         df = dask_cudf.read_csv(
-            "s3://daskcsv/*.csv", chunksize="50 B", storage_options=s3so
+            "s3://daskcsv/*.csv", blocksize="50 B", storage_options=s3so
         )
         assert df.a.sum().compute() == 4

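As an aside, the keyword change above reflects that the query-planning (dask-expr) reader expects blocksize; the older chunksize alias is not supported there. A hedged sketch of the equivalent call outside the test harness, with a hypothetical bucket name and credentials:

# Hedged sketch, not part of the diff: read CSV data from S3 with
# dask_cudf using ``blocksize``.  Bucket and credentials are hypothetical.
import dask_cudf

df = dask_cudf.read_csv(
    "s3://example-bucket/*.csv",
    blocksize="64 MiB",  # target partition size
    storage_options={"key": "<aws-access-key>", "secret": "<aws-secret-key>"},
)
print(df.head())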
24 changes: 20 additions & 4 deletions python/dask_cudf/dask_cudf/tests/utils.py
@@ -3,13 +3,21 @@
 import numpy as np
 import pandas as pd
 import pytest
+from packaging.version import Version

 import dask.dataframe as dd

 import cudf

 from dask_cudf.expr import QUERY_PLANNING_ON

+if QUERY_PLANNING_ON:
+    import dask_expr
+
+    DASK_EXPR_VERSION = Version(dask_expr.__version__)
+else:
+    DASK_EXPR_VERSION = None
+

 def _make_random_frame(nelem, npartitions=2, include_na=False):
     df = pd.DataFrame(
@@ -27,9 +35,17 @@ def _make_random_frame(nelem, npartitions=2, include_na=False):
 _default_reason = "Not compatible with dask-expr"


-def skip_dask_expr(reason=_default_reason):
-    return pytest.mark.skipif(QUERY_PLANNING_ON, reason=reason)
+def skip_dask_expr(reason=_default_reason, lt_version=None):
+    if lt_version is not None:
+        skip = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version)
+    else:
+        skip = QUERY_PLANNING_ON
+    return pytest.mark.skipif(skip, reason=reason)


-def xfail_dask_expr(reason=_default_reason):
-    return pytest.mark.xfail(QUERY_PLANNING_ON, reason=reason)
+def xfail_dask_expr(reason=_default_reason, lt_version=None):
+    if lt_version is not None:
+        xfail = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version)
+    else:
+        xfail = QUERY_PLANNING_ON
+    return pytest.mark.xfail(xfail, reason=reason)
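To show how the updated helpers are intended to be used, here is a hedged sketch of a hypothetical test module; the test name is illustrative, and the version thresholds mirror the real usages in the diffs above.

# Hedged sketch, not part of the diff: a hypothetical test module using
# the version-aware markers.
from dask_cudf.tests.utils import skip_dask_expr, xfail_dask_expr

# Skip only when query-planning is on AND dask-expr is older than
# 1.0.5+a (i.e. any release up to and including 1.0.5).
pytestmark = skip_dask_expr(lt_version="1.0.5+a")


# Expect failure only for dask-expr versions older than 1.0.
@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="1.0")
def test_hive_partitioned_read(tmpdir):
    ...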
