[WIP][POC] Add ResourceBarrier expression to change resources within an expression graph #1116

Draft · wants to merge 6 commits into base: main
47 changes: 45 additions & 2 deletions dask_expr/_collection.py
@@ -445,7 +445,11 @@ def __array__(self, dtype=None, **kwargs):

     def persist(self, fuse=True, **kwargs):
         out = self.optimize(fuse=fuse)
-        return DaskMethodsMixin.persist(out, **kwargs)
+        return DaskMethodsMixin.persist(
+            out,
+            task_resources=out.expr.collect_task_resources(),
+            **kwargs,
+        )

     def compute(self, fuse=True, **kwargs):
         """Compute this DataFrame.
@@ -474,7 +478,11 @@ def compute(self, fuse=True, **kwargs):
         if not isinstance(out, Scalar):
             out = out.repartition(npartitions=1)
         out = out.optimize(fuse=fuse)
-        return DaskMethodsMixin.compute(out, **kwargs)
+        return DaskMethodsMixin.compute(
+            out,
+            task_resources=out.expr.collect_task_resources(),
+            **kwargs,
+        )

     def analyze(self, filename: str | None = None, format: str | None = None) -> None:
         """Outputs statistics about every node in the expression.
@@ -2494,6 +2502,30 @@ def to_delayed(self, optimize_graph=True):
"""
return self.to_legacy_dataframe().to_delayed(optimize_graph=optimize_graph)

def resource_barrier(self, resources):
"""Define a resource-constraint barrier

Parameters
----------
resources : dict
Resource constraint (e.g. ``{GPU: 1}``).

Notes
-----
1. This resources constraint will be applied to all tasks
generated by operations after this point (or until the
`resource_barrier` API is used again).
2. This resource constraint will superceed any other
resource constraints defined with global annotations.
3. Creating a resource barrier will not block optimizations
like column projection or predicate pushdown. We assume
both projection and filtering are resource agnostic.
4. Resource constraints only apply to distributed execution.
5. The scheduler will only try to satisfy resource constraints
when relevant worker resources exist.
"""
return new_collection(expr.ElemwiseResourceBarrier(self.expr, resources))

def to_backend(self, backend: str | None = None, **kwargs):
"""Move to a new DataFrame backend

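For context, a minimal usage sketch of the proposed collection API. This is illustrative only: it assumes a distributed cluster whose workers advertise a "GPU" resource, and the path is made up.

```python
import dask_expr as dx

df = dx.read_parquet("s3://bucket/data.parquet")  # hypothetical dataset

# All tasks generated after this point request one GPU each
gpu_df = df.resource_barrier({"GPU": 1})
result = gpu_df.groupby("x").sum().compute()
```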
@@ -5192,6 +5224,7 @@ def read_parquet(
     filesystem="fsspec",
     engine=None,
     arrow_to_pandas=None,
+    resources=None,
     **kwargs,
 ):
     """
@@ -5371,6 +5404,10 @@ def read_parquet(
     arrow_to_pandas: dict, default None
         Dictionary of options to use when converting from ``pyarrow.Table`` to
         a pandas ``DataFrame`` object. Only used by the "arrow" engine.
+    resources: dict, default None
+        Resource constraint to apply to the generated IO tasks and all
+        subsequent operations. The `resource_barrier` API can be used to
+        modify the resource constraints of later operations after the
+        collection is created.
     **kwargs: dict (of dicts)
         Options to pass through to ``engine.read_partitions`` as stand-alone
         key-word arguments. Note that these options will be ignored by the
@@ -5386,6 +5423,7 @@
     to_parquet
     pyarrow.parquet.ParquetDataset
     """
+    from dask_expr.io.io import IOResourceBarrier
     from dask_expr.io.parquet import (
         ReadParquetFSSpec,
         ReadParquetPyarrowFS,
@@ -5405,6 +5443,9 @@
             if op == "in" and not isinstance(val, (set, list, tuple)):
                 raise TypeError("Value of 'in' filter must be a list, set or tuple.")

+    if resources is not None:
+        resources = IOResourceBarrier(resources)
+
     if (
         isinstance(filesystem, pa_fs.FileSystem)
         or isinstance(filesystem, str)
@@ -5454,6 +5495,7 @@
                 pyarrow_strings_enabled=pyarrow_strings_enabled(),
                 kwargs=kwargs,
                 _series=isinstance(columns, str),
+                resource_requirement=resources,
             )
         )

@@ -5476,6 +5518,7 @@
         engine=_set_parquet_engine(engine),
         kwargs=kwargs,
         _series=isinstance(columns, str),
+        resource_requirement=resources,
     )
 )

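A sketch of the IO-level entry point added here, combined with a later barrier to change the constraint mid-pipeline (paths and resource names are illustrative):

```python
import dask_expr as dx

# IO tasks, and everything downstream of them, request one GPU each
df = dx.read_parquet("s3://bucket/data.parquet", resources={"GPU": 1})

# A second barrier switches subsequent operations to a different constraint
big_mem = df.resource_barrier({"MEMORY": 64e9})
big_mem.sort_values("x").compute()
```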
24 changes: 24 additions & 0 deletions dask_expr/_core.py
@@ -738,6 +738,30 @@ def walk(self) -> Generator[Expr]:

             yield node

+    def collect_task_resources(self) -> dict:
+        resources_annotation = {}
+        stack = [self]
+        seen = set()
+        while stack:
+            node = stack.pop()
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            resources = node._resources
+            if resources is not None:
+                resources_annotation.update(
+                    {
+                        k: (resources(k) if callable(resources) else resources)
+                        for k in node._layer().keys()
+                    }
+                )
+
+            for dep in node.dependencies():
+                stack.append(dep)
+
+        return resources_annotation
+
     def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
         """Search the expression graph for a specific operation type

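To make the return shape concrete: `collect_task_resources` walks the expression graph once and produces a flat mapping from task keys to resource dicts, which is what the POC forwards to compute/persist as `task_resources`. A hand-written sketch of a possible result (key tokens are made up):

```python
# Two IO partitions and one downstream task, all behind a {"GPU": 1} barrier:
{
    ("readparquet-abc1234", 0): {"GPU": 1},
    ("readparquet-abc1234", 1): {"GPU": 1},
    ("sum-def5678", 0): {"GPU": 1},
}
```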
39 changes: 37 additions & 2 deletions dask_expr/_expr.py
@@ -52,7 +52,7 @@
     random_state_data,
 )
 from pandas.errors import PerformanceWarning
-from tlz import merge_sorted, partition, unique
+from tlz import merge, merge_sorted, partition, unique

 from dask_expr import _core as core
 from dask_expr._util import (
@@ -87,6 +87,11 @@ def ndim(self):
         except AttributeError:
             return 0

+    @functools.cached_property
+    def _resources(self):
+        dep_resources = merge(dep._resources or {} for dep in self.dependencies())
+        return dep_resources or None
+
     def __dask_keys__(self):
         return [(self._name, i) for i in range(self.npartitions)]

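On the base `Expr`, `_resources` simply merges whatever the dependencies carry, so a barrier's spec flows to every downstream expression, and `tlz.merge` gives later dicts precedence per key. A standalone illustration of that merge behavior (not dask-expr code):

```python
from tlz import merge

# Later dicts win per key, mirroring how dependency resources combine
merge({"GPU": 1}, {"GPU": 2, "MEMORY": 1e9})
# {'GPU': 2, 'MEMORY': 1000000000.0}
```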
@@ -1303,6 +1308,36 @@ def operation(df):
         return df.copy(deep=True)


+class ResourceBarrier(Expr):
+    @property
+    def _resources(self):
+        raise NotImplementedError()
+
+    def __str__(self):
+        return f"{type(self).__name__}({self._resources})"
+
+
+class ElemwiseResourceBarrier(Elemwise, ResourceBarrier):
+    _parameters = ["frame", "resource_spec"]
+    _projection_passthrough = True
+    _filter_passthrough = True
+    _preserves_partitioning_information = True
+
+    @property
+    def _resources(self):
+        return self.resource_spec
+
+    def _task(self, index: int):
+        return (self.frame._name, index)
+
+    @property
+    def _meta(self):
+        return self.frame._meta
+
+    def _divisions(self):
+        return self.frame.divisions
+
+
 class RenameSeries(Elemwise):
     _parameters = ["frame", "index", "sorted_index"]
     _defaults = {"sorted_index": False}
@@ -3128,7 +3163,7 @@ def are_co_aligned(*exprs):

 def is_valid_blockwise_op(expr):
     return isinstance(expr, Blockwise) and not isinstance(
-        expr, (FromPandas, FromArray, FromDelayed)
+        expr, (FromPandas, FromArray, FromDelayed, ResourceBarrier)
     )


19 changes: 18 additions & 1 deletion dask_expr/io/io.py
@@ -19,6 +19,7 @@
     Literal,
     PartitionsFiltered,
     Projection,
+    ResourceBarrier,
     determine_column_projection,
     no_default,
 )
@@ -31,6 +32,17 @@ def __str__(self):
return f"{type(self).__name__}({self._name[-7:]})"


class IOResourceBarrier(ResourceBarrier):
_parameters = ["resource_spec"]

@property
def _resources(self):
return self.resource_spec

def _layer(self):
return {}


class FromGraph(IO):
"""A DataFrame created from an opaque Dask task graph

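`IOResourceBarrier` carries no tasks of its own (`_layer` returns an empty dict); it exists so an IO expression can hold it as an operand and pick up the spec through `_resources`. A simplified sketch of the wiring `read_parquet` performs, per the `_collection.py` change above:

```python
# Simplified from read_parquet (resource name illustrative)
resources = {"GPU": 1}
if resources is not None:
    resources = IOResourceBarrier(resources)

# The reader stores the barrier as an operand
# (resource_requirement=resources), so it appears as a graph
# dependency and collect_task_resources() tags every IO task with it.
```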
@@ -149,7 +161,8 @@ def _tune_up(self, parent):


 class FusedParquetIO(FusedIO):
-    _parameters = ["_expr"]
+    _parameters = ["_expr", "resource_requirement"]
+    _defaults = {"resource_requirement": None}

     @functools.cached_property
     def _name(self):
@@ -159,6 +172,10 @@ def _name(self):
             + _tokenize_deterministic(*self.operands)
         )

+    def dependencies(self):
+        dep = self.resource_requirement
+        return [] if dep is None else [dep]
+
     @staticmethod
     def _load_multiple_files(
         frag_filters,
6 changes: 5 additions & 1 deletion dask_expr/io/parquet.py
@@ -828,6 +828,7 @@ class ReadParquetPyarrowFS(ReadParquet):
"arrow_to_pandas",
"pyarrow_strings_enabled",
"kwargs",
"resource_requirement",
"_partitions",
"_series",
"_dataset_info_cache",
@@ -844,6 +845,7 @@ class ReadParquetPyarrowFS(ReadParquet):
"arrow_to_pandas": None,
"pyarrow_strings_enabled": True,
"kwargs": None,
"resource_requirement": None,
"_partitions": None,
"_series": False,
"_dataset_info_cache": None,
@@ -1098,7 +1100,7 @@ def _tune_up(self, parent):
             return
         if isinstance(parent, FusedParquetIO):
             return
-        return parent.substitute(self, FusedParquetIO(self))
+        return parent.substitute(self, FusedParquetIO(self, self.resource_requirement))

     @cached_property
     def fragments(self):
@@ -1253,6 +1255,7 @@ class ReadParquetFSSpec(ReadParquet):
"filesystem",
"engine",
"kwargs",
"resource_requirement",
"_partitions",
"_series",
"_dataset_info_cache",
@@ -1273,6 +1276,7 @@
"filesystem": "fsspec",
"engine": "pyarrow",
"kwargs": None,
"resource_requirement": None,
"_partitions": None,
"_series": False,
"_dataset_info_cache": None,
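Putting the pieces together, an end-to-end sketch (all addresses, paths, and resource names are illustrative; per the docstring notes above, constraints only take effect under distributed execution, and only when workers actually advertise the named resource):

```python
from distributed import Client

import dask_expr as dx

# Assumes workers were started with a resource flag, e.g.:
#   dask worker tcp://scheduler:8786 --resources "GPU=1"
client = Client("tcp://scheduler:8786")

# IO tasks and everything downstream request one GPU each
df = dx.read_parquet("s3://bucket/data.parquet", resources={"GPU": 1})

# Column projection and predicate pushdown still optimize across the barrier
out = df[df.x > 0].resource_barrier({"GPU": 2}).x.sum()

# compute() forwards collect_task_resources() as task_resources
out.compute()
```

Note that the `task_resources` keyword on `DaskMethodsMixin.compute`/`persist` presumably relies on a companion change in dask itself; it is not part of this diff.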