diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 97e1dffc65b..c1dd16eac8d 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -202,27 +202,58 @@ class Index(DXIndex, CudfFrameBase): ## -try: - from dask_expr._backends import create_array_collection - - @get_collection_type.register_lazy("cupy") - def _register_cupy(): - import cupy - - @get_collection_type.register(cupy.ndarray) - def get_collection_type_cupy_array(_): - return create_array_collection - - @get_collection_type.register_lazy("cupyx") - def _register_cupyx(): - # Needed for cuml - from cupyx.scipy.sparse import spmatrix - - @get_collection_type.register(spmatrix) - def get_collection_type_csr_matrix(_): - return create_array_collection - -except ImportError: - # Older version of dask-expr. - # Implicit conversion to array wont work. - pass +def _create_array_collection_with_meta(expr): + # NOTE: This is the GPU compatible version of + # `new_dd_object` for DataFrame -> Array conversion. + # This can be removed if dask#11017 is resolved + # (See: https://github.com/dask/dask/issues/11017) + import numpy as np + + import dask.array as da + from dask.blockwise import Blockwise + from dask.highlevelgraph import HighLevelGraph + + result = expr.optimize() + dsk = result.__dask_graph__() + name = result._name + meta = result._meta + divisions = result.divisions + chunks = ((np.nan,) * (len(divisions) - 1),) + tuple( + (d,) for d in meta.shape[1:] + ) + if len(chunks) > 1: + if isinstance(dsk, HighLevelGraph): + layer = dsk.layers[name] + else: + # dask-expr provides a dict only + layer = dsk + if isinstance(layer, Blockwise): + layer.new_axes["j"] = chunks[1][0] + layer.output_indices = layer.output_indices + ("j",) + else: + suffix = (0,) * (len(chunks) - 1) + for i in range(len(chunks[0])): + layer[(name, i) + suffix] = layer.pop((name, i)) + + return da.Array(dsk, name=name, chunks=chunks, meta=meta) + + +@get_collection_type.register_lazy("cupy") +def _register_cupy(): + import cupy + + get_collection_type.register( + cupy.ndarray, + lambda _: _create_array_collection_with_meta, + ) + + +@get_collection_type.register_lazy("cupyx") +def _register_cupyx(): + # Needed for cuml + from cupyx.scipy.sparse import spmatrix + + get_collection_type.register( + spmatrix, + lambda _: _create_array_collection_with_meta, + ) diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 7aa0f6320f2..9f54aba3e13 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -16,6 +16,7 @@ import dask_cudf from dask_cudf.tests.utils import ( + QUERY_PLANNING_ON, require_dask_expr, skip_dask_expr, xfail_dask_expr, @@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy(): def func(x): return x.values - # Need to compute the dask collection for now. - # See: https://github.com/dask/dask/issues/11017 - result = ds.map_partitions(func, meta=s.values).compute() - expect = func(s) + result = ds.map_partitions(func, meta=s.values) - dask.array.assert_eq(result, expect) + if QUERY_PLANNING_ON: + # Check Array and round-tripped DataFrame + dask.array.assert_eq(result, func(s)) + dd.assert_eq(result.to_dask_dataframe(), s, check_index=False) + else: + # Legacy version still carries numpy metadata + # See: https://github.com/dask/dask/issues/11017 + dask.array.assert_eq(result.compute(), func(s)) def test_implicit_array_conversion_cupy_sparse(): @@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse(): def func(x): return cupyx.scipy.sparse.csr_matrix(x.values) - # Need to compute the dask collection for now. - # See: https://github.com/dask/dask/issues/11017 result = ds.map_partitions(func, meta=s.values).compute() expect = func(s)