From 75c5c83f1375213c94527eba1d0488145d7fdce7 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Wed, 25 Sep 2024 09:12:32 -0500 Subject: [PATCH] Add dask-cudf workaround for missing `rename_axis` support in cudf (#16899) See https://github.com/rapidsai/cudf/issues/16895 Closes https://github.com/rapidsai/cudf/issues/16892 Dask-expr uses `rename_axis`, which is not supported by cudf yet. This is a temporary workaround until #16895 is resolved. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Mads R. B. Kristensen (https://github.com/madsbk) - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/16899 --- python/dask_cudf/dask_cudf/expr/_collection.py | 12 ++++++++++++ python/dask_cudf/dask_cudf/expr/_expr.py | 16 +++++++++++++++- python/dask_cudf/dask_cudf/tests/test_core.py | 12 ++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index c1dd16eac8d..907abaa2bfc 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -15,6 +15,7 @@ from dask import config from dask.dataframe.core import is_dataframe_like +from dask.typing import no_default import cudf @@ -90,6 +91,17 @@ def var( ) ) + def rename_axis( + self, mapper=no_default, index=no_default, columns=no_default, axis=0 + ): + from dask_cudf.expr._expr import RenameAxisCudf + + return new_collection( + RenameAxisCudf( + self, mapper=mapper, index=index, columns=columns, axis=axis + ) + ) + class DataFrame(DXDataFrame, CudfFrameBase): @classmethod diff --git a/python/dask_cudf/dask_cudf/expr/_expr.py b/python/dask_cudf/dask_cudf/expr/_expr.py index 8a2c50d3fe7..b284ab3774d 100644 --- a/python/dask_cudf/dask_cudf/expr/_expr.py +++ b/python/dask_cudf/dask_cudf/expr/_expr.py @@ -4,11 +4,12 @@ import dask_expr._shuffle as _shuffle_module from dask_expr import new_collection from dask_expr._cumulative import CumulativeBlockwise -from dask_expr._expr import Elemwise, Expr, VarColumns +from dask_expr._expr import Elemwise, Expr, RenameAxis, VarColumns from dask_expr._reductions import Reduction, Var from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty from dask.dataframe.dispatch import is_categorical_dtype +from dask.typing import no_default import cudf @@ -17,6 +18,19 @@ ## +class RenameAxisCudf(RenameAxis): + # TODO: Remove this after rename_axis is supported in cudf + # (See: https://github.com/rapidsai/cudf/issues/16895) + @staticmethod + def operation(df, index=no_default, **kwargs): + if index != no_default: + df.index.name = index + return df + raise NotImplementedError( + "Only `index` is supported for the cudf backend" + ) + + class ToCudfBackend(Elemwise): # TODO: Inherit from ToBackend when rapids-dask-dependency # is pinned to dask>=2024.8.1 diff --git a/python/dask_cudf/dask_cudf/tests/test_core.py b/python/dask_cudf/dask_cudf/tests/test_core.py index 9f54aba3e13..5f0fae86691 100644 --- a/python/dask_cudf/dask_cudf/tests/test_core.py +++ b/python/dask_cudf/dask_cudf/tests/test_core.py @@ -1027,3 +1027,15 @@ def test_cov_corr(op, numeric_only): # (See: https://github.com/rapidsai/cudf/issues/12626) expect = getattr(df.to_pandas(), op)(numeric_only=numeric_only) dd.assert_eq(res, expect) + + +def test_rename_axis_after_join(): + df1 = cudf.DataFrame(index=["a", "b", "c"], data=dict(a=[1, 2, 3])) + df1.index.name = "test" + ddf1 = dd.from_pandas(df1, 2) + + df2 = cudf.DataFrame(index=["a", "b", "d"], data=dict(b=[1, 2, 3])) + ddf2 = dd.from_pandas(df2, 2) + result = ddf1.join(ddf2, how="outer") + expected = df1.join(df2, how="outer") + dd.assert_eq(result, expected, check_index=False)