-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add initial (and very minimal) support for Dask.DataFrame (#635)
- Loading branch information
1 parent
7c7da8b
commit c380fc0
Showing
10 changed files
with
371 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
|
||
from narwhals._dask.utils import parse_exprs_and_named_exprs | ||
from narwhals.dependencies import get_dask_dataframe | ||
|
||
if TYPE_CHECKING: | ||
from typing_extensions import Self | ||
|
||
from narwhals._dask.expr import DaskExpr | ||
from narwhals._dask.namespace import DaskNamespace | ||
|
||
|
||
class DaskLazyFrame: | ||
def __init__( | ||
self, native_dataframe: Any, *, backend_version: tuple[int, ...] | ||
) -> None: | ||
self._native_dataframe = native_dataframe | ||
self._backend_version = backend_version | ||
|
||
def __native_namespace__(self) -> Any: # pragma: no cover | ||
return get_dask_dataframe() | ||
|
||
def __narwhals_namespace__(self) -> DaskNamespace: | ||
from narwhals._dask.namespace import DaskNamespace | ||
|
||
return DaskNamespace(backend_version=self._backend_version) | ||
|
||
def __narwhals_lazyframe__(self) -> Self: | ||
return self | ||
|
||
def _from_native_dataframe(self, df: Any) -> Self: | ||
return self.__class__(df, backend_version=self._backend_version) | ||
|
||
def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: | ||
df = self._native_dataframe | ||
new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) | ||
df = df.assign(**new_series) | ||
return self._from_native_dataframe(df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from __future__ import annotations | ||
|
||
from copy import copy | ||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
from typing import Callable | ||
|
||
from narwhals.dependencies import get_dask_expr | ||
|
||
if TYPE_CHECKING: | ||
from typing_extensions import Self | ||
|
||
from narwhals._dask.dataframe import DaskLazyFrame | ||
|
||
from narwhals._dask.utils import maybe_evaluate | ||
|
||
|
||
class DaskExpr: | ||
def __init__( | ||
self, | ||
# callable from DaskLazyFrame to list of (native) Dask Series | ||
call: Callable[[DaskLazyFrame], Any], | ||
*, | ||
depth: int, | ||
function_name: str, | ||
root_names: list[str] | None, | ||
output_names: list[str] | None, | ||
backend_version: tuple[int, ...], | ||
) -> None: | ||
self._call = call | ||
self._depth = depth | ||
self._function_name = function_name | ||
self._root_names = root_names | ||
self._output_names = output_names | ||
self._backend_version = backend_version | ||
|
||
@classmethod | ||
def from_column_names( | ||
cls: type[Self], | ||
*column_names: str, | ||
backend_version: tuple[int, ...], | ||
) -> Self: | ||
def func(df: DaskLazyFrame) -> list[Any]: | ||
return [ | ||
df._native_dataframe.loc[:, column_name] for column_name in column_names | ||
] | ||
|
||
return cls( | ||
func, | ||
depth=0, | ||
function_name="col", | ||
root_names=list(column_names), | ||
output_names=list(column_names), | ||
backend_version=backend_version, | ||
) | ||
|
||
def _from_call( | ||
self, | ||
# callable from DaskLazyFrame to list of (native) Dask Series | ||
call: Any, | ||
expr_name: str, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Self: | ||
def func(df: DaskLazyFrame) -> list[Any]: | ||
results = [] | ||
inputs = self._call(df) | ||
for _input in inputs: | ||
_args = [maybe_evaluate(df, x) for x in args] | ||
_kwargs = { | ||
key: maybe_evaluate(df, value) for key, value in kwargs.items() | ||
} | ||
result = call(_input, *_args, **_kwargs) | ||
if isinstance(result, get_dask_expr()._collection.Series): | ||
result = result.rename(_input.name) | ||
results.append(result) | ||
return results | ||
|
||
# Try tracking root and output names by combining them from all | ||
# expressions appearing in args and kwargs. If any anonymous | ||
# expression appears (e.g. nw.all()), then give up on tracking root names | ||
# and just set it to None. | ||
root_names = copy(self._root_names) | ||
output_names = self._output_names | ||
for arg in list(args) + list(kwargs.values()): | ||
if root_names is not None and isinstance(arg, self.__class__): | ||
if arg._root_names is not None: | ||
root_names.extend(arg._root_names) | ||
else: # pragma: no cover | ||
# TODO(unassigned): increase coverage | ||
root_names = None | ||
output_names = None | ||
break | ||
elif root_names is None: # pragma: no cover | ||
# TODO(unassigned): increase coverage | ||
output_names = None | ||
break | ||
|
||
if not ( | ||
(output_names is None and root_names is None) | ||
or (output_names is not None and root_names is not None) | ||
): # pragma: no cover | ||
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" | ||
raise AssertionError(msg) | ||
|
||
return self.__class__( | ||
func, | ||
depth=self._depth + 1, | ||
function_name=f"{self._function_name}->{expr_name}", | ||
root_names=root_names, | ||
output_names=output_names, | ||
backend_version=self._backend_version, | ||
) | ||
|
||
def alias(self, name: str) -> Self: | ||
def func(df: DaskLazyFrame) -> list[Any]: | ||
results = [] | ||
inputs = self._call(df) | ||
for _input in inputs: | ||
result = _input.rename(name) | ||
results.append(result) | ||
return results | ||
|
||
return self.__class__( | ||
func, | ||
depth=self._depth, | ||
function_name=self._function_name, | ||
root_names=self._root_names, | ||
output_names=[name], | ||
backend_version=self._backend_version, | ||
) | ||
|
||
def __add__(self, other: Any) -> Self: | ||
return self._from_call( | ||
lambda _input, other: _input.__add__(other), | ||
"__add__", | ||
other, | ||
) | ||
|
||
def mean(self) -> Self: | ||
return self._from_call( | ||
lambda _input: _input.mean(), | ||
"mean", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import annotations | ||
|
||
from narwhals import dtypes | ||
from narwhals._dask.expr import DaskExpr | ||
|
||
|
||
class DaskNamespace: | ||
Int64 = dtypes.Int64 | ||
Int32 = dtypes.Int32 | ||
Int16 = dtypes.Int16 | ||
Int8 = dtypes.Int8 | ||
UInt64 = dtypes.UInt64 | ||
UInt32 = dtypes.UInt32 | ||
UInt16 = dtypes.UInt16 | ||
UInt8 = dtypes.UInt8 | ||
Float64 = dtypes.Float64 | ||
Float32 = dtypes.Float32 | ||
Boolean = dtypes.Boolean | ||
Object = dtypes.Object | ||
Unknown = dtypes.Unknown | ||
Categorical = dtypes.Categorical | ||
Enum = dtypes.Enum | ||
String = dtypes.String | ||
Datetime = dtypes.Datetime | ||
Duration = dtypes.Duration | ||
Date = dtypes.Date | ||
|
||
def __init__(self, *, backend_version: tuple[int, ...]) -> None: | ||
self._backend_version = backend_version | ||
|
||
def col(self, *column_names: str) -> DaskExpr: | ||
return DaskExpr.from_column_names( | ||
*column_names, | ||
backend_version=self._backend_version, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
|
||
from narwhals.dependencies import get_dask_expr | ||
|
||
if TYPE_CHECKING: | ||
from narwhals._dask.dataframe import DaskLazyFrame | ||
|
||
|
||
def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: | ||
from narwhals._dask.expr import DaskExpr | ||
|
||
if isinstance(obj, DaskExpr): | ||
results = obj._call(df) | ||
if len(results) != 1: # pragma: no cover | ||
msg = "Multi-output expressions not supported in this context" | ||
raise NotImplementedError(msg) | ||
result = results[0] | ||
if not get_dask_expr()._expr.are_co_aligned( | ||
df._native_dataframe._expr, result._expr | ||
): # pragma: no cover | ||
# are_co_aligned is a method which cheaply checks if two Dask expressions | ||
# have the same index, and therefore don't require index alignment. | ||
# If someone only operates on a Dask DataFrame via expressions, then this | ||
# should always be the case: expression outputs (by definition) all come from the | ||
# same input dataframe, and Dask Series does not have any operations which | ||
# change the index. Nonetheless, we perform this safety check anyway. | ||
|
||
# However, we still need to carefully vet which methods we support for Dask, to | ||
# avoid issues where `are_co_aligned` doesn't do what we want it to do: | ||
# https://github.com/dask/dask-expr/issues/1112. | ||
msg = "Implicit index alignment is not support for Dask DataFrame in Narwhals" | ||
raise NotImplementedError(msg) | ||
return result | ||
return obj | ||
|
||
|
||
def parse_exprs_and_named_exprs( | ||
df: DaskLazyFrame, *exprs: Any, **named_exprs: Any | ||
) -> dict[str, Any]: | ||
results = {} | ||
for expr in exprs: | ||
_results = expr._call(df) | ||
for _result in _results: | ||
results[_result.name] = _result | ||
for name, value in named_exprs.items(): | ||
_results = value._call(df) | ||
if len(_results) != 1: # pragma: no cover | ||
msg = "Named expressions must return a single column" | ||
raise AssertionError(msg) | ||
results[name] = _results[0] | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,4 @@ pytest | |
pytest-cov | ||
hypothesis | ||
scikit-learn | ||
|
||
dask[dataframe]; python_version >= '3.9' |
Oops, something went wrong.