Skip to content

Commit

Permalink
multifill and typetracing optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Mar 1, 2024
1 parent 8caccd5 commit 78672dd
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 123 deletions.
153 changes: 82 additions & 71 deletions src/dask_histogram/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import boost_histogram.storage as storage
import dask
import dask.array as da
from dask.bag.core import empty_safe_aggregate, partition_all
from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
from dask.context import globalmethod
from dask.delayed import Delayed, delayed
Expand All @@ -20,7 +19,14 @@
from tlz import first

from dask_histogram.bins import normalize_bins_range
from dask_histogram.core import AggHistogram, _get_optimization_function, factory
from dask_histogram.core import (
AggHistogram,
_get_optimization_function,
_partitioned_histogram_multifill,
_reduction,
factory,
is_dask_awkward_like,
)

if TYPE_CHECKING:
from dask_histogram.typing import (
Expand All @@ -36,55 +42,6 @@
__all__ = ("Histogram", "histogram", "histogram2d", "histogramdd")


def _build_staged_tree_reduce(
stages: list[AggHistogram], split_every: int | bool
) -> HighLevelGraph:
if not split_every:
split_every = len(stages)

reducer = sum

token = tokenize(stages, reducer, split_every)

k = len(stages)
b = ""
fmt = f"staged-fill-aggregate-{token}"
depth = 0

dsk = {}

if k > 1:
while k > split_every:
c = fmt + str(depth)
for i, inds in enumerate(partition_all(split_every, range(k))):
dsk[(c, i)] = (
empty_safe_aggregate,
reducer,
[
(stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
for j in inds
],
False,
)

k = i + 1
b = c
depth += 1

dsk[(fmt, 0)] = (
empty_safe_aggregate,
reducer,
[
(stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
for j in range(k)
],
True,
)
return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages)

return stages[0].name, stages[0].dask


class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
"""Histogram object capable of lazy computation.
Expand All @@ -97,9 +54,6 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
type is :py:class:`boost_histogram.storage.Double`.
metadata : Any
Data that is passed along if a new histogram is created.
split_every : int | bool | None, default None
Width of aggregation layers for staged fills.
If False, all staged fills are added in one layer (memory intensive!).
See Also
--------
Expand Down Expand Up @@ -139,7 +93,7 @@ def __init__(
) -> None:
"""Construct a Histogram object."""
super().__init__(*axes, storage=storage, metadata=metadata)
self._staged: list[AggHistogram] | None = None
self._staged: AggHistogram | None = None
self._dask_name: str | None = (
f"empty-histogram-{tokenize(*axes, storage, metadata)}"
)
Expand All @@ -148,14 +102,12 @@ def __init__(
{},
)
self._split_every = split_every
if self._split_every is None:
self._split_every = dask.config.get("histogram.aggregation.split_every", 8)

@property
def _histref(self):
return (
tuple(self.axes),
self.storage_type,
self.storage_type(),
self.metadata,
)

Expand All @@ -164,12 +116,6 @@ def __iadd__(self, other):
self._staged += other._staged
elif not self.staged_fills() and other.staged_fills():
self._staged = other._staged
if self.staged_fills():
new_name, new_graph = _build_staged_tree_reduce(
self._staged, self._split_every
)
self._dask = new_graph
self._dask_name = new_name
return self

def __add__(self, other):
Expand Down Expand Up @@ -234,6 +180,8 @@ def _in_memory_type(self) -> type[bh.Histogram]:

@property
def dask_name(self) -> str:
if self._dask_name == "__not_yet_calculated__" and self._dask is None:
self._build_taskgraph()
if self._dask_name is None:
raise RuntimeError(
"The dask name should never be None when it's requested."
Expand All @@ -242,12 +190,73 @@ def dask_name(self) -> str:

@property
def dask(self) -> HighLevelGraph:
if self._dask_name == "__not_yet_calculated__" and self._dask is None:
self._build_taskgraph()
if self._dask is None:
raise RuntimeError(
"The dask graph should never be None when it's requested."
)
return self._dask

def _build_taskgraph(self):
data_list = []
weights = []
samples = []

dask_data = tuple(
datum
for datum in (
self._staged[0]["args"] + tuple(self._staged[0]["kwargs"].values())
)
if is_dask_collection(datum)
)

if is_dask_awkward_like(dask_data[0]):

for afill in self._staged:
data_list.append(afill["args"])
weights.append(afill["kwargs"]["weight"])
samples.append(afill["kwargs"]["sample"])

if all(weight is None for weight in weights):
weights = None

if not all(sample is None for sample in samples):
samples = None

split_every = self._split_every
if split_every is None:
split_every = dask.config.get("histogram.aggregation.split_every", 8)

fills = _partitioned_histogram_multifill(
data_list, self._histref, weights, samples
)

output_hist = _reduction(fills, split_every)
else:

first_fill = self._staged.pop()

output_hist = factory(
*first_fill["args"],
histref=self._histref,
weights=first_fill["kwargs"]["weight"],
sample=first_fill["kwargs"]["sample"],
)

for afill in self._staged:
output_hist += factory(
*afill["args"],
histref=self._histref,
weights=afill["kwargs"]["weight"],
sample=afill["kwargs"]["sample"],
)

self._staged = None
self._staged_result = output_hist
self._dask = output_hist.dask
self._dask_name = output_hist.name

def fill( # type: ignore
self,
*args: DaskCollection,
Expand Down Expand Up @@ -318,14 +327,14 @@ def fill( # type: ignore
else:
raise ValueError(f"Cannot interpret input data: {args}")

new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample)
# new_fill = partitioned_factory(*args, histref=self._histref, weights=weight, sample=sample)
new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}}
if self._staged is None:
self._staged = [new_fill]
else:
self._staged += [new_fill]
new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every)
self._dask = new_graph
self._dask_name = new_name
self._staged.append(new_fill)
self._dask = None # self._staged.__dask_graph__()
self._dask_name = "__not_yet_calculated__"

return self

Expand Down Expand Up @@ -383,7 +392,8 @@ def to_delayed(self) -> Delayed:
"""
if self._staged is not None:
return sum(self._staged[1:], start=self._staged[0]).to_delayed()
self._build_taskgraph()
return self._staged_result.to_delayed()
return delayed(bh.Histogram(self))

def __repr__(self) -> str:
Expand Down Expand Up @@ -449,7 +459,8 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any:
"""
if self._staged is not None:
return sum(self._staged).to_dask_array(flow=flow, dd=dd)
self._build_taskgraph()
return self._staged_result.to_dask_array(flow=flow, dd=dd)
else:
counts, edges = self.to_numpy(flow=flow, dd=True, view=False)
counts = da.from_array(counts)
Expand Down
Loading

0 comments on commit 78672dd

Please sign in to comment.