Merge pull request #16903 from rapidsai/branch-24.10
Forward-merge branch-24.10 into branch-24.12
GPUtester authored Sep 25, 2024
2 parents a4b4151 + 22cefc9 commit e9a40d8
Showing 2 changed files with 65 additions and 31 deletions.
79 changes: 55 additions & 24 deletions python/dask_cudf/dask_cudf/expr/_collection.py
@@ -202,27 +202,58 @@ class Index(DXIndex, CudfFrameBase):
 ##


-try:
-    from dask_expr._backends import create_array_collection
-
-    @get_collection_type.register_lazy("cupy")
-    def _register_cupy():
-        import cupy
-
-        @get_collection_type.register(cupy.ndarray)
-        def get_collection_type_cupy_array(_):
-            return create_array_collection
-
-    @get_collection_type.register_lazy("cupyx")
-    def _register_cupyx():
-        # Needed for cuml
-        from cupyx.scipy.sparse import spmatrix
-
-        @get_collection_type.register(spmatrix)
-        def get_collection_type_csr_matrix(_):
-            return create_array_collection
-
-except ImportError:
-    # Older version of dask-expr.
-    # Implicit conversion to array won't work.
-    pass
+def _create_array_collection_with_meta(expr):
+    # NOTE: This is the GPU compatible version of
+    # `new_dd_object` for DataFrame -> Array conversion.
+    # This can be removed if dask#11017 is resolved
+    # (See: https://github.com/dask/dask/issues/11017)
+    import numpy as np
+
+    import dask.array as da
+    from dask.blockwise import Blockwise
+    from dask.highlevelgraph import HighLevelGraph
+
+    result = expr.optimize()
+    dsk = result.__dask_graph__()
+    name = result._name
+    meta = result._meta
+    divisions = result.divisions
+    chunks = ((np.nan,) * (len(divisions) - 1),) + tuple(
+        (d,) for d in meta.shape[1:]
+    )
+    if len(chunks) > 1:
+        if isinstance(dsk, HighLevelGraph):
+            layer = dsk.layers[name]
+        else:
+            # dask-expr provides a dict only
+            layer = dsk
+        if isinstance(layer, Blockwise):
+            layer.new_axes["j"] = chunks[1][0]
+            layer.output_indices = layer.output_indices + ("j",)
+        else:
+            suffix = (0,) * (len(chunks) - 1)
+            for i in range(len(chunks[0])):
+                layer[(name, i) + suffix] = layer.pop((name, i))
+
+    return da.Array(dsk, name=name, chunks=chunks, meta=meta)
+
+
+@get_collection_type.register_lazy("cupy")
+def _register_cupy():
+    import cupy
+
+    get_collection_type.register(
+        cupy.ndarray,
+        lambda _: _create_array_collection_with_meta,
+    )
+
+
+@get_collection_type.register_lazy("cupyx")
+def _register_cupyx():
+    # Needed for cuml
+    from cupyx.scipy.sparse import spmatrix
+
+    get_collection_type.register(
+        spmatrix,
+        lambda _: _create_array_collection_with_meta,
+    )
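
Note: the following is a minimal sketch, not part of this commit, of the conversion path the registrations above enable. It assumes a CUDA environment with dask-cudf installed and the expr-based query planner active; the names below are illustrative.

import cudf
import dask_cudf

s = cudf.Series(range(10), dtype="float64")
ds = dask_cudf.from_cudf(s, npartitions=2)

# Each partition maps to a cupy.ndarray, so get_collection_type dispatches
# to _create_array_collection_with_meta and `arr` is built as a lazy
# dask.array.Array with cupy-backed metadata (row counts are unknown at
# graph-construction time, hence the NaN chunk sizes).
arr = ds.map_partitions(lambda x: x.values, meta=s.values)
print(type(arr))        # dask.array.Array
print(type(arr._meta))  # cupy.ndarray
print(arr.chunks)       # ((nan, nan),)

The register_lazy hooks keep the cupy and cupyx imports deferred, so the dispatch is only wired up once those modules are actually importable.
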
17 changes: 10 additions & 7 deletions python/dask_cudf/dask_cudf/tests/test_core.py
@@ -16,6 +16,7 @@

 import dask_cudf
 from dask_cudf.tests.utils import (
+    QUERY_PLANNING_ON,
     require_dask_expr,
     skip_dask_expr,
     xfail_dask_expr,
@@ -950,12 +951,16 @@ def test_implicit_array_conversion_cupy():
     def func(x):
         return x.values

-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
-    result = ds.map_partitions(func, meta=s.values).compute()
-    expect = func(s)
+    result = ds.map_partitions(func, meta=s.values)

-    dask.array.assert_eq(result, expect)
+    if QUERY_PLANNING_ON:
+        # Check Array and round-tripped DataFrame
+        dask.array.assert_eq(result, func(s))
+        dd.assert_eq(result.to_dask_dataframe(), s, check_index=False)
+    else:
+        # Legacy version still carries numpy metadata
+        # See: https://github.com/dask/dask/issues/11017
+        dask.array.assert_eq(result.compute(), func(s))


 def test_implicit_array_conversion_cupy_sparse():
@@ -967,8 +972,6 @@ def test_implicit_array_conversion_cupy_sparse():
     def func(x):
         return cupyx.scipy.sparse.csr_matrix(x.values)

-    # Need to compute the dask collection for now.
-    # See: https://github.com/dask/dask/issues/11017
     result = ds.map_partitions(func, meta=s.values).compute()
     expect = func(s)

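Note: a short sketch, again not part of this commit, of the round trip the updated cupy test exercises under the query planner (same CUDA and dask-cudf assumptions as the sketch above).

import cudf
import dask.array as da
import dask.dataframe as dd
import dask_cudf

s = cudf.Series([1.0, 2.0, 3.0, 4.0])
ds = dask_cudf.from_cudf(s, npartitions=2)

# DataFrame -> Array conversion now stays lazy...
arr = ds.map_partitions(lambda x: x.values, meta=s.values)
da.assert_eq(arr, s.values)

# ...and can be round-tripped back to a DataFrame collection without an
# intermediate compute(), which is what the QUERY_PLANNING_ON branch asserts.
dd.assert_eq(arr.to_dask_dataframe(), s, check_index=False)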
