Skip to content

Commit

Permalink
Fix metadata after implicit array conversion from Dask cuDF (#16842)
Browse files Browse the repository at this point in the history
Temporary workaround for dask/dask#11017 in Dask cuDF (when query-planning is enabled).
I will try to move this fix upstream soon. However, the next dask release will probably not be used by 24.10, and it's still unclear whether the same fix works for all CPU cases.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #16842
  • Loading branch information
rjzamora authored Sep 25, 2024
1 parent 73fa557 commit 22cefc9
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 31 deletions.
79 changes: 55 additions & 24 deletions python/dask_cudf/dask_cudf/expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,27 +202,58 @@ class Index(DXIndex, CudfFrameBase):
##


try:
    # `create_array_collection` only exists in newer dask-expr releases;
    # guard the whole registration so older versions degrade gracefully.
    from dask_expr._backends import create_array_collection

    @get_collection_type.register_lazy("cupy")
    def _register_cupy():
        # Deferred until cupy is first imported by the dispatch machinery.
        import cupy

        @get_collection_type.register(cupy.ndarray)
        def get_collection_type_cupy_array(_):
            # DataFrame expressions whose meta is a cupy array are
            # materialized as dask Array collections.
            return create_array_collection

    @get_collection_type.register_lazy("cupyx")
    def _register_cupyx():
        # Needed for cuml
        from cupyx.scipy.sparse import spmatrix

        @get_collection_type.register(spmatrix)
        def get_collection_type_csr_matrix(_):
            # Sparse GPU matrices take the same Array-collection path.
            return create_array_collection

except ImportError:
    # Older version of dask-expr.
    # Implicit conversion to array won't work.
    pass
def _create_array_collection_with_meta(expr):
    """Build a ``dask.array.Array`` collection from a dask-expr expression.

    Unlike the upstream DataFrame -> Array conversion, this preserves the
    expression's (GPU-backed) ``_meta`` so the resulting Array carries cupy
    metadata instead of numpy.

    Parameters
    ----------
    expr
        A dask-expr expression object (must expose ``optimize``,
        ``__dask_graph__``, ``_name``, ``_meta``, and ``divisions``).

    Returns
    -------
    dask.array.Array
        An array backed by the optimized task graph of ``expr``.
    """
    # NOTE: This is the GPU compatible version of
    # `new_dd_object` for DataFrame -> Array conversion.
    # This can be removed if dask#11017 is resolved
    # (See: https://github.com/dask/dask/issues/11017)
    import numpy as np

    import dask.array as da
    from dask.blockwise import Blockwise
    from dask.highlevelgraph import HighLevelGraph

    result = expr.optimize()
    dsk = result.__dask_graph__()
    name = result._name
    meta = result._meta
    divisions = result.divisions
    # Axis 0: one chunk per partition with unknown length (NaN).
    # Trailing axes: a single chunk whose extent comes from the meta shape.
    chunks = ((np.nan,) * (len(divisions) - 1),) + tuple(
        (d,) for d in meta.shape[1:]
    )
    if len(chunks) > 1:
        # The graph keys are 1-D (name, i); they must be widened to match
        # the extra array dimensions before handing the graph to da.Array.
        if isinstance(dsk, HighLevelGraph):
            layer = dsk.layers[name]
        else:
            # dask-expr provides a dict only
            layer = dsk
        if isinstance(layer, Blockwise):
            # Extend the blockwise layer with a new output axis "j" so the
            # generated keys gain a trailing index (mutated in place).
            layer.new_axes["j"] = chunks[1][0]
            layer.output_indices = layer.output_indices + ("j",)
        else:
            # Materialized layer: rewrite each key (name, i) -> (name, i, 0, ...)
            # in place, padding with one zero per extra dimension.
            suffix = (0,) * (len(chunks) - 1)
            for i in range(len(chunks[0])):
                layer[(name, i) + suffix] = layer.pop((name, i))

    return da.Array(dsk, name=name, chunks=chunks, meta=meta)


@get_collection_type.register_lazy("cupy")
def _register_cupy():
    """Lazily wire up collection dispatch once cupy is imported."""
    import cupy

    @get_collection_type.register(cupy.ndarray)
    def _collection_type_for_cupy_ndarray(_):
        # DataFrame -> Array conversion must go through the metadata-aware
        # helper above (workaround for dask#11017).
        return _create_array_collection_with_meta


@get_collection_type.register_lazy("cupyx")
def _register_cupyx():
    """Lazily wire up collection dispatch once cupyx is imported."""
    # Needed for cuml
    from cupyx.scipy.sparse import spmatrix

    @get_collection_type.register(spmatrix)
    def _collection_type_for_spmatrix(_):
        # Sparse GPU matrices follow the same metadata-aware Array path
        # (workaround for dask#11017).
        return _create_array_collection_with_meta
17 changes: 10 additions & 7 deletions python/dask_cudf/dask_cudf/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dask_cudf
from dask_cudf.tests.utils import (
QUERY_PLANNING_ON,
require_dask_expr,
skip_dask_expr,
xfail_dask_expr,
Expand Down Expand Up @@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy():
def func(x):
return x.values

# Need to compute the dask collection for now.
# See: https://github.com/dask/dask/issues/11017
result = ds.map_partitions(func, meta=s.values).compute()
expect = func(s)
result = ds.map_partitions(func, meta=s.values)

dask.array.assert_eq(result, expect)
if QUERY_PLANNING_ON:
# Check Array and round-tripped DataFrame
dask.array.assert_eq(result, func(s))
dd.assert_eq(result.to_dask_dataframe(), s, check_index=False)
else:
# Legacy version still carries numpy metadata
# See: https://github.com/dask/dask/issues/11017
dask.array.assert_eq(result.compute(), func(s))


def test_implicit_array_conversion_cupy_sparse():
Expand All @@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse():
def func(x):
return cupyx.scipy.sparse.csr_matrix(x.values)

# Need to compute the dask collection for now.
# See: https://github.com/dask/dask/issues/11017
result = ds.map_partitions(func, meta=s.values).compute()
expect = func(s)

Expand Down

0 comments on commit 22cefc9

Please sign in to comment.