From c380fc0bc6d684f35d06ed3847fa937ddf2df8bb Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sat, 27 Jul 2024 14:33:07 +0100 Subject: [PATCH] feat: Add initial (and very minimal) support for Dask.DataFrame (#635) --- .github/workflows/extremes.yml | 5 +- narwhals/_dask/__init__.py | 0 narwhals/_dask/dataframe.py | 41 ++++++++++ narwhals/_dask/expr.py | 144 +++++++++++++++++++++++++++++++++ narwhals/_dask/namespace.py | 35 ++++++++ narwhals/_dask/utils.py | 54 +++++++++++++ narwhals/dependencies.py | 21 +++++ narwhals/translate.py | 23 ++++++ requirements-dev.txt | 2 +- tests/dask_test.py | 48 +++++++++++ 10 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 narwhals/_dask/__init__.py create mode 100644 narwhals/_dask/dataframe.py create mode 100644 narwhals/_dask/expr.py create mode 100644 narwhals/_dask/namespace.py create mode 100644 narwhals/_dask/utils.py create mode 100644 tests/dask_test.py diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 1140d8ee9..e689de046 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -66,7 +66,7 @@ jobs: - name: Run doctests run: pytest narwhals --doctest-modules - pandas-nightly: + pandas-nightly-and-dask: strategy: matrix: python-version: ["3.12"] @@ -102,6 +102,9 @@ jobs: run: python -m pip uninstall numpy -y - name: install numpy nightly run: python -m pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy + - name: install dask + run: | + pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask git+https://github.com/dask/dask-expr - name: show-deps run: pip freeze - name: Run pytest diff --git a/narwhals/_dask/__init__.py b/narwhals/_dask/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py new file mode 100644 index 000000000..a51e4c7d9 --- /dev/null +++ b/narwhals/_dask/dataframe.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +from narwhals._dask.utils import parse_exprs_and_named_exprs +from narwhals.dependencies import get_dask_dataframe + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._dask.expr import DaskExpr + from narwhals._dask.namespace import DaskNamespace + + +class DaskLazyFrame: + def __init__( + self, native_dataframe: Any, *, backend_version: tuple[int, ...] + ) -> None: + self._native_dataframe = native_dataframe + self._backend_version = backend_version + + def __native_namespace__(self) -> Any: # pragma: no cover + return get_dask_dataframe() + + def __narwhals_namespace__(self) -> DaskNamespace: + from narwhals._dask.namespace import DaskNamespace + + return DaskNamespace(backend_version=self._backend_version) + + def __narwhals_lazyframe__(self) -> Self: + return self + + def _from_native_dataframe(self, df: Any) -> Self: + return self.__class__(df, backend_version=self._backend_version) + + def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: + df = self._native_dataframe + new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + df = df.assign(**new_series) + return self._from_native_dataframe(df) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py new file mode 100644 index 000000000..7172bf246 --- /dev/null +++ b/narwhals/_dask/expr.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +from narwhals.dependencies import get_dask_expr + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._dask.dataframe import DaskLazyFrame + +from narwhals._dask.utils import maybe_evaluate + + +class DaskExpr: + def __init__( + self, + # callable from DaskLazyFrame to list of (native) Dask Series + call: Callable[[DaskLazyFrame], Any], + *, + depth: int, + function_name: str, + root_names: list[str] | None, + output_names: list[str] | None, + backend_version: tuple[int, ...], + ) -> None: + self._call = call + self._depth = depth + self._function_name = function_name + self._root_names = root_names + self._output_names = output_names + self._backend_version = backend_version + + @classmethod + def from_column_names( + cls: type[Self], + *column_names: str, + backend_version: tuple[int, ...], + ) -> Self: + def func(df: DaskLazyFrame) -> list[Any]: + return [ + df._native_dataframe.loc[:, column_name] for column_name in column_names + ] + + return cls( + func, + depth=0, + function_name="col", + root_names=list(column_names), + output_names=list(column_names), + backend_version=backend_version, + ) + + def _from_call( + self, + # callable from DaskLazyFrame to list of (native) Dask Series + call: Any, + expr_name: str, + *args: Any, + **kwargs: Any, + ) -> Self: + def func(df: DaskLazyFrame) -> list[Any]: + results = [] + inputs = self._call(df) + for _input in inputs: + _args = [maybe_evaluate(df, x) for x in args] + _kwargs = { + key: maybe_evaluate(df, value) for key, value in kwargs.items() + } + result = call(_input, *_args, **_kwargs) + if isinstance(result, get_dask_expr()._collection.Series): + result = result.rename(_input.name) + results.append(result) + return results + + # Try tracking root and output names by combining them from all + # expressions appearing in args and kwargs. If any anonymous + # expression appears (e.g. nw.all()), then give up on tracking root names + # and just set it to None. + root_names = copy(self._root_names) + output_names = self._output_names + for arg in list(args) + list(kwargs.values()): + if root_names is not None and isinstance(arg, self.__class__): + if arg._root_names is not None: + root_names.extend(arg._root_names) + else: # pragma: no cover + # TODO(unassigned): increase coverage + root_names = None + output_names = None + break + elif root_names is None: # pragma: no cover + # TODO(unassigned): increase coverage + output_names = None + break + + if not ( + (output_names is None and root_names is None) + or (output_names is not None and root_names is not None) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + return self.__class__( + func, + depth=self._depth + 1, + function_name=f"{self._function_name}->{expr_name}", + root_names=root_names, + output_names=output_names, + backend_version=self._backend_version, + ) + + def alias(self, name: str) -> Self: + def func(df: DaskLazyFrame) -> list[Any]: + results = [] + inputs = self._call(df) + for _input in inputs: + result = _input.rename(name) + results.append(result) + return results + + return self.__class__( + func, + depth=self._depth, + function_name=self._function_name, + root_names=self._root_names, + output_names=[name], + backend_version=self._backend_version, + ) + + def __add__(self, other: Any) -> Self: + return self._from_call( + lambda _input, other: _input.__add__(other), + "__add__", + other, + ) + + def mean(self) -> Self: + return self._from_call( + lambda _input: _input.mean(), + "mean", + ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py new file mode 100644 index 000000000..f30adba1a --- /dev/null +++ b/narwhals/_dask/namespace.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from narwhals import dtypes +from narwhals._dask.expr import DaskExpr + + +class DaskNamespace: + Int64 = dtypes.Int64 + Int32 = dtypes.Int32 + Int16 = dtypes.Int16 + Int8 = dtypes.Int8 + UInt64 = dtypes.UInt64 + UInt32 = dtypes.UInt32 + UInt16 = dtypes.UInt16 + UInt8 = dtypes.UInt8 + Float64 = dtypes.Float64 + Float32 = dtypes.Float32 + Boolean = dtypes.Boolean + Object = dtypes.Object + Unknown = dtypes.Unknown + Categorical = dtypes.Categorical + Enum = dtypes.Enum + String = dtypes.String + Datetime = dtypes.Datetime + Duration = dtypes.Duration + Date = dtypes.Date + + def __init__(self, *, backend_version: tuple[int, ...]) -> None: + self._backend_version = backend_version + + def col(self, *column_names: str) -> DaskExpr: + return DaskExpr.from_column_names( + *column_names, + backend_version=self._backend_version, + ) diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py new file mode 100644 index 000000000..d6eeeb5ea --- /dev/null +++ b/narwhals/_dask/utils.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Any + +from narwhals.dependencies import get_dask_expr + +if TYPE_CHECKING: + from narwhals._dask.dataframe import DaskLazyFrame + + +def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: + from narwhals._dask.expr import DaskExpr + + if isinstance(obj, DaskExpr): + results = obj._call(df) + if len(results) != 1: # pragma: no cover + msg = "Multi-output expressions not supported in this context" + raise NotImplementedError(msg) + result = results[0] + if not get_dask_expr()._expr.are_co_aligned( + df._native_dataframe._expr, result._expr + ): # pragma: no cover + # are_co_aligned is a method which cheaply checks if two Dask expressions + # have the same index, and therefore don't require index alignment. + # If someone only operates on a Dask DataFrame via expressions, then this + # should always be the case: expression outputs (by definition) all come from the + # same input dataframe, and Dask Series does not have any operations which + # change the index. Nonetheless, we perform this safety check anyway. + + # However, we still need to carefully vet which methods we support for Dask, to + # avoid issues where `are_co_aligned` doesn't do what we want it to do: + # https://github.com/dask/dask-expr/issues/1112. + msg = "Implicit index alignment is not support for Dask DataFrame in Narwhals" + raise NotImplementedError(msg) + return result + return obj + + +def parse_exprs_and_named_exprs( + df: DaskLazyFrame, *exprs: Any, **named_exprs: Any +) -> dict[str, Any]: + results = {} + for expr in exprs: + _results = expr._call(df) + for _result in _results: + results[_result.name] = _result + for name, value in named_exprs.items(): + _results = value._call(df) + if len(_results) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise AssertionError(msg) + results[name] = _results[0] + return results diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index cbb5eca3c..63a103b56 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -4,6 +4,7 @@ from __future__ import annotations import sys +import warnings from typing import TYPE_CHECKING from typing import Any @@ -65,6 +66,26 @@ def get_numpy() -> Any: return sys.modules.get("numpy", None) +def get_dask_dataframe() -> Any: + """Get dask.dataframe module (if already imported - else return None).""" + if "dask" in sys.modules: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="\nDask dataframe query planning.*", + category=FutureWarning, + ) + import dask.dataframe as dd + + return dd + return None # pragma: no cover + + +def get_dask_expr() -> Any: + """Get dask_expr module (if already imported - else return None).""" + return sys.modules.get("dask_expr", None) + + def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]: """Check whether `df` is a pandas DataFrame without importing pandas.""" return bool((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame)) diff --git a/narwhals/translate.py b/narwhals/translate.py index cf19af7c3..2d01fd222 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -9,6 +9,8 @@ from typing import overload from narwhals.dependencies import get_cudf +from narwhals.dependencies import get_dask_dataframe +from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars @@ -280,6 +282,7 @@ def from_native( # noqa: PLR0915 """ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries + from narwhals._dask.dataframe import DaskLazyFrame from narwhals._interchange.dataframe import InterchangeFrame from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.series import PandasLikeSeries @@ -467,6 +470,26 @@ def from_native( # noqa: PLR0915 level="full", ) + # Dask + elif (dd := get_dask_dataframe()) is not None and isinstance( + native_object, dd.DataFrame + ): + if series_only: # pragma: no cover + # TODO(unassigned): increase coverage + msg = "Cannot only use `series_only` with dask DataFrame" + raise TypeError(msg) + if eager_only or eager_or_interchange_only: # pragma: no cover + # TODO(unassigned): increase coverage + msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame" + raise TypeError(msg) + if get_dask_expr() is None: # pragma: no cover + msg = "Please install dask-expr" + raise ImportError(msg) + return LazyFrame( + DaskLazyFrame(native_object, backend_version=parse_version(pl.__version__)), + level="full", + ) + # Interchange protocol elif hasattr(native_object, "__dataframe__"): if eager_only or series_only: diff --git a/requirements-dev.txt b/requirements-dev.txt index 0586d00c6..98bf9d8c7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,4 +7,4 @@ pytest pytest-cov hypothesis scikit-learn - +dask[dataframe]; python_version >= '3.9' diff --git a/tests/dask_test.py b/tests/dask_test.py new file mode 100644 index 000000000..6563d34be --- /dev/null +++ b/tests/dask_test.py @@ -0,0 +1,48 @@ +""" +Dask support in Narwhals is still _very_ scant. + +Start with a simple test file whilst we develop the basics. +Once we're a bit further along (say, we can at least evaluate +TPC-H Q1 with Dask), then we can integrate dask tests into +the main test suite. +""" + +import sys + +import pandas as pd +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import compare_dicts + +pytest.importorskip("dask_expr") + + +if sys.version_info < (3, 9): + pytest.skip("Dask tests require Python 3.9+", allow_module_level=True) + + +def test_with_columns() -> None: + import dask.dataframe as dd + + dfdd = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + + df = nw.from_native(dfdd) + df = df.with_columns( + nw.col("a") + 1, + (nw.col("a") + nw.col("b").mean()).alias("c"), + d=nw.col("a"), + e=nw.col("a") + nw.col("b"), + ) + + result = nw.to_native(df).compute() + compare_dicts( + result, + { + "a": [2, 3, 4], + "b": [4, 5, 6], + "c": [6.0, 7.0, 8.0], + "d": [1, 2, 3], + "e": [5, 7, 9], + }, + )