Skip to content

Commit

Permalink
feat: Add initial (and very minimal) support for Dask.DataFrame (#635)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 27, 2024
1 parent 7c7da8b commit c380fc0
Show file tree
Hide file tree
Showing 10 changed files with 371 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
- name: Run doctests
run: pytest narwhals --doctest-modules

pandas-nightly:
pandas-nightly-and-dask:
strategy:
matrix:
python-version: ["3.12"]
Expand Down Expand Up @@ -102,6 +102,9 @@ jobs:
run: python -m pip uninstall numpy -y
- name: install numpy nightly
run: python -m pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy
- name: install dask
run: |
pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask git+https://github.com/dask/dask-expr
- name: show-deps
run: pip freeze
- name: Run pytest
Expand Down
Empty file added narwhals/_dask/__init__.py
Empty file.
41 changes: 41 additions & 0 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals.dependencies import get_dask_dataframe

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._dask.expr import DaskExpr
from narwhals._dask.namespace import DaskNamespace


class DaskLazyFrame:
def __init__(
self, native_dataframe: Any, *, backend_version: tuple[int, ...]
) -> None:
self._native_dataframe = native_dataframe
self._backend_version = backend_version

def __native_namespace__(self) -> Any: # pragma: no cover
return get_dask_dataframe()

def __narwhals_namespace__(self) -> DaskNamespace:
from narwhals._dask.namespace import DaskNamespace

return DaskNamespace(backend_version=self._backend_version)

def __narwhals_lazyframe__(self) -> Self:
return self

def _from_native_dataframe(self, df: Any) -> Self:
return self.__class__(df, backend_version=self._backend_version)

def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
df = self._native_dataframe
new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
df = df.assign(**new_series)
return self._from_native_dataframe(df)
144 changes: 144 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

from narwhals.dependencies import get_dask_expr

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._dask.dataframe import DaskLazyFrame

from narwhals._dask.utils import maybe_evaluate


class DaskExpr:
def __init__(
self,
# callable from DaskLazyFrame to list of (native) Dask Series
call: Callable[[DaskLazyFrame], Any],
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
backend_version: tuple[int, ...],
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names
self._backend_version = backend_version

@classmethod
def from_column_names(
cls: type[Self],
*column_names: str,
backend_version: tuple[int, ...],
) -> Self:
def func(df: DaskLazyFrame) -> list[Any]:
return [
df._native_dataframe.loc[:, column_name] for column_name in column_names
]

return cls(
func,
depth=0,
function_name="col",
root_names=list(column_names),
output_names=list(column_names),
backend_version=backend_version,
)

def _from_call(
self,
# callable from DaskLazyFrame to list of (native) Dask Series
call: Any,
expr_name: str,
*args: Any,
**kwargs: Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[Any]:
results = []
inputs = self._call(df)
for _input in inputs:
_args = [maybe_evaluate(df, x) for x in args]
_kwargs = {
key: maybe_evaluate(df, value) for key, value in kwargs.items()
}
result = call(_input, *_args, **_kwargs)
if isinstance(result, get_dask_expr()._collection.Series):
result = result.rename(_input.name)
results.append(result)
return results

# Try tracking root and output names by combining them from all
# expressions appearing in args and kwargs. If any anonymous
# expression appears (e.g. nw.all()), then give up on tracking root names
# and just set it to None.
root_names = copy(self._root_names)
output_names = self._output_names
for arg in list(args) + list(kwargs.values()):
if root_names is not None and isinstance(arg, self.__class__):
if arg._root_names is not None:
root_names.extend(arg._root_names)
else: # pragma: no cover
# TODO(unassigned): increase coverage
root_names = None
output_names = None
break
elif root_names is None: # pragma: no cover
# TODO(unassigned): increase coverage
output_names = None
break

if not (
(output_names is None and root_names is None)
or (output_names is not None and root_names is not None)
): # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

return self.__class__(
func,
depth=self._depth + 1,
function_name=f"{self._function_name}->{expr_name}",
root_names=root_names,
output_names=output_names,
backend_version=self._backend_version,
)

def alias(self, name: str) -> Self:
def func(df: DaskLazyFrame) -> list[Any]:
results = []
inputs = self._call(df)
for _input in inputs:
result = _input.rename(name)
results.append(result)
return results

return self.__class__(
func,
depth=self._depth,
function_name=self._function_name,
root_names=self._root_names,
output_names=[name],
backend_version=self._backend_version,
)

def __add__(self, other: Any) -> Self:
return self._from_call(
lambda _input, other: _input.__add__(other),
"__add__",
other,
)

def mean(self) -> Self:
return self._from_call(
lambda _input: _input.mean(),
"mean",
)
35 changes: 35 additions & 0 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from narwhals import dtypes
from narwhals._dask.expr import DaskExpr


class DaskNamespace:
Int64 = dtypes.Int64
Int32 = dtypes.Int32
Int16 = dtypes.Int16
Int8 = dtypes.Int8
UInt64 = dtypes.UInt64
UInt32 = dtypes.UInt32
UInt16 = dtypes.UInt16
UInt8 = dtypes.UInt8
Float64 = dtypes.Float64
Float32 = dtypes.Float32
Boolean = dtypes.Boolean
Object = dtypes.Object
Unknown = dtypes.Unknown
Categorical = dtypes.Categorical
Enum = dtypes.Enum
String = dtypes.String
Datetime = dtypes.Datetime
Duration = dtypes.Duration
Date = dtypes.Date

def __init__(self, *, backend_version: tuple[int, ...]) -> None:
self._backend_version = backend_version

def col(self, *column_names: str) -> DaskExpr:
return DaskExpr.from_column_names(
*column_names,
backend_version=self._backend_version,
)
54 changes: 54 additions & 0 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals.dependencies import get_dask_expr

if TYPE_CHECKING:
from narwhals._dask.dataframe import DaskLazyFrame


def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
from narwhals._dask.expr import DaskExpr

if isinstance(obj, DaskExpr):
results = obj._call(df)
if len(results) != 1: # pragma: no cover
msg = "Multi-output expressions not supported in this context"
raise NotImplementedError(msg)
result = results[0]
if not get_dask_expr()._expr.are_co_aligned(
df._native_dataframe._expr, result._expr
): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
# If someone only operates on a Dask DataFrame via expressions, then this
# should always be the case: expression outputs (by definition) all come from the
# same input dataframe, and Dask Series does not have any operations which
# change the index. Nonetheless, we perform this safety check anyway.

# However, we still need to carefully vet which methods we support for Dask, to
# avoid issues where `are_co_aligned` doesn't do what we want it to do:
# https://github.com/dask/dask-expr/issues/1112.
msg = "Implicit index alignment is not support for Dask DataFrame in Narwhals"
raise NotImplementedError(msg)
return result
return obj


def parse_exprs_and_named_exprs(
df: DaskLazyFrame, *exprs: Any, **named_exprs: Any
) -> dict[str, Any]:
results = {}
for expr in exprs:
_results = expr._call(df)
for _result in _results:
results[_result.name] = _result
for name, value in named_exprs.items():
_results = value._call(df)
if len(_results) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise AssertionError(msg)
results[name] = _results[0]
return results
21 changes: 21 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import sys
import warnings
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -65,6 +66,26 @@ def get_numpy() -> Any:
return sys.modules.get("numpy", None)


def get_dask_dataframe() -> Any:
"""Get dask.dataframe module (if already imported - else return None)."""
if "dask" in sys.modules:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="\nDask dataframe query planning.*",
category=FutureWarning,
)
import dask.dataframe as dd

return dd
return None # pragma: no cover


def get_dask_expr() -> Any:
"""Get dask_expr module (if already imported - else return None)."""
return sys.modules.get("dask_expr", None)


def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]:
"""Check whether `df` is a pandas DataFrame without importing pandas."""
return bool((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame))
Expand Down
23 changes: 23 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import overload

from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_dask_dataframe
from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
Expand Down Expand Up @@ -280,6 +282,7 @@ def from_native( # noqa: PLR0915
"""
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._interchange.dataframe import InterchangeFrame
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -467,6 +470,26 @@ def from_native( # noqa: PLR0915
level="full",
)

# Dask
elif (dd := get_dask_dataframe()) is not None and isinstance(
native_object, dd.DataFrame
):
if series_only: # pragma: no cover
# TODO(unassigned): increase coverage
msg = "Cannot only use `series_only` with dask DataFrame"
raise TypeError(msg)
if eager_only or eager_or_interchange_only: # pragma: no cover
# TODO(unassigned): increase coverage
msg = "Cannot only use `eager_only` or `eager_or_interchange_only` with dask DataFrame"
raise TypeError(msg)
if get_dask_expr() is None: # pragma: no cover
msg = "Please install dask-expr"
raise ImportError(msg)
return LazyFrame(
DaskLazyFrame(native_object, backend_version=parse_version(pl.__version__)),
level="full",
)

# Interchange protocol
elif hasattr(native_object, "__dataframe__"):
if eager_only or series_only:
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ pytest
pytest-cov
hypothesis
scikit-learn

dask[dataframe]; python_version >= '3.9'
Loading

0 comments on commit c380fc0

Please sign in to comment.