Skip to content

Commit

Permalink
fix: properly use dask-awkward optimizations in all scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Feb 1, 2024
1 parent 5289f10 commit 49516f1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/dask_histogram/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tlz import first

from dask_histogram.bins import normalize_bins_range
from dask_histogram.core import AggHistogram, factory, optimize
from dask_histogram.core import AggHistogram, _get_optimization_function, factory

if TYPE_CHECKING:
from dask_histogram.typing import (
Expand Down Expand Up @@ -201,7 +201,7 @@ def __dask_postpersist__(self) -> Any:
return self._rebuild, ()

__dask_optimize__ = globalmethod(
optimize, key="histogram_optimize", falsey=dont_optimize
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
)

__dask_scheduler__ = staticmethod(tget)
Expand Down
29 changes: 15 additions & 14 deletions src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,27 +401,28 @@ def optimize(
keys: Hashable | list[Hashable] | set[Hashable],
**kwargs: Any,
) -> Mapping:
if not isinstance(keys, (list, set)):
keys = [keys]
keys = list(flatten(keys))
keys = tuple(flatten(keys))

if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
dsk = HighLevelGraph.from_collections(str(id(dsk)), dsk, dependencies=())

dsk = optimize_blockwise(dsk, keys=keys)
dsk = fuse_roots(dsk, keys=keys) # type: ignore
dsk = dsk.cull(set(keys)) # type: ignore
return dsk


def _get_optimization_function():
# Here we try to run optimizations from dask-awkward (if we detect
# that dask-awkward has been imported). There is no cost to
# running this optimization even in cases where it's unnecessary
# because if no AwkwardInputLayers from daks-awkward are not
# because if no AwkwardInputLayers from dask-awkward are
# detected then the original graph is returned unchanged.
if dask.config.get("awkward", default=False):
from dask_awkward.lib.optimize import optimize
from dask_awkward.lib.optimize import all_optimizations

dsk = optimize(dsk, keys=keys) # type: ignore[arg-type]

dsk = optimize_blockwise(dsk, keys=keys)
dsk = fuse_roots(dsk, keys=keys) # type: ignore
dsk = dsk.cull(set(keys)) # type: ignore
return dsk
return all_optimizations
return optimize


class AggHistogram(DaskMethodsMixin):
Expand Down Expand Up @@ -479,7 +480,7 @@ def __dask_postpersist__(self) -> Any:
return self._rebuild, ()

__dask_optimize__ = globalmethod(
optimize, key="histogram_optimize", falsey=dont_optimize
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
)

__dask_scheduler__ = staticmethod(tget)
Expand Down Expand Up @@ -706,7 +707,7 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
return type(self)(dsk, name, self.npartitions, self.histref)

__dask_optimize__ = globalmethod(
optimize, key="histogram_optimize", falsey=dont_optimize
_get_optimization_function(), key="histogram_optimize", falsey=dont_optimize
)

__dask_scheduler__ = staticmethod(tget)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights):
dhb.axis.Regular(9, -3.2, 3.2),
storage=storage,
)

# check that we are using the correct optimizer
assert h.__dask_optimize__ == dak.lib.optimize.all_optimizations

for i in range(25):
h.fill(f"testcat{i+1}", i + 1, x, y, z, weight=weights)
h = h.compute()
Expand Down

0 comments on commit 49516f1

Please sign in to comment.