Skip to content

Commit

Permalink
Merge pull request #120 from lgray/tree_reduction_for_staged_fills
Browse files Browse the repository at this point in the history
feat: use a tree reduce for staged fills instead of pairwise adds
  • Loading branch information
martindurant authored Jan 30, 2024
2 parents 6e3fa1a + 7740f00 commit 5289f10
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 65 deletions.
82 changes: 72 additions & 10 deletions src/dask_histogram/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import boost_histogram as bh
import boost_histogram.axis as axis
import boost_histogram.storage as storage
import dask
import dask.array as da
from dask.bag.core import empty_safe_aggregate, partition_all
from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
from dask.context import globalmethod
from dask.delayed import Delayed, delayed
Expand All @@ -34,6 +36,55 @@
__all__ = ("Histogram", "histogram", "histogram2d", "histogramdd")


def _build_staged_tree_reduce(
    stages: list[AggHistogram], split_every: int | bool
) -> tuple[str, HighLevelGraph]:
    """Build a tree-reduction graph summing a list of staged fills.

    Parameters
    ----------
    stages : list[AggHistogram]
        Staged fill collections to be summed into a single histogram.
    split_every : int | bool
        Maximum number of inputs combined by a single aggregation task.
        A falsy value (``False``, ``0``) collapses all stages in one
        wide layer, which can be memory intensive for many stages.

    Returns
    -------
    tuple[str, HighLevelGraph]
        The name of the output layer and the graph whose key
        ``(name, 0)`` holds the summed histogram.
    """
    if not split_every:
        # Degenerate case: aggregate everything in a single layer.
        split_every = len(stages)

    reducer = sum

    token = tokenize(stages, reducer, split_every)

    k = len(stages)  # width (number of keys) of the current layer
    b = ""  # name of the previous intermediate layer
    fmt = f"staged-fill-aggregate-{token}"
    depth = 0

    dsk = {}

    if k > 1:
        # Insert intermediate layers until the remaining width fits in
        # one final aggregation task.
        while k > split_every:
            c = fmt + str(depth)
            for i, inds in enumerate(partition_all(split_every, range(k))):
                dsk[(c, i)] = (
                    empty_safe_aggregate,
                    reducer,
                    [
                        # At depth 0 the inputs are the stage collections
                        # (key 0 of each); deeper layers read key ``j`` of
                        # the previous intermediate layer ``b``.
                        (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
                        for j in inds
                    ],
                    False,
                )

            # ``i`` intentionally leaks from the loop: number of
            # partitions produced in this layer.
            k = i + 1
            b = c
            depth += 1

        # Final aggregation task producing the single output key.
        dsk[(fmt, 0)] = (
            empty_safe_aggregate,
            reducer,
            [
                (stages[j].name if depth == 0 else b, 0 if depth == 0 else j)
                for j in range(k)
            ],
            True,
        )
        return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages)

    # Single stage: no reduction needed, reuse its name and graph directly.
    return stages[0].name, stages[0].dask


class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
"""Histogram object capable of lazy computation.
Expand All @@ -46,6 +97,9 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram):
type is :py:class:`boost_histogram.storage.Double`.
metadata : Any
Data that is passed along if a new histogram is created.
split_every : int | bool | None, default None
Width of aggregation layers for staged fills.
If False, all staged fills are added in one layer (memory intensive!).
See Also
--------
Expand Down Expand Up @@ -81,23 +135,27 @@ def __init__(
*axes: bh.axis.Axis,
storage: bh.storage.Storage = bh.storage.Double(),
metadata: Any = None,
split_every: int | None = None,
) -> None:
"""Construct a Histogram object."""
super().__init__(*axes, storage=storage, metadata=metadata)
self._staged: AggHistogram | None = None
self._staged: list[AggHistogram] | None = None
self._dask_name: str | None = (
f"empty-histogram-{tokenize(*axes, storage, metadata)}"
)
self._dask: HighLevelGraph | None = HighLevelGraph(
{self._dask_name: {(self._dask_name, 0): (lambda: self,)}},
{},
)
self._split_every = split_every
if self._split_every is None:
self._split_every = dask.config.get("histogram.aggregation.split_every", 8)

@property
def _histref(self):
return (
tuple(self.axes),
self.storage_type(),
self.storage_type,
self.metadata,
)

Expand All @@ -107,8 +165,11 @@ def __iadd__(self, other):
elif not self.staged_fills() and other.staged_fills():
self._staged = other._staged
if self.staged_fills():
self._dask = self._staged.__dask_graph__()
self._dask_name = self._staged.name
new_name, new_graph = _build_staged_tree_reduce(
self._staged, self._split_every
)
self._dask = new_graph
self._dask_name = new_name
return self

def __add__(self, other):
Expand Down Expand Up @@ -259,11 +320,12 @@ def fill( # type: ignore

new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample)
if self._staged is None:
self._staged = new_fill
self._staged = [new_fill]
else:
self._staged += new_fill
self._dask = self._staged.__dask_graph__()
self._dask_name = self._staged.name
self._staged += [new_fill]
new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every)
self._dask = new_graph
self._dask_name = new_name

return self

Expand Down Expand Up @@ -321,7 +383,7 @@ def to_delayed(self) -> Delayed:
"""
if self._staged is not None:
return self._staged.to_delayed()
return sum(self._staged[1:], start=self._staged[0]).to_delayed()
return delayed(bh.Histogram(self))

def __repr__(self) -> str:
Expand Down Expand Up @@ -387,7 +449,7 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any:
"""
if self._staged is not None:
return self._staged.to_dask_array(flow=flow, dd=dd)
return sum(self._staged).to_dask_array(flow=flow, dd=dd)
else:
counts, edges = self.to_numpy(flow=flow, dd=True, view=False)
counts = da.from_array(counts)
Expand Down
53 changes: 30 additions & 23 deletions src/dask_histogram/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _blocked_sa(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
if data.ndim == 1:
return thehist.fill(data)
Expand All @@ -83,7 +83,7 @@ def _blocked_sa_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
if data.ndim == 1:
return thehist.fill(data, sample=sample)
Expand All @@ -103,7 +103,7 @@ def _blocked_sa_w(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
if data.ndim == 1:
return thehist.fill(data, weight=weights)
Expand All @@ -124,7 +124,7 @@ def _blocked_sa_w_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
if data.ndim == 1:
return thehist.fill(data, weight=weights, sample=sample)
Expand All @@ -142,7 +142,7 @@ def _blocked_ma(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*data)

Expand All @@ -157,7 +157,7 @@ def _blocked_ma_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*data, sample=sample)

Expand All @@ -172,7 +172,7 @@ def _blocked_ma_w(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*data, weight=weights)

Expand All @@ -188,7 +188,7 @@ def _blocked_ma_w_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*data, weight=weights, sample=sample)

Expand All @@ -201,7 +201,7 @@ def _blocked_df(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*(data[c] for c in data.columns), weight=None)

Expand All @@ -215,7 +215,7 @@ def _blocked_df_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*(data[c] for c in data.columns), sample=sample)

Expand All @@ -230,7 +230,7 @@ def _blocked_df_w(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*(data[c] for c in data.columns), weight=weights)

Expand All @@ -246,7 +246,7 @@ def _blocked_df_w_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*(data[c] for c in data.columns), weight=weights, sample=sample)

Expand Down Expand Up @@ -279,7 +279,7 @@ def _blocked_dak(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(thedata, weight=theweights, sample=thesample)

Expand All @@ -302,7 +302,7 @@ def _blocked_dak_ma(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*tuple(thedata))

Expand Down Expand Up @@ -330,7 +330,7 @@ def _blocked_dak_ma_w(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*tuple(thedata), weight=theweights)

Expand Down Expand Up @@ -358,7 +358,7 @@ def _blocked_dak_ma_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*tuple(thedata), sample=thesample)

Expand Down Expand Up @@ -391,7 +391,7 @@ def _blocked_dak_ma_w_s(
thehist = (
clone(histref)
if not isinstance(histref, tuple)
else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2])
else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2])
)
return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample)

Expand Down Expand Up @@ -515,11 +515,15 @@ def histref(self):
@property
def _storage_type(self) -> type[bh.storage.Storage]:
    """Storage type of the histogram."""
    # histref may be a packed spec tuple (axes, storage type, metadata)
    # instead of a concrete histogram; index 1 is the storage class.
    if isinstance(self.histref, tuple):
        return self.histref[1]
    return self.histref.storage_type

@property
def ndim(self) -> int:
    """Total number of dimensions."""
    # histref may be a packed spec tuple (axes, storage type, metadata);
    # the length of the axes tuple gives the dimensionality.
    if isinstance(self.histref, tuple):
        return len(self.histref[0])
    return self.histref.ndim

@property
Expand Down Expand Up @@ -746,6 +750,11 @@ def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]:
return [Delayed(k, graph, layer=layer) for k in keys]


def _hist_safe_sum(items):
safe_items = [item for item in items if not isinstance(item, tuple)]
return sum(safe_items)


def _reduction(
ph: PartitionedHistogram,
split_every: int | None = None,
Expand All @@ -762,15 +771,11 @@ def _reduction(
name_comb = f"{label}-combine-{token}"
name_agg = f"{label}-agg-{token}"

def hist_safe_sum(items):
safe_items = [item for item in items if not isinstance(item, tuple)]
return sum(safe_items)

mdftr = MockableDataFrameTreeReduction(
name=name_agg,
name_input=ph.name,
npartitions_input=ph.npartitions,
concat_func=hist_safe_sum,
concat_func=_hist_safe_sum,
tree_node_func=lambda x: x,
finalize_func=lambda x: x,
split_every=split_every,
Expand Down Expand Up @@ -1000,7 +1005,9 @@ def to_dask_array(agghist: AggHistogram, flow: bool = False, dd: bool = False) -
thehist = agghist.histref
if isinstance(thehist, tuple):
thehist = bh.Histogram(
*agghist.histref[0], storage=agghist.histref[1], metadata=agghist.histref[2]
*agghist.histref[0],
storage=agghist.histref[1](),
metadata=agghist.histref[2],
)
zeros = (0,) * thehist.ndim
dsk = {(name, *zeros): (lambda x, f: x.to_numpy(flow=f)[0], agghist.key, flow)}
Expand Down
Loading

0 comments on commit 5289f10

Please sign in to comment.