Skip to content

Commit

Permalink
Introduce ToBackend expression (#1115)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora authored Aug 13, 2024
1 parent 2b8f765 commit 37a5116
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
26 changes: 16 additions & 10 deletions dask_expr/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import pandas as pd
from dask.backends import CreationDispatch
from dask.dataframe.backends import DataFrameBackendEntrypoint
from dask.dataframe.dispatch import to_pandas_dispatch

from dask_expr._dispatch import get_collection_type
from dask_expr._expr import ToBackend

try:
import sparse
Expand All @@ -32,25 +34,29 @@
)


class ToPandasBackend(ToBackend):
    """Expression that converts its input frame to pandas-backed data."""

    @staticmethod
    def operation(df, options):
        """Convert ``df`` via ``to_pandas_dispatch``.

        ``options`` is a mapping that is unpacked into keyword arguments
        for the dispatch call.
        """
        converted = to_pandas_dispatch(df, **options)
        return converted

    def _simplify_down(self):
        """Return the input frame when its meta is already a pandas object.

        Returns ``None`` otherwise.
        NOTE(review): presumably ``None`` tells the optimizer to leave this
        node unchanged -- confirm against the base expression class.
        """
        meta = self.frame._meta
        if not isinstance(meta, (pd.DataFrame, pd.Series, pd.Index)):
            return None
        # The input is pandas data already, so this conversion is redundant.
        return self.frame


class PandasBackendEntrypoint(DataFrameBackendEntrypoint):
    """Pandas-Backend Entrypoint Class for Dask-Expressions
    Note that all DataFrame-creation functions are defined
    and registered 'in-place'.
    """

    @classmethod
    def to_backend_dispatch(cls):
        """Return the dispatch function used to convert data to pandas."""
        from dask.dataframe.dispatch import to_pandas_dispatch

        return to_pandas_dispatch

    @classmethod
    def to_backend(cls, data, **kwargs):
        """Move ``data`` to the pandas backend.

        Parameters
        ----------
        data
            Collection to convert.
        **kwargs
            Conversion options. They are stored on the expression as its
            ``options`` operand and unpacked into ``to_pandas_dispatch`` by
            ``ToPandasBackend.operation``.

        Returns
        -------
        A new collection wrapping a ``ToPandasBackend`` expression.
        """
        # Function-scope import, as in the original.
        # NOTE(review): presumably avoids a circular import -- confirm.
        from dask_expr._collection import new_collection

        # Deliberately no early return for pandas-backed input and no
        # ``map_partitions`` call: the conversion is always recorded as an
        # expression, and ``ToPandasBackend._simplify_down`` drops it again
        # when the input already holds pandas data.
        return new_collection(ToPandasBackend(data, kwargs))


# Register an instance of the entrypoint above under the backend name "pandas".
dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint())
Expand Down
7 changes: 7 additions & 0 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,6 +1303,13 @@ def operation(df):
return df.copy(deep=True)


class ToBackend(Elemwise):
    """Base expression for converting a frame to another backend.

    Takes the input ``frame`` and an ``options`` operand; a subclass
    provides the ``operation(df, options)`` that performs the conversion.
    """

    # Optimizer hints.
    # NOTE(review): presumably these let projections and filters move
    # through this node and keep partitioning information intact --
    # confirm against how ``Elemwise`` and the optimizer read these flags.
    _preserves_partitioning_information = True
    _filter_passthrough = True
    _projection_passthrough = True

    # Operand names, in positional order.
    _parameters = ["frame", "options"]


class RenameSeries(Elemwise):
_parameters = ["frame", "index", "sorted_index"]
_defaults = {"sorted_index": False}
Expand Down
9 changes: 9 additions & 0 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,3 +2665,12 @@ def test_empty_from_pandas_projection():
df["foo"] = from_pandas(foo, npartitions=1)
pdf["foo"] = foo
assert_eq(df["foo"], pdf["foo"])


def test_to_backend_simplify():
    """A pandas -> pandas ``to_backend`` must vanish after ``simplify``."""
    with dask.config.set({"dataframe.backend": "pandas"}):
        frame = from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}, npartitions=2)
        # String form of the plain projection, used as the reference.
        plain_projection = str(frame[["y"]].expr)

        converted = frame.to_backend("pandas")[["y"]]
        # Before optimization the conversion node is still in the expression.
        assert str(converted.expr) != plain_projection

        # After simplification the expression equals the plain projection.
        simplified = converted.simplify()
        assert str(simplified.expr) == plain_projection

0 comments on commit 37a5116

Please sign in to comment.