From 22cefc94d05727607563d6519eb17e1eb95c5478 Mon Sep 17 00:00:00 2001
From: "Richard (Rick) Zamora"
Date: Tue, 24 Sep 2024 21:40:11 -0500
Subject: [PATCH] Fix metadata after implicit array conversion from Dask cuDF
 (#16842)

Temporary workaround for https://github.com/dask/dask/issues/11017 in
Dask cuDF (when query-planning is enabled). I will try to move this fix
upstream soon. However, the next dask release will probably not be used
by 24.10, and it's still unclear whether the same fix works for all CPU
cases.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: https://github.com/rapidsai/cudf/pull/16842
---
 .../dask_cudf/dask_cudf/expr/_collection.py   | 79 +++++++++++++------
 python/dask_cudf/dask_cudf/tests/test_core.py | 17 ++--
 2 files changed, 65 insertions(+), 31 deletions(-)

diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py
index 97e1dffc65b..c1dd16eac8d 100644
--- a/python/dask_cudf/dask_cudf/expr/_collection.py
+++ b/python/dask_cudf/dask_cudf/expr/_collection.py
@@ -202,27 +202,58 @@ class Index(DXIndex, CudfFrameBase):
 ##
 
-try:
-    from dask_expr._backends import create_array_collection
-
-    @get_collection_type.register_lazy("cupy")
-    def _register_cupy():
-        import cupy
-
-        @get_collection_type.register(cupy.ndarray)
-        def get_collection_type_cupy_array(_):
-            return create_array_collection
-
-    @get_collection_type.register_lazy("cupyx")
-    def _register_cupyx():
-        # Needed for cuml
-        from cupyx.scipy.sparse import spmatrix
-
-        @get_collection_type.register(spmatrix)
-        def get_collection_type_csr_matrix(_):
-            return create_array_collection
-
-except ImportError:
-    # Older version of dask-expr.
-    # Implicit conversion to array wont work.
-    pass
+def _create_array_collection_with_meta(expr):
+    # NOTE: This is the GPU compatible version of
+    # `new_dd_object` for DataFrame -> Array conversion.
+    # This can be removed if dask#11017 is resolved
+    # (See: https://github.com/dask/dask/issues/11017)
+    import numpy as np
+
+    import dask.array as da
+    from dask.blockwise import Blockwise
+    from dask.highlevelgraph import HighLevelGraph
+
+    result = expr.optimize()
+    dsk = result.__dask_graph__()
+    name = result._name
+    meta = result._meta
+    divisions = result.divisions
+    chunks = ((np.nan,) * (len(divisions) - 1),) + tuple(
+        (d,) for d in meta.shape[1:]
+    )
+    if len(chunks) > 1:
+        if isinstance(dsk, HighLevelGraph):
+            layer = dsk.layers[name]
+        else:
+            # dask-expr provides a dict only
+            layer = dsk
+        if isinstance(layer, Blockwise):
+            layer.new_axes["j"] = chunks[1][0]
+            layer.output_indices = layer.output_indices + ("j",)
+        else:
+            suffix = (0,) * (len(chunks) - 1)
+            for i in range(len(chunks[0])):
+                layer[(name, i) + suffix] = layer.pop((name, i))
+
+    return da.Array(dsk, name=name, chunks=chunks, meta=meta)
+
+
+@get_collection_type.register_lazy("cupy")
+def _register_cupy():
+    import cupy
+
+    get_collection_type.register(
+        cupy.ndarray,
+        lambda _: _create_array_collection_with_meta,
+    )
+
+
+@get_collection_type.register_lazy("cupyx")
+def _register_cupyx():
+    # Needed for cuml
+    from cupyx.scipy.sparse import spmatrix
+
+    get_collection_type.register(
+        spmatrix,
+        lambda _: _create_array_collection_with_meta,
+    )
diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py
index 7aa0f6320f2..9f54aba3e13 100644
--- a/python/dask_cudf/dask_cudf/tests/test_core.py
+++ b/python/dask_cudf/dask_cudf/tests/test_core.py
@@ -16,6 +16,7 @@
 
 import dask_cudf
 from dask_cudf.tests.utils import (
+    QUERY_PLANNING_ON,
     require_dask_expr,
     skip_dask_expr,
     xfail_dask_expr,
@@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy():
     def func(x):
         return x.values
 
-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
-    result = ds.map_partitions(func, meta=s.values).compute()
-    expect = func(s)
+    result = ds.map_partitions(func, meta=s.values)
 
-    dask.array.assert_eq(result, expect)
+    if QUERY_PLANNING_ON:
+        # Check Array and round-tripped DataFrame
+        dask.array.assert_eq(result, func(s))
+        dd.assert_eq(result.to_dask_dataframe(), s, check_index=False)
+    else:
+        # Legacy version still carries numpy metadata
+        # See: https://github.com/dask/dask/issues/11017
+        dask.array.assert_eq(result.compute(), func(s))
 
 
 def test_implicit_array_conversion_cupy_sparse():
@@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse():
     def func(x):
         return cupyx.scipy.sparse.csr_matrix(x.values)
 
-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
     result = ds.map_partitions(func, meta=s.values).compute()
     expect = func(s)
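
For reference (not part of the patch), a minimal sketch of the implicit
DataFrame-to-Array conversion this change targets, mirroring the updated test
and assuming a GPU environment with cudf, dask_cudf, and cupy installed and
query-planning enabled:

import cudf
import cupy
import dask_cudf

# Build a small GPU-backed Series and a two-partition Dask cuDF collection.
s = cudf.Series(range(10))
ds = dask_cudf.from_cudf(s, npartitions=2)

# Returning `.values` from map_partitions implicitly converts the
# DataFrame collection into a Dask Array collection.
arr = ds.map_partitions(lambda x: x.values, meta=s.values)

# With this fix, the Array metadata should be a cupy.ndarray (rather than
# numpy), matching the computed result.
assert isinstance(arr._meta, cupy.ndarray)
cupy.testing.assert_array_equal(arr.compute(), s.values)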